From 6de490e1e6ef3213b288f447f0a222c2a16dd77b Mon Sep 17 00:00:00 2001 From: yongzhuo <31341349+yongzhuo@users.noreply.github.com> Date: Tue, 30 Jul 2019 15:03:10 +0800 Subject: [PATCH] fix pooled --- .../bert/extract_keras_bert_feature.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/FeatureProject/bert/extract_keras_bert_feature.py b/FeatureProject/bert/extract_keras_bert_feature.py index c65244b..190eb64 100644 --- a/FeatureProject/bert/extract_keras_bert_feature.py +++ b/FeatureProject/bert/extract_keras_bert_feature.py @@ -117,16 +117,22 @@ class KerasBertVector(): # 相当于pool,采用的是https://github.com/terrifyzhao/bert-utils/blob/master/graph.py mul_mask = lambda x, m: x * np.expand_dims(m, axis=-1) masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9) - pooled = masked_reduce_mean(predicts[0], input_masks) - pooled = pooled.tolist() - print('bert:', pooled) - return pooled + + pools = [] + for i in range(len(predicts)): + pred = predicts[i] + masks = input_masks.tolist() + mask_np = np.array([masks[i]]) + pooled = masked_reduce_mean(pred, mask_np) + pooled = pooled.tolist() + pools.append(pooled[0]) + print('bert:', pools) + return pools if __name__ == "__main__": - # 一次只提取一个句子 bert_vector = KerasBertVector() - pooled = bert_vector.bert_encode(['你是谁呀']) + pooled = bert_vector.bert_encode(['你是谁呀', '小老弟']) print(pooled) while True: print("input:")