From 98219427aad9634373afa6555f433576ff3a5009 Mon Sep 17 00:00:00 2001 From: joe Date: Mon, 27 May 2019 13:53:07 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0bert=E5=8F=A5=E5=90=91?= =?UTF-8?q?=E9=87=8F=E6=8C=81=E4=B9=85=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bert_vec.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 bert_vec.py diff --git a/bert_vec.py b/bert_vec.py new file mode 100644 index 0000000..c18750c --- /dev/null +++ b/bert_vec.py @@ -0,0 +1,75 @@ +import pickle +from graph import set_logger +from termcolor import colored + +logger = set_logger(colored('BERT_VEC', 'yellow')) +bert_file_name = 'bert_data.pkl' + + +class BertData: + def __init__(self): + self.dic = {} + self._read_dic() + + # 批量插入数据 + def add_batch_data(self, keys, values): + for key, value in zip(keys, values): + self.dic[key] = value + + # 插入单条数据 + def add_data(self, key, value): + self.dic[key] = value + + # 根据key删除数据 + def delete_data(self, key): + if self.dic and self.dic.get(key, ''): + self.dic.pop(key) + + # 根据key获取数据 + def get_data(self, key): + return self.dic.get(key, '') + + # 获取全部数据 + def get_all_data(self): + return self.dic + + # 提交 + def commit(self): + self._save_dic() + + def _save_dic(self): + try: + with open(bert_file_name, 'wb')as file: + pickle.dump(self.dic, file) + logger.info('bert data saved successfully') + except: + logger.info('save bert data failed') + + def _read_dic(self): + try: + with open(bert_file_name, 'rb')as file: + self.dic = pickle.load(file) + except FileNotFoundError: + logger.info('local bert data is none') + + +if __name__ == '__main__': + bd = BertData() + data = [] + vec = [] + import numpy as np + + for i in range(30000): + data.append('阿迪和考虑就鞍山市会计法哈三联空间和福利卡就很烦' + str(i)) + vec.append(np.random.rand(768)) + bd.add_batch_data(data, vec) + # 增删改需要调用commit方法才会修改本地缓存的内容,查询不需要调用该方法 + bd.commit() + # bd.delete_data('上午好啊天气真的不错0') + # res = bd.get_data('上午好啊天气真的不错1') + # bd.add_data('上午好啊天气真的不错test', [1, 2, 3]) + # res = bd.get_all_data() + # print(res.keys()) + # for i in res.items(): + # print(i[0], ':', i[1]) + # print(res.values())