bert feature and sim
This commit is contained in:
parent
2dd6c6d10b
commit
6de88d7c70
5
FeatureProject/bert/__init__.py
Normal file
5
FeatureProject/bert/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# !/usr/bin/python
|
||||
# @time :2019/5/10 9:12
|
||||
# @author :Mo
|
||||
# @function :
|
85
FeatureProject/bert/extract_keras_bert_feature.py
Normal file
85
FeatureProject/bert/extract_keras_bert_feature.py
Normal file
@ -0,0 +1,85 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# !/usr/bin/python
|
||||
# @time :2019/5/8 20:04
|
||||
# @author :Mo
|
||||
# @function :
|
||||
|
||||
from conf.feature_config import gpu_memory_fraction, config_name, ckpt_name, vocab_file, max_seq_len
|
||||
from keras_bert import load_trained_model_from_checkpoint, Tokenizer
|
||||
import keras.backend.tensorflow_backend as ktf_keras
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import codecs
|
||||
import os
|
||||
|
||||
# Module-level handles, filled in by KerasBertVector.__init__, so that the
# loaded graph/model can be reached from web frameworks such as django,
# flask or tornado.
graph = model = None

# GPU selection and memory cap: expose only device 0 to this process and
# limit the share of its memory that the TensorFlow session may claim.
os.environ.update(CUDA_VISIBLE_DEVICES='0')
config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
)
sess = tf.Session(config=config)
# Make Keras run on the session configured above.
ktf_keras.set_session(sess)
|
||||
|
||||
class KerasBertVector():
    """Encode sentences into fixed-size vectors with a keras_bert checkpoint.

    Creating an instance loads the checkpoint into the module-level ``model``
    and stores the current default TensorFlow graph in the module-level
    ``graph``; ``bert_encode`` later runs predictions inside that graph.
    """

    def __init__(self):
        """Load the BERT model and build the tokenizer from the vocabulary file.

        Paths and the sequence length come from ``conf.feature_config``
        (``config_name``, ``ckpt_name``, ``vocab_file``, ``max_seq_len``).
        """
        self.config_path = config_name
        self.checkpoint_path = ckpt_name
        self.dict_path = vocab_file
        self.max_seq_len = max_seq_len

        # Kept global (author's note) so the model can be used from django,
        # flask, tornado and similar frameworks.
        global graph, model
        graph = tf.get_default_graph()
        model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path,
                                                   seq_len=self.max_seq_len)

        # Each vocabulary line is one token; its id is the number of distinct
        # tokens seen so far.
        self.token_dict = {}
        with codecs.open(self.dict_path, 'r', 'utf8') as reader:
            for line in reader:
                token = line.strip()
                self.token_dict[token] = len(self.token_dict)

        self.tokenizer = Tokenizer(self.token_dict)

    @staticmethod
    def masked_mean_pool(hidden_states, masks):
        """Average token vectors per sample, ignoring masked-out positions.

        Pooling idea taken from
        https://github.com/terrifyzhao/bert-utils/blob/master/graph.py

        :param hidden_states: array-like of shape (batch, seq_len, dim).
        :param masks: array-like of shape (batch, seq_len); 1 marks a real
            token, 0 marks padding.
        :return: numpy array of shape (batch, dim). A sample whose mask is all
            zeros yields a zero vector (the 1e-9 term avoids division by zero).
        """
        hidden_states = np.asarray(hidden_states)
        masks = np.asarray(masks)
        summed = np.sum(hidden_states * np.expand_dims(masks, axis=-1), axis=1)
        counts = np.sum(masks, axis=1, keepdims=True)
        return summed / (counts + 1e-9)

    def bert_encode(self, texts):
        """Encode a batch of texts into one pooled vector per text.

        :param texts: iterable of strings.
        :return: list with one entry per text, each a list of floats (the
            masked mean of that text's token vectors). Empty input gives ``[]``.

        Diagnostic output (the text, its tokens, the prediction shape, every
        token vector and the pooled result) is printed to stdout.
        """
        texts = list(texts)
        if not texts:
            # Nothing to encode; avoid calling the model with empty arrays.
            return []

        input_ids = []
        input_masks = []
        input_type_ids = []
        tokens_per_text = []
        for text in texts:
            print(text)
            tokens_text = self.tokenizer.tokenize(text)
            print('Tokens:', tokens_text)
            tokens_per_text.append(tokens_text)
            input_id, input_type_id = self.tokenizer.encode(first=text, max_len=self.max_seq_len)
            # Token id 0 is treated as padding and excluded from pooling.
            input_mask = [0 if ids == 0 else 1 for ids in input_id]
            input_ids.append(input_id)
            input_type_ids.append(input_type_id)
            input_masks.append(input_mask)

        input_ids = np.array(input_ids)
        input_masks = np.array(input_masks)
        input_type_ids = np.array(input_type_ids)

        # Run inside the graph captured in __init__ so that calls coming from
        # django, flask, tornado etc. use the graph the model was loaded into.
        with graph.as_default():
            predicts = model.predict([input_ids, input_type_ids], batch_size=1)
        print(predicts.shape)

        # Dump each token next to its own vector; zip keeps every sample's
        # tokens aligned with that sample and stops at the sequence length.
        for sample_tokens, sample_vectors in zip(tokens_per_text, predicts):
            for token, vector in zip(sample_tokens, sample_vectors):
                vector = vector.tolist()
                print(token, [len(vector)], vector)

        # Pool the whole batch: one masked mean per input text.
        pooled = self.masked_mean_pool(predicts, input_masks).tolist()
        print('bert:', pooled)
        return pooled
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: encode two fixed sentences, then keep encoding whatever
    # is typed on stdin, one line at a time.
    encoder = KerasBertVector()
    print(encoder.bert_encode(['你好呀', '你是谁']))
    while True:
        print("input:")
        print(encoder.bert_encode([input()]))
