fix pooled
This commit is contained in:
parent
a397f9bddf
commit
6de490e1e6
@ -117,16 +117,22 @@ class KerasBertVector():
|
||||
# 相当于pool,采用的是https://github.com/terrifyzhao/bert-utils/blob/master/graph.py
|
||||
mul_mask = lambda x, m: x * np.expand_dims(m, axis=-1)
|
||||
masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9)
|
||||
pooled = masked_reduce_mean(predicts[0], input_masks)
|
||||
|
||||
pools = []
|
||||
for i in range(len(predicts)):
|
||||
pred = predicts[i]
|
||||
masks = input_masks.tolist()
|
||||
mask_np = np.array([masks[i]])
|
||||
pooled = masked_reduce_mean(pred, mask_np)
|
||||
pooled = pooled.tolist()
|
||||
print('bert:', pooled)
|
||||
return pooled
|
||||
pools.append(pooled[0])
|
||||
print('bert:', pools)
|
||||
return pools
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 一次只提取一个句子
|
||||
bert_vector = KerasBertVector()
|
||||
pooled = bert_vector.bert_encode(['你是谁呀'])
|
||||
pooled = bert_vector.bert_encode(['你是谁呀', '小老弟'])
|
||||
print(pooled)
|
||||
while True:
|
||||
print("input:")
|
||||
|
Loading…
Reference in New Issue
Block a user