augment text with seq2seq trained on webank data
This commit is contained in:
parent
b53939d895
commit
849ae9fa08
5
AugmentText/augment_seq2seq/__init__.py
Normal file
@ -0,0 +1,5 @@
# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time     :2019/4/15 10:17
# @author   :Mo
# @function :
@ -0,0 +1,5 @@
# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time     :2019/4/15 10:50
# @author   :Mo
# @function :
@ -0,0 +1,91 @@
"""
Convert the raw file format into a trainable format
Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
"""
from conf.path_config import train_data_web_ws_anti
from conf.path_config import train_data_web_xy_anti
from conf.path_config import model_ckpt_web_anti
from conf.path_config import path_webank_sim

from utils.mode_util.seq2seq.word_sequence import WordSequence
from utils.text_tools import txtRead
from tqdm import tqdm
import pickle
import sys
import re

sys.path.append('..')


def make_split(line):
    """Construct the separator inserted between two merged sentences."""
    if re.match(r'.*([,。…?!~\.,!?])$', ''.join(line)):
        return []
    return [',']


def good_line(line):
    """Filter out lines that contain more than two ASCII letters or digits."""
    if len(re.findall(r'[a-zA-Z0-9]', ''.join(line))) > 2:
        return False
    return True


def regular(sen, limit=50):
    """Collapse repeated punctuation and truncate the sentence to `limit` characters."""
    sen = re.sub(r'\.{3,100}', '…', sen)
    sen = re.sub(r'…{2,100}', '…', sen)
    sen = re.sub(r'[,]{1,100}', ',', sen)
    sen = re.sub(r'[\.]{1,100}', '。', sen)
    sen = re.sub(r'[\?]{1,100}', '?', sen)
    sen = re.sub(r'[!]{1,100}', '!', sen)
    if len(sen) > limit:
        sen = sen[0:limit]
    return sen


def creat_train_data_of_sim_corpus(limit=50, x_limit=2, y_limit=2):
    x_datas = []
    y_datas = []
    max_len = 0
    sim_ali_web_gov_dli_datas = txtRead(path_webank_sim, encodeType="gbk")
    for sim_ali_web_gov_dli_datas_one in sim_ali_web_gov_dli_datas[1:]:
        sim_ali_web_gov_dli_datas_one_split = sim_ali_web_gov_dli_datas_one.strip().split(",")
        if sim_ali_web_gov_dli_datas_one_split[2] == "1":
            len_x1 = len(sim_ali_web_gov_dli_datas_one_split[0])
            len_x2 = len(sim_ali_web_gov_dli_datas_one_split[1])
            # if max_len < len_x1 or max_len < len_x2:
            max_len = max(len_x1, len_x2, max_len)

            sentence_org = regular(sim_ali_web_gov_dli_datas_one_split[0], limit=limit)
            sentence_sim = regular(sim_ali_web_gov_dli_datas_one_split[1], limit=limit)
            x_datas.append([sen for sen in sentence_org])
            y_datas.append([sen for sen in sentence_sim])
            x_datas.append([sen for sen in sentence_sim])
            y_datas.append([sen for sen in sentence_org])

    datas = list(zip(x_datas, y_datas))
    datas = [
        (x, y)
        for x, y in datas
        if len(x) < limit and len(y) < limit and len(y) >= y_limit and len(x) >= x_limit
    ]
    x_datas, y_datas = zip(*datas)

    print('fit word_sequence')

    ws_input = WordSequence()
    ws_input.fit(x_datas + y_datas)

    print('dump')

    pickle.dump((x_datas, y_datas),
                open(train_data_web_xy_anti, 'wb')
                )
    pickle.dump(ws_input, open(train_data_web_ws_anti, 'wb'))

    print('done')
    print(max_len)


if __name__ == '__main__':
    creat_train_data_of_sim_corpus()
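For reference, a minimal standalone sketch of the char-level pairs that creat_train_data_of_sim_corpus dumps for one labelled-similar row (the sample row is made up; regular is reproduced from the file above for illustration only):

# Illustration only: shows the (x, y) character pairs written to train_data_web_xy_anti.
import re

def regular(sen, limit=50):
    # same normalization as above: collapse repeated punctuation, then truncate
    sen = re.sub(r'\.{3,100}', '…', sen)
    sen = re.sub(r'…{2,100}', '…', sen)
    sen = re.sub(r'[,]{1,100}', ',', sen)
    sen = re.sub(r'[\.]{1,100}', '。', sen)
    sen = re.sub(r'[\?]{1,100}', '?', sen)
    sen = re.sub(r'[!]{1,100}', '!', sen)
    return sen[0:limit]

# hypothetical webank-style row: "sentence_1,sentence_2,label"
row = "怎么提高额度,如何提升借款额度,1"
s1, s2, label = row.split(",")
if label == "1":                      # only pairs labelled similar are kept
    x1, y1 = list(regular(s1)), list(regular(s2))
    pairs = [(x1, y1), (y1, x1)]      # both directions are added, as above
    print(pairs[0][0])                # ['怎', '么', '提', '高', '额', '度']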
@ -0,0 +1,95 @@
"""
Basic parameter-combination test of the SequenceToSequence model
Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
"""

from utils.mode_util.seq2seq.data_utils import batch_flow_bucket as batch_flow
from utils.mode_util.seq2seq.thread_generator import ThreadedGenerator
from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
from utils.mode_util.seq2seq.word_sequence import WordSequence

from conf.path_config import train_data_web_ws_anti
from conf.path_config import train_data_web_xy_anti
from conf.path_config import model_ckpt_web_anti
from conf.path_config import path_params

import tensorflow as tf
import numpy as np
import pickle
import json
import sys

sys.path.append('..')


def predict_anti(params):
    """Test how different parameters run on the generated fake data"""
    x_data, _ = pickle.load(open(train_data_web_xy_anti, 'rb'))
    ws = pickle.load(open(train_data_web_ws_anti, 'rb'))

    for x in x_data[:5]:
        print(' '.join(x))

    config = tf.ConfigProto(
        # device_count={'CPU': 1, 'GPU': 0},
        allow_soft_placement=True,
        log_device_placement=False
    )

    save_path = model_ckpt_web_anti

    # test part
    tf.reset_default_graph()
    model_pred = SequenceToSequence(
        input_vocab_size=len(ws),
        target_vocab_size=len(ws),
        batch_size=1,
        mode='decode',
        beam_width=0,
        **params
    )
    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        sess.run(init)
        model_pred.load(sess, save_path)

        while True:
            user_text = input('Input Chat Sentence:')
            if user_text in ('exit', 'quit'):
                exit(0)
            x_test = [list(user_text.lower())]
            # x_test = [word_tokenize(user_text)]
            bar = batch_flow([x_test], ws, 1)
            x, xl = next(bar)
            x = np.flip(x, axis=1)
            # x = np.array([
            #     list(reversed(xx))
            #     for xx in x
            # ])
            print(x, xl)
            pred = model_pred.predict(
                sess,
                np.array(x),
                np.array(xl)
            )
            print(pred)
            # prob = np.exp(prob.transpose())
            print(ws.inverse_transform(x[0]))
            # print(ws.inverse_transform(pred[0]))
            # print(pred.shape, prob.shape)
            for p in pred:
                ans = ws.inverse_transform(p)
                print(ans)


def main():
    """Entry point"""
    predict_anti(json.load(open(path_params)))


if __name__ == '__main__':
    main()
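The contents of path_params are not part of this commit; a plausible params file, assuming its keys mirror the keyword arguments spelled out explicitly in the word-level scripts below, might be generated like this (file name and key set are assumptions):

# Hypothetical contents of the params JSON loaded via path_params.
import json

params = {
    "bidirectional": True,
    "cell_type": "lstm",
    "depth": 2,
    "attention_type": "Bahdanau",
    "use_residual": False,
    "use_dropout": False,
    "time_major": False,
    "hidden_units": 512,
}
with open("params.json", "w") as f:   # stand-in for path_params from conf.path_config
    json.dump(params, f, indent=2)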
174
AugmentText/augment_seq2seq/code_seq2seq_char/train_char_anti.py
Normal file
@ -0,0 +1,174 @@
"""
Basic parameter-combination test of the SequenceToSequence model
Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
"""

from utils.mode_util.seq2seq.data_utils import batch_flow_bucket as batch_flow
from utils.mode_util.seq2seq.thread_generator import ThreadedGenerator
from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
from utils.mode_util.seq2seq.word_sequence import WordSequence

from conf.path_config import train_data_web_ws_anti
from conf.path_config import train_data_web_xy_anti
from conf.path_config import model_ckpt_web_anti
from conf.path_config import path_params

from sklearn.utils import shuffle
import tensorflow as tf
from tqdm import tqdm
import numpy as np
import pickle
import random
import json
import sys

sys.path.append('..')


def train_and_dev(params):
    """Test how different parameters run on the generated fake data"""
    x_data, y_data = pickle.load(open(train_data_web_xy_anti, 'rb'))
    ws = pickle.load(open(train_data_web_ws_anti, 'rb'))

    # training part
    n_epoch = 2
    batch_size = 128
    x_data, y_data = shuffle(x_data, y_data, random_state=20190412)

    steps = int(len(x_data) / batch_size) + 1

    config = tf.ConfigProto(
        # device_count={'CPU': 1, 'GPU': 0},
        allow_soft_placement=True,
        log_device_placement=False
    )

    save_path = model_ckpt_web_anti

    tf.reset_default_graph()
    with tf.Graph().as_default():
        random.seed(0)
        np.random.seed(0)
        tf.set_random_seed(0)

        with tf.Session(config=config) as sess:

            model = SequenceToSequence(
                input_vocab_size=len(ws),
                target_vocab_size=len(ws),
                batch_size=batch_size,
                **params
            )
            init = tf.global_variables_initializer()
            sess.run(init)

            flow = ThreadedGenerator(
                batch_flow([x_data, y_data], ws, batch_size,
                           add_end=[False, True]),
                queue_maxsize=30)

            dummy_encoder_inputs = np.array([
                np.array([WordSequence.PAD]) for _ in range(batch_size)])
            dummy_encoder_inputs_lengths = np.array([1] * batch_size)

            for epoch in range(1, n_epoch + 1):
                costs = []
                bar = tqdm(range(steps), total=steps,
                           desc='epoch {}, loss=0.000000'.format(epoch))
                for _ in bar:
                    x, xl, y, yl = next(flow)
                    x = np.flip(x, axis=1)

                    add_loss = model.train(sess,
                                           dummy_encoder_inputs,
                                           dummy_encoder_inputs_lengths,
                                           y, yl, loss_only=True)

                    add_loss *= -0.5

                    cost, lr = model.train(sess, x, xl, y, yl,
                                           return_lr=True, add_loss=add_loss)
                    costs.append(cost)
                    bar.set_description('epoch {} loss={:.6f} lr={:.6f}'.format(
                        epoch,
                        np.mean(costs),
                        lr
                    ))

                model.save(sess, save_path)

            flow.close()

    # test part
    tf.reset_default_graph()
    model_pred = SequenceToSequence(
        input_vocab_size=len(ws),
        target_vocab_size=len(ws),
        batch_size=1,
        mode='decode',
        beam_width=12,
        **params
    )
    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        sess.run(init)
        model_pred.load(sess, save_path)

        bar = batch_flow([x_data, y_data], ws, 1, add_end=False)
        t = 0
        for x, xl, y, yl in bar:
            x = np.flip(x, axis=1)
            pred = model_pred.predict(
                sess,
                np.array(x),
                np.array(xl)
            )
            print(ws.inverse_transform(x[0]))
            print(ws.inverse_transform(y[0]))
            print(ws.inverse_transform(pred[0]))
            t += 1
            if t >= 3:
                break

    tf.reset_default_graph()
    model_pred = SequenceToSequence(
        input_vocab_size=len(ws),
        target_vocab_size=len(ws),
        batch_size=1,
        mode='decode',
        beam_width=1,
        **params
    )
    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        sess.run(init)
        model_pred.load(sess, save_path)

        bar = batch_flow([x_data, y_data], ws, 1, add_end=False)
        t = 0
        for x, xl, y, yl in bar:
            pred = model_pred.predict(
                sess,
                np.array(x),
                np.array(xl)
            )
            print(ws.inverse_transform(x[0]))
            print(ws.inverse_transform(y[0]))
            print(ws.inverse_transform(pred[0]))
            t += 1
            if t >= 3:
                break


def main():
    """Entry point"""
    train_and_dev(json.load(open(path_params)))


if __name__ == '__main__':
    main()
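The training loop above scores each batch's targets once against dummy PAD encoder inputs (loss_only=True) and subtracts half of that value from the real seq2seq loss through add_loss, which down-weights replies a bare language model would produce regardless of the input. A minimal sketch of the arithmetic, with made-up loss values and the assumption that model.train simply adds add_loss to its own loss:

# Sketch of the objective used in the loop above (loss values are made up).
lm_loss = 3.2    # loss with dummy PAD encoder inputs: pure language-model score of y
s2s_loss = 2.1   # ordinary seq2seq loss for the real (x, y) batch

add_loss = -0.5 * lm_loss          # mirrors `add_loss *= -0.5` above
combined = s2s_loss + add_loss     # assumed: model.train adds add_loss to its loss
print(combined)                    # 0.5 -> generic, input-independent replies are penalized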
@ -0,0 +1,5 @@
# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time     :2019/4/15 10:52
# @author   :Mo
# @function :
142
AugmentText/augment_seq2seq/code_seq2seq_word/extract_webank.py
Normal file
@ -0,0 +1,142 @@
"""
Convert the raw file format into a trainable format
Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
"""

import re
import sys
import pickle
import jieba
import gensim
import numpy as np
from tqdm import tqdm
from conf.path_config import projectdir
from conf.path_config import w2v_model_merge_short_path
from utils.mode_util.seq2seq.word_sequence import WordSequence

from conf.path_config import model_ckpt_web_anti_word
from conf.path_config import train_data_web_xyw_anti
from conf.path_config import train_data_web_emb_anti
from conf.path_config import path_webank_sim

sys.path.append('..')


def make_split(line):
    """Construct the separator inserted between two merged sentences."""
    if re.match(r'.*([,。…?!~\.,!?])$', ''.join(line)):
        return []
    return [',']


def good_line(line):
    """Judge whether a line is good (not too many ASCII letters or digits)."""
    if len(re.findall(r'[a-zA-Z0-9]', ''.join(line))) > 2:
        return False
    return True


def regular(sen, limit=50):
    """Collapse repeated punctuation and truncate the sentence to `limit` characters."""
    sen = re.sub(r'\.{3,100}', '…', sen)
    sen = re.sub(r'…{2,100}', '…', sen)
    sen = re.sub(r'[,]{1,100}', ',', sen)
    sen = re.sub(r'[\.]{1,100}', '。', sen)
    sen = re.sub(r'[\?]{1,100}', '?', sen)
    sen = re.sub(r'[!]{1,100}', '!', sen)
    if len(sen) > limit:
        sen = sen[0:limit]
    return sen


def creat_train_data_of_bank_corpus(limit=50, x_limit=3, y_limit=3):
    """Run the conversion.
    Args:
        limit: only output sentences whose length is less than `limit`
    """

    print('load word2vec start!')
    word_vec = gensim.models.KeyedVectors.load_word2vec_format(w2v_model_merge_short_path, encoding='gbk', binary=False, limit=None)
    print('load word2vec end!')
    fp = open(path_webank_sim, 'r', encoding='gbk', errors='ignore')

    x_datas = []
    y_datas = []
    max_len = 0
    count_fp = 0
    for line in tqdm(fp):
        count_fp += 1
        if count_fp == 1:
            continue
        sim_bank_datas_one_split = line.strip().split(",")
        len_x1 = len(sim_bank_datas_one_split[0])
        len_x2 = len(sim_bank_datas_one_split[1])
        # if max_len < len_x1 or max_len < len_x2:
        max_len = max(len_x1, len_x2, max_len)

        sentence_org = regular(sim_bank_datas_one_split[0], limit=limit)
        sentence_sim = regular(sim_bank_datas_one_split[1], limit=limit)
        org_cut = jieba.lcut(sentence_org)
        sen_cut = jieba.lcut(sentence_sim)

        x_datas.append(org_cut)
        y_datas.append(sen_cut)
        x_datas.append(sen_cut)
        y_datas.append(org_cut)

    print(len(x_datas), len(y_datas))
    for ask, answer in zip(x_datas[:50], y_datas[:50]):
        print(''.join(ask))
        print(''.join(answer))
        print('-' * 50)

    data = list(zip(x_datas, y_datas))
    data = [
        (x, y)
        for x, y in data
        if len(x) < limit
        and len(y) < limit
        and len(y) >= y_limit
        and len(x) >= x_limit
    ]
    x_data, y_data = zip(*data)

    print('refine train data')

    train_data = x_data + y_data

    print('fit word_sequence')

    ws_input = WordSequence()

    ws_input.fit(train_data, max_features=100000)

    print('dump word_sequence')

    pickle.dump((x_data, y_data, ws_input),
                open(train_data_web_xyw_anti, 'wb')
                )

    print('make embedding vecs')

    emb = np.zeros((len(ws_input), len(word_vec['</s>'])))

    np.random.seed(1)
    for word, ind in ws_input.dict.items():
        if word in word_vec:
            emb[ind] = word_vec[word]
        else:
            emb[ind] = np.random.random(size=(300,)) - 0.5

    print('dump emb')

    pickle.dump(
        emb,
        open(train_data_web_emb_anti, 'wb')
    )

    print('done')


if __name__ == '__main__':
    creat_train_data_of_bank_corpus()
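A minimal standalone sketch of the embedding-matrix construction above, using a tiny made-up vocabulary in place of the gensim KeyedVectors (dimension 3 instead of 300 for readability):

# Illustration only: rows of `emb` are word2vec vectors where available,
# otherwise small random vectors, indexed by each word's WordSequence id.
import numpy as np

word_vec = {"</s>": np.zeros(3), "额度": np.array([0.1, 0.2, 0.3])}  # made-up vectors
vocab = {"PAD": 0, "UNK": 1, "额度": 2}                               # stand-in for ws_input.dict

dim = len(word_vec["</s>"])
emb = np.zeros((len(vocab), dim))
np.random.seed(1)
for word, ind in vocab.items():
    if word in word_vec:
        emb[ind] = word_vec[word]                        # known word: copy its vector
    else:
        emb[ind] = np.random.random(size=(dim,)) - 0.5   # unknown word: random init
print(emb.shape)  # (3, 3)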
@ -0,0 +1,117 @@
"""
Basic parameter-combination test of the SequenceToSequence model
"""

from utils.mode_util.seq2seq.thread_generator import ThreadedGenerator
from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
from utils.mode_util.seq2seq.word_sequence import WordSequence
from utils.mode_util.seq2seq.data_utils import batch_flow

from conf.path_config import model_ckpt_web_anti_word
from conf.path_config import train_data_web_xyw_anti
from conf.path_config import train_data_web_emb_anti
from conf.path_config import path_webank_sim
from conf.path_config import path_params

import tensorflow as tf
import numpy as np
import random
import pickle
import jieba
import sys

sys.path.append('..')


def pred_word_anti(bidirectional, cell_type, depth,
                   attention_type, use_residual, use_dropout, time_major, hidden_units):
    """Test how different parameters run on the generated fake data"""
    x_data, _, ws = pickle.load(open(train_data_web_xyw_anti, 'rb'))

    for x in x_data[:5]:
        print(' '.join(x))

    config = tf.ConfigProto(
        device_count={'CPU': 1, 'GPU': 0},
        allow_soft_placement=True,
        log_device_placement=False
    )

    save_path = model_ckpt_web_anti_word

    # test part
    tf.reset_default_graph()
    model_pred = SequenceToSequence(
        input_vocab_size=len(ws),
        target_vocab_size=len(ws),
        batch_size=1,
        mode='decode',
        beam_width=1,
        bidirectional=bidirectional,
        cell_type=cell_type,
        depth=depth,
        attention_type=attention_type,
        use_residual=use_residual,
        use_dropout=use_dropout,
        parallel_iterations=1,
        time_major=time_major,
        hidden_units=hidden_units,
        share_embedding=True,
        pretrained_embedding=True
    )
    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        sess.run(init)
        model_pred.load(sess, save_path)

        while True:
            user_text = input('Input Chat Sentence:')
            if user_text in ('exit', 'quit'):
                exit(0)
            x_test = [jieba.lcut(user_text.lower())]
            # x_test = [word_tokenize(user_text)]
            bar = batch_flow([x_test], ws, 1)
            x, xl = next(bar)
            x = np.flip(x, axis=1)
            # x = np.array([
            #     list(reversed(xx))
            #     for xx in x
            # ])
            print(x, xl)
            pred = model_pred.predict(
                sess,
                np.array(x),
                np.array(xl)
            )
            print(pred)
            # prob = np.exp(prob.transpose())
            print(ws.inverse_transform(x[0]))
            # print(ws.inverse_transform(pred[0]))
            # print(pred.shape, prob.shape)
            for p in pred:
                ans = ws.inverse_transform(p)
                print(ans)


def main():
    """Entry point: start testing different parameter combinations"""
    random.seed(0)
    np.random.seed(0)
    tf.set_random_seed(0)
    pred_word_anti(
        bidirectional=True,
        cell_type='lstm',
        depth=2,
        attention_type='Bahdanau',
        use_residual=False,
        use_dropout=False,
        time_major=False,
        hidden_units=512
    )


if __name__ == '__main__':
    main()
227
AugmentText/augment_seq2seq/code_seq2seq_word/train_word_anti.py
Normal file
@ -0,0 +1,227 @@
"""
Basic parameter-combination test of the SequenceToSequence model
"""

from utils.mode_util.seq2seq.data_utils import batch_flow
from utils.mode_util.seq2seq.thread_generator import ThreadedGenerator
from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
from utils.mode_util.seq2seq.word_sequence import WordSequence

from conf.path_config import model_ckpt_web_anti_word
from conf.path_config import train_data_web_xyw_anti
from conf.path_config import train_data_web_emb_anti
from conf.path_config import path_webank_sim
from conf.path_config import path_params

import tensorflow as tf
from tqdm import tqdm
import numpy as np
import random
import pickle
import sys

sys.path.append('..')


def train_word_anti(bidirectional, cell_type, depth,
                    attention_type, use_residual, use_dropout, time_major, hidden_units):
    """Test how different parameters run on the generated fake data"""
    emb = pickle.load(open(train_data_web_emb_anti, 'rb'))

    x_data, y_data, ws = pickle.load(
        open(train_data_web_xyw_anti, 'rb'))

    # training part
    n_epoch = 10
    batch_size = 128
    # x_data, y_data = shuffle(x_data, y_data, random_state=0)
    # x_data = x_data[:100000]
    # y_data = y_data[:100000]
    steps = int(len(x_data) / batch_size) + 1

    config = tf.ConfigProto(
        # device_count={'CPU': 1, 'GPU': 0},
        allow_soft_placement=True,
        log_device_placement=False
    )

    save_path = model_ckpt_web_anti_word

    tf.reset_default_graph()
    with tf.Graph().as_default():
        random.seed(0)
        np.random.seed(0)
        tf.set_random_seed(0)

        with tf.Session(config=config) as sess:

            model = SequenceToSequence(
                input_vocab_size=len(ws),
                target_vocab_size=len(ws),
                batch_size=batch_size,
                bidirectional=bidirectional,
                cell_type=cell_type,
                depth=depth,
                attention_type=attention_type,
                use_residual=use_residual,
                use_dropout=use_dropout,
                hidden_units=hidden_units,
                time_major=time_major,
                learning_rate=0.001,
                optimizer='adam',
                share_embedding=True,
                dropout=0.2,
                pretrained_embedding=True
            )
            init = tf.global_variables_initializer()
            sess.run(init)

            # load the pretrained embedding
            model.feed_embedding(sess, encoder=emb)

            # print(sess.run(model.input_layer.kernel))
            # exit(1)

            flow = ThreadedGenerator(
                batch_flow([x_data, y_data], ws, batch_size),
                queue_maxsize=30)

            dummy_encoder_inputs = np.array([
                np.array([WordSequence.PAD]) for _ in range(batch_size)])
            dummy_encoder_inputs_lengths = np.array([1] * batch_size)

            for epoch in range(1, n_epoch + 1):
                costs = []
                bar = tqdm(range(steps), total=steps,
                           desc='epoch {}, loss=0.000000'.format(epoch))
                for _ in bar:
                    x, xl, y, yl = next(flow)
                    x = np.flip(x, axis=1)

                    add_loss = model.train(sess,
                                           dummy_encoder_inputs,
                                           dummy_encoder_inputs_lengths,
                                           y, yl, loss_only=True)

                    add_loss *= -0.5
                    # print(x, y)
                    cost, lr = model.train(sess, x, xl, y, yl,
                                           return_lr=True, add_loss=add_loss)
                    costs.append(cost)
                    bar.set_description('epoch {} loss={:.6f} lr={:.6f}'.format(
                        epoch,
                        np.mean(costs),
                        lr
                    ))

                model.save(sess, save_path)

            flow.close()

    # test part
    tf.reset_default_graph()
    model_pred = SequenceToSequence(
        input_vocab_size=len(ws),
        target_vocab_size=len(ws),
        batch_size=1,
        mode='decode',
        beam_width=12,
        bidirectional=bidirectional,
        cell_type=cell_type,
        depth=depth,
        attention_type=attention_type,
        use_residual=use_residual,
        use_dropout=use_dropout,
        hidden_units=hidden_units,
        time_major=time_major,
        parallel_iterations=1,
        learning_rate=0.001,
        optimizer='adam',
        share_embedding=True,
        pretrained_embedding=True
    )
    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        sess.run(init)
        model_pred.load(sess, save_path)

        bar = batch_flow([x_data, y_data], ws, 1)
        t = 0
        for x, xl, y, yl in bar:
            x = np.flip(x, axis=1)
            pred = model_pred.predict(
                sess,
                np.array(x),
                np.array(xl)
            )
            print(ws.inverse_transform(x[0]))
            print(ws.inverse_transform(y[0]))
            print(ws.inverse_transform(pred[0]))
            t += 1
            if t >= 3:
                break

    tf.reset_default_graph()
    model_pred = SequenceToSequence(
        input_vocab_size=len(ws),
        target_vocab_size=len(ws),
        batch_size=1,
        mode='decode',
        beam_width=1,
        bidirectional=bidirectional,
        cell_type=cell_type,
        depth=depth,
        attention_type=attention_type,
        use_residual=use_residual,
        use_dropout=use_dropout,
        hidden_units=hidden_units,
        time_major=time_major,
        parallel_iterations=1,
        learning_rate=0.001,
        optimizer='adam',
        share_embedding=True,
        pretrained_embedding=True
    )
    init = tf.global_variables_initializer()

    with tf.Session(config=config) as sess:
        sess.run(init)
        model_pred.load(sess, save_path)

        bar = batch_flow([x_data, y_data], ws, 1)
        t = 0
        for x, xl, y, yl in bar:
            pred = model_pred.predict(
                sess,
                np.array(x),
                np.array(xl)
            )
            print(ws.inverse_transform(x[0]))
            print(ws.inverse_transform(y[0]))
            print(ws.inverse_transform(pred[0]))
            t += 1
            if t >= 3:
                break


def main():
    """Entry point: start testing different parameter combinations"""
    random.seed(0)
    np.random.seed(0)
    tf.set_random_seed(0)
    train_word_anti(
        bidirectional=True,
        cell_type='lstm',
        depth=2,
        attention_type='Bahdanau',
        use_residual=False,
        use_dropout=False,
        time_major=False,
        hidden_units=512
    )


if __name__ == '__main__':
    main()
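Since the point of the commit is text augmentation, a sketch of how the trained word-level model could be used to paraphrase original sentences in batch. This reuses only calls that appear above; the helper name augment_sentences is ours, the checkpoint from the training script is assumed to exist, and the same architecture kwargs used at training time must be passed in:

# Hypothetical batch-augmentation helper assembled from the calls used above.
import pickle
import numpy as np
import tensorflow as tf
import jieba

from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
from utils.mode_util.seq2seq.data_utils import batch_flow
from conf.path_config import train_data_web_xyw_anti, model_ckpt_web_anti_word


def augment_sentences(sentences, **model_kwargs):
    """Return one seq2seq paraphrase per input sentence (hypothetical helper)."""
    _, _, ws = pickle.load(open(train_data_web_xyw_anti, 'rb'))
    tf.reset_default_graph()
    model = SequenceToSequence(input_vocab_size=len(ws), target_vocab_size=len(ws),
                               batch_size=1, mode='decode', beam_width=1, **model_kwargs)
    out = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model.load(sess, model_ckpt_web_anti_word)
        for sen in sentences:
            x, xl = next(batch_flow([[jieba.lcut(sen.lower())]], ws, 1))
            x = np.flip(x, axis=1)                # encoder inputs are reversed, as in training
            pred = model.predict(sess, np.array(x), np.array(xl))
            out.append(''.join(ws.inverse_transform(pred[0])))
    return out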
1
AugmentText/augment_seq2seq/data_mid/char/useless.txt
Normal file
@ -0,0 +1 @@
useless
1
AugmentText/augment_seq2seq/data_mid/word/useless.txt
Normal file
@ -0,0 +1 @@
useless
@ -0,0 +1 @@
useless
@ -0,0 +1 @@
useless