|
77
FeatureProject/bert/tet_bert_keras_sim.py
Normal file
77
FeatureProject/bert/tet_bert_keras_sim.py
Normal file
@ -0,0 +1,77 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# !/usr/bin/python
|
||||
# @time :2019/5/7 20:27
|
||||
# @author :Mo
|
||||
# @function :test BERT sentence encoding and the cosine similarity of two questions
|
||||
|
||||
|
||||
def calculate_count():
    """Measure the average time of encoding one fixed sentence 1000 times.

    Prints the last encoding returned and the mean seconds per call.

    :return: None
    """
    import time
    from FeatureProject.bert.extract_keras_bert_feature import KerasBertVector

    runs = 1000
    sample = ["jy,你知道吗,我一直都很喜欢你呀,在一起在一起在一起,哈哈哈哈"]

    encoder = KerasBertVector()
    print("bert start ok!")

    started = time.time()
    for _ in range(runs):
        last_vector = encoder.bert_encode(sample)
    elapsed = time.time() - started

    print(last_vector)
    print(elapsed / runs)
    # Averages noted by the original author:
    #   0.12605296468734742  (win10, gpu)
    #   0.01629048466682434  (linux, cpu)
|
||||
|
||||
|
||||
def sim_two_question():
    """Interactively print the cosine similarity of two typed questions.

    Loads ``KerasBertVector`` once, then loops forever: reads two questions
    from stdin, encodes each with ``bert_encode`` and prints the cosine
    similarity of the two resulting vectors.
    """
    from FeatureProject.bert.extract_keras_bert_feature import KerasBertVector
    from math import pi
    import numpy as np
    import math

    def cosine_distance(v1, v2):
        """Return the cosine similarity of v1 and v2, or 0 if either has zero norm.

        ``bert_encode`` returns plain Python lists, so both inputs are
        converted to arrays before any numpy operation is applied.
        """
        v1 = np.asarray(v1, dtype=float)
        v2 = np.asarray(v2, dtype=float)
        norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
        if norm_product == 0:
            # Only a zero-length vector makes the ratio undefined; individual
            # zero components are fine.
            return 0
        return np.dot(v1, v2) / norm_product

    def scale_zoom(rate):
        """Optional rescaling of a score: (1 + exp(-rate)) / 2. Currently unused."""
        zoom = (1 + np.exp(-float(rate))) / 2
        return zoom

    def scale_triangle(rate):
        """Optional rescaling of a score: sin(rate * pi / 2 - pi / 2). Currently unused."""
        triangle = math.sin(rate/1*pi/2 - pi/2)
        return triangle

    bert_vector = KerasBertVector()
    print("bert start ok!")
    while True:
        print("input ques-1: ")
        ques_1 = input()
        print("input ques_2: ")
        ques_2 = input()
        vector_1 = bert_vector.bert_encode([ques_1])
        vector_2 = bert_vector.bert_encode([ques_2])
        sim = cosine_distance(vector_1[0], vector_2[0])
        # NOTE: earlier experiments rescaled `sim` against the reference list
        # [sim, 0, 0.2, 0.4, 0.6, 0.8, 1.0] with sklearn.preprocessing
        # (scale, MinMaxScaler, l1/l2 normalize) or with scale_zoom /
        # scale_triangle above; all are disabled, so the raw value is printed.
        print(sim)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the timing benchmark first, then the interactive similarity loop.
    for entry_point in (calculate_count, sim_two_question):
        entry_point()
|
Loading…
Reference in New Issue
Block a user