fix syn_words and CRF
This commit is contained in:
parent
d4a94149a5
commit
da4cba7fb8
@ -12,7 +12,7 @@ import jieba
|
||||
|
||||
|
||||
KEY_WORDS = ["macropodus"] # 不替换同义词的词语
|
||||
ENGLISH = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
ENGLISH = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
|
||||
def is_english(text):
|
||||
@ -22,7 +22,7 @@ def is_english(text):
|
||||
:return: boolean, True or False
|
||||
"""
|
||||
try:
|
||||
text_r = text.replace(' ', '').strip()
|
||||
text_r = text.replace(" ", "").strip()
|
||||
for tr in text_r:
|
||||
if tr in ENGLISH:
|
||||
continue
|
||||
@ -39,7 +39,7 @@ def is_number(text):
|
||||
:return: boolean, True or False
|
||||
"""
|
||||
try:
|
||||
text_r = text.replace(' ', '').strip()
|
||||
text_r = text.replace(" ", "").strip()
|
||||
for tr in text_r:
|
||||
if tr.isdigit():
|
||||
continue
|
||||
@ -57,7 +57,7 @@ def get_syn_word(word):
|
||||
"""
|
||||
if not is_number(word.strip()) or not is_english(word.strip()):
|
||||
word_syn = synonyms.nearby(word)
|
||||
word_syn = word_syn if not word_syn else [word]
|
||||
word_syn = word_syn[0] if len(word_syn[0]) else [word]
|
||||
return word_syn
|
||||
else:
|
||||
return [word]
|
||||
@ -124,7 +124,7 @@ def word_swap(words, n=1):
|
||||
while count < n:
|
||||
idx_select = random.sample(idxs, 2)
|
||||
temp = words[idx_select[0]]
|
||||
words[idx_select[0]] = words[idx_select[1]]
|
||||
words[idx_select[0]] = words[idx_select[1]]
|
||||
words[idx_select[1]] = temp
|
||||
count += 1
|
||||
return words
|
||||
@ -182,7 +182,7 @@ def eda(text, n=1, use_syn=True):
|
||||
return sens_4
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sens = "".join(["macropodus", "是不是", "哪个", "啦啦",
|
||||
"只需做好这四点,就能让你养的天竺葵全年花开不断!"])
|
||||
print(eda(sens))
|
||||
|
@ -23,7 +23,7 @@ from keras import regularizers
|
||||
from keras import activations
|
||||
from keras import constraints
|
||||
import warnings
|
||||
import keras
|
||||
import os
|
||||
# crf_loss
|
||||
from keras.losses import sparse_categorical_crossentropy
|
||||
from keras.losses import categorical_crossentropy
|
||||
@ -220,7 +220,7 @@ def to_tuple(shape):
|
||||
with keras-team/keras. So we must apply this function to
|
||||
all input_shapes of the build methods in custom layers.
|
||||
"""
|
||||
if is_tf_keras:
|
||||
if os.environ.get("TF_KERAS")==1:
|
||||
import tensorflow as tf
|
||||
return tuple(tf.TensorShape(shape).as_list())
|
||||
else:
|
||||
@ -530,36 +530,36 @@ class CRF(Layer):
|
||||
base_config = super(CRF, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
# @property
|
||||
# def loss_function(self):
|
||||
# warnings.warn('CRF.loss_function is deprecated '
|
||||
# 'and it might be removed in the future. Please '
|
||||
# 'use losses.crf_loss instead.')
|
||||
# return crf_loss
|
||||
#
|
||||
# @property
|
||||
# def accuracy(self):
|
||||
# warnings.warn('CRF.accuracy is deprecated and it '
|
||||
# 'might be removed in the future. Please '
|
||||
# 'use metrics.crf_accuracy')
|
||||
# if self.test_mode == 'viterbi':
|
||||
# return crf_viterbi_accuracy
|
||||
# else:
|
||||
# return crf_marginal_accuracy
|
||||
#
|
||||
# @property
|
||||
# def viterbi_acc(self):
|
||||
# warnings.warn('CRF.viterbi_acc is deprecated and it might '
|
||||
# 'be removed in the future. Please '
|
||||
# 'use metrics.viterbi_acc instead.')
|
||||
# return crf_viterbi_accuracy
|
||||
#
|
||||
# @property
|
||||
# def marginal_acc(self):
|
||||
# warnings.warn('CRF.moarginal_acc is deprecated and it '
|
||||
# 'might be removed in the future. Please '
|
||||
# 'use metrics.marginal_acc instead.')
|
||||
# return crf_marginal_accuracy
|
||||
@property
|
||||
def loss_function(self):
|
||||
warnings.warn('CRF.loss_function is deprecated '
|
||||
'and it might be removed in the future. Please '
|
||||
'use losses.crf_loss instead.')
|
||||
return crf_loss
|
||||
|
||||
@property
|
||||
def accuracy(self):
|
||||
warnings.warn('CRF.accuracy is deprecated and it '
|
||||
'might be removed in the future. Please '
|
||||
'use metrics.crf_accuracy')
|
||||
if self.test_mode == 'viterbi':
|
||||
return crf_viterbi_accuracy
|
||||
else:
|
||||
return crf_marginal_accuracy
|
||||
|
||||
@property
|
||||
def viterbi_acc(self):
|
||||
warnings.warn('CRF.viterbi_acc is deprecated and it might '
|
||||
'be removed in the future. Please '
|
||||
'use metrics.viterbi_acc instead.')
|
||||
return crf_viterbi_accuracy
|
||||
|
||||
@property
|
||||
def marginal_acc(self):
|
||||
warnings.warn('CRF.moarginal_acc is deprecated and it '
|
||||
'might be removed in the future. Please '
|
||||
'use metrics.marginal_acc instead.')
|
||||
return crf_marginal_accuracy
|
||||
|
||||
@staticmethod
|
||||
def softmaxNd(x, axis=-1):
|
||||
@ -655,9 +655,22 @@ class CRF(Layer):
|
||||
chain_energy = chain_energy * K.expand_dims(
|
||||
K.expand_dims(m[:, 0] * m[:, 1]))
|
||||
if return_logZ:
|
||||
# shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
|
||||
# # shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
|
||||
# energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
|
||||
# new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
|
||||
# return new_target_val, [new_target_val, i + 1]
|
||||
|
||||
energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
|
||||
new_target_val = K.logsumexp(-energy, 1) # shapes: (B, F)
|
||||
new_target_val = K.logsumexp(-energy, 1)
|
||||
# added from here
|
||||
if len(states) > 3:
|
||||
if K.backend() == 'theano':
|
||||
m = states[3][:, t:(t + 2)]
|
||||
else:
|
||||
m = K.slice(states[3], [0, t], [-1, 2])
|
||||
is_valid = K.expand_dims(m[:, 0])
|
||||
new_target_val = is_valid * new_target_val + (1 - is_valid) * prev_target_val
|
||||
# added until here
|
||||
return new_target_val, [new_target_val, i + 1]
|
||||
else:
|
||||
energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)
|
||||
|
Loading…
Reference in New Issue
Block a user