Add files via upload
This commit is contained in:
parent
a0dcf570f0
commit
a9b9a5333f
5
utils/mode_util/__init__.py
Normal file
5
utils/mode_util/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# !/usr/bin/python
|
||||||
|
# @time :2019/4/15 9:58
|
||||||
|
# @author :Mo
|
||||||
|
# @function :
|
5
utils/mode_util/seq2seq/__init__.py
Normal file
5
utils/mode_util/seq2seq/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# !/usr/bin/python
|
||||||
|
# @time :2019/4/15 9:58
|
||||||
|
# @author :Mo
|
||||||
|
# @function :
|
264
utils/mode_util/seq2seq/data_utils.py
Normal file
264
utils/mode_util/seq2seq/data_utils.py
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
"""
|
||||||
|
一些数据操作所需的模块
|
||||||
|
Code from: QHDuan(2018-02-05) url: https://github.com/qhduan/just_another_seq2seq
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from utils.mode_util.seq2seq.word_sequence import WordSequence
|
||||||
|
from tensorflow.python.client import device_lib
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
VOCAB_SIZE_THRESHOLD_CPU = 50000
|
||||||
|
|
||||||
|
|
||||||
|
def _get_available_gpus():
|
||||||
|
"""获取当前可用GPU数量"""
|
||||||
|
local_device_protos = device_lib.list_local_devices()
|
||||||
|
return [x.name for x in local_device_protos if x.device_type == 'GPU']
|
||||||
|
|
||||||
|
|
||||||
|
def _get_embed_device(vocab_size):
|
||||||
|
"""Decide on which device to place an embed matrix given its vocab size.
|
||||||
|
根据输入输出的字典大小,选择在CPU还是GPU上初始化embedding向量
|
||||||
|
"""
|
||||||
|
gpus = _get_available_gpus()
|
||||||
|
if not gpus or vocab_size > VOCAB_SIZE_THRESHOLD_CPU:
|
||||||
|
return "/cpu:0"
|
||||||
|
return "/gpu:0"
|
||||||
|
|
||||||
|
|
||||||
|
def transform_sentence(sentence, ws, max_len=None, add_end=False):
|
||||||
|
"""转换一个单独句子
|
||||||
|
Args:
|
||||||
|
sentence: 一句话,例如一个数组['你', '好', '吗']
|
||||||
|
ws: 一个WordSequence对象,转换器
|
||||||
|
max_len:
|
||||||
|
进行padding的长度,也就是如果sentence长度小于max_len
|
||||||
|
则padding到max_len这么长
|
||||||
|
Ret:
|
||||||
|
encoded:
|
||||||
|
一个经过ws转换的数组,例如[4, 5, 6, 3]
|
||||||
|
encoded_len: 上面的长度
|
||||||
|
"""
|
||||||
|
encoded = ws.transform(
|
||||||
|
sentence,
|
||||||
|
max_len=max_len if max_len is not None else len(sentence))
|
||||||
|
encoded_len = len(sentence) + (1 if add_end else 0) # add end
|
||||||
|
if encoded_len > len(encoded):
|
||||||
|
encoded_len = len(encoded)
|
||||||
|
return encoded, encoded_len
|
||||||
|
|
||||||
|
|
||||||
|
def batch_flow(data, ws, batch_size, raw=False, add_end=True):
|
||||||
|
"""从数据中随机 batch_size 个的数据,然后 yield 出去
|
||||||
|
Args:
|
||||||
|
data:
|
||||||
|
是一个数组,必须包含一个护着更多个同等的数据队列数组
|
||||||
|
ws:
|
||||||
|
可以是一个WordSequence对象,也可以是多个组成的数组
|
||||||
|
如果是多个,那么数组数量应该与data的数据数量保持一致,即len(data) == len(ws)
|
||||||
|
batch_size:
|
||||||
|
批量的大小
|
||||||
|
raw:
|
||||||
|
是否返回原始对象,如果为True,假设结果ret,那么len(ret) == len(data) * 3
|
||||||
|
如果为False,那么len(ret) == len(data) * 2
|
||||||
|
|
||||||
|
例如需要输入问题与答案的队列,问题队列Q = (q_1, q_2, q_3 ... q_n)
|
||||||
|
答案队列A = (a_1, a_2, a_3 ... a_n),有len(Q) == len(A)
|
||||||
|
ws是一个Q与A共用的WordSequence对象,
|
||||||
|
那么可以有: batch_flow([Q, A], ws, batch_size=32)
|
||||||
|
这样会返回一个generator,每次next(generator)会返回一个包含4个对象的数组,分别代表:
|
||||||
|
next(generator) == q_i_encoded, q_i_len, a_i_encoded, a_i_len
|
||||||
|
如果设置raw = True,则:
|
||||||
|
next(generator) == q_i_encoded, q_i_len, q_i, a_i_encoded, a_i_len, a_i
|
||||||
|
|
||||||
|
其中 q_i_encoded 相当于 ws.transform(q_i)
|
||||||
|
|
||||||
|
不过经过了batch修正,把一个batch中每个结果的长度,padding到了数组内最大的句子长度
|
||||||
|
"""
|
||||||
|
|
||||||
|
all_data = list(zip(*data))
|
||||||
|
|
||||||
|
if isinstance(ws, (list, tuple)):
|
||||||
|
assert len(ws) == len(data), \
|
||||||
|
'len(ws) must equal to len(data) if ws is list or tuple'
|
||||||
|
|
||||||
|
if isinstance(add_end, bool):
|
||||||
|
add_end = [add_end] * len(data)
|
||||||
|
else:
|
||||||
|
assert(isinstance(add_end, (list, tuple))), \
|
||||||
|
'add_end 不是 boolean,就应该是一个list(tuple) of boolean'
|
||||||
|
assert len(add_end) == len(data), \
|
||||||
|
'如果 add_end 是list(tuple),那么 add_end 的长度应该和输入数据长度一致'
|
||||||
|
|
||||||
|
mul = 2
|
||||||
|
if raw:
|
||||||
|
mul = 3
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data_batch = random.sample(all_data, batch_size)
|
||||||
|
batches = [[] for i in range(len(data) * mul)]
|
||||||
|
|
||||||
|
max_lens = []
|
||||||
|
for j in range(len(data)):
|
||||||
|
max_len = max([
|
||||||
|
len(x[j]) if hasattr(x[j], '__len__') else 0
|
||||||
|
for x in data_batch
|
||||||
|
]) + (1 if add_end[j] else 0)
|
||||||
|
max_lens.append(max_len)
|
||||||
|
|
||||||
|
for d in data_batch:
|
||||||
|
for j in range(len(data)):
|
||||||
|
if isinstance(ws, (list, tuple)):
|
||||||
|
w = ws[j]
|
||||||
|
else:
|
||||||
|
w = ws
|
||||||
|
|
||||||
|
# 添加结尾
|
||||||
|
line = d[j]
|
||||||
|
if add_end[j] and isinstance(line, (tuple, list)):
|
||||||
|
line = list(line) + [WordSequence.END_TAG]
|
||||||
|
|
||||||
|
if w is not None:
|
||||||
|
x, xl = transform_sentence(line, w, max_lens[j], add_end[j])
|
||||||
|
batches[j * mul].append(x)
|
||||||
|
batches[j * mul + 1].append(xl)
|
||||||
|
else:
|
||||||
|
batches[j * mul].append(line)
|
||||||
|
batches[j * mul + 1].append(line)
|
||||||
|
if raw:
|
||||||
|
batches[j * mul + 2].append(line)
|
||||||
|
batches = [np.asarray(x) for x in batches]
|
||||||
|
|
||||||
|
yield batches
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def batch_flow_bucket(data, ws, batch_size, raw=False,
|
||||||
|
add_end=True,
|
||||||
|
n_buckets=5, bucket_ind=1,
|
||||||
|
debug=False):
|
||||||
|
"""batch_flow的bucket版本
|
||||||
|
多了两重要参数,一个是n_buckets,一个是bucket_ind
|
||||||
|
n_buckets是分成几个buckets,理论上n_buckets == 1时就相当于没有进行buckets操作
|
||||||
|
bucket_ind是指定哪一维度的输入数据作为bucket的依据
|
||||||
|
"""
|
||||||
|
|
||||||
|
all_data = list(zip(*data))
|
||||||
|
# for x in all_data:
|
||||||
|
# print(x[0][bucket_ind])
|
||||||
|
#
|
||||||
|
# lengths = 0
|
||||||
|
lengths = sorted(list(set([len(x[0][bucket_ind]) for x in all_data])))
|
||||||
|
if n_buckets > len(lengths):
|
||||||
|
n_buckets = len(lengths)
|
||||||
|
|
||||||
|
splits = np.array(lengths)[
|
||||||
|
(np.linspace(0, 1, 5, endpoint=False) * len(lengths)).astype(int)
|
||||||
|
].tolist()
|
||||||
|
splits += [np.inf]
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(splits)
|
||||||
|
|
||||||
|
ind_data = {}
|
||||||
|
for x in all_data:
|
||||||
|
l = len(x[0][bucket_ind])
|
||||||
|
for ind, s in enumerate(splits[:-1]):
|
||||||
|
if l >= s and l <= splits[ind + 1]:
|
||||||
|
if ind not in ind_data:
|
||||||
|
ind_data[ind] = []
|
||||||
|
ind_data[ind].append(x)
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
inds = sorted(list(ind_data.keys()))
|
||||||
|
ind_p = [len(ind_data[x]) / len(all_data) for x in inds]
|
||||||
|
if debug:
|
||||||
|
print(np.sum(ind_p), ind_p)
|
||||||
|
|
||||||
|
if isinstance(ws, (list, tuple)):
|
||||||
|
assert len(ws) == len(data), \
|
||||||
|
'len(ws) must equal to len(data) if ws is list or tuple'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if isinstance(add_end, bool):
|
||||||
|
add_end = [add_end] * len(data)
|
||||||
|
else:
|
||||||
|
assert(isinstance(add_end, (list, tuple))), \
|
||||||
|
'add_end 不是 boolean,就应该是一个list(tuple) of boolean'
|
||||||
|
assert len(add_end) == len(data), \
|
||||||
|
'如果 add_end 是list(tuple),那么 add_end 的长度应该和输入数据长度一致'
|
||||||
|
|
||||||
|
mul = 2
|
||||||
|
if raw:
|
||||||
|
mul = 3
|
||||||
|
|
||||||
|
while True:
|
||||||
|
choice_ind = np.random.choice(inds, p=ind_p)
|
||||||
|
if debug:
|
||||||
|
print('choice_ind', choice_ind)
|
||||||
|
data_batch = random.sample(ind_data[choice_ind], batch_size)
|
||||||
|
batches = [[] for i in range(len(data) * mul)]
|
||||||
|
|
||||||
|
max_lens = []
|
||||||
|
for j in range(len(data)):
|
||||||
|
max_len = max([
|
||||||
|
len(x[j]) if hasattr(x[j], '__len__') else 0
|
||||||
|
for x in data_batch
|
||||||
|
]) + (1 if add_end[j] else 0)
|
||||||
|
max_lens.append(max_len)
|
||||||
|
|
||||||
|
for d in data_batch:
|
||||||
|
for j in range(len(data)):
|
||||||
|
if isinstance(ws, (list, tuple)):
|
||||||
|
w = ws[j]
|
||||||
|
else:
|
||||||
|
w = ws
|
||||||
|
|
||||||
|
# 添加结尾
|
||||||
|
line = d[j]
|
||||||
|
if add_end[j] and isinstance(line, (tuple, list)):
|
||||||
|
line = list(line) + [WordSequence.END_TAG]
|
||||||
|
|
||||||
|
if w is not None:
|
||||||
|
x, xl = transform_sentence(line, w, max_lens[j], add_end[j])
|
||||||
|
batches[j * mul].append(x)
|
||||||
|
batches[j * mul + 1].append(xl)
|
||||||
|
else:
|
||||||
|
batches[j * mul].append(line)
|
||||||
|
batches[j * mul + 1].append(line)
|
||||||
|
if raw:
|
||||||
|
batches[j * mul + 2].append(line)
|
||||||
|
batches = [np.asarray(x) for x in batches]
|
||||||
|
|
||||||
|
yield batches
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# def test_batch_flow():
|
||||||
|
# """test batch_flow function"""
|
||||||
|
# from fake_data import generate
|
||||||
|
# x_data, y_data, ws_input, ws_target = generate(size=10000)
|
||||||
|
# flow = batch_flow([x_data, y_data], [ws_input, ws_target], 4)
|
||||||
|
# x, xl, y, yl = next(flow)
|
||||||
|
# print(x.shape, y.shape, xl.shape, yl.shape)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def test_batch_flow_bucket():
|
||||||
|
# """test batch_flow function"""
|
||||||
|
# from fake_data import generate
|
||||||
|
# x_data, y_data, ws_input, ws_target = generate(size=10000)
|
||||||
|
# flow = batch_flow_bucket(
|
||||||
|
# [x_data, y_data], [ws_input, ws_target], 4,
|
||||||
|
# debug=True)
|
||||||
|
# for _ in range(10):
|
||||||
|
# x, xl, y, yl = next(flow)
|
||||||
|
# print(x.shape, y.shape, xl.shape, yl.shape)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# test_batch_flow_bucket()
|
1099
utils/mode_util/seq2seq/model_seq2seq.py
Normal file
1099
utils/mode_util/seq2seq/model_seq2seq.py
Normal file
File diff suppressed because it is too large
Load Diff
110
utils/mode_util/seq2seq/thread_generator.py
Normal file
110
utils/mode_util/seq2seq/thread_generator.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
"""
|
||||||
|
Code from : https://gist.github.com/everilae/9697228
|
||||||
|
QHduan added __next__:https://github.com/qhduan/just_another_seq2seq
|
||||||
|
"""
|
||||||
|
|
||||||
|
# A simple generator wrapper, not sure if it's good for anything at all.
|
||||||
|
# With basic python threading
|
||||||
|
from threading import Thread
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
|
# ... or use multiprocessing versions
|
||||||
|
# WARNING: use sentinel based on value, not identity
|
||||||
|
# from multiprocessing import Process, Queue as MpQueue
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadedGenerator(object):
|
||||||
|
"""
|
||||||
|
Generator that runs on a separate thread, returning values to calling
|
||||||
|
thread. Care must be taken that the iterator does not mutate any shared
|
||||||
|
variables referenced in the calling thread.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, iterator,
|
||||||
|
sentinel=object(),
|
||||||
|
queue_maxsize=0,
|
||||||
|
daemon=False):
|
||||||
|
self._iterator = iterator
|
||||||
|
self._sentinel = sentinel
|
||||||
|
self._queue = Queue(maxsize=queue_maxsize)
|
||||||
|
self._thread = Thread(
|
||||||
|
name=repr(iterator),
|
||||||
|
target=self._run
|
||||||
|
)
|
||||||
|
self._thread.daemon = daemon
|
||||||
|
self._started = False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'ThreadedGenerator({!r})'.format(self._iterator)
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
try:
|
||||||
|
for value in self._iterator:
|
||||||
|
if not self._started:
|
||||||
|
return
|
||||||
|
self._queue.put(value)
|
||||||
|
finally:
|
||||||
|
self._queue.put(self._sentinel)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._started = False
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self._queue.get(timeout=0)
|
||||||
|
except KeyboardInterrupt as e:
|
||||||
|
raise e
|
||||||
|
except: # pylint: disable=bare-except
|
||||||
|
pass
|
||||||
|
# self._thread.join()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
self._started = True
|
||||||
|
self._thread.start()
|
||||||
|
for value in iter(self._queue.get, self._sentinel):
|
||||||
|
yield value
|
||||||
|
self._thread.join()
|
||||||
|
self._started = False
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if not self._started:
|
||||||
|
self._started = True
|
||||||
|
self._thread.start()
|
||||||
|
value = self._queue.get(timeout=30)
|
||||||
|
if value == self._sentinel:
|
||||||
|
raise StopIteration()
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
"""测试"""
|
||||||
|
|
||||||
|
def gene():
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
yield i
|
||||||
|
i += 1
|
||||||
|
t = gene()
|
||||||
|
tt = ThreadedGenerator(t)
|
||||||
|
for _ in range(10):
|
||||||
|
print(next(tt))
|
||||||
|
tt.close()
|
||||||
|
# for i in range(10):
|
||||||
|
# print(next(tt))
|
||||||
|
|
||||||
|
# for t in ThreadedGenerator(range(10)):
|
||||||
|
# print(t)
|
||||||
|
# print('-' * 10)
|
||||||
|
#
|
||||||
|
# t = ThreadedGenerator(range(10))
|
||||||
|
# # def gene():
|
||||||
|
# # for t in range(10):
|
||||||
|
# # yield t
|
||||||
|
# # t = gene()
|
||||||
|
# for _ in range(10):
|
||||||
|
# print(next(t))
|
||||||
|
# print('-' * 10)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test()
|
178
utils/mode_util/seq2seq/word_sequence.py
Normal file
178
utils/mode_util/seq2seq/word_sequence.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
"""
|
||||||
|
WordSequence类
|
||||||
|
Code from https://github.com/qhduan/just_another_seq2seq/blob/master/word_sequence.py
|
||||||
|
维护一个字典,把一个list(或者字符串)编码化,或者反向恢复
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class WordSequence(object):
|
||||||
|
"""一个可以把句子编码化(index)的类
|
||||||
|
"""
|
||||||
|
|
||||||
|
PAD_TAG = '<pad>'
|
||||||
|
UNK_TAG = '<unk>'
|
||||||
|
START_TAG = '<s>'
|
||||||
|
END_TAG = '</s>'
|
||||||
|
PAD = 0
|
||||||
|
UNK = 1
|
||||||
|
START = 2
|
||||||
|
END = 3
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化基本的dict
|
||||||
|
"""
|
||||||
|
self.dict = {
|
||||||
|
WordSequence.PAD_TAG: WordSequence.PAD,
|
||||||
|
WordSequence.UNK_TAG: WordSequence.UNK,
|
||||||
|
WordSequence.START_TAG: WordSequence.START,
|
||||||
|
WordSequence.END_TAG: WordSequence.END,
|
||||||
|
}
|
||||||
|
self.fited = False
|
||||||
|
|
||||||
|
|
||||||
|
def to_index(self, word):
|
||||||
|
"""把一个单字转换为index
|
||||||
|
"""
|
||||||
|
assert self.fited, 'WordSequence 尚未 fit'
|
||||||
|
if word in self.dict:
|
||||||
|
return self.dict[word]
|
||||||
|
return WordSequence.UNK
|
||||||
|
|
||||||
|
|
||||||
|
def to_word(self, index):
|
||||||
|
"""把一个index转换为单字
|
||||||
|
"""
|
||||||
|
assert self.fited, 'WordSequence 尚未 fit'
|
||||||
|
for k, v in self.dict.items():
|
||||||
|
if v == index:
|
||||||
|
return k
|
||||||
|
return WordSequence.UNK_TAG
|
||||||
|
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
"""返回字典大小
|
||||||
|
"""
|
||||||
|
assert self.fited, 'WordSequence 尚未 fit'
|
||||||
|
return len(self.dict) + 1
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""返回字典大小
|
||||||
|
"""
|
||||||
|
return self.size()
|
||||||
|
|
||||||
|
|
||||||
|
def fit(self, sentences, min_count=5, max_count=None, max_features=None):
|
||||||
|
"""训练 WordSequence
|
||||||
|
Args:
|
||||||
|
min_count 最小出现次数
|
||||||
|
max_count 最大出现次数
|
||||||
|
max_features 最大特征数
|
||||||
|
|
||||||
|
ws = WordSequence()
|
||||||
|
ws.fit([['hello', 'world']])
|
||||||
|
"""
|
||||||
|
assert not self.fited, 'WordSequence 只能 fit 一次'
|
||||||
|
|
||||||
|
count = {}
|
||||||
|
for sentence in sentences:
|
||||||
|
arr = list(sentence)
|
||||||
|
for a in arr:
|
||||||
|
if a not in count:
|
||||||
|
count[a] = 0
|
||||||
|
count[a] += 1
|
||||||
|
|
||||||
|
if min_count is not None:
|
||||||
|
count = {k: v for k, v in count.items() if v >= min_count}
|
||||||
|
|
||||||
|
if max_count is not None:
|
||||||
|
count = {k: v for k, v in count.items() if v <= max_count}
|
||||||
|
|
||||||
|
self.dict = {
|
||||||
|
WordSequence.PAD_TAG: WordSequence.PAD,
|
||||||
|
WordSequence.UNK_TAG: WordSequence.UNK,
|
||||||
|
WordSequence.START_TAG: WordSequence.START,
|
||||||
|
WordSequence.END_TAG: WordSequence.END,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(max_features, int):
|
||||||
|
count = sorted(list(count.items()), key=lambda x: x[1])
|
||||||
|
if max_features is not None and len(count) > max_features:
|
||||||
|
count = count[-int(max_features):]
|
||||||
|
for w, _ in count:
|
||||||
|
self.dict[w] = len(self.dict)
|
||||||
|
else:
|
||||||
|
for w in sorted(count.keys()):
|
||||||
|
self.dict[w] = len(self.dict)
|
||||||
|
|
||||||
|
self.fited = True
|
||||||
|
|
||||||
|
|
||||||
|
def transform(self,
|
||||||
|
sentence, max_len=None):
|
||||||
|
"""把句子转换为向量
|
||||||
|
例如输入 ['a', 'b', 'c']
|
||||||
|
输出 [1, 2, 3] 这个数字是字典里的编号,顺序没有意义
|
||||||
|
"""
|
||||||
|
assert self.fited, 'WordSequence 尚未 fit'
|
||||||
|
|
||||||
|
# if max_len is not None:
|
||||||
|
# r = [self.PAD] * max_len
|
||||||
|
# else:
|
||||||
|
# r = [self.PAD] * len(sentence)
|
||||||
|
|
||||||
|
if max_len is not None:
|
||||||
|
r = [self.PAD] * max_len
|
||||||
|
else:
|
||||||
|
r = [self.PAD] * len(sentence)
|
||||||
|
|
||||||
|
for index, a in enumerate(sentence):
|
||||||
|
if max_len is not None and index >= len(r):
|
||||||
|
break
|
||||||
|
r[index] = self.to_index(a)
|
||||||
|
|
||||||
|
return np.array(r)
|
||||||
|
|
||||||
|
|
||||||
|
def inverse_transform(self, indices,
|
||||||
|
ignore_pad=False, ignore_unk=False,
|
||||||
|
ignore_start=False, ignore_end=False):
|
||||||
|
"""把向量转换为句子,和上面的相反
|
||||||
|
"""
|
||||||
|
ret = []
|
||||||
|
for i in indices:
|
||||||
|
word = self.to_word(i)
|
||||||
|
if word == WordSequence.PAD_TAG and ignore_pad:
|
||||||
|
continue
|
||||||
|
if word == WordSequence.UNK_TAG and ignore_unk:
|
||||||
|
continue
|
||||||
|
if word == WordSequence.START_TAG and ignore_start:
|
||||||
|
continue
|
||||||
|
if word == WordSequence.END_TAG and ignore_end:
|
||||||
|
continue
|
||||||
|
ret.append(word)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
"""测试
|
||||||
|
"""
|
||||||
|
ws = WordSequence()
|
||||||
|
ws.fit([
|
||||||
|
['第', '一', '句', '话'],
|
||||||
|
['第', '二', '句', '话']
|
||||||
|
])
|
||||||
|
|
||||||
|
indice = ws.transform(['第', '三'])
|
||||||
|
print(indice)
|
||||||
|
|
||||||
|
back = ws.inverse_transform(indice)
|
||||||
|
print(back)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test()
|
Loading…
Reference in New Issue
Block a user