loujie0822 2020-05-04 16:44:39 +08:00
parent 5f33f5bdc0
commit 540d6f61d2
4 changed files with 9 additions and 5 deletions

View File

@@ -73,7 +73,7 @@ class Data:
         self.HP_char_hidden_dim = 50
         self.HP_hidden_dim = 128
         self.HP_dropout = 0.5
-        self.HP_lstm_layer = 1
+        self.HP_lstm_layer = 2
         self.HP_bilstm = True
         self.HP_use_char = False
         self.HP_gpu = True
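
Note: the hunk above doubles the stacked BiLSTM depth from one layer to two. A minimal sketch of how these HP_* fields would typically wire up a PyTorch encoder (the input_size of 50 and the variable names are illustrative assumptions, not taken from this repo); nn.LSTM only applies its dropout between stacked layers, so HP_dropout starts affecting the LSTM itself once HP_lstm_layer > 1:

import torch
from torch import nn

# Hypothetical wiring of the HP_* fields (a sketch, not repo code).
lstm = nn.LSTM(
    input_size=50,          # illustrative; in practice the embedding dim
    hidden_size=128 // 2,   # HP_hidden_dim split across the two directions
    num_layers=2,           # HP_lstm_layer after this commit
    dropout=0.5,            # HP_dropout; applied only between stacked layers
    bidirectional=True,     # HP_bilstm
    batch_first=True,
)
out, _ = lstm(torch.zeros(4, 10, 50))   # (batch, seq_len, input_size)
assert out.shape == (4, 10, 128)        # both directions concatenated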

View File

@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
+import warnings
+warnings.filterwarnings("ignore")
 import argparse
 import copy
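
Note: the two added lines sit above the remaining imports, so warnings raised while later modules load are silenced as well. A tiny runnable demonstration of what the blanket filter does (nothing here is repo-specific):

import warnings

warnings.filterwarnings("ignore")  # blanket filter, as added by this commit

# From this point on all warnings are dropped, even explicit ones:
warnings.warn("silently discarded", DeprecationWarning)
print("no warning text was printed above")
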
@@ -537,6 +539,8 @@ if __name__ == '__main__':
     with open(save_data_name, 'wb') as f:
         pickle.dump(data, f)
     set_seed(seed_num)
     data.show_data_summary()
+    print('data.use_biword=', data.use_bigram)
+    print('data.HP_batch_size=', data.HP_batch_size)
     train(data, save_model_dir, seg,debug=False)
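
Note: set_seed(seed_num) is invoked just before training, but its body is not part of this diff. A sketch of the usual implementation, assuming the conventional random/NumPy/PyTorch trio (the repo's own helper may differ):

import random

import numpy as np
import torch

def set_seed(seed):
    # Typical reproducibility helper (sketch): seed Python's, NumPy's and
    # PyTorch's RNGs together so reruns produce comparable results.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)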

View File

@@ -62,8 +62,8 @@ def get_args():
     parser.add_argument("--max_len", default=1000, type=int)
     parser.add_argument('--word_emb_dim', type=int, default=300)
     parser.add_argument('--char_emb_dim', type=int, default=300)
-    parser.add_argument('--hidden_size', type=int, default=300)
-    parser.add_argument('--num_layers', type=int, default=1)
+    parser.add_argument('--hidden_size', type=int, default=150)
+    parser.add_argument('--num_layers', type=int, default=2)
     parser.add_argument('--dropout', type=int, default=0.5)
     parser.add_argument('--rnn_encoder', type=str, default='lstm', help="must choose in blow: lstm or gru")
     parser.add_argument('--bidirectional', type=bool, default=True)
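
Note: since --bidirectional defaults to True, each direction contributes hidden_size units per timestep, so the new defaults give an encoder output of 2 * 150 = 300 (down from 2 * 300 = 600), matching the 300-dim embeddings; that reading of the intent is an assumption. A sketch of the arithmetic:

import torch
from torch import nn

# Two stacked bidirectional layers with hidden_size=150: the per-timestep
# output is 2 * 150 = 300, the same width as word_emb_dim above.
enc = nn.LSTM(input_size=300, hidden_size=150, num_layers=2,
              batch_first=True, bidirectional=True)
out, _ = enc(torch.zeros(2, 8, 300))
assert out.shape == (2, 8, 300)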

View File

@@ -9,7 +9,7 @@ import torch
 from torch import nn
 from tqdm import tqdm
-import models.ner_net.lstm_crf_v2 as ner
+import models.ner_net.lstm_crf as ner
 from models.ner_net.tener import TENER
 from utils.metrics import SpanFPreRecMetric
 from utils.optimizer_util import set_optimizer
@@ -73,7 +73,7 @@ class Trainer(object):
         step_gap = int(int(len(self.eval_file_choice['train']) / args.train_batch_size) / 20)
         for epoch in range(int(args.epoch_num)):
-            self.optimizer = lr_decay(self.optimizer, epoch, 0.05, args.learning_rate)
+            # self.optimizer = lr_decay(self.optimizer, epoch, 0.05, args.learning_rate)
             global_loss = 0.0
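
Note: commenting out the lr_decay call means the optimizer keeps args.learning_rate unchanged for the whole run. The repo's lr_decay is not shown in this diff; a common implementation matching the call's signature (a sketch, the actual helper may differ) is:

def lr_decay(optimizer, epoch, decay_rate, init_lr):
    # 1 / (1 + k * epoch) schedule, applied in place to every parameter
    # group; with decay_rate=0.05 the LR halves after 20 epochs.
    lr = init_lr / (1 + decay_rate * epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer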