diff --git a/FeatureProject/bert/extract_keras_bert_feature.py b/FeatureProject/bert/extract_keras_bert_feature.py index c65244b..190eb64 100644 --- a/FeatureProject/bert/extract_keras_bert_feature.py +++ b/FeatureProject/bert/extract_keras_bert_feature.py @@ -117,16 +117,22 @@ class KerasBertVector(): # 相当于pool,采用的是https://github.com/terrifyzhao/bert-utils/blob/master/graph.py mul_mask = lambda x, m: x * np.expand_dims(m, axis=-1) masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9) - pooled = masked_reduce_mean(predicts[0], input_masks) - pooled = pooled.tolist() - print('bert:', pooled) - return pooled + + pools = [] + for i in range(len(predicts)): + pred = predicts[i] + masks = input_masks.tolist() + mask_np = np.array([masks[i]]) + pooled = masked_reduce_mean(pred, mask_np) + pooled = pooled.tolist() + pools.append(pooled[0]) + print('bert:', pools) + return pools if __name__ == "__main__": - # 一次只提取一个句子 bert_vector = KerasBertVector() - pooled = bert_vector.bert_encode(['你是谁呀']) + pooled = bert_vector.bert_encode(['你是谁呀', '小老弟']) print(pooled) while True: print("input:")