init rnn hidden states
This commit is contained in:
parent
9734e9014e
commit
e4eb9a53e6
@ -13,6 +13,8 @@ from torch.nn.utils.rnn import PackedSequence
|
|||||||
from .util import *
|
from .util import *
|
||||||
from .memory import *
|
from .memory import *
|
||||||
|
|
||||||
|
from torch.nn.init import orthogonal, xavier_uniform
|
||||||
|
|
||||||
|
|
||||||
class DNC(nn.Module):
|
class DNC(nn.Module):
|
||||||
|
|
||||||
@ -114,6 +116,7 @@ class DNC(nn.Module):
|
|||||||
|
|
||||||
# final output layer
|
# final output layer
|
||||||
self.output = nn.Linear(self.nn_output_size, self.input_size)
|
self.output = nn.Linear(self.nn_output_size, self.input_size)
|
||||||
|
orthogonal(self.output.weight)
|
||||||
|
|
||||||
if self.gpu_id != -1:
|
if self.gpu_id != -1:
|
||||||
[x.cuda(self.gpu_id) for x in self.rnns]
|
[x.cuda(self.gpu_id) for x in self.rnns]
|
||||||
@ -127,7 +130,8 @@ class DNC(nn.Module):
|
|||||||
|
|
||||||
# initialize hidden state of the controller RNN
|
# initialize hidden state of the controller RNN
|
||||||
if chx is None:
|
if chx is None:
|
||||||
chx = [None for x in range(self.num_layers)]
|
h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
|
||||||
|
xavier_uniform(h)
|
||||||
|
|
||||||
# Last read vectors
|
# Last read vectors
|
||||||
if last_read is None:
|
if last_read is None:
|
||||||
|
@ -9,7 +9,7 @@ import numpy as np
|
|||||||
from torch.nn.utils.rnn import pad_packed_sequence as pad
|
from torch.nn.utils.rnn import pad_packed_sequence as pad
|
||||||
from torch.nn.utils.rnn import pack_padded_sequence as pack
|
from torch.nn.utils.rnn import pack_padded_sequence as pack
|
||||||
from torch.nn.utils.rnn import PackedSequence
|
from torch.nn.utils.rnn import PackedSequence
|
||||||
from torch.nn.init import orthogonal
|
from torch.nn.init import orthogonal, xavier_uniform
|
||||||
|
|
||||||
from .util import *
|
from .util import *
|
||||||
from .sparse_memory import SparseMemory
|
from .sparse_memory import SparseMemory
|
||||||
@ -134,7 +134,10 @@ class SDNC(nn.Module):
|
|||||||
|
|
||||||
# initialize hidden state of the controller RNN
|
# initialize hidden state of the controller RNN
|
||||||
if chx is None:
|
if chx is None:
|
||||||
chx = [None for x in range(self.num_layers)]
|
h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
|
||||||
|
xavier_uniform(h)
|
||||||
|
|
||||||
|
chx = [ (h, h) if self.rnn_type.lower() == 'lstm' else h for x in range(self.num_layers)]
|
||||||
|
|
||||||
# Last read vectors
|
# Last read vectors
|
||||||
if last_read is None:
|
if last_read is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user