From da4cba7fb85c8acb843891657cfdec3c9763088b Mon Sep 17 00:00:00 2001 From: yongzhuo <2714618994@qq.com> Date: Sat, 28 Nov 2020 09:27:07 +0800 Subject: [PATCH] fix syn_words and CRF --- AugmentText/augment_eda/enhance_eda_v2.py | 12 ++-- Ner/bert/keras_bert_layer.py | 81 +++++++++++++---------- 2 files changed, 53 insertions(+), 40 deletions(-) diff --git a/AugmentText/augment_eda/enhance_eda_v2.py b/AugmentText/augment_eda/enhance_eda_v2.py index 4d27c9b..1b17ab3 100644 --- a/AugmentText/augment_eda/enhance_eda_v2.py +++ b/AugmentText/augment_eda/enhance_eda_v2.py @@ -12,7 +12,7 @@ import jieba KEY_WORDS = ["macropodus"] # 不替换同义词的词语 -ENGLISH = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' +ENGLISH = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" def is_english(text): @@ -22,7 +22,7 @@ def is_english(text): :return: boolean, True or False """ try: - text_r = text.replace(' ', '').strip() + text_r = text.replace(" ", "").strip() for tr in text_r: if tr in ENGLISH: continue @@ -39,7 +39,7 @@ def is_number(text): :return: boolean, True or False """ try: - text_r = text.replace(' ', '').strip() + text_r = text.replace(" ", "").strip() for tr in text_r: if tr.isdigit(): continue @@ -57,7 +57,7 @@ def get_syn_word(word): """ if not is_number(word.strip()) or not is_english(word.strip()): word_syn = synonyms.nearby(word) - word_syn = word_syn if not word_syn else [word] + word_syn = word_syn[0] if len(word_syn[0]) else [word] return word_syn else: return [word] @@ -124,7 +124,7 @@ def word_swap(words, n=1): while count < n: idx_select = random.sample(idxs, 2) temp = words[idx_select[0]] - words[idx_select[0]] = words[idx_select[1]] + words[idx_select[0]] = words[idx_select[1]] words[idx_select[1]] = temp count += 1 return words @@ -182,7 +182,7 @@ def eda(text, n=1, use_syn=True): return sens_4 -if __name__ == '__main__': +if __name__ == "__main__": sens = "".join(["macropodus", "是不是", "哪个", "啦啦", "只需做好这四点,就能让你养的天竺葵全年花开不断!"]) print(eda(sens)) diff --git a/Ner/bert/keras_bert_layer.py b/Ner/bert/keras_bert_layer.py index 7478147..2e84f62 100644 --- a/Ner/bert/keras_bert_layer.py +++ b/Ner/bert/keras_bert_layer.py @@ -23,7 +23,7 @@ from keras import regularizers from keras import activations from keras import constraints import warnings -import keras +import os # crf_loss from keras.losses import sparse_categorical_crossentropy from keras.losses import categorical_crossentropy @@ -220,7 +220,7 @@ def to_tuple(shape): with keras-team/keras. So we must apply this function to all input_shapes of the build methods in custom layers. """ - if is_tf_keras: + if os.environ.get("TF_KERAS")==1: import tensorflow as tf return tuple(tf.TensorShape(shape).as_list()) else: @@ -530,36 +530,36 @@ class CRF(Layer): base_config = super(CRF, self).get_config() return dict(list(base_config.items()) + list(config.items())) - # @property - # def loss_function(self): - # warnings.warn('CRF.loss_function is deprecated ' - # 'and it might be removed in the future. Please ' - # 'use losses.crf_loss instead.') - # return crf_loss - # - # @property - # def accuracy(self): - # warnings.warn('CRF.accuracy is deprecated and it ' - # 'might be removed in the future. Please ' - # 'use metrics.crf_accuracy') - # if self.test_mode == 'viterbi': - # return crf_viterbi_accuracy - # else: - # return crf_marginal_accuracy - # - # @property - # def viterbi_acc(self): - # warnings.warn('CRF.viterbi_acc is deprecated and it might ' - # 'be removed in the future. Please ' - # 'use metrics.viterbi_acc instead.') - # return crf_viterbi_accuracy - # - # @property - # def marginal_acc(self): - # warnings.warn('CRF.moarginal_acc is deprecated and it ' - # 'might be removed in the future. Please ' - # 'use metrics.marginal_acc instead.') - # return crf_marginal_accuracy + @property + def loss_function(self): + warnings.warn('CRF.loss_function is deprecated ' + 'and it might be removed in the future. Please ' + 'use losses.crf_loss instead.') + return crf_loss + + @property + def accuracy(self): + warnings.warn('CRF.accuracy is deprecated and it ' + 'might be removed in the future. Please ' + 'use metrics.crf_accuracy') + if self.test_mode == 'viterbi': + return crf_viterbi_accuracy + else: + return crf_marginal_accuracy + + @property + def viterbi_acc(self): + warnings.warn('CRF.viterbi_acc is deprecated and it might ' + 'be removed in the future. Please ' + 'use metrics.viterbi_acc instead.') + return crf_viterbi_accuracy + + @property + def marginal_acc(self): + warnings.warn('CRF.moarginal_acc is deprecated and it ' + 'might be removed in the future. Please ' + 'use metrics.marginal_acc instead.') + return crf_marginal_accuracy @staticmethod def softmaxNd(x, axis=-1): @@ -655,9 +655,22 @@ class CRF(Layer): chain_energy = chain_energy * K.expand_dims( K.expand_dims(m[:, 0] * m[:, 1])) if return_logZ: - # shapes: (1, B, F) + (B, F, 1) -> (B, F, F) + # # shapes: (1, B, F) + (B, F, 1) -> (B, F, F) + # energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2) + # new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F) + # return new_target_val, [new_target_val, i + 1] + energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2) - new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F) + new_target_val = K.logsumexp(-energy, 1) + # added from here + if len(states) > 3: + if K.backend() == 'theano': + m = states[3][:, t:(t + 2)] + else: + m = K.slice(states[3], [0, t], [-1, 2]) + is_valid = K.expand_dims(m[:, 0]) + new_target_val = is_valid * new_target_val + (1 - is_valid) * prev_target_val + # added until here return new_target_val, [new_target_val, i + 1] else: energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)