fix syn_words and CRF

This commit is contained in:
yongzhuo 2020-11-28 09:27:07 +08:00
parent d4a94149a5
commit da4cba7fb8
2 changed files with 53 additions and 40 deletions

View File

@ -12,7 +12,7 @@ import jieba
KEY_WORDS = ["macropodus"] # 不替换同义词的词语 KEY_WORDS = ["macropodus"] # 不替换同义词的词语
ENGLISH = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' ENGLISH = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def is_english(text): def is_english(text):
@ -22,7 +22,7 @@ def is_english(text):
:return: boolean, True or False :return: boolean, True or False
""" """
try: try:
text_r = text.replace(' ', '').strip() text_r = text.replace(" ", "").strip()
for tr in text_r: for tr in text_r:
if tr in ENGLISH: if tr in ENGLISH:
continue continue
@ -39,7 +39,7 @@ def is_number(text):
:return: boolean, True or False :return: boolean, True or False
""" """
try: try:
text_r = text.replace(' ', '').strip() text_r = text.replace(" ", "").strip()
for tr in text_r: for tr in text_r:
if tr.isdigit(): if tr.isdigit():
continue continue
@ -57,7 +57,7 @@ def get_syn_word(word):
""" """
if not is_number(word.strip()) or not is_english(word.strip()): if not is_number(word.strip()) or not is_english(word.strip()):
word_syn = synonyms.nearby(word) word_syn = synonyms.nearby(word)
word_syn = word_syn if not word_syn else [word] word_syn = word_syn[0] if len(word_syn[0]) else [word]
return word_syn return word_syn
else: else:
return [word] return [word]
@ -124,7 +124,7 @@ def word_swap(words, n=1):
while count < n: while count < n:
idx_select = random.sample(idxs, 2) idx_select = random.sample(idxs, 2)
temp = words[idx_select[0]] temp = words[idx_select[0]]
words[idx_select[0]] = words[idx_select[1]] words[idx_select[0]] = words[idx_select[1]]
words[idx_select[1]] = temp words[idx_select[1]] = temp
count += 1 count += 1
return words return words
@ -182,7 +182,7 @@ def eda(text, n=1, use_syn=True):
return sens_4 return sens_4
if __name__ == '__main__': if __name__ == "__main__":
sens = "".join(["macropodus", "是不是", "哪个", "啦啦", sens = "".join(["macropodus", "是不是", "哪个", "啦啦",
"只需做好这四点,就能让你养的天竺葵全年花开不断!"]) "只需做好这四点,就能让你养的天竺葵全年花开不断!"])
print(eda(sens)) print(eda(sens))

View File

@ -23,7 +23,7 @@ from keras import regularizers
from keras import activations from keras import activations
from keras import constraints from keras import constraints
import warnings import warnings
import keras import os
# crf_loss # crf_loss
from keras.losses import sparse_categorical_crossentropy from keras.losses import sparse_categorical_crossentropy
from keras.losses import categorical_crossentropy from keras.losses import categorical_crossentropy
@ -220,7 +220,7 @@ def to_tuple(shape):
with keras-team/keras. So we must apply this function to with keras-team/keras. So we must apply this function to
all input_shapes of the build methods in custom layers. all input_shapes of the build methods in custom layers.
""" """
if is_tf_keras: if os.environ.get("TF_KERAS")==1:
import tensorflow as tf import tensorflow as tf
return tuple(tf.TensorShape(shape).as_list()) return tuple(tf.TensorShape(shape).as_list())
else: else:
@ -530,36 +530,36 @@ class CRF(Layer):
base_config = super(CRF, self).get_config() base_config = super(CRF, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
# @property @property
# def loss_function(self): def loss_function(self):
# warnings.warn('CRF.loss_function is deprecated ' warnings.warn('CRF.loss_function is deprecated '
# 'and it might be removed in the future. Please ' 'and it might be removed in the future. Please '
# 'use losses.crf_loss instead.') 'use losses.crf_loss instead.')
# return crf_loss return crf_loss
#
# @property @property
# def accuracy(self): def accuracy(self):
# warnings.warn('CRF.accuracy is deprecated and it ' warnings.warn('CRF.accuracy is deprecated and it '
# 'might be removed in the future. Please ' 'might be removed in the future. Please '
# 'use metrics.crf_accuracy') 'use metrics.crf_accuracy')
# if self.test_mode == 'viterbi': if self.test_mode == 'viterbi':
# return crf_viterbi_accuracy return crf_viterbi_accuracy
# else: else:
# return crf_marginal_accuracy return crf_marginal_accuracy
#
# @property @property
# def viterbi_acc(self): def viterbi_acc(self):
# warnings.warn('CRF.viterbi_acc is deprecated and it might ' warnings.warn('CRF.viterbi_acc is deprecated and it might '
# 'be removed in the future. Please ' 'be removed in the future. Please '
# 'use metrics.viterbi_acc instead.') 'use metrics.viterbi_acc instead.')
# return crf_viterbi_accuracy return crf_viterbi_accuracy
#
# @property @property
# def marginal_acc(self): def marginal_acc(self):
# warnings.warn('CRF.moarginal_acc is deprecated and it ' warnings.warn('CRF.moarginal_acc is deprecated and it '
# 'might be removed in the future. Please ' 'might be removed in the future. Please '
# 'use metrics.marginal_acc instead.') 'use metrics.marginal_acc instead.')
# return crf_marginal_accuracy return crf_marginal_accuracy
@staticmethod @staticmethod
def softmaxNd(x, axis=-1): def softmaxNd(x, axis=-1):
@ -655,9 +655,22 @@ class CRF(Layer):
chain_energy = chain_energy * K.expand_dims( chain_energy = chain_energy * K.expand_dims(
K.expand_dims(m[:, 0] * m[:, 1])) K.expand_dims(m[:, 0] * m[:, 1]))
if return_logZ: if return_logZ:
# shapes: (1, B, F) + (B, F, 1) -> (B, F, F) # # shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
# energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
# new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
# return new_target_val, [new_target_val, i + 1]
energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2) energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F) new_target_val = K.logsumexp(-energy, 1)
# added from here
if len(states) > 3:
if K.backend() == 'theano':
m = states[3][:, t:(t + 2)]
else:
m = K.slice(states[3], [0, t], [-1, 2])
is_valid = K.expand_dims(m[:, 0])
new_target_val = is_valid * new_target_val + (1 - is_valid) * prev_target_val
# added until here
return new_target_val, [new_target_val, i + 1] return new_target_val, [new_target_val, i + 1]
else: else:
energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2) energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)