Merge read vectors correctly, add an option to skip the forward pass through memory, and apply the linear layers correctly

author ixaxaar 2017-11-10 14:30:37 +05:30
parent f7dc1b5aab
commit 67d9722231
5 changed files with 41 additions and 72 deletions
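
In short: read vectors are now merged into the controller output through a learned linear projection instead of concatenation, memory access can be skipped per forward call via a new pass_through_memory flag, and the final linear layers are applied in the right order. A minimal sketch of the merge change with made-up sizes (only read_vectors_weights comes from this commit; every other name and number below is illustrative):

import torch as T
import torch.nn as nn

batch_size, output_size = 10, 64
r, w = 4, 16                           # read heads, cell size
read_vectors_size = r * w

controller_out = T.randn(batch_size, output_size)
read_vectors = T.randn(batch_size, read_vectors_size)

# Old scheme: concatenation widens the layer to output_size + r*w.
merged_old = T.cat([controller_out, read_vectors], 1)              # (10, 128)

# New scheme: project the read vectors to output_size and add them,
# so every layer keeps a fixed width of output_size.
read_vectors_weights = nn.Linear(read_vectors_size, output_size)
merged_new = controller_out + read_vectors_weights(read_vectors)   # (10, 64)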

View File

@@ -59,34 +59,23 @@ class DNC(nn.Module):
     self.w = self.cell_size
     self.r = self.read_heads

-    # input size
-    self.nn_input_size = self.r * self.w + self.input_size
-    self.nn_output_size = self.r * self.w + self.hidden_size
-    self.interface_size = (self.w * self.r) + (3 * self.w) + (5 * self.r) + 3
+    self.read_vectors_size = self.r * self.w
+    self.interface_size = self.read_vectors_size + (3 * self.w) + (5 * self.r) + 3
     self.output_size = self.hidden_size

-    self.rnns = [[None] * self.num_hidden_layers] * self.num_layers
+    self.rnns = []
     self.memories = []

     for layer in range(self.num_layers):
       # controllers for each layer
+      self.rnns.append([])
       for hlayer in range(self.num_hidden_layers):
         if self.rnn_type.lower() == 'rnn':
-          if hlayer == 0:
-            self.rnns[layer][hlayer] = nn.RNNCell(self.nn_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity)
-          else:
-            self.rnns[layer][hlayer] = nn.RNNCell(self.output_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity)
+          self.rnns[layer].append(nn.RNNCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
         elif self.rnn_type.lower() == 'gru':
-          if hlayer == 0:
-            self.rnns[layer][hlayer] = nn.GRUCell(self.nn_input_size, self.output_size, bias=self.bias)
-          else:
-            self.rnns[layer][hlayer] = nn.GRUCell(self.output_size, self.output_size, bias=self.bias)
+          self.rnns[layer].append(nn.GRUCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias))
         elif self.rnn_type.lower() == 'lstm':
-          if hlayer == 0:
-            self.rnns[layer][hlayer] = nn.LSTMCell(self.nn_input_size, self.output_size, bias=self.bias)
-          else:
-            self.rnns[layer][hlayer] = nn.LSTMCell(self.output_size, self.output_size, bias=self.bias)
+          self.rnns[layer].append(nn.LSTMCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias))
+        setattr(self, self.rnn_type.lower() + '_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])

       # memories for each layer
       if not self.share_memory:
@@ -100,6 +89,7 @@ class DNC(nn.Module):
             independent_linears=self.independent_linears
           )
         )
+        setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])

     # only one memory shared by all layers
     if self.share_memory:
@@ -113,18 +103,12 @@ class DNC(nn.Module):
           independent_linears=self.independent_linears
         )
       )
+      setattr(self, 'rnn_layer_memory_shared', self.memories[0])

-    for layer in range(self.num_layers):
-      for hlayer in range(self.num_hidden_layers):
-        setattr(self, 'rnn_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])
-      if not self.share_memory:
-        setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])
-    if self.share_memory:
-      setattr(self, 'rnn_layer_memory_shared', self.memories[0])
-
     # final output layer
-    self.output_weights = nn.Linear(self.nn_input_size, self.input_size)
-    self.mem_out = nn.Linear(self.hidden_size, self.input_size)
+    self.read_vectors_weights = nn.Linear(self.read_vectors_size, self.output_size)
+    self.mem_out = nn.Linear(self.hidden_size, self.output_size)
+    self.output = nn.Linear(self.output_size, self.input_size)
     self.dropout_layer = nn.Dropout(self.dropout)

     if self.gpu_id != -1:
@@ -183,19 +167,9 @@ class DNC(nn.Module):
       debug_obj['usage_vector'].append(mhx['usage_vector'][0].unsqueeze(0).data.cpu().numpy())
     return debug_obj

-  def _layer_forward(self, input, layer, hx=(None, None)):
+  def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
     (chx, mhx) = hx

-    if self.debug:
-      mem_debug = {
-        "memory": [],
-        "link_matrix": [],
-        "precedence": [],
-        "read_weights": [],
-        "write_weights": [],
-        "usage_vector": [],
-      }
-
     layer_input = input
     hchx = []

     for hlayer in range(self.num_hidden_layers):
@@ -210,20 +184,20 @@ class DNC(nn.Module):
     output = self.dropout_layer(self.mem_out(layer_input))

     # pass through memory
-    if self.share_memory:
-      read_vecs, mhx = self.memories[0](ξ, mhx)
-    else:
-      read_vecs, mhx = self.memories[layer](ξ, mhx)
-
-    # the read vectors
-    read_vectors = read_vecs.view(-1, self.w * self.r)
-
-    if self.debug:
-      return output, read_vectors, mem_debug, (chx, mhx)
-    else:
-      return output, read_vectors, (chx, mhx)
+    if pass_through_memory:
+      if self.share_memory:
+        read_vecs, mhx = self.memories[0](ξ, mhx)
+      else:
+        read_vecs, mhx = self.memories[layer](ξ, mhx)
+      # the read vectors
+      read_vectors = read_vecs.view(-1, self.w * self.r)
+    else:
+      read_vectors = None
+
+    return output, read_vectors, (chx, mhx)

-  def forward(self, input, hx=(None, None, None), reset_experience=False):
+  def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
     # handle packed data
     is_packed = type(input) is PackedSequence
     if is_packed:
@@ -235,9 +209,10 @@ class DNC(nn.Module):
     batch_size = input.size(0) if self.batch_first else input.size(1)

+    # make the data batch-first
     if not self.batch_first:
       input = input.transpose(0, 1)

-    # make the data time-first
-    inputs = [ input[:, x, :] for x in range(max_length) ]
-
     controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)
@@ -245,10 +220,8 @@ class DNC(nn.Module):
     if self.debug:
       viz = None

-    # read_vectors = [last_read] * max_length
-    # outs = [input[:, x, :] for x in range(max_length)]
+    inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
     outs = [None] * max_length
+    read_vectors = None

     for time in range(max_length):
       for layer in range(self.num_layers):
@@ -256,10 +229,7 @@ class DNC(nn.Module):
         chx = controller_hidden[layer]
         m = mem_hidden if self.share_memory else mem_hidden[layer]

         # pass through controller
-        if self.debug:
-          outs[time], read_vectors, mem_debug, (chx, m) = self._layer_forward(inputs[time], layer, (chx, m))
-        else:
-          outs[time], read_vectors, (chx, m) = self._layer_forward(inputs[time], layer, (chx, m))
+        outs[time], read_vectors, (chx, m) = self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)

         # debug memory
         if self.debug:
@@ -272,23 +242,22 @@ class DNC(nn.Module):
           mem_hidden[layer] = m
         controller_hidden[layer] = chx

-        # the controller output + read vectors go into next layer
-        outs[time] = (T.cat([outs[time], read_vectors], 1))
+        if read_vectors is not None:
+          # the controller output + read vectors go into next layer
+          outs[time] = outs[time] + self.read_vectors_weights(read_vectors)
         inputs[time] = outs[time]

     if self.debug:
       viz = { k: np.array(v) for k, v in viz.items() }
       viz = { k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k, v in viz.items() }

-    inputs = [ self.output_weights(i) for i in inputs ]
-    outputs = T.stack(inputs, 1)
+    inputs = [ self.output(i) for i in inputs ]
+    outputs = T.stack(inputs, 1 if self.batch_first else 0)

-    if not self.batch_first:
-      outputs = outputs.transpose(0, 1)
-
     if is_packed:
       outputs = pack(outputs, lengths)

     if self.debug:
-      return outputs, (controller_hidden, mem_hidden, read_vectors[-1]), viz
+      return outputs, (controller_hidden, mem_hidden, read_vectors), viz
     else:
-      return outputs, (controller_hidden, mem_hidden, read_vectors[-1])
+      return outputs, (controller_hidden, mem_hidden, read_vectors)
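
A minimal usage sketch of the new flag, assuming a DNC constructed as in the repository's examples (the constructor sizes here are illustrative, not from this commit):

import torch as T
from dnc import DNC

rnn = DNC(input_size=64, hidden_size=128, rnn_type='lstm', num_layers=1,
          nr_cells=16, cell_size=16, read_heads=4, batch_first=True)

x = T.randn(10, 21, 64)  # (batch, time, input_size), batch-first

# Default: the controller reads from and writes to memory at every step.
out, (chx, mhx, rv) = rnn(x, None, pass_through_memory=True)

# Memory skipped: _layer_forward returns read_vectors=None, so only the
# controller path feeds the final self.output linear layer.
out, (chx, mhx, rv) = rnn(x, None, pass_through_memory=False)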

View File

@@ -140,7 +140,7 @@ if __name__ == '__main__':
       # input_data = input_data.transpose(0, 1).contiguous()
       target_output = target_output.transpose(0, 1).contiguous()

-      output, (chx, mhx, rv), v = rnn(input_data, None)
+      output, (chx, mhx, rv), v = rnn(input_data, None, pass_through_memory=True)

       # dncs operate batch first
       output = output.transpose(0, 1)

View File

@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0].size() == T.Size([10, 100])
   assert mhx['memory'].size() == T.Size([10, 1, 1])
-  assert rv.size() == T.Size([1])
+  assert rv.size() == T.Size([10, 1])
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[1][2].size() == T.Size([10, 100])
   assert mhx['memory'].size() == T.Size([10, 12, 17])
-  assert rv.size() == T.Size([51])
+  assert rv.size() == T.Size([10, 51])

View File

@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0][0].size() == T.Size([10, 100])
   assert mhx['memory'].size() == T.Size([10, 1, 1])
-  assert rv.size() == T.Size([1])
+  assert rv.size() == T.Size([10, 1])
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[0][0][0].size() == T.Size([10, 100])
   assert mhx['memory'].size() == T.Size([10, 12, 17])
-  assert rv.size() == T.Size([51])
+  assert rv.size() == T.Size([10, 51])

View File

@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0].size() == T.Size([10, 100])
   assert mhx['memory'].size() == T.Size([10, 1, 1])
-  assert rv.size() == T.Size([1])
+  assert rv.size() == T.Size([10, 1])
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[1][2].size() == T.Size([10, 100])
   assert mhx['memory'].size() == T.Size([10, 12, 17])
-  assert rv.size() == T.Size([51])
+  assert rv.size() == T.Size([10, 51])
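
The updated assertions follow from the read_vecs.view(-1, self.w * self.r) reshape above: read vectors are now returned batch-first rather than flattened to a single vector. A sketch with the sizes from the second test case (three read heads over cells of width 17 is an inference from 51 = 3 * 17):

import torch as T

batch_size, r, w = 10, 3, 17
read_vecs = T.randn(batch_size, r, w)  # one (r, w) block per batch element
rv = read_vecs.view(-1, r * w)         # flatten the heads, keep the batch dim
assert rv.size() == T.Size([10, 51])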