Initialize RNN hidden states
This commit is contained in:
parent
9734e9014e
commit
e4eb9a53e6
@ -13,6 +13,8 @@ from torch.nn.utils.rnn import PackedSequence
|
||||
from .util import *
|
||||
from .memory import *
|
||||
|
||||
from torch.nn.init import orthogonal, xavier_uniform
|
||||
|
||||
|
||||
class DNC(nn.Module):
|
||||
|
||||
@ -114,6 +116,7 @@ class DNC(nn.Module):
|
||||
|
||||
# final output layer
|
||||
self.output = nn.Linear(self.nn_output_size, self.input_size)
|
||||
orthogonal(self.output.weight)
|
||||
|
||||
if self.gpu_id != -1:
|
||||
[x.cuda(self.gpu_id) for x in self.rnns]
|
||||
@ -127,7 +130,8 @@ class DNC(nn.Module):
|
||||
|
||||
# initialize hidden state of the controller RNN
|
||||
if chx is None:
|
||||
chx = [None for x in range(self.num_layers)]
|
||||
h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
|
||||
xavier_uniform(h)
|
||||
|
||||
# Last read vectors
|
||||
if last_read is None:
|
||||
|
@ -9,7 +9,7 @@ import numpy as np
|
||||
from torch.nn.utils.rnn import pad_packed_sequence as pad
|
||||
from torch.nn.utils.rnn import pack_padded_sequence as pack
|
||||
from torch.nn.utils.rnn import PackedSequence
|
||||
from torch.nn.init import orthogonal
|
||||
from torch.nn.init import orthogonal, xavier_uniform
|
||||
|
||||
from .util import *
|
||||
from .sparse_memory import SparseMemory
|
||||
@ -134,7 +134,10 @@ class SDNC(nn.Module):
|
||||
|
||||
# initialize hidden state of the controller RNN
|
||||
if chx is None:
|
||||
chx = [None for x in range(self.num_layers)]
|
||||
h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
|
||||
xavier_uniform(h)
|
||||
|
||||
chx = [ (h, h) if self.rnn_type.lower() == 'lstm' else h for x in range(self.num_layers)]
|
||||
|
||||
# Last read vectors
|
||||
if last_read is None:
|
||||
|
Loading…
Reference in New Issue
Block a user