From 849ae9fa08dcb263491aaa7e2528c3a210db96b7 Mon Sep 17 00:00:00 2001
From: yongzhuo <31341349+yongzhuo@users.noreply.github.com>
Date: Wed, 17 Apr 2019 20:07:33 +0800
Subject: [PATCH] augmentation of seq2seq with WeBank data

---
 AugmentText/augment_seq2seq/__init__.py             |   5 +
 .../code_seq2seq_char/__init__.py                   |   5 +
 .../code_seq2seq_char/extract_char_webank.py        |  91 +++++++
 .../code_seq2seq_char/predict_char_anti.py          |  95 ++++++++
 .../code_seq2seq_char/train_char_anti.py            | 174 ++++++++++++++
 .../code_seq2seq_word/__init__.py                   |   5 +
 .../code_seq2seq_word/extract_webank.py             | 142 +++++++++++
 .../code_seq2seq_word/predict_word_anti.py          | 117 +++++++++
 .../code_seq2seq_word/train_word_anti.py            | 227 ++++++++++++++++++
 .../augment_seq2seq/data_mid/char/useless.txt       |   1 +
 .../augment_seq2seq/data_mid/word/useless.txt       |   1 +
 .../seq2seq_char_webank/useless.txt                 |   1 +
 .../seq2seq_word_webank/useless.txt                 |   1 +
 13 files changed, 865 insertions(+)
 create mode 100644 AugmentText/augment_seq2seq/__init__.py
 create mode 100644 AugmentText/augment_seq2seq/code_seq2seq_char/__init__.py
 create mode 100644 AugmentText/augment_seq2seq/code_seq2seq_char/extract_char_webank.py
 create mode 100644 AugmentText/augment_seq2seq/code_seq2seq_char/predict_char_anti.py
 create mode 100644 AugmentText/augment_seq2seq/code_seq2seq_char/train_char_anti.py
 create mode 100644 AugmentText/augment_seq2seq/code_seq2seq_word/__init__.py
 create mode 100644 AugmentText/augment_seq2seq/code_seq2seq_word/extract_webank.py
 create mode 100644 AugmentText/augment_seq2seq/code_seq2seq_word/predict_word_anti.py
 create mode 100644 AugmentText/augment_seq2seq/code_seq2seq_word/train_word_anti.py
 create mode 100644 AugmentText/augment_seq2seq/data_mid/char/useless.txt
 create mode 100644 AugmentText/augment_seq2seq/data_mid/word/useless.txt
 create mode 100644 AugmentText/augment_seq2seq/model_seq2seq_tp/seq2seq_char_webank/useless.txt
 create mode 100644 AugmentText/augment_seq2seq/model_seq2seq_tp/seq2seq_word_webank/useless.txt

diff --git a/AugmentText/augment_seq2seq/__init__.py b/AugmentText/augment_seq2seq/__init__.py
new file mode 100644
index 0000000..500dc6e
--- /dev/null
+++ b/AugmentText/augment_seq2seq/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: UTF-8 -*-
+# !/usr/bin/python
+# @time :2019/4/15 10:17
+# @author :Mo
+# @function :
\ No newline at end of file
diff --git a/AugmentText/augment_seq2seq/code_seq2seq_char/__init__.py b/AugmentText/augment_seq2seq/code_seq2seq_char/__init__.py
new file mode 100644
index 0000000..397e1d7
--- /dev/null
+++ b/AugmentText/augment_seq2seq/code_seq2seq_char/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: UTF-8 -*-
+# !/usr/bin/python
+# @time :2019/4/15 10:50
+# @author :Mo
+# @function :
\ No newline at end of file
diff --git a/AugmentText/augment_seq2seq/code_seq2seq_char/extract_char_webank.py b/AugmentText/augment_seq2seq/code_seq2seq_char/extract_char_webank.py
new file mode 100644
index 0000000..af9f080
--- /dev/null
+++ b/AugmentText/augment_seq2seq/code_seq2seq_char/extract_char_webank.py
@@ -0,0 +1,91 @@
+"""
+Convert the WeBank similarity corpus into a trainable (char-level) format.
+Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
+"""
+import sys
+sys.path.append('..')  # must run before the conf/ and utils/ imports below
+
+from conf.path_config import train_data_web_ws_anti
+from conf.path_config import train_data_web_xy_anti
+from conf.path_config import model_ckpt_web_anti
+from conf.path_config import path_webank_sim
+
+from utils.mode_util.seq2seq.word_sequence import WordSequence
+from utils.text_tools import txtRead
+from tqdm import tqdm
+import pickle
+import re
+
+
+def make_split(line):
+    """Separator inserted when two sentences are joined:
+    nothing if the line already ends with punctuation, otherwise a comma."""
+    if re.match(r'.*([,。…?!~\.,!?])$', ''.join(line)):
+        return []
+    return [',']
+
+
+def good_line(line):
+    """A line is 'good' if it contains at most two ASCII letters/digits."""
+    if len(re.findall(r'[a-zA-Z0-9]', ''.join(line))) > 2:
+        return False
+    return True
+
+
+def regular(sen, limit=50):
+    """Collapse repeated punctuation and truncate the sentence to `limit` chars."""
+    sen = re.sub(r'\.{3,100}', '…', sen)
+    sen = re.sub(r'…{2,100}', '…', sen)
+    sen = re.sub(r'[,]{1,100}', ',', sen)
+    sen = re.sub(r'[\.]{1,100}', '。', sen)
+    sen = re.sub(r'[\?]{1,100}', '?', sen)
+    sen = re.sub(r'[!]{1,100}', '!', sen)
+    if len(sen) > limit:
+        sen = sen[0:limit]
+    return sen
+
+
+def creat_train_data_of_sim_corpus(limit=50, x_limit=2, y_limit=2):
+    """Build char-level (x, y) pairs: every WeBank pair labelled similar ("1")
+    yields two training samples, org->sim and sim->org."""
+    x_datas = []
+    y_datas = []
+    max_len = 0
+    sim_ali_web_gov_dli_datas = txtRead(path_webank_sim, encodeType="gbk")
+    for sim_ali_web_gov_dli_datas_one in sim_ali_web_gov_dli_datas[1:]:
+        sim_ali_web_gov_dli_datas_one_split = sim_ali_web_gov_dli_datas_one.strip().split(",")
+        if sim_ali_web_gov_dli_datas_one_split[2] == "1":
+            len_x1 = len(sim_ali_web_gov_dli_datas_one_split[0])
+            len_x2 = len(sim_ali_web_gov_dli_datas_one_split[1])
+            max_len = max(len_x1, len_x2, max_len)
+
+            sentence_org = regular(sim_ali_web_gov_dli_datas_one_split[0], limit=limit)
+            sentence_sim = regular(sim_ali_web_gov_dli_datas_one_split[1], limit=limit)
+            x_datas.append([sen for sen in sentence_org])
+            y_datas.append([sen for sen in sentence_sim])
+            x_datas.append([sen for sen in sentence_sim])
+            y_datas.append([sen for sen in sentence_org])
+
+    datas = list(zip(x_datas, y_datas))
+    datas = [
+        (x, y)
+        for x, y in datas
+        if len(x) < limit and len(y) < limit and len(y) >= y_limit and len(x) >= x_limit
+    ]
+    x_datas, y_datas = zip(*datas)
+
+    print('fit word_sequence')
+
+    ws_input = WordSequence()
+    ws_input.fit(x_datas + y_datas)
+
+    print('dump')
+
+    pickle.dump((x_datas, y_datas),
+                open(train_data_web_xy_anti, 'wb'))
+    pickle.dump(ws_input, open(train_data_web_ws_anti, 'wb'))
+
+    print('done')
+    print(max_len)
+
+
+if __name__ == '__main__':
+    creat_train_data_of_sim_corpus()
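To make the normalisation above concrete: a minimal, standalone sketch of what `regular` and `good_line` do to a noisy corpus line (the import path assumes the repository root is on `sys.path`; illustration only, not part of the patch):

    # hypothetical usage of the helpers defined in extract_char_webank.py
    from AugmentText.augment_seq2seq.code_seq2seq_char.extract_char_webank import good_line, regular

    print(regular('为什么借款失败.....呢!!!'))  # -> '为什么借款失败…呢!' (runs of '.' and '!' collapsed)
    print(good_line('为什么ABC123'))            # -> False: more than two ASCII letters/digits
    print(good_line('为什么借款失败'))          # -> True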
diff --git a/AugmentText/augment_seq2seq/code_seq2seq_char/predict_char_anti.py b/AugmentText/augment_seq2seq/code_seq2seq_char/predict_char_anti.py
new file mode 100644
index 0000000..5983866
--- /dev/null
+++ b/AugmentText/augment_seq2seq/code_seq2seq_char/predict_char_anti.py
@@ -0,0 +1,95 @@
+"""
+Interactive decoding test for the SequenceToSequence model (char level).
+Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
+"""
+import sys
+sys.path.append('..')  # must run before the conf/ and utils/ imports below
+
+from utils.mode_util.seq2seq.data_utils import batch_flow_bucket as batch_flow
+from utils.mode_util.seq2seq.thread_generator import ThreadedGenerator
+from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
+from utils.mode_util.seq2seq.word_sequence import WordSequence
+
+from conf.path_config import train_data_web_ws_anti
+from conf.path_config import train_data_web_xy_anti
+from conf.path_config import model_ckpt_web_anti
+from conf.path_config import path_params
+
+import tensorflow as tf
+import numpy as np
+import pickle
+import json
+
+
+def predict_anti(params):
+    """Load the trained checkpoint and decode user input interactively."""
+    x_data, _ = pickle.load(open(train_data_web_xy_anti, 'rb'))
+    ws = pickle.load(open(train_data_web_ws_anti, 'rb'))
+
+    for x in x_data[:5]:
+        print(' '.join(x))
+
+    config = tf.ConfigProto(
+        # device_count={'CPU': 1, 'GPU': 0},
+        allow_soft_placement=True,
+        log_device_placement=False
+    )
+
+    save_path = model_ckpt_web_anti
+
+    # build the decoding graph
+    tf.reset_default_graph()
+    model_pred = SequenceToSequence(
+        input_vocab_size=len(ws),
+        target_vocab_size=len(ws),
+        batch_size=1,
+        mode='decode',
+        beam_width=0,
+        **params
+    )
+    init = tf.global_variables_initializer()
+
+    with tf.Session(config=config) as sess:
+        sess.run(init)
+        model_pred.load(sess, save_path)
+
+        while True:
+            user_text = input('Input Chat Sentence:')
+            if user_text in ('exit', 'quit'):
+                exit(0)
+            x_test = [list(user_text.lower())]  # char-level tokenisation
+            bar = batch_flow([x_test], ws, 1)
+            x, xl = next(bar)
+            x = np.flip(x, axis=1)  # the encoder was trained on reversed inputs
+            print(x, xl)
+            pred = model_pred.predict(
+                sess,
+                np.array(x),
+                np.array(xl)
+            )
+            print(pred)
+            print(ws.inverse_transform(x[0]))
+            for p in pred:
+                ans = ws.inverse_transform(p)
+                print(ans)
+
+
+def main():
+    """Entry point: read hyper-parameters from the JSON params file."""
+    predict_anti(json.load(open(path_params)))
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
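The `np.flip(x, axis=1)` call above implements the classic seq2seq trick of feeding the source sequence to the encoder reversed, matching how the model was trained. A toy illustration, independent of this repo's API (with right-padded batches the PAD ids end up in front; this is harmless only because the true length `xl` is passed alongside):

    import numpy as np

    x = np.array([[5, 8, 2, 0, 0]])  # one id sequence, right-padded with PAD=0
    print(np.flip(x, axis=1))        # [[0 0 2 8 5]] -- reversed, PADs now lead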
diff --git a/AugmentText/augment_seq2seq/code_seq2seq_char/train_char_anti.py b/AugmentText/augment_seq2seq/code_seq2seq_char/train_char_anti.py
new file mode 100644
index 0000000..374a3e7
--- /dev/null
+++ b/AugmentText/augment_seq2seq/code_seq2seq_char/train_char_anti.py
@@ -0,0 +1,174 @@
+"""
+Training and decoding test for the SequenceToSequence model (char level).
+Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
+"""
+import sys
+sys.path.append('..')  # must run before the conf/ and utils/ imports below
+
+from utils.mode_util.seq2seq.data_utils import batch_flow_bucket as batch_flow
+from utils.mode_util.seq2seq.thread_generator import ThreadedGenerator
+from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
+from utils.mode_util.seq2seq.word_sequence import WordSequence
+
+from conf.path_config import train_data_web_ws_anti
+from conf.path_config import train_data_web_xy_anti
+from conf.path_config import model_ckpt_web_anti
+from conf.path_config import path_params
+
+from sklearn.utils import shuffle
+import tensorflow as tf
+from tqdm import tqdm
+import numpy as np
+import pickle
+import random
+import json
+
+
+def train_and_dev(params):
+    """Train on the char-level WeBank data, then print a few sample decodes."""
+    x_data, y_data = pickle.load(open(train_data_web_xy_anti, 'rb'))
+    ws = pickle.load(open(train_data_web_ws_anti, 'rb'))
+
+    # training
+    n_epoch = 2
+    batch_size = 128
+    x_data, y_data = shuffle(x_data, y_data, random_state=20190412)
+
+    steps = int(len(x_data) / batch_size) + 1
+
+    config = tf.ConfigProto(
+        # device_count={'CPU': 1, 'GPU': 0},
+        allow_soft_placement=True,
+        log_device_placement=False
+    )
+
+    save_path = model_ckpt_web_anti
+
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+        random.seed(0)
+        np.random.seed(0)
+        tf.set_random_seed(0)
+
+        with tf.Session(config=config) as sess:
+
+            model = SequenceToSequence(
+                input_vocab_size=len(ws),
+                target_vocab_size=len(ws),
+                batch_size=batch_size,
+                **params
+            )
+            init = tf.global_variables_initializer()
+            sess.run(init)
+
+            flow = ThreadedGenerator(
+                batch_flow([x_data, y_data], ws, batch_size,
+                           add_end=[False, True]),
+                queue_maxsize=30)
+
+            # a PAD-only "dummy" source batch, used for the anti-LM loss below
+            dummy_encoder_inputs = np.array([
+                np.array([WordSequence.PAD]) for _ in range(batch_size)])
+            dummy_encoder_inputs_lengths = np.array([1] * batch_size)
+
+            for epoch in range(1, n_epoch + 1):
+                costs = []
+                bar = tqdm(range(steps), total=steps,
+                           desc='epoch {}, loss=0.000000'.format(epoch))
+                for _ in bar:
+                    x, xl, y, yl = next(flow)
+                    x = np.flip(x, axis=1)  # feed the source reversed
+
+                    # unconditional loss: how likely y is given a PAD-only source
+                    add_loss = model.train(sess,
+                                           dummy_encoder_inputs,
+                                           dummy_encoder_inputs_lengths,
+                                           y, yl, loss_only=True)
+                    add_loss *= -0.5  # subtracted: anti-language-model term
+
+                    cost, lr = model.train(sess, x, xl, y, yl,
+                                           return_lr=True, add_loss=add_loss)
+                    costs.append(cost)
+                    bar.set_description('epoch {} loss={:.6f} lr={:.6f}'.format(
+                        epoch,
+                        np.mean(costs),
+                        lr
+                    ))
+
+            model.save(sess, save_path)
+
+        flow.close()
+
+    # decoding test with beam search (beam_width=12)
+    tf.reset_default_graph()
+    model_pred = SequenceToSequence(
+        input_vocab_size=len(ws),
+        target_vocab_size=len(ws),
+        batch_size=1,
+        mode='decode',
+        beam_width=12,
+        **params
+    )
+    init = tf.global_variables_initializer()
+
+    with tf.Session(config=config) as sess:
+        sess.run(init)
+        model_pred.load(sess, save_path)
+
+        bar = batch_flow([x_data, y_data], ws, 1, add_end=False)
+        t = 0
+        for x, xl, y, yl in bar:
+            x = np.flip(x, axis=1)
+            pred = model_pred.predict(
+                sess,
+                np.array(x),
+                np.array(xl)
+            )
+            print(ws.inverse_transform(x[0]))
+            print(ws.inverse_transform(y[0]))
+            print(ws.inverse_transform(pred[0]))
+            t += 1
+            if t >= 3:
+                break
+
+    # greedy decoding test (beam_width=1)
+    tf.reset_default_graph()
+    model_pred = SequenceToSequence(
+        input_vocab_size=len(ws),
+        target_vocab_size=len(ws),
+        batch_size=1,
+        mode='decode',
+        beam_width=1,
+        **params
+    )
+    init = tf.global_variables_initializer()
+
+    with tf.Session(config=config) as sess:
+        sess.run(init)
+        model_pred.load(sess, save_path)
+
+        bar = batch_flow([x_data, y_data], ws, 1, add_end=False)
+        t = 0
+        for x, xl, y, yl in bar:
+            x = np.flip(x, axis=1)  # reverse here too, matching training
+            pred = model_pred.predict(
+                sess,
+                np.array(x),
+                np.array(xl)
+            )
+            print(ws.inverse_transform(x[0]))
+            print(ws.inverse_transform(y[0]))
+            print(ws.inverse_transform(pred[0]))
+            t += 1
+            if t >= 3:
+                break
+
+
+def main():
+    """Entry point: read hyper-parameters from the JSON params file."""
+    train_and_dev(json.load(open(path_params)))
+
+
+if __name__ == '__main__':
+    main()
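The `add_loss` dance above is the "anti" part of the training: each batch is scored once against a PAD-only dummy source (`loss_only=True`), and half of that unconditional loss is subtracted from the real loss, penalising generic outputs that are likely regardless of the input. Schematically, assuming `SequenceToSequence.train` adds `add_loss` onto its own cross-entropy (a sketch, not this repo's API):

    def anti_lm_objective(ce_given_x, ce_given_pad, weight=0.5):
        # ce_given_x:   cross-entropy of y conditioned on the real source x
        # ce_given_pad: cross-entropy of y conditioned on a PAD-only source
        # lower when y actually depends on x, higher when y is just a frequent sentence
        return ce_given_x - weight * ce_given_pad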
diff --git a/AugmentText/augment_seq2seq/code_seq2seq_word/__init__.py b/AugmentText/augment_seq2seq/code_seq2seq_word/__init__.py
new file mode 100644
index 0000000..2a1ecd2
--- /dev/null
+++ b/AugmentText/augment_seq2seq/code_seq2seq_word/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: UTF-8 -*-
+# !/usr/bin/python
+# @time :2019/4/15 10:52
+# @author :Mo
+# @function :
\ No newline at end of file
diff --git a/AugmentText/augment_seq2seq/code_seq2seq_word/extract_webank.py b/AugmentText/augment_seq2seq/code_seq2seq_word/extract_webank.py
new file mode 100644
index 0000000..2fe617a
--- /dev/null
+++ b/AugmentText/augment_seq2seq/code_seq2seq_word/extract_webank.py
@@ -0,0 +1,142 @@
+"""
+Convert the WeBank similarity corpus into a trainable (word-level) format.
+Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
+"""
+import sys
+sys.path.append('..')  # must run before the conf/ and utils/ imports below
+
+import re
+import pickle
+import jieba
+import gensim
+import numpy as np
+from tqdm import tqdm
+from conf.path_config import projectdir
+from conf.path_config import w2v_model_merge_short_path
+from utils.mode_util.seq2seq.word_sequence import WordSequence
+
+from conf.path_config import model_ckpt_web_anti_word
+from conf.path_config import train_data_web_xyw_anti
+from conf.path_config import train_data_web_emb_anti
+from conf.path_config import path_webank_sim
+
+
+def make_split(line):
+    """Separator inserted when two sentences are joined:
+    nothing if the line already ends with punctuation, otherwise a comma."""
+    if re.match(r'.*([,。…?!~\.,!?])$', ''.join(line)):
+        return []
+    return [',']
+
+
+def good_line(line):
+    """A line is 'good' if it contains at most two ASCII letters/digits."""
+    if len(re.findall(r'[a-zA-Z0-9]', ''.join(line))) > 2:
+        return False
+    return True
+
+
+def regular(sen, limit=50):
+    """Collapse repeated punctuation and truncate the sentence to `limit` chars."""
+    sen = re.sub(r'\.{3,100}', '…', sen)
+    sen = re.sub(r'…{2,100}', '…', sen)
+    sen = re.sub(r'[,]{1,100}', ',', sen)
+    sen = re.sub(r'[\.]{1,100}', '。', sen)
+    sen = re.sub(r'[\?]{1,100}', '?', sen)
+    sen = re.sub(r'[!]{1,100}', '!', sen)
+    if len(sen) > limit:
+        sen = sen[0:limit]
+    return sen
+
+
+def creat_train_data_of_bank_corpus(limit=50, x_limit=3, y_limit=3):
+    """Build word-level training data, the vocabulary and the embedding matrix.
+    Args:
+        limit: keep only sentences shorter than `limit`.
+    """
+    print('load word2vec start!')
+    word_vec = gensim.models.KeyedVectors.load_word2vec_format(w2v_model_merge_short_path, encoding='gbk', binary=False, limit=None)
+    print('load word2vec end!')
+    fp = open(path_webank_sim, 'r', encoding='gbk', errors='ignore')
+
+    x_datas = []
+    y_datas = []
+    max_len = 0
+    count_fp = 0
+    for line in tqdm(fp):
+        count_fp += 1
+        if count_fp == 1:  # skip the csv header line
+            continue
+        sim_bank_datas_one_split = line.strip().split(",")
+        len_x1 = len(sim_bank_datas_one_split[0])
+        len_x2 = len(sim_bank_datas_one_split[1])
+        max_len = max(len_x1, len_x2, max_len)
+
+        sentence_org = regular(sim_bank_datas_one_split[0], limit=limit)
+        sentence_sim = regular(sim_bank_datas_one_split[1], limit=limit)
+        org_cut = jieba.lcut(sentence_org)
+        sen_cut = jieba.lcut(sentence_sim)
+
+        x_datas.append(org_cut)
+        y_datas.append(sen_cut)
+        x_datas.append(sen_cut)
+        y_datas.append(org_cut)
+
+    print(len(x_datas), len(y_datas))
+    for ask, answer in zip(x_datas[:50], y_datas[:50]):
+        print(''.join(ask))
+        print(''.join(answer))
+        print('-' * 50)
+
+    data = list(zip(x_datas, y_datas))
+    data = [
+        (x, y)
+        for x, y in data
+        if len(x) < limit
+        and len(y) < limit
+        and len(y) >= y_limit
+        and len(x) >= x_limit
+    ]
+    x_data, y_data = zip(*data)
+
+    print('refine train data')
+
+    train_data = x_data + y_data
+
+    print('fit word_sequence')
+
+    ws_input = WordSequence()
+    ws_input.fit(train_data, max_features=100000)
+
+    print('dump word_sequence')
+
+    pickle.dump((x_data, y_data, ws_input),
+                open(train_data_web_xyw_anti, 'wb'))
+
+    print('make embedding vecs')
+
+    # embedding dimensionality is taken from the model's '' entry
+    emb = np.zeros((len(ws_input), len(word_vec[''])))
+
+    np.random.seed(1)
+    for word, ind in ws_input.dict.items():
+        if word in word_vec:
+            emb[ind] = word_vec[word]
+        else:
+            # OOV words get uniform noise in [-0.5, 0.5)
+            emb[ind] = np.random.random(size=(300,)) - 0.5
+
+    print('dump emb')
+
+    pickle.dump(
+        emb,
+        open(train_data_web_emb_anti, 'wb')
+    )
+
+    print('done')
+
+
+if __name__ == '__main__':
+    creat_train_data_of_bank_corpus()
\ No newline at end of file
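The word-level pipeline differs from the char-level one mainly in tokenisation: `jieba.lcut` segments a sentence into words where extract_char_webank.py iterates over characters. For example (exact segmentation may vary with the jieba version and dictionary):

    import jieba

    sentence = '为什么我无法借款'
    print(jieba.lcut(sentence))  # e.g. ['为什么', '我', '无法', '借款']
    print(list(sentence))        # char level: ['为', '什', '么', '我', '无', '法', '借', '款']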
diff --git a/AugmentText/augment_seq2seq/code_seq2seq_word/predict_word_anti.py b/AugmentText/augment_seq2seq/code_seq2seq_word/predict_word_anti.py
new file mode 100644
index 0000000..4c30239
--- /dev/null
+++ b/AugmentText/augment_seq2seq/code_seq2seq_word/predict_word_anti.py
@@ -0,0 +1,117 @@
+"""
+Interactive decoding test for the SequenceToSequence model (word level).
+"""
+import sys
+sys.path.append('..')  # must run before the conf/ and utils/ imports below
+
+from utils.mode_util.seq2seq.thread_generator import ThreadedGenerator
+from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
+from utils.mode_util.seq2seq.word_sequence import WordSequence
+from utils.mode_util.seq2seq.data_utils import batch_flow
+
+from conf.path_config import model_ckpt_web_anti_word
+from conf.path_config import train_data_web_xyw_anti
+from conf.path_config import train_data_web_emb_anti
+from conf.path_config import path_webank_sim
+from conf.path_config import path_params
+
+import tensorflow as tf
+import numpy as np
+import random
+import pickle
+import jieba
+
+
+def pred_word_anti(bidirectional, cell_type, depth,
+                   attention_type, use_residual, use_dropout, time_major, hidden_units):
+    """Load the trained word-level checkpoint and decode user input interactively."""
+    x_data, _, ws = pickle.load(open(train_data_web_xyw_anti, 'rb'))
+
+    for x in x_data[:5]:
+        print(' '.join(x))
+
+    config = tf.ConfigProto(
+        device_count={'CPU': 1, 'GPU': 0},  # decode on CPU
+        allow_soft_placement=True,
+        log_device_placement=False
+    )
+
+    save_path = model_ckpt_web_anti_word
+
+    # build the decoding graph
+    tf.reset_default_graph()
+    model_pred = SequenceToSequence(
+        input_vocab_size=len(ws),
+        target_vocab_size=len(ws),
+        batch_size=1,
+        mode='decode',
+        beam_width=1,
+        bidirectional=bidirectional,
+        cell_type=cell_type,
+        depth=depth,
+        attention_type=attention_type,
+        use_residual=use_residual,
+        use_dropout=use_dropout,
+        parallel_iterations=1,
+        time_major=time_major,
+        hidden_units=hidden_units,
+        share_embedding=True,
+        pretrained_embedding=True
+    )
+    init = tf.global_variables_initializer()
+
+    with tf.Session(config=config) as sess:
+        sess.run(init)
+        model_pred.load(sess, save_path)
+
+        while True:
+            user_text = input('Input Chat Sentence:')
+            if user_text in ('exit', 'quit'):
+                exit(0)
+            x_test = [jieba.lcut(user_text.lower())]  # word-level tokenisation
+            bar = batch_flow([x_test], ws, 1)
+            x, xl = next(bar)
+            x = np.flip(x, axis=1)  # the encoder was trained on reversed inputs
+            print(x, xl)
+            pred = model_pred.predict(
+                sess,
+                np.array(x),
+                np.array(xl)
+            )
+            print(pred)
+            print(ws.inverse_transform(x[0]))
+            for p in pred:
+                ans = ws.inverse_transform(p)
+                print(ans)
+
+
+def main():
+    """Entry point: fix the seeds and test one parameter combination."""
+    random.seed(0)
+    np.random.seed(0)
+    tf.set_random_seed(0)
+    pred_word_anti(
+        bidirectional=True,
+        cell_type='lstm',
+        depth=2,
+        attention_type='Bahdanau',
+        use_residual=False,
+        use_dropout=False,
+        time_major=False,
+        hidden_units=512
+    )
+
+
+if __name__ == '__main__':
+    main()
diff --git a/AugmentText/augment_seq2seq/code_seq2seq_word/train_word_anti.py b/AugmentText/augment_seq2seq/code_seq2seq_word/train_word_anti.py
new file mode 100644
index 0000000..1ec7415
--- /dev/null
+++ b/AugmentText/augment_seq2seq/code_seq2seq_word/train_word_anti.py
@@ -0,0 +1,227 @@
+"""
+Training and decoding test for the SequenceToSequence model (word level).
+"""
+import sys
+sys.path.append('..')  # must run before the conf/ and utils/ imports below
+
+from utils.mode_util.seq2seq.data_utils import batch_flow
+from utils.mode_util.seq2seq.thread_generator import ThreadedGenerator
+from utils.mode_util.seq2seq.model_seq2seq import SequenceToSequence
+from utils.mode_util.seq2seq.word_sequence import WordSequence
+
+from conf.path_config import model_ckpt_web_anti_word
+from conf.path_config import train_data_web_xyw_anti
+from conf.path_config import train_data_web_emb_anti
+from conf.path_config import path_webank_sim
+from conf.path_config import path_params
+
+import tensorflow as tf
+from tqdm import tqdm
+import numpy as np
+import random
+import pickle
+
+
+def train_word_anti(bidirectional, cell_type, depth,
+                    attention_type, use_residual, use_dropout, time_major, hidden_units):
+    """Train the word-level model with the given parameters, then sample decodes."""
+    emb = pickle.load(open(train_data_web_emb_anti, 'rb'))
+
+    x_data, y_data, ws = pickle.load(
+        open(train_data_web_xyw_anti, 'rb'))
+
+    # training
+    n_epoch = 10
+    batch_size = 128
+    steps = int(len(x_data) / batch_size) + 1
+
+    config = tf.ConfigProto(
+        # device_count={'CPU': 1, 'GPU': 0},
+        allow_soft_placement=True,
+        log_device_placement=False
+    )
+
+    save_path = model_ckpt_web_anti_word
+
+    tf.reset_default_graph()
+    with tf.Graph().as_default():
+        random.seed(0)
+        np.random.seed(0)
+        tf.set_random_seed(0)
+
+        with tf.Session(config=config) as sess:
+
+            model = SequenceToSequence(
+                input_vocab_size=len(ws),
+                target_vocab_size=len(ws),
+                batch_size=batch_size,
+                bidirectional=bidirectional,
+                cell_type=cell_type,
+                depth=depth,
+                attention_type=attention_type,
+                use_residual=use_residual,
+                use_dropout=use_dropout,
+                hidden_units=hidden_units,
+                time_major=time_major,
+                learning_rate=0.001,
+                optimizer='adam',
+                share_embedding=True,
+                dropout=0.2,
+                pretrained_embedding=True
+            )
+            init = tf.global_variables_initializer()
+            sess.run(init)
+
+            # load the pretrained embedding matrix built by extract_webank.py
+            model.feed_embedding(sess, encoder=emb)
+
+            flow = ThreadedGenerator(
+                batch_flow([x_data, y_data], ws, batch_size),
+                queue_maxsize=30)
+
+            # a PAD-only "dummy" source batch, used for the anti-LM loss below
+            dummy_encoder_inputs = np.array([
+                np.array([WordSequence.PAD]) for _ in range(batch_size)])
+            dummy_encoder_inputs_lengths = np.array([1] * batch_size)
+
+            for epoch in range(1, n_epoch + 1):
+                costs = []
+                bar = tqdm(range(steps), total=steps,
+                           desc='epoch {}, loss=0.000000'.format(epoch))
+                for _ in bar:
+                    x, xl, y, yl = next(flow)
+                    x = np.flip(x, axis=1)  # feed the source reversed
+
+                    # unconditional loss: how likely y is given a PAD-only source
+                    add_loss = model.train(sess,
+                                           dummy_encoder_inputs,
+                                           dummy_encoder_inputs_lengths,
+                                           y, yl, loss_only=True)
+                    add_loss *= -0.5  # subtracted: anti-language-model term
+
+                    cost, lr = model.train(sess, x, xl, y, yl,
+                                           return_lr=True, add_loss=add_loss)
+                    costs.append(cost)
+                    bar.set_description('epoch {} loss={:.6f} lr={:.6f}'.format(
+                        epoch,
+                        np.mean(costs),
+                        lr
+                    ))
+
+            model.save(sess, save_path)
+
+        flow.close()
+
+    # decoding test with beam search (beam_width=12)
+    tf.reset_default_graph()
+    model_pred = SequenceToSequence(
+        input_vocab_size=len(ws),
+        target_vocab_size=len(ws),
+        batch_size=1,
+        mode='decode',
+        beam_width=12,
+        bidirectional=bidirectional,
+        cell_type=cell_type,
+        depth=depth,
+        attention_type=attention_type,
+        use_residual=use_residual,
+        use_dropout=use_dropout,
+        hidden_units=hidden_units,
+        time_major=time_major,
+        parallel_iterations=1,
+        learning_rate=0.001,
+        optimizer='adam',
+        share_embedding=True,
+        pretrained_embedding=True
+    )
+    init = tf.global_variables_initializer()
+
+    with tf.Session(config=config) as sess:
+        sess.run(init)
+        model_pred.load(sess, save_path)
+
+        bar = batch_flow([x_data, y_data], ws, 1)
+        t = 0
+        for x, xl, y, yl in bar:
+            x = np.flip(x, axis=1)
+            pred = model_pred.predict(
+                sess,
+                np.array(x),
+                np.array(xl)
+            )
+            print(ws.inverse_transform(x[0]))
+            print(ws.inverse_transform(y[0]))
+            print(ws.inverse_transform(pred[0]))
+            t += 1
+            if t >= 3:
+                break
+
+    # greedy decoding test (beam_width=1)
+    tf.reset_default_graph()
+    model_pred = SequenceToSequence(
+        input_vocab_size=len(ws),
+        target_vocab_size=len(ws),
+        batch_size=1,
+        mode='decode',
+        beam_width=1,
+        bidirectional=bidirectional,
+        cell_type=cell_type,
+        depth=depth,
+        attention_type=attention_type,
+        use_residual=use_residual,
+        use_dropout=use_dropout,
+        hidden_units=hidden_units,
+        time_major=time_major,
+        parallel_iterations=1,
+        learning_rate=0.001,
+        optimizer='adam',
+        share_embedding=True,
+        pretrained_embedding=True
+    )
+    init = tf.global_variables_initializer()
+
+    with tf.Session(config=config) as sess:
+        sess.run(init)
+        model_pred.load(sess, save_path)
+
+        bar = batch_flow([x_data, y_data], ws, 1)
+        t = 0
+        for x, xl, y, yl in bar:
+            x = np.flip(x, axis=1)  # reverse here too, matching training
+            pred = model_pred.predict(
+                sess,
+                np.array(x),
+                np.array(xl)
+            )
+            print(ws.inverse_transform(x[0]))
+            print(ws.inverse_transform(y[0]))
+            print(ws.inverse_transform(pred[0]))
+            t += 1
+            if t >= 3:
+                break
+
+
+def main():
+    """Entry point: fix the seeds and train one parameter combination."""
+    random.seed(0)
+    np.random.seed(0)
+    tf.set_random_seed(0)
+    train_word_anti(
+        bidirectional=True,
+        cell_type='lstm',
+        depth=2,
+        attention_type='Bahdanau',
+        use_residual=False,
+        use_dropout=False,
+        time_major=False,
+        hidden_units=512
+    )
+
+
+if __name__ == '__main__':
+    main()
diff --git a/AugmentText/augment_seq2seq/data_mid/char/useless.txt b/AugmentText/augment_seq2seq/data_mid/char/useless.txt
new file mode 100644
index 0000000..5257e61
--- /dev/null
+++ b/AugmentText/augment_seq2seq/data_mid/char/useless.txt
@@ -0,0 +1 @@
+useless
diff --git a/AugmentText/augment_seq2seq/data_mid/word/useless.txt b/AugmentText/augment_seq2seq/data_mid/word/useless.txt
new file mode 100644
index 0000000..5257e61
--- /dev/null
+++ b/AugmentText/augment_seq2seq/data_mid/word/useless.txt
@@ -0,0 +1 @@
+useless
diff --git a/AugmentText/augment_seq2seq/model_seq2seq_tp/seq2seq_char_webank/useless.txt b/AugmentText/augment_seq2seq/model_seq2seq_tp/seq2seq_char_webank/useless.txt
new file mode 100644
index 0000000..5257e61
--- /dev/null
+++ b/AugmentText/augment_seq2seq/model_seq2seq_tp/seq2seq_char_webank/useless.txt
@@ -0,0 +1 @@
+useless
diff --git a/AugmentText/augment_seq2seq/model_seq2seq_tp/seq2seq_word_webank/useless.txt b/AugmentText/augment_seq2seq/model_seq2seq_tp/seq2seq_word_webank/useless.txt
new file mode 100644
index 0000000..5257e61
--- /dev/null
+++ b/AugmentText/augment_seq2seq/model_seq2seq_tp/seq2seq_word_webank/useless.txt
@@ -0,0 +1 @@
+useless
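Taken together, the word-level pipeline added by this patch runs in three steps; a sketch, assuming the repository root is on `sys.path` and the `conf.path_config` constants point at existing files:

    # 1. build the (x, y, vocab) pickle and the pretrained embedding matrix
    from AugmentText.augment_seq2seq.code_seq2seq_word.extract_webank import creat_train_data_of_bank_corpus
    creat_train_data_of_bank_corpus()

    # 2. train the anti-LM seq2seq and print a few sample decodes
    from AugmentText.augment_seq2seq.code_seq2seq_word.train_word_anti import main as train_main
    train_main()

    # 3. augment interactively: read a sentence, segment it with jieba,
    #    and decode a paraphrase with the trained model
    from AugmentText.augment_seq2seq.code_seq2seq_word.predict_word_anti import main as predict_main
    predict_main()

The char-level scripts under code_seq2seq_char/ follow the same three-step layout.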