add augment-nmt and augment-simbert
This commit is contained in:
parent
0eba3a83d2
commit
8b427f3a20
@ -52,6 +52,18 @@
|
||||
- transformer
|
||||
- GAN
|
||||
|
||||
## 预训练模型-UMILM
|
||||
使用BERT(UNILM)的生成能力, 即BERT的NSP句对任务
|
||||
- simbert(bert + unilm + adv): [https://github.com/ZhuiyiTechnology/simbert](https://github.com/ZhuiyiTechnology/simbert)
|
||||
- simbert: [鱼与熊掌兼得:融合检索和生成的SimBERT模型](https://spaces.ac.cn/archives/7427)
|
||||
- roformer-sim: [https://github.com/ZhuiyiTechnology/roformer-sim](https://github.com/ZhuiyiTechnology/roformer-sim)
|
||||
- simbert-v2(roformer + unilm + adv + bart + distill): [SimBERTv2来了!融合检索和生成的RoFormer-Sim模型](https://spaces.ac.cn/archives/8454)
|
||||
|
||||
## 回译(开源模型效果不是很好)
|
||||
中文转化成其他语言(如英语), 其他语言(如英语)转化成中文, Helsinki-NLP开源的预训练模型
|
||||
- opus-mt-en-zh: https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
|
||||
- opus-mt-zh-en: https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
|
||||
|
||||
|
||||
# 参考/感谢
|
||||
* eda_chinese:[https://github.com/zhanlaoban/eda_nlp_for_Chinese](https://github.com/zhanlaoban/eda_nlp_for_Chinese)
|
11
AugmentText/augment_nmt/README.md
Normal file
11
AugmentText/augment_nmt/README.md
Normal file
@ -0,0 +1,11 @@
|
||||
# Augment NMT
|
||||
|
||||
## 回译(开源模型效果不是很好)
|
||||
中文转化成其他语言(如英语), 其他语言(如英语)转化成中文, Helsinki-NLP开源的预训练模型
|
||||
- opus-mt-en-zh: https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
|
||||
- opus-mt-zh-en: https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
|
||||
|
||||
## 备注
|
||||
开源模型的效果不是那么理想, 只能少部分生成, 比如一条
|
||||
|
||||
|
5
AugmentText/augment_nmt/__init__.py
Normal file
5
AugmentText/augment_nmt/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# !/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @time : 2021/9/22 21:04
|
||||
# @author : Mo
|
||||
# @function:
|
96
AugmentText/augment_nmt/nmt_local.py
Normal file
96
AugmentText/augment_nmt/nmt_local.py
Normal file
@ -0,0 +1,96 @@
|
||||
# !/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @time : 2021/9/22 21:37
|
||||
# @author : Mo
|
||||
# @function: NMT of Helsinki-NLP
|
||||
# 下载地址:
|
||||
# opus-mt-en-zh: https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
|
||||
# opus-mt-zh-en: https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
|
||||
|
||||
|
||||
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer, pipeline)
|
||||
import time
|
||||
import os
|
||||
|
||||
|
||||
class BackTranslate:
|
||||
def __init__(self, pretrained_dir):
|
||||
# zh-to-en
|
||||
tokenizer = AutoTokenizer.from_pretrained(os.path.join(pretrained_dir, "Helsinki-NLP/opus-mt-zh-en"))
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(os.path.join(pretrained_dir, "Helsinki-NLP/opus-mt-zh-en"))
|
||||
# en-to-zh
|
||||
tokenizer_back_translate = AutoTokenizer.from_pretrained(os.path.join(pretrained_dir, "Helsinki-NLP/opus-mt-en-zh"))
|
||||
model_back_translate = AutoModelForSeq2SeqLM.from_pretrained(os.path.join(pretrained_dir, "Helsinki-NLP/opus-mt-en-zh"))
|
||||
# pipeline
|
||||
self.zh2en = pipeline("translation_zh_to_en", model=model, tokenizer=tokenizer)
|
||||
self.en2zh = pipeline("translation_en_to_zh", model=model_back_translate, tokenizer=tokenizer_back_translate)
|
||||
|
||||
def back_translate(self, text):
|
||||
""" 回译 """
|
||||
text_en = self.zh2en(text, max_length=510)[0]["translation_text"]
|
||||
print("text_en:", text_en)
|
||||
text_back = self.en2zh(text_en, max_length=510)[0]["translation_text"]
|
||||
print("text_back:", text_back)
|
||||
return text_back
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
pretrained_dir = "D:/soft_install/dataset/bert-model/translate"
|
||||
bt = BackTranslate(pretrained_dir)
|
||||
datas = [{"text": "平乐县,古称昭州,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭城,总面积1919.34平方公里。"},
|
||||
{"text": "平乐县主要旅游景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐县为漓江分界点,平乐以北称漓江,以南称桂江,是著名的大桂林旅游区之一。"},
|
||||
{"text": "印岭玲珑,昭水晶莹,环绕我平中。青年的乐园,多士受陶熔。生活自觉自治,学习自发自动。五育并重,手脑并用。迎接新潮流,建设新平中"},
|
||||
{"text": "桂林山水甲天下, 阳朔山水甲桂林"},
|
||||
{"text": "三国一统天下"},
|
||||
{"text": "世间万物皆系于其上"},
|
||||
{"text": "2020年真是一个糟糕的年代, 进入20年代,新冠爆发、经济下行,什么的都来了。"},
|
||||
{"text": "仿佛一切都变得不那么重要了。"},
|
||||
{"text": "苹果多少钱一斤"}
|
||||
]
|
||||
time_start = time.time()
|
||||
for da in datas:
|
||||
text = da.get("text", "")
|
||||
bt.back_translate(text)
|
||||
time_total = time.time() - time_start
|
||||
print("time_total:{}".format(time_total))
|
||||
print("time_per:{}".format(time_total / len(datas)))
|
||||
|
||||
while True:
|
||||
print("请输入:")
|
||||
ques = input()
|
||||
res = bt.back_translate(ques)
|
||||
print("####################################################")
|
||||
|
||||
|
||||
# 下载地址:
|
||||
# opus-mt-en-zh: https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
|
||||
# opus-mt-zh-en: https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
|
||||
|
||||
|
||||
# 备注: 翻译效果不大好
|
||||
|
||||
|
||||
|
||||
"""
|
||||
text_en: Ping Lei County, anciently known as Zhao County, belongs to the city of Gui Lin, Guangxi Liang Autonomous Region, and is located in the north-east of Guangxi, south-east of the city of Gui Lin, eastern Pingshan County, south-west Su Ping, north-west of Yangyon and north-west of the city of Lilongqi, with a total area of 1919.34 square kilometres.
|
||||
text_back: 平莱县,古代称为赵县,属于广西梁自治区Gui Lin市,位于广西东北、Gui Lin市东南、Pingshan县东南、Su Ping西南、Yangyon西北和Lilongqi市西北,总面积1919.34平方公里。
|
||||
text_en: The main tourist attractions in the district of Ping Lei are Xin Xianjin Quan, Cold Water Qing Qing, Qingjiang, Qingjiang, Qingjiang, etc. The district of Ping Le is one of the well-known Grand Gui Lin tourist areas, which is known as Jingjiang, north of Ping Lei and south of Ping Lei.
|
||||
text_back: 平莱区的主要旅游景点为新贤进泉、冷水清清、青江、青江、青江、青江等。 平来区是著名的大桂林旅游区之一,称为青江,位于平莱以北和平莱以南。
|
||||
text_en: The young man's garden, the Doss, is molten with pottery. Life is self-governing, learning self-involvement. It's full and heavy, and the hands and brains work together. It takes a new tide and builds a new flat.
|
||||
text_back: 年轻人的花园,多斯人,被陶器熔化了。生活是自治的,学习自我参与。生活是满的和沉重的,手和大脑一起工作。它需要新的潮水,建造新的公寓。
|
||||
text_en: Guilin Mountain Watermarin, Sunshaw Hill Watermarin
|
||||
text_back: 古林山水马林、桑肖山水马林
|
||||
text_en: All three of us.
|
||||
text_back: 我们三个人
|
||||
text_en: Everything in the world is in it.
|
||||
text_back: 世界上所有的东西都在里面
|
||||
text_en: The year 2020 was a really bad time, and in the 20s, the crown broke out, the economy went down, everything came up.
|
||||
text_back: 2020年是一个非常糟糕的时期, 在20年代,王冠崩盘, 经济下滑,一切都出现了。
|
||||
text_en: As if everything had become less important.
|
||||
text_back: 仿佛一切都变得不重要了
|
||||
text_en: How much is an apple?
|
||||
text_back: 苹果多少钱?
|
||||
"""
|
||||
|
3
AugmentText/augment_nmt/requestments.txt
Normal file
3
AugmentText/augment_nmt/requestments.txt
Normal file
@ -0,0 +1,3 @@
|
||||
tensorflow-gpu==1.15.2
|
||||
transformers==0.4.10
|
||||
|
12
AugmentText/augment_simbert/README.md
Normal file
12
AugmentText/augment_simbert/README.md
Normal file
@ -0,0 +1,12 @@
|
||||
# Augment Simbert
|
||||
|
||||
## 预训练模型-UMILM
|
||||
使用BERT(UNILM)的生成能力, 即BERT的NSP句对任务
|
||||
- simbert(bert + unilm + adv): [https://github.com/ZhuiyiTechnology/simbert](https://github.com/ZhuiyiTechnology/simbert)
|
||||
- simbert: [鱼与熊掌兼得:融合检索和生成的SimBERT模型](https://spaces.ac.cn/archives/7427)
|
||||
- roformer-sim: [https://github.com/ZhuiyiTechnology/roformer-sim](https://github.com/ZhuiyiTechnology/roformer-sim)
|
||||
- simbert-v2(roformer + unilm + adv + bart + distill): [SimBERTv2来了!融合检索和生成的RoFormer-Sim模型](https://spaces.ac.cn/archives/8454)
|
||||
|
||||
## 备注
|
||||
效果还是比较好的, 可以生成多个相似句子, 但是生成式的模型一般都比较慢。
|
||||
|
5
AugmentText/augment_simbert/__init__.py
Normal file
5
AugmentText/augment_simbert/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# !/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @time : 2021/9/18 21:12
|
||||
# @author : Mo
|
||||
# @function:
|
155
AugmentText/augment_simbert/enhance_roformer.py
Normal file
155
AugmentText/augment_simbert/enhance_roformer.py
Normal file
@ -0,0 +1,155 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# !/usr/bin/python
|
||||
# @time :2019/4/9 23:05
|
||||
# @author :Mo
|
||||
# @function :SimBERT再训练BERT-base(NSP任务), UNILM的生成能力
|
||||
# @reference:https://github.com/ZhuiyiTechnology/roformer-sim
|
||||
# 目前仅保证支持 Tensorflow 1.x + Keras <= 2.3.1 + bert4keras>=0.10.6。
|
||||
# 具体用法请看 https://github.com/bojone/bert4keras/blob/8ffb46a16a79f87aa8cdf045df7994036b4be47d/bert4keras/snippets.py#L580
|
||||
|
||||
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
from bert4keras.snippets import sequence_padding, AutoRegressiveDecoder
|
||||
from bert4keras.models import build_transformer_model
|
||||
from bert4keras.tokenizers import Tokenizer
|
||||
from bert4keras.backend import keras, K
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
# bert配置
|
||||
# BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_roformer-sim-char_L-12_H-768_A-12"
|
||||
BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_roformer-sim-char_L-6_H-384_A-6"
|
||||
|
||||
config_path = BERT_DIR + "/bert_config.json"
|
||||
checkpoint_path = BERT_DIR + "/bert_model.ckpt"
|
||||
dict_path = BERT_DIR + "/vocab.txt"
|
||||
maxlen = 128
|
||||
|
||||
|
||||
# 建立分词器
|
||||
tokenizer = Tokenizer(dict_path, do_lower_case=True) # 建立分词器
|
||||
|
||||
# 建立加载模型
|
||||
bert = build_transformer_model(
|
||||
config_path,
|
||||
checkpoint_path,
|
||||
with_pool='linear',
|
||||
model='roformer',
|
||||
application='unilm',
|
||||
return_keras_model=False,
|
||||
)
|
||||
|
||||
encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
|
||||
seq2seq = keras.models.Model(bert.model.inputs, bert.model.outputs[1])
|
||||
|
||||
|
||||
class SynonymsGenerator(AutoRegressiveDecoder):
|
||||
"""seq2seq解码器
|
||||
"""
|
||||
@AutoRegressiveDecoder.wraps(default_rtype='probas')
|
||||
def predict(self, inputs, output_ids, step):
|
||||
token_ids, segment_ids = inputs
|
||||
token_ids = np.concatenate([token_ids, output_ids], 1)
|
||||
segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
|
||||
return self.last_token(seq2seq).predict([token_ids, segment_ids])
|
||||
|
||||
def generate(self, text, n=1, topp=0.95, mask_idxs=[]):
|
||||
token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
|
||||
for i in mask_idxs:
|
||||
token_ids[i] = tokenizer._token_mask_id
|
||||
output_ids = self.random_sample([token_ids, segment_ids], n, topp=topp) # 基于随机采样
|
||||
return [tokenizer.decode(ids) for ids in output_ids]
|
||||
|
||||
|
||||
synonyms_generator = SynonymsGenerator(start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen)
|
||||
|
||||
|
||||
def gen_synonyms(text, n=100, k=20):
|
||||
""""含义: 产生sent的n个相似句,然后返回最相似的k个。
|
||||
做法:用seq2seq生成,并用encoder算相似度并排序。
|
||||
"""
|
||||
r = synonyms_generator.generate(text, n)
|
||||
r = [i for i in set(r) if i != text]
|
||||
r = [text] + r
|
||||
X, S = [], []
|
||||
for t in r:
|
||||
x, s = tokenizer.encode(t)
|
||||
X.append(x)
|
||||
S.append(s)
|
||||
X = sequence_padding(X)
|
||||
S = sequence_padding(S)
|
||||
Z = encoder.predict([X, S])
|
||||
Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
|
||||
argsort = np.dot(Z[1:], -Z[0]).argsort()
|
||||
return [r[i + 1] for i in argsort[:k]]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
datas = [{"text": "平乐县,古称昭州,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭城,总面积1919.34平方公里。"},
|
||||
{"text": "平乐县主要旅游景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐县为漓江分界点,平乐以北称漓江,以南称桂江,是著名的大桂林旅游区之一。"},
|
||||
{"text": "印岭玲珑,昭水晶莹,环绕我平中。青年的乐园,多士受陶熔。生活自觉自治,学习自发自动。五育并重,手脑并用。迎接新潮流,建设新平中"},
|
||||
{"text": "桂林山水甲天下, 阳朔山水甲桂林"},
|
||||
{"text": "三国一统天下"},
|
||||
{"text": "世间万物皆系于其上"},
|
||||
{"text": "2020年真是一个糟糕的年代, 进入20年代,新冠爆发、经济下行,什么的都来了。"},
|
||||
{"text": "仿佛一切都变得不那么重要了。"},
|
||||
{"text": "苹果多少钱一斤"}
|
||||
]
|
||||
time_start = time.time()
|
||||
for da in datas:
|
||||
text = da.get("text", "")
|
||||
res = gen_synonyms(text)
|
||||
print(res)
|
||||
time_total = time.time()-time_start
|
||||
print("time_total:{}".format(time_total))
|
||||
print("time_per:{}".format(time_total/len(datas)/20))
|
||||
|
||||
while True:
|
||||
print("请输入:")
|
||||
text = input()
|
||||
res = gen_synonyms(text)
|
||||
print(res)
|
||||
|
||||
['平乐县古称昭州,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,西毗邻阳朔,北连恭城',
|
||||
'平乐县,古称昭州,隶属于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭城,总面积1919.',
|
||||
'平乐县隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭',
|
||||
'平乐县,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连',
|
||||
'平乐县,属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭',
|
||||
'平乐县,古称昭州,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平',
|
||||
'平乐县,古称昭州,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平县。',
|
||||
'平乐县,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔',
|
||||
'平乐县,古称昭州,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭城,总面积1919.34',
|
||||
'平乐县,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔。',
|
||||
'平乐县,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平县,西北毗邻阳朔县。',
|
||||
'平乐县隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔',
|
||||
'平乐县隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔。',
|
||||
'昭平县,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连',
|
||||
'昭乐县隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔。',
|
||||
'昭州,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭',
|
||||
'平乐县,古称昭州,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县',
|
||||
'广西壮族自治区桂林市,隶属于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭城,总面积1919',
|
||||
'平乐县,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平',
|
||||
'广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭城,总面积1919']
|
||||
|
||||
['榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐县为漓江分界点,平乐以北称漓江。',
|
||||
'景点主要就是榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐县为漓江分界点',
|
||||
'主要景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐以北称漓江,以南称桂',
|
||||
'景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐以北称漓江,以南称桂江',
|
||||
'榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐县为漓江分界点',
|
||||
'榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐县为漓江分界点。',
|
||||
'这个景点主要有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐以北称漓江,以南称',
|
||||
'1. 榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐县为漓江分界点。',
|
||||
'第一,平乐县的主要旅游景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等。',
|
||||
'平乐县主要旅游景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等。',
|
||||
'学校主要旅游景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐以北称漓江。',
|
||||
'1、榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐以北称漓江,以南称桂江。',
|
||||
'宜宾市榕津千年古榕、冷水石景苑、仙家温泉、漓江风景区等,普洱山为漓江分界点,平乐以北称漓江。',
|
||||
'花开湖畔的景观特色:榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐以北称漓江',
|
||||
'广州市榕津千年古榕、冷水石景苑、仙家温泉、漓江风景区、漓江风景区等,平乐县为漓江分界点。',
|
||||
'该景区主要旅游景点有:榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等。',
|
||||
'郴州市永化区榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,桂江县为漓江分界点,平',
|
||||
'公益群岛水平县主要旅游景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等',
|
||||
'此处主要是桂江最为优秀的桂江旅游区,主要是榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区等。',
|
||||
'桂江“大桂林”景点”的三角半南边有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等。']
|
135
AugmentText/augment_simbert/enhance_simbert.py
Normal file
135
AugmentText/augment_simbert/enhance_simbert.py
Normal file
@ -0,0 +1,135 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# !/usr/bin/python
|
||||
# @time :2019/4/9 23:05
|
||||
# @author :Mo
|
||||
# @function :SimBERT再训练BERT-base(NSP任务), UNILM的生成能力
|
||||
# @reference:https://github.com/bojone/bert4keras
|
||||
# 目前仅保证支持 Tensorflow 1.x + Keras <= 2.3.1 + bert4keras>=0.10.6。
|
||||
# 具体用法请看 https://github.com/bojone/bert4keras/blob/8ffb46a16a79f87aa8cdf045df7994036b4be47d/bert4keras/snippets.py#L580
|
||||
|
||||
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
from bert4keras.snippets import sequence_padding, AutoRegressiveDecoder
|
||||
from bert4keras.models import build_transformer_model
|
||||
from bert4keras.tokenizers import Tokenizer
|
||||
from bert4keras.backend import keras, K
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
# bert配置
|
||||
# BERT_DIR = "D:/soft_install/dataset/bert-model/chinese_L-12_H-768_A-12"
|
||||
# BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_roberta_L-4_H-312_A-12_K-104"
|
||||
# BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_roberta_L-6_H-384_A-12_K-128"
|
||||
# BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-4_H-312_A-12"
|
||||
# BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-12_H-768_A-12"
|
||||
BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-6_H-384_A-12"
|
||||
|
||||
config_path = BERT_DIR + "/bert_config.json"
|
||||
checkpoint_path = BERT_DIR + "/bert_model.ckpt"
|
||||
dict_path = BERT_DIR + "/vocab.txt"
|
||||
maxlen = 128
|
||||
|
||||
|
||||
# 建立分词器
|
||||
tokenizer = Tokenizer(dict_path, do_lower_case=True) # 建立分词器
|
||||
|
||||
# 建立加载模型
|
||||
bert = build_transformer_model(
|
||||
config_path,
|
||||
checkpoint_path,
|
||||
with_pool='linear',
|
||||
application='unilm',
|
||||
return_keras_model=False,
|
||||
)
|
||||
|
||||
encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
|
||||
seq2seq = keras.models.Model(bert.model.inputs, bert.model.outputs[1])
|
||||
|
||||
|
||||
# class SynonymsGenerator(AutoRegressiveDecoder):
|
||||
# """seq2seq解码器
|
||||
# """
|
||||
# @AutoRegressiveDecoder.set_rtype('probas')
|
||||
# def predict(self, inputs, output_ids, states):
|
||||
# token_ids, segment_ids = inputs
|
||||
# token_ids = np.concatenate([token_ids, output_ids], 1)
|
||||
# segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
|
||||
# # return self.last_token(seq2seq).predict([token_ids, segment_ids])
|
||||
# return seq2seq.predict([token_ids, segment_ids])[:, -1]
|
||||
#
|
||||
# def generate(self, text, n=1, topp=0.95):
|
||||
# token_ids, segment_ids = tokenizer.encode(text, max_length=maxlen)
|
||||
# output_ids = self.random_sample([token_ids, segment_ids], n, topp=topp) # 基于随机采样
|
||||
# return [tokenizer.decode(ids) for ids in output_ids]
|
||||
|
||||
|
||||
class SynonymsGenerator(AutoRegressiveDecoder):
|
||||
"""seq2seq解码器
|
||||
"""
|
||||
@AutoRegressiveDecoder.wraps(default_rtype='probas')
|
||||
def predict(self, inputs, output_ids, step):
|
||||
token_ids, segment_ids = inputs
|
||||
token_ids = np.concatenate([token_ids, output_ids], 1)
|
||||
segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
|
||||
return self.last_token(seq2seq).predict([token_ids, segment_ids])
|
||||
|
||||
def generate(self, text, n=1, topp=0.95, mask_idxs=[]):
|
||||
token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
|
||||
for i in mask_idxs:
|
||||
token_ids[i] = tokenizer._token_mask_id
|
||||
output_ids = self.random_sample([token_ids, segment_ids], n, topp=topp) # 基于随机采样
|
||||
return [tokenizer.decode(ids) for ids in output_ids]
|
||||
|
||||
|
||||
synonyms_generator = SynonymsGenerator(start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen)
|
||||
|
||||
|
||||
def gen_synonyms(text, n=100, k=20):
|
||||
""""含义: 产生sent的n个相似句,然后返回最相似的k个。
|
||||
做法:用seq2seq生成,并用encoder算相似度并排序。
|
||||
"""
|
||||
r = synonyms_generator.generate(text, n)
|
||||
r = [i for i in set(r) if i != text]
|
||||
r = [text] + r
|
||||
X, S = [], []
|
||||
for t in r:
|
||||
x, s = tokenizer.encode(t)
|
||||
X.append(x)
|
||||
S.append(s)
|
||||
X = sequence_padding(X)
|
||||
S = sequence_padding(S)
|
||||
Z = encoder.predict([X, S])
|
||||
Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
|
||||
argsort = np.dot(Z[1:], -Z[0]).argsort()
|
||||
return [r[i + 1] for i in argsort[:k]]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
datas = [{"text": "平乐县,古称昭州,隶属于广西壮族自治区桂林市,位于广西东北部,桂林市东南部,东临钟山县,南接昭平,西北毗邻阳朔,北连恭城,总面积1919.34平方公里。"},
|
||||
{"text": "平乐县主要旅游景点有榕津千年古榕、冷水石景苑、仙家温泉、桂江风景区、漓江风景区等,平乐县为漓江分界点,平乐以北称漓江,以南称桂江,是著名的大桂林旅游区之一。"},
|
||||
{"text": "印岭玲珑,昭水晶莹,环绕我平中。青年的乐园,多士受陶熔。生活自觉自治,学习自发自动。五育并重,手脑并用。迎接新潮流,建设新平中"},
|
||||
{"text": "桂林山水甲天下, 阳朔山水甲桂林"},
|
||||
{"text": "三国一统天下"},
|
||||
{"text": "世间万物皆系于其上"},
|
||||
{"text": "2020年真是一个糟糕的年代, 进入20年代,新冠爆发、经济下行,什么的都来了。"},
|
||||
{"text": "仿佛一切都变得不那么重要了。"},
|
||||
{"text": "苹果多少钱一斤"}
|
||||
]
|
||||
time_start = time.time()
|
||||
for da in datas:
|
||||
text = da.get("text", "")
|
||||
res = gen_synonyms(text)
|
||||
print(res)
|
||||
time_total = time.time() - time_start
|
||||
print("time_total:{}".format(time_total))
|
||||
print("time_per:{}".format(time_total / len(datas)))
|
||||
|
||||
while True:
|
||||
print("请输入:")
|
||||
text = input()
|
||||
res = gen_synonyms(text)
|
||||
print(res)
|
||||
|
||||
|
3
AugmentText/augment_simbert/requestments.txt
Normal file
3
AugmentText/augment_simbert/requestments.txt
Normal file
@ -0,0 +1,3 @@
|
||||
tensorflow==1.15.2
|
||||
bert4keras==0.10.7
|
||||
|
Loading…
Reference in New Issue
Block a user