From 94aa5c8aab71c91893f32413c6a7adbbd60ee54a Mon Sep 17 00:00:00 2001
From: yongzhuo <31341349+yongzhuo@users.noreply.github.com>
Date: Mon, 20 May 2019 19:52:26 +0800
Subject: [PATCH] =?UTF-8?q?add=20bert=20text-cnn=E3=80=81r-cnn=20and=20avt?=
 =?UTF-8?q?-cnn?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../bert/keras_bert_classify_text_cnn.py | 317 ++++++++++++++++++
 1 file changed, 317 insertions(+)
 create mode 100644 ClassificationText/bert/keras_bert_classify_text_cnn.py

diff --git a/ClassificationText/bert/keras_bert_classify_text_cnn.py b/ClassificationText/bert/keras_bert_classify_text_cnn.py
new file mode 100644
index 0000000..0309ea8
--- /dev/null
+++ b/ClassificationText/bert/keras_bert_classify_text_cnn.py
@@ -0,0 +1,317 @@
# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time     :2019/5/18 23:51
# @author   :Mo
# @function :classify text with bert and (text-cnn, r-cnn or avt-cnn)

from __future__ import division, absolute_import

from keras.objectives import sparse_categorical_crossentropy, categorical_crossentropy
from conf.path_config import path_webank_train, path_webank_dev, path_webank_test
from keras.layers import Conv1D, GlobalAveragePooling1D, GlobalMaxPooling1D
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import SpatialDropout1D, Dropout
from sklearn.metrics import classification_report
from keras.layers import CuDNNGRU, CuDNNLSTM
from keras.layers import Bidirectional
from keras.layers import RepeatVector
from keras.layers import Concatenate
from keras.layers import GRU, LSTM
from keras.layers import Multiply
from keras.layers import Permute
from keras.layers import Lambda
from keras.layers import Dense
from keras.models import Model
from keras import regularizers
import numpy as np
import codecs

import keras.backend as k_keras
import logging as logger

from keras_bert import Tokenizer

from ClassificationText.bert.keras_bert_layer import AttentionWeightedAverage
from ClassificationText.bert.keras_bert_embedding import KerasBertEmbedding
from ClassificationText.bert import args

from conf.feature_config import config_name, ckpt_name, vocab_file, max_seq_len, layer_indexes, gpu_memory_fraction


def attention(inputs, single_attention_vector=False):
    # attention mechanism: score each time step, softmax over time, then reweight the inputs
    time_steps = k_keras.int_shape(inputs)[1]
    input_dim = k_keras.int_shape(inputs)[2]
    x = Permute((2, 1))(inputs)
    x = Dense(time_steps, activation='softmax')(x)
    if single_attention_vector:
        # share a single attention vector across all feature dimensions
        x = Lambda(lambda x: k_keras.mean(x, axis=1))(x)
        x = RepeatVector(input_dim)(x)

    a_probs = Permute((2, 1))(x)
    output_attention_mul = Multiply()([inputs, a_probs])
    return output_attention_mul
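

# Added illustrative sketch (not part of the original patch): attention() above
# reduces to "per-dimension softmax over time steps, then element-wise
# reweighting". The numpy-only function below mirrors that math on a single
# sample so the shapes are easy to follow; the random matrix is a stand-in for
# the Dense layer's learned kernel (bias omitted).
def _attention_numpy_sketch(inputs_2d):
    # inputs_2d: (time_steps, input_dim), one sample without the batch axis
    time_steps, input_dim = inputs_2d.shape
    w = np.random.randn(time_steps, time_steps)            # stand-in for the Dense kernel
    scores = inputs_2d.T.dot(w)                            # (input_dim, time_steps), like Permute + Dense
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    a_probs = (e / e.sum(axis=-1, keepdims=True)).T        # softmax over time, back to (time_steps, input_dim)
    return inputs_2d * a_probs                             # element-wise reweighting, like Multiply()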
logger.info("BertBiLstmModel init end!") + print("BertBiLstmModel init end!") + + def build_model_text_cnn(self): + ######### text-cnn ######### + # bert embedding + bert_inputs, bert_output = KerasBertEmbedding().bert_encode() + # text cnn + bert_output_emmbed = SpatialDropout1D(rate=self.keep_prob)(bert_output) + concat_out = [] + for index, filter_size in enumerate(self.filters): + x = Conv1D(name='TextCNN_Conv1D_{}'.format(index), filters=int(self.embedding_dim/2), kernel_size=self.filters[index], padding='valid', kernel_initializer='normal', activation='relu')(bert_output_emmbed) + x = GlobalMaxPooling1D(name='TextCNN_MaxPool1D_{}'.format(index))(x) + concat_out.append(x) + x = Concatenate(axis=1)(concat_out) + x = Dropout(self.keep_prob)(x) + + # 最后就是softmax + dense_layer = Dense(self.label, activation=self.activation)(x) + output_layers = [dense_layer] + self.model = Model(bert_inputs, output_layers) + + def build_model_r_cnn(self): + ######### RCNN ######### + # bert embedding + bert_inputs, bert_output = KerasBertEmbedding().bert_encode() + # rcnn + bert_output_emmbed = SpatialDropout1D(rate=self.keep_prob)(bert_output) + if args.use_lstm: + if args.use_cudnn_cell: + layer_cell = CuDNNLSTM + else: + layer_cell = LSTM + else: + if args.use_cudnn_cell: + layer_cell = CuDNNGRU + else: + layer_cell = GRU + + x = Bidirectional(layer_cell(units=args.units, return_sequences=args.return_sequences, + kernel_regularizer=regularizers.l2(args.l2 * 0.1), + recurrent_regularizer=regularizers.l2(args.l2) + ))(bert_output_emmbed) + x = Dropout(args.keep_prob)(x) + x = Conv1D(filters=int(self.embedding_dim / 2), kernel_size=2, padding='valid', kernel_initializer='normal', activation='relu')(x) + x = GlobalMaxPooling1D()(x) + x = Dropout(args.keep_prob)(x) + # 最后就是softmax + dense_layer = Dense(self.label, activation=self.activation)(x) + output_layers = [dense_layer] + self.model = Model(bert_inputs, output_layers) + + def build_model_avt_cnn(self): + #########text-cnn######### + # bert embedding + bert_inputs, bert_output = KerasBertEmbedding().bert_encode() + # text cnn + bert_output_emmbed = SpatialDropout1D(rate=self.keep_prob)(bert_output) + concat_x = [] + concat_y = [] + concat_z = [] + for index, filter_size in enumerate(self.filters): + conv = Conv1D(name='TextCNN_Conv1D_{}'.format(index), filters=int(self.embedding_dim/2), kernel_size=self.filters[index], padding='valid', kernel_initializer='normal', activation='relu')(bert_output_emmbed) + x = GlobalMaxPooling1D(name='TextCNN_MaxPooling1D_{}'.format(index))(conv) + y = GlobalAveragePooling1D(name='TextCNN_AveragePooling1D_{}'.format(index))(conv) + z = AttentionWeightedAverage(name='TextCNN_Annention_{}'.format(index))(conv) + concat_x.append(x) + concat_y.append(y) + concat_z.append(z) + + merge_x = Concatenate(axis=1)(concat_x) + merge_y = Concatenate(axis=1)(concat_y) + merge_z = Concatenate(axis=1)(concat_z) + merge_xyz = Concatenate(axis=1)([merge_x, merge_y, merge_z]) + x = Dropout(self.keep_prob)(merge_xyz) + + # 最后就是softmax + dense_layer = Dense(self.label, activation=self.activation)(x) + output_layers = [dense_layer] + self.model = Model(bert_inputs, output_layers) + + def compile_model(self): + self.model.compile(optimizer=args.optimizers, + loss=categorical_crossentropy, + metrics=args.metrics) + + def callback(self): + c_b = [ModelCheckpoint(args.path_save_model, monitor='val_loss', verbose=1, save_best_only=True, + save_weights_only=False, mode='min'), + EarlyStopping(min_delta=1e-9, patience=4, mode='min') + ] + return c_b 


def classify_pair_corpus_webank(bert_model, path_webank):
    # data preprocessing: read the webank corpus of labelled sentence pairs
    from utils.text_tools import text_preprocess, txtRead, txtWrite
    import random

    webank_q_2_l = txtRead(path_webank, encodeType='utf-8')
    questions = []
    labels = []
    for ques_label in webank_q_2_l[1:]:  # skip the header line
        q_2_l = ques_label.split(',')
        q_1 = q_2_l[0]
        q_2 = "".join(q_2_l[1:-1])  # the second question may itself contain commas
        label = q_2_l[-1]
        questions.append([text_preprocess(q_1), text_preprocess(q_2)])
        label_int = int(label)
        labels.append([0, 1] if label_int == 1 else [1, 0])  # one-hot: [not similar, similar]

    questions = np.array(questions)
    labels = np.array(labels)

    input_ids, input_masks, input_type_ids = bert_model.process_pair(questions)

    return questions, labels, input_ids, input_masks, input_type_ids


def train():
    # 1. train
    bert_model = BertTextCnnModel()
    bert_model.compile_model()
    _, labels_train, input_ids_train, input_masks_train, _ = classify_pair_corpus_webank(bert_model, path_webank_train)
    _, labels_dev, input_ids_dev, input_masks_dev, _ = classify_pair_corpus_webank(bert_model, path_webank_dev)
    # questions_test, labels_test, input_ids_test, input_masks_test, _ = classify_pair_corpus_webank(bert_model, path_webank_test)
    print("process corpus ok!")
    bert_model.fit([input_ids_train, input_masks_train], labels_train, [input_ids_dev, input_masks_dev], labels_dev)
    print("bert_model fit ok!")
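

# Added illustrative sketch (an assumption about the webank corpus layout,
# inferred from the parsing in classify_pair_corpus_webank above): each line
# after the header is "question_1,question_2,label" with label 1 for similar
# pairs and 0 for dissimilar ones. A tiny self-contained parse of one such
# hypothetical line:
def _parse_webank_line_sketch():
    line = '为什么我无法登录,为啥我登录不上去,1'  # hypothetical sample line
    parts = line.split(',')
    q_1, q_2, label = parts[0], "".join(parts[1:-1]), int(parts[-1])
    one_hot = [0, 1] if label == 1 else [1, 0]  # index 0: not similar, index 1: similar
    return q_1, q_2, one_hot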

def tet():
    # 2. test
    bert_model = BertTextCnnModel()
    bert_model.load_model()
    questions_test, labels_test, input_ids_test, input_masks_test, _ = classify_pair_corpus_webank(bert_model,
                                                                                                   path_webank_test)
    print('predict_list start! you will wait for a few minutes')
    labels_pred = bert_model.predict_list(questions_test)
    print('predict_list end!')

    labels_pred_np = np.array(labels_pred)
    labels_pred_np_arg = np.argmax(labels_pred_np, axis=1)
    labels_test_np = np.array(labels_test)
    labels_test_np_arg = np.argmax(labels_test_np, axis=1)
    target_names = ['不相似', '相似']
    report_predict = classification_report(labels_test_np_arg, labels_pred_np_arg,
                                           target_names=target_names, digits=9)
    print(report_predict)


def predict():
    # 3. predict
    bert_model = BertTextCnnModel()
    bert_model.load_model()
    pred = bert_model.predict(sen_1='jy', sen_2='myz')
    print(pred[0][1])
    while True:
        print("sen_1: ")
        sen_1 = input()
        print("sen_2: ")
        sen_2 = input()
        pred = bert_model.predict(sen_1=sen_1, sen_2=sen_2)
        print(pred[0][1])


if __name__ == "__main__":
    train()
    # tet()
    # predict()

# text-cnn, not early-stopped
# 100000/100000 [==============================] - 1842s 18ms/step - loss: 0.5173 - acc: 0.7488 - val_loss: 0.4704 - val_acc: 0.7827
# Epoch 00010: val_loss improved from 0.47364 to 0.47036, saving model to model_webank_tdt/bert_bi_lstm_pair.h5
#                 precision       recall     f1-score   support
#      不相似   0.783744856  0.761800000  0.772616633      5000
#        相似   0.768287938  0.789800000  0.778895464      5000
# avg / total   0.776016397  0.775800000  0.775756048     10000


# text-rcnn, early-stopped
# 100000/100000 [==============================] - 1671s 17ms/step - loss: 0.4627 - acc: 0.7971 - val_loss: 0.4810 - val_acc: 0.8018
#                 precision       recall     f1-score   support
#      不相似   0.777479378  0.810600000  0.793694311      5000
#        相似   0.802172551  0.768000000  0.784714417      5000
# avg / total   0.789825965  0.789300000  0.789204364     10000


# avt-cnn, early-stopped
# 100000/100000 [==============================] - 1562s 16ms/step - loss: 0.4204 - acc: 0.8091 - val_loss: 0.4391 - val_acc: 0.7925
# Epoch 00015: val_loss improved from 0.44410 to 0.43914, saving model to model_webank_tdt/bert_avt_cnn.h5
#                 precision       recall     f1-score   support
#      不相似   0.789808917  0.768800000  0.779162866      5000
#        相似   0.774790571  0.795400000  0.784960032      5000
# avg / total   0.782299744  0.782100000  0.782061449     10000
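
# Added usage note (a sketch, not in the original patch; assumes args.label == 2
# with a softmax activation, as the one-hot training labels above suggest):
# predict() returns the softmax output of shape (1, 2), so pred[0][1] is the
# probability of the "similar" class and pred[0][0] of "not similar", e.g.:
#
#   pred = bert_model.predict(sen_1='为什么我无法登录', sen_2='为啥我登录不上去')
#   similar_prob = pred[0][1]          # probability the pair is similar
#   is_similar = similar_prob >= 0.5   # simple threshold decision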