Create vocab.py

Author: missQian (committed via GitHub)
Date: 2020-10-04 21:30:03 +08:00
Commit: 13d553482f (parent 2daa8cd3bf)

import pickle
from collections import Counter


class Vocab(object):
    # Reserved special tokens.
    PAD = '*PAD*'
    UNK = '*UNK*'
    NULL = '*NULL*'
    START = '*START*'
    END = '*END*'
    ROOT = '*ROOT*'

    def __init__(self):
        self.tok2idx = {}    # token -> index
        self.idx2count = {}  # index -> frequency count
        self.idx2tok = {}    # index -> token, built lazily by get_token()
        self.special_token_size = 0
        self.singleton_size = 0
        self.singleton_max_count = 0
        self.min_count = 0
        self.max_count = 0

    def __len__(self):
        return len(self.tok2idx)

    def items(self):
        for k, v in self.tok2idx.items():
            yield k, v

    def __getitem__(self, item):
        return self.tok2idx[item]

    def keys(self):
        return self.tok2idx.keys()

    def vals(self):
        return self.tok2idx.values()

    def add_counter(self,
                    counter,
                    min_count=1,
                    max_count=1e7,
                    singleton_max_count=1,
                    update_count=False):
        '''
        :param counter: a collections.Counter mapping tokens to frequencies
        :param min_count: tokens whose count is below min_count are dropped
        :param max_count: cap on how many of the most common tokens are
               considered, i.e. an upper bound on the vocabulary size
        :param singleton_max_count: int, we treat a token as a singleton
               when 0 < token count <= singleton_max_count; this supports
               some UNK replacement strategies
        :param update_count: if True, accumulate counts for tokens that are
               already in the vocabulary
        :return:
        '''
        self.min_count = min_count
        self.max_count = max_count
        self.singleton_max_count = singleton_max_count
        for tok, count in counter.most_common(n=int(max_count)):
            if count >= self.min_count:
                self.add_token(tok, count, update_count)
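
    # A minimal usage sketch (hypothetical toy data, not part of the original
    # commit). Special tokens are added first so that PAD/UNK get the lowest
    # indices:
    #
    #   counter = Counter('the cat sat on the mat'.split())
    #   vocab = Vocab()
    #   vocab.add_spec_toks(pad_tok=True, unk_tok=True)  # PAD -> 0, UNK -> 1
    #   vocab.add_counter(counter, min_count=1, singleton_max_count=1)
    #   vocab.get_index('the')  # index of a known token
    #   vocab.get_index('dog')  # unknown token: falls back to the *UNK* index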

    def add_spec_toks(self,
                      pad_tok=True,
                      unk_tok=True,
                      start_tok=False,
                      end_tok=False,
                      root_tok=False,
                      null_tok=False):
        if pad_tok:
            self.add_token(Vocab.PAD)
            self.special_token_size += 1
        if unk_tok:
            self.add_token(Vocab.UNK)
            self.special_token_size += 1
        if start_tok:
            self.add_token(Vocab.START)
            self.special_token_size += 1
        if end_tok:
            self.add_token(Vocab.END)
            self.special_token_size += 1
        if root_tok:
            self.add_token(Vocab.ROOT)
            self.special_token_size += 1
        if null_tok:
            self.add_token(Vocab.NULL)
            self.special_token_size += 1

    def add_token(self, token, count=1, update_count=False):
        idx = self.tok2idx.get(token, None)
        if idx is None:
            idx = len(self.tok2idx)
            self.tok2idx[token] = idx
            self.idx2count[idx] = count
            if count <= self.singleton_max_count:
                self.singleton_size += 1
        elif update_count:
            old_count = self.idx2count[idx]
            new_count = old_count + count
            self.idx2count[idx] = new_count
            # a token stops being a singleton once its count passes the threshold
            if old_count <= self.singleton_max_count < new_count:
                self.singleton_size -= 1
        return idx
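
    # Sketch of the return/update behaviour of add_token (hypothetical calls):
    #
    #   v = Vocab()
    #   v.add_token('hello')                              # -> 0 (new entry)
    #   v.add_token('hello', count=3, update_count=True)  # -> 0, count is now 4
    #   v.add_token('hello')                              # -> 0, count unchanged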

    def get_vocab_size(self):
        return len(self.tok2idx)

    def get_vocab_size_without_spec(self):
        return len(self.tok2idx) - self.special_token_size

    def get_index(self, token, default_value='*UNK*'):
        idx = self.tok2idx.get(token, None)
        if idx is None:
            if default_value:
                return self.tok2idx[default_value]
            else:
                raise RuntimeError('Token %s not found' % token)
        else:
            return idx

    def __iter__(self):
        for tok, idx in self.tok2idx.items():
            yield idx, tok

    def get_token(self, index):
        # build the index -> token map lazily on first lookup
        if len(self.idx2tok) == 0:
            for tok, idx in self.tok2idx.items():
                self.idx2tok[idx] = tok
        return self.idx2tok[index]

    def get_token_set(self):
        return self.tok2idx.keys()

    def recount_singleton_size(self):
        singleton_size = 0
        for count in self.idx2count.values():
            if count <= self.singleton_max_count:
                singleton_size += 1
        self.singleton_size = singleton_size

    def get_singleton_size(self, re_count=False):
        if re_count:
            self.recount_singleton_size()
        return self.singleton_size

    def is_singleton(self, token_or_index):
        if isinstance(token_or_index, str):
            idx = self.tok2idx.get(token_or_index, None)
            if idx is None:
                # we treat OOV as singletons
                return True
        elif isinstance(token_or_index, int):
            idx = token_or_index
        else:
            raise TypeError('Unknown type %s' % (type(token_or_index)))
        count = self.idx2count[idx]
        return count <= self.singleton_max_count

    def __contains__(self, token):
        return token in self.tok2idx
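
    # is_singleton() exists to support UNK replacement strategies during
    # training, e.g. stochastically mapping rare words to *UNK*. A sketch
    # (not part of the original commit):
    #
    #   import random
    #   def maybe_unk(vocab, token, p=0.5):
    #       if vocab.is_singleton(token) and random.random() < p:
    #           return vocab.get_index(Vocab.UNK)
    #       return vocab.get_index(token)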

    def __str__(self):
        spec_tok_size_str = 'special_token_size\t' + str(self.special_token_size) + '\n'
        tok_size_str = 'token_size\t' + str(self.get_vocab_size_without_spec()) + '\n'
        singleton_size_str = 'singleton_size\t' + str(self.singleton_size) + '\n'
        singleton_max_count_str = 'singleton_max_count\t' + str(self.singleton_max_count) + '\n'
        return spec_tok_size_str + tok_size_str + singleton_size_str + singleton_max_count_str
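
    # __str__ doubles as the header of the text save format: four
    # tab-separated key/value lines, e.g. (hypothetical counts):
    #
    #   special_token_size   2
    #   token_size           6
    #   singleton_size       5
    #   singleton_max_count  1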

    def save(self, file_path, format='text'):
        '''
        :param file_path: destination path
        :param format: 'pickle' or 'text'
        :return:
        '''
        if format == 'pickle':
            with open(file_path, 'wb') as file:
                pickle.dump(self, file)
        elif format == 'text':
            with open(file_path, 'w') as file:
                # header first (see __str__), then one 'idx\ttok\tcount' line per token
                file.write(str(self))
                # write tokens in increasing index order
                tok2idx_list = sorted(self.tok2idx.items(), key=lambda x: x[1])
                for tok, idx in tok2idx_list:
                    count = self.idx2count[idx]
                    file.write(str(idx) + '\t' + tok + '\t' + str(count) + '\n')
        else:
            raise RuntimeError('Unknown save format')

    @staticmethod
    def load(file_path, format='text'):
        if format == 'pickle':
            with open(file_path, 'rb') as file:
                return pickle.load(file)
        elif format == 'text':
            with open(file_path, 'r') as file:
                vocab = Vocab()
                lines = file.readlines()
                # the four header lines written by __str__
                spec_tok_size = int(lines[0].split('\t')[1])
                tok_size = int(lines[1].split('\t')[1])
                singleton_size = int(lines[2].split('\t')[1])
                singleton_max_count = int(lines[3].split('\t')[1])
                vocab.special_token_size = spec_tok_size
                vocab.singleton_size = singleton_size
                vocab.singleton_max_count = singleton_max_count
                offset = 4  # skip head information
                for i in range(offset, offset + tok_size + spec_tok_size):
                    line_arr = lines[i].strip().split('\t')
                    idx, tok, count = int(line_arr[0]), line_arr[1], int(line_arr[2])
                    vocab.tok2idx[tok] = idx
                    vocab.idx2count[idx] = count
                return vocab
        else:
            raise RuntimeError('Unknown load format')
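

# A minimal save/load round-trip sketch (hypothetical file name, not part of
# the original commit):
if __name__ == '__main__':
    counter = Counter('the cat sat on the mat the end'.split())
    vocab = Vocab()
    vocab.add_spec_toks(pad_tok=True, unk_tok=True)
    vocab.add_counter(counter, min_count=1, singleton_max_count=1)
    print(vocab)  # the four header lines: special/token/singleton sizes

    vocab.save('vocab.txt', format='text')
    reloaded = Vocab.load('vocab.txt', format='text')
    assert reloaded.get_index('the') == vocab.get_index('the')
    assert reloaded.get_index('unseen-token') == vocab.get_index(Vocab.UNK)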