add lstm crf
This commit is contained in:
parent
338541b2c0
commit
beb5389aeb
@ -63,6 +63,7 @@ class NERNet(nn.Module):
|
|||||||
if args.soft_word:
|
if args.soft_word:
|
||||||
self.soft_word_emb = nn.Embedding(num_embeddings=5, embedding_dim=50, padding_idx=0)
|
self.soft_word_emb = nn.Embedding(num_embeddings=5, embedding_dim=50, padding_idx=0)
|
||||||
embed_size += 50
|
embed_size += 50
|
||||||
|
self.soft_word_emb.weight.requires_grad = False
|
||||||
|
|
||||||
self.sentence_encoder = SentenceEncoder(args, embed_size)
|
self.sentence_encoder = SentenceEncoder(args, embed_size)
|
||||||
self.emission = nn.Linear(args.hidden_size * 2, len(model_conf['entity_type']))
|
self.emission = nn.Linear(args.hidden_size * 2, len(model_conf['entity_type']))
|
||||||
|
@ -151,7 +151,7 @@ def main():
|
|||||||
|
|
||||||
logger.info("** ** * bulid dataset ** ** * ")
|
logger.info("** ** * bulid dataset ** ** * ")
|
||||||
|
|
||||||
eval_examples, data_loaders, model_conf = bulid_dataset(args, debug=False)
|
eval_examples, data_loaders, model_conf = bulid_dataset(args, debug=True)
|
||||||
|
|
||||||
trainer = Trainer(args, data_loaders, eval_examples, model_conf)
|
trainer = Trainer(args, data_loaders, eval_examples, model_conf)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user