fix pooled

This commit is contained in:
yongzhuo 2019-07-30 15:03:10 +08:00 committed by GitHub
parent a397f9bddf
commit 6de490e1e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -117,16 +117,22 @@ class KerasBertVector():
# 相当于pool采用的是https://github.com/terrifyzhao/bert-utils/blob/master/graph.py
mul_mask = lambda x, m: x * np.expand_dims(m, axis=-1)
masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9)
pooled = masked_reduce_mean(predicts[0], input_masks)
pooled = pooled.tolist()
print('bert:', pooled)
return pooled
pools = []
for i in range(len(predicts)):
pred = predicts[i]
masks = input_masks.tolist()
mask_np = np.array([masks[i]])
pooled = masked_reduce_mean(pred, mask_np)
pooled = pooled.tolist()
pools.append(pooled[0])
print('bert:', pools)
return pools
if __name__ == "__main__":
# 一次只提取一个句子
bert_vector = KerasBertVector()
pooled = bert_vector.bert_encode(['你是谁呀'])
pooled = bert_vector.bert_encode(['你是谁呀', '小老弟'])
print(pooled)
while True:
print("input:")