fix pooled
This commit is contained in:
parent
a397f9bddf
commit
6de490e1e6
@ -117,16 +117,22 @@ class KerasBertVector():
|
|||||||
# 相当于pool,采用的是https://github.com/terrifyzhao/bert-utils/blob/master/graph.py
|
# 相当于pool,采用的是https://github.com/terrifyzhao/bert-utils/blob/master/graph.py
|
||||||
mul_mask = lambda x, m: x * np.expand_dims(m, axis=-1)
|
mul_mask = lambda x, m: x * np.expand_dims(m, axis=-1)
|
||||||
masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9)
|
masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9)
|
||||||
pooled = masked_reduce_mean(predicts[0], input_masks)
|
|
||||||
pooled = pooled.tolist()
|
pools = []
|
||||||
print('bert:', pooled)
|
for i in range(len(predicts)):
|
||||||
return pooled
|
pred = predicts[i]
|
||||||
|
masks = input_masks.tolist()
|
||||||
|
mask_np = np.array([masks[i]])
|
||||||
|
pooled = masked_reduce_mean(pred, mask_np)
|
||||||
|
pooled = pooled.tolist()
|
||||||
|
pools.append(pooled[0])
|
||||||
|
print('bert:', pools)
|
||||||
|
return pools
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 一次只提取一个句子
|
|
||||||
bert_vector = KerasBertVector()
|
bert_vector = KerasBertVector()
|
||||||
pooled = bert_vector.bert_encode(['你是谁呀'])
|
pooled = bert_vector.bert_encode(['你是谁呀', '小老弟'])
|
||||||
print(pooled)
|
print(pooled)
|
||||||
while True:
|
while True:
|
||||||
print("input:")
|
print("input:")
|
||||||
|
Loading…
Reference in New Issue
Block a user