fix syn_words and CRF
This commit is contained in:
parent
d4a94149a5
commit
da4cba7fb8
@ -12,7 +12,7 @@ import jieba
|
|||||||
|
|
||||||
|
|
||||||
KEY_WORDS = ["macropodus"] # 不替换同义词的词语
|
KEY_WORDS = ["macropodus"] # 不替换同义词的词语
|
||||||
ENGLISH = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
ENGLISH = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
|
||||||
|
|
||||||
def is_english(text):
|
def is_english(text):
|
||||||
@ -22,7 +22,7 @@ def is_english(text):
|
|||||||
:return: boolean, True or False
|
:return: boolean, True or False
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
text_r = text.replace(' ', '').strip()
|
text_r = text.replace(" ", "").strip()
|
||||||
for tr in text_r:
|
for tr in text_r:
|
||||||
if tr in ENGLISH:
|
if tr in ENGLISH:
|
||||||
continue
|
continue
|
||||||
@ -39,7 +39,7 @@ def is_number(text):
|
|||||||
:return: boolean, True or False
|
:return: boolean, True or False
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
text_r = text.replace(' ', '').strip()
|
text_r = text.replace(" ", "").strip()
|
||||||
for tr in text_r:
|
for tr in text_r:
|
||||||
if tr.isdigit():
|
if tr.isdigit():
|
||||||
continue
|
continue
|
||||||
@ -57,7 +57,7 @@ def get_syn_word(word):
|
|||||||
"""
|
"""
|
||||||
if not is_number(word.strip()) or not is_english(word.strip()):
|
if not is_number(word.strip()) or not is_english(word.strip()):
|
||||||
word_syn = synonyms.nearby(word)
|
word_syn = synonyms.nearby(word)
|
||||||
word_syn = word_syn if not word_syn else [word]
|
word_syn = word_syn[0] if len(word_syn[0]) else [word]
|
||||||
return word_syn
|
return word_syn
|
||||||
else:
|
else:
|
||||||
return [word]
|
return [word]
|
||||||
@ -182,7 +182,7 @@ def eda(text, n=1, use_syn=True):
|
|||||||
return sens_4
|
return sens_4
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
sens = "".join(["macropodus", "是不是", "哪个", "啦啦",
|
sens = "".join(["macropodus", "是不是", "哪个", "啦啦",
|
||||||
"只需做好这四点,就能让你养的天竺葵全年花开不断!"])
|
"只需做好这四点,就能让你养的天竺葵全年花开不断!"])
|
||||||
print(eda(sens))
|
print(eda(sens))
|
||||||
|
@ -23,7 +23,7 @@ from keras import regularizers
|
|||||||
from keras import activations
|
from keras import activations
|
||||||
from keras import constraints
|
from keras import constraints
|
||||||
import warnings
|
import warnings
|
||||||
import keras
|
import os
|
||||||
# crf_loss
|
# crf_loss
|
||||||
from keras.losses import sparse_categorical_crossentropy
|
from keras.losses import sparse_categorical_crossentropy
|
||||||
from keras.losses import categorical_crossentropy
|
from keras.losses import categorical_crossentropy
|
||||||
@ -220,7 +220,7 @@ def to_tuple(shape):
|
|||||||
with keras-team/keras. So we must apply this function to
|
with keras-team/keras. So we must apply this function to
|
||||||
all input_shapes of the build methods in custom layers.
|
all input_shapes of the build methods in custom layers.
|
||||||
"""
|
"""
|
||||||
if is_tf_keras:
|
if os.environ.get("TF_KERAS")==1:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
return tuple(tf.TensorShape(shape).as_list())
|
return tuple(tf.TensorShape(shape).as_list())
|
||||||
else:
|
else:
|
||||||
@ -530,36 +530,36 @@ class CRF(Layer):
|
|||||||
base_config = super(CRF, self).get_config()
|
base_config = super(CRF, self).get_config()
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
# @property
|
@property
|
||||||
# def loss_function(self):
|
def loss_function(self):
|
||||||
# warnings.warn('CRF.loss_function is deprecated '
|
warnings.warn('CRF.loss_function is deprecated '
|
||||||
# 'and it might be removed in the future. Please '
|
'and it might be removed in the future. Please '
|
||||||
# 'use losses.crf_loss instead.')
|
'use losses.crf_loss instead.')
|
||||||
# return crf_loss
|
return crf_loss
|
||||||
#
|
|
||||||
# @property
|
@property
|
||||||
# def accuracy(self):
|
def accuracy(self):
|
||||||
# warnings.warn('CRF.accuracy is deprecated and it '
|
warnings.warn('CRF.accuracy is deprecated and it '
|
||||||
# 'might be removed in the future. Please '
|
'might be removed in the future. Please '
|
||||||
# 'use metrics.crf_accuracy')
|
'use metrics.crf_accuracy')
|
||||||
# if self.test_mode == 'viterbi':
|
if self.test_mode == 'viterbi':
|
||||||
# return crf_viterbi_accuracy
|
return crf_viterbi_accuracy
|
||||||
# else:
|
else:
|
||||||
# return crf_marginal_accuracy
|
return crf_marginal_accuracy
|
||||||
#
|
|
||||||
# @property
|
@property
|
||||||
# def viterbi_acc(self):
|
def viterbi_acc(self):
|
||||||
# warnings.warn('CRF.viterbi_acc is deprecated and it might '
|
warnings.warn('CRF.viterbi_acc is deprecated and it might '
|
||||||
# 'be removed in the future. Please '
|
'be removed in the future. Please '
|
||||||
# 'use metrics.viterbi_acc instead.')
|
'use metrics.viterbi_acc instead.')
|
||||||
# return crf_viterbi_accuracy
|
return crf_viterbi_accuracy
|
||||||
#
|
|
||||||
# @property
|
@property
|
||||||
# def marginal_acc(self):
|
def marginal_acc(self):
|
||||||
# warnings.warn('CRF.moarginal_acc is deprecated and it '
|
warnings.warn('CRF.moarginal_acc is deprecated and it '
|
||||||
# 'might be removed in the future. Please '
|
'might be removed in the future. Please '
|
||||||
# 'use metrics.marginal_acc instead.')
|
'use metrics.marginal_acc instead.')
|
||||||
# return crf_marginal_accuracy
|
return crf_marginal_accuracy
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def softmaxNd(x, axis=-1):
|
def softmaxNd(x, axis=-1):
|
||||||
@ -655,9 +655,22 @@ class CRF(Layer):
|
|||||||
chain_energy = chain_energy * K.expand_dims(
|
chain_energy = chain_energy * K.expand_dims(
|
||||||
K.expand_dims(m[:, 0] * m[:, 1]))
|
K.expand_dims(m[:, 0] * m[:, 1]))
|
||||||
if return_logZ:
|
if return_logZ:
|
||||||
# shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
|
# # shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
|
||||||
|
# energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
|
||||||
|
# new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
|
||||||
|
# return new_target_val, [new_target_val, i + 1]
|
||||||
|
|
||||||
energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
|
energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
|
||||||
new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
|
new_target_val = K.logsumexp(-energy, 1)
|
||||||
|
# added from here
|
||||||
|
if len(states) > 3:
|
||||||
|
if K.backend() == 'theano':
|
||||||
|
m = states[3][:, t:(t + 2)]
|
||||||
|
else:
|
||||||
|
m = K.slice(states[3], [0, t], [-1, 2])
|
||||||
|
is_valid = K.expand_dims(m[:, 0])
|
||||||
|
new_target_val = is_valid * new_target_val + (1 - is_valid) * prev_target_val
|
||||||
|
# added until here
|
||||||
return new_target_val, [new_target_val, i + 1]
|
return new_target_val, [new_target_val, i + 1]
|
||||||
else:
|
else:
|
||||||
energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)
|
energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)
|
||||||
|
Loading…
Reference in New Issue
Block a user