添加readme
This commit is contained in:
parent
7ab44a43a6
commit
47ecaf6a3f
49
README.md
49
README.md
@ -1,16 +1,53 @@
|
|||||||
# bert_feature
|
# bert-utils
|
||||||
|
|
||||||
how to use Bert generate the sentence vector
|
本文对BERT进行了进一步的封装,方便生成句向量与做文本分类
|
||||||
|
|
||||||
1、download the model
|
1、下载BERT中文模型
|
||||||
|
|
||||||
model path: https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
|
下载地址: https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
|
||||||
|
|
||||||
2、Move the model in the same directory
|
2、把下载好的模型添加到当前目录下
|
||||||
|
|
||||||
3、init BertVector object and invokes the encode method, the param must be list
|
3、句向量生成
|
||||||
|
|
||||||
|
生成句向量不需要做fine tune,使用预先训练好的模型即可,可参考`extract_feature.py`的`main`方法,注意参数必须是一个list
|
||||||
```
|
```
|
||||||
from bert.extrac_feature import BertVector
|
from bert.extrac_feature import BertVector
|
||||||
bv = BertVector()
|
bv = BertVector()
|
||||||
bv.encode(['你好'])
|
bv.encode(['你好'])
|
||||||
```
|
```
|
||||||
|
|
||||||
|
4、文本分类
|
||||||
|
|
||||||
|
文本分类需要做fine tune,首先把数据准备好存放在`data`目录下,训练集的名字必须为`train.csv`,验证集的名字必须为`dev.csv`,测试集的名字必须为`test.csv`,
|
||||||
|
必须先调用`set_mode`方法,可参考`similarity.py`的`main`方法,
|
||||||
|
|
||||||
|
训练:
|
||||||
|
```
|
||||||
|
from similarity import BertSim
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
bs = BertSim()
|
||||||
|
bs.set_mode(tf.estimator.ModeKeys.TRAIN)
|
||||||
|
bs.train()
|
||||||
|
```
|
||||||
|
|
||||||
|
验证:
|
||||||
|
```
|
||||||
|
from similarity import BertSim
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
bs = BertSim()
|
||||||
|
bs.set_mode(tf.estimator.ModeKeys.EVAL)
|
||||||
|
bs.eval()
|
||||||
|
```
|
||||||
|
|
||||||
|
测试:
|
||||||
|
```
|
||||||
|
from similarity import BertSim
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
bs = BertSim()
|
||||||
|
bs.set_mode(tf.estimator.ModeKeys.PREDICT)
|
||||||
|
bs.test
|
||||||
|
```
|
@ -79,7 +79,7 @@ class SimProcessor(DataProcessor):
|
|||||||
return train_data
|
return train_data
|
||||||
|
|
||||||
def get_dev_examples(self, data_dir):
|
def get_dev_examples(self, data_dir):
|
||||||
file_path = os.path.join(data_dir, 'test.csv')
|
file_path = os.path.join(data_dir, 'dev.csv')
|
||||||
dev_df = pd.read_csv(file_path, encoding='utf-8')
|
dev_df = pd.read_csv(file_path, encoding='utf-8')
|
||||||
dev_data = []
|
dev_data = []
|
||||||
for index, dev in enumerate(dev_df.values):
|
for index, dev in enumerate(dev_df.values):
|
||||||
|
Loading…
Reference in New Issue
Block a user