fix syn_words and CRF

This commit is contained in:
yongzhuo 2020-11-28 09:27:07 +08:00
parent d4a94149a5
commit da4cba7fb8
2 changed files with 53 additions and 40 deletions

View File

@ -12,7 +12,7 @@ import jieba
KEY_WORDS = ["macropodus"] # 不替换同义词的词语
ENGLISH = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
ENGLISH = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def is_english(text):
@ -22,7 +22,7 @@ def is_english(text):
:return: boolean, True or False
"""
try:
text_r = text.replace(' ', '').strip()
text_r = text.replace(" ", "").strip()
for tr in text_r:
if tr in ENGLISH:
continue
@ -39,7 +39,7 @@ def is_number(text):
:return: boolean, True or False
"""
try:
text_r = text.replace(' ', '').strip()
text_r = text.replace(" ", "").strip()
for tr in text_r:
if tr.isdigit():
continue
@ -57,7 +57,7 @@ def get_syn_word(word):
"""
if not is_number(word.strip()) or not is_english(word.strip()):
word_syn = synonyms.nearby(word)
word_syn = word_syn if not word_syn else [word]
word_syn = word_syn[0] if len(word_syn[0]) else [word]
return word_syn
else:
return [word]
@ -182,7 +182,7 @@ def eda(text, n=1, use_syn=True):
return sens_4
if __name__ == "__main__":
    # Manual smoke run: concatenate a few sample fragments (including the
    # "macropodus" entry that also appears in KEY_WORDS) into one sentence
    # and print whatever eda() returns for it.
    sens = "".join(["macropodus", "是不是", "哪个", "啦啦",
                    "只需做好这四点,就能让你养的天竺葵全年花开不断!"])
    print(eda(sens))

View File

@ -23,7 +23,7 @@ from keras import regularizers
from keras import activations
from keras import constraints
import warnings
import keras
import os
# crf_loss
from keras.losses import sparse_categorical_crossentropy
from keras.losses import categorical_crossentropy
@ -220,7 +220,7 @@ def to_tuple(shape):
with keras-team/keras. So we must apply this function to
all input_shapes of the build methods in custom layers.
"""
if is_tf_keras:
if os.environ.get("TF_KERAS")==1:
import tensorflow as tf
return tuple(tf.TensorShape(shape).as_list())
else:
@ -530,36 +530,36 @@ class CRF(Layer):
base_config = super(CRF, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# @property
# def loss_function(self):
# warnings.warn('CRF.loss_function is deprecated '
# 'and it might be removed in the future. Please '
# 'use losses.crf_loss instead.')
# return crf_loss
#
# @property
# def accuracy(self):
# warnings.warn('CRF.accuracy is deprecated and it '
# 'might be removed in the future. Please '
# 'use metrics.crf_accuracy')
# if self.test_mode == 'viterbi':
# return crf_viterbi_accuracy
# else:
# return crf_marginal_accuracy
#
# @property
# def viterbi_acc(self):
# warnings.warn('CRF.viterbi_acc is deprecated and it might '
# 'be removed in the future. Please '
# 'use metrics.viterbi_acc instead.')
# return crf_viterbi_accuracy
#
# @property
# def marginal_acc(self):
# warnings.warn('CRF.marginal_acc is deprecated and it '
# 'might be removed in the future. Please '
# 'use metrics.marginal_acc instead.')
# return crf_marginal_accuracy
@property
def loss_function(self):
    """Deprecated accessor that hands back ``crf_loss``.

    Every access emits a warning pointing callers at ``losses.crf_loss``
    and then returns the ``crf_loss`` callable (defined outside this
    excerpt).
    """
    deprecation_note = ('CRF.loss_function is deprecated '
                        'and it might be removed in the future. Please '
                        'use losses.crf_loss instead.')
    warnings.warn(deprecation_note)
    return crf_loss
@property
def accuracy(self):
    """Deprecated accessor that picks an accuracy metric by test mode.

    Emits a warning on every access, then returns
    ``crf_viterbi_accuracy`` when ``self.test_mode`` equals ``'viterbi'``
    and ``crf_marginal_accuracy`` for any other value.
    """
    deprecation_note = ('CRF.accuracy is deprecated and it '
                        'might be removed in the future. Please '
                        'use metrics.crf_accuracy')
    warnings.warn(deprecation_note)
    uses_viterbi = self.test_mode == 'viterbi'
    return crf_viterbi_accuracy if uses_viterbi else crf_marginal_accuracy
@property
def viterbi_acc(self):
    """Deprecated accessor that hands back ``crf_viterbi_accuracy``.

    Every access emits a warning pointing callers at
    ``metrics.viterbi_acc`` before returning the metric callable
    (defined outside this excerpt).
    """
    deprecation_note = ('CRF.viterbi_acc is deprecated and it might '
                        'be removed in the future. Please '
                        'use metrics.viterbi_acc instead.')
    warnings.warn(deprecation_note)
    return crf_viterbi_accuracy
@property
def marginal_acc(self):
    """Deprecated accessor that hands back ``crf_marginal_accuracy``.

    Every access emits a warning pointing callers at
    ``metrics.marginal_acc`` before returning the metric callable
    (defined outside this excerpt).
    """
    # The message previously named a non-existent attribute
    # ('CRF.moarginal_acc'); it now names this property correctly.
    warnings.warn('CRF.marginal_acc is deprecated and it '
                  'might be removed in the future. Please '
                  'use metrics.marginal_acc instead.')
    return crf_marginal_accuracy
@staticmethod
def softmaxNd(x, axis=-1):
@ -655,9 +655,22 @@ class CRF(Layer):
chain_energy = chain_energy * K.expand_dims(
K.expand_dims(m[:, 0] * m[:, 1]))
if return_logZ:
# shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
# # shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
# energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
# new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
# return new_target_val, [new_target_val, i + 1]
energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
new_target_val = K.logsumexp(-energy, 1)
# added from here
if len(states) > 3:
if K.backend() == 'theano':
m = states[3][:, t:(t + 2)]
else:
m = K.slice(states[3], [0, t], [-1, 2])
is_valid = K.expand_dims(m[:, 0])
new_target_val = is_valid * new_target_val + (1 - is_valid) * prev_target_val
# added until here
return new_target_val, [new_target_val, i + 1]
else:
energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)