Merge read vectors correctly, option to not forward pass through memory, correctly apply linear layers
parent f7dc1b5aab
commit 67d9722231

dnc/dnc.py
@@ -59,34 +59,23 @@ class DNC(nn.Module):
     self.w = self.cell_size
     self.r = self.read_heads

-    # input size
-    self.nn_input_size = self.r * self.w + self.input_size
-    self.nn_output_size = self.r * self.w + self.hidden_size
-
-    self.interface_size = (self.w * self.r) + (3 * self.w) + (5 * self.r) + 3
+    self.read_vectors_size = self.r * self.w
+    self.interface_size = self.read_vectors_size + (3 * self.w) + (5 * self.r) + 3
     self.output_size = self.hidden_size

-    self.rnns = [[None] * self.num_hidden_layers] * self.num_layers
+    self.rnns = []
     self.memories = []

     for layer in range(self.num_layers):
       # controllers for each layer
+      self.rnns.append([])
       for hlayer in range(self.num_hidden_layers):
         if self.rnn_type.lower() == 'rnn':
-          if hlayer == 0:
-            self.rnns[layer][hlayer] = nn.RNNCell(self.nn_input_size, self.output_size,bias=self.bias, nonlinearity=self.nonlinearity)
-          else:
-            self.rnns[layer][hlayer] = nn.RNNCell(self.output_size, self.output_size,bias=self.bias, nonlinearity=self.nonlinearity)
+          self.rnns[layer].append(nn.RNNCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
         elif self.rnn_type.lower() == 'gru':
-          if hlayer == 0:
-            self.rnns[layer][hlayer] = nn.GRUCell(self.nn_input_size, self.output_size, bias=self.bias)
-          else:
-            self.rnns[layer][hlayer] = nn.GRUCell(self.output_size, self.output_size, bias=self.bias)
+          self.rnns[layer].append(nn.GRUCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias))
         elif self.rnn_type.lower() == 'lstm':
-          if hlayer == 0:
-            self.rnns[layer][hlayer] = nn.LSTMCell(self.nn_input_size, self.output_size, bias=self.bias)
-          else:
-            self.rnns[layer][hlayer] = nn.LSTMCell(self.output_size, self.output_size, bias=self.bias)
+          self.rnns[layer].append(nn.LSTMCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias))
         setattr(self, self.rnn_type.lower()+'_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])

     # memories for each layer
     if not self.share_memory:
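Note: the interface_size arithmetic matches the DNC interface vector ξ from Graves et al. (2016). A standalone sketch of the same breakdown (hypothetical helper, not part of the commit):

def interface_size(w, r):
  # r read keys (r * w)
  # + write key, erase vector, write vector (3 * w)
  # + r read strengths, r free gates, 3*r read modes (5 * r)
  # + write strength, allocation gate, write gate (3 scalars)
  return (r * w) + (3 * w) + (5 * r) + 3

assert interface_size(w=16, r=4) == 64 + 48 + 20 + 3  # 135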
@@ -100,6 +89,7 @@ class DNC(nn.Module):
             independent_linears=self.independent_linears
           )
         )
+        setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])

     # only one memory shared by all layers
     if self.share_memory:
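The setattr registration matters because self.rnns and self.memories are plain Python lists, which nn.Module does not traverse; binding each cell as a named attribute is what makes its parameters visible to optimizers. A toy illustration of that PyTorch behavior (hypothetical class, not from the repo):

import torch.nn as nn

class Holder(nn.Module):
  def __init__(self):
    super(Holder, self).__init__()
    self.cells = [nn.Linear(4, 4)]          # plain list: invisible to .parameters()
    setattr(self, 'cell_0', self.cells[0])  # attribute binding registers the module

h = Holder()
assert len(list(h.parameters())) == 2  # weight and bias of the registered cell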
@@ -113,18 +103,12 @@ class DNC(nn.Module):
             independent_linears=self.independent_linears
           )
         )
-
-    for layer in range(self.num_layers):
-      for hlayer in range(self.num_hidden_layers):
-        setattr(self, 'rnn_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])
-      if not self.share_memory:
-        setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])
-    if self.share_memory:
       setattr(self, 'rnn_layer_memory_shared', self.memories[0])

     # final output layer
-    self.output_weights = nn.Linear(self.nn_input_size, self.input_size)
-    self.mem_out = nn.Linear(self.hidden_size, self.input_size)
+    self.read_vectors_weights = nn.Linear(self.read_vectors_size, self.output_size)
+    self.mem_out = nn.Linear(self.hidden_size, self.output_size)
+    self.output = nn.Linear(self.output_size, self.input_size)
     self.dropout_layer = nn.Dropout(self.dropout)

     if self.gpu_id != -1:
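These three linear layers define the corrected output path: mem_out projects each controller layer's output into output_size, read_vectors_weights projects the merged read vectors into that same space so the two can be summed (see the forward changes below), and output maps the result back to input_size. A shape-only sketch under assumed sizes; everything except the three layer names is hypothetical:

import torch as T
import torch.nn as nn

batch, hidden_size, input_size, w, r = 10, 64, 6, 16, 4
output_size = hidden_size  # as set in __init__ above

mem_out = nn.Linear(hidden_size, output_size)
read_vectors_weights = nn.Linear(r * w, output_size)
output = nn.Linear(output_size, input_size)

controller_out = T.randn(batch, hidden_size)  # stand-in for layer_input
read_vectors = T.randn(batch, r * w)          # stand-in for merged reads

# project both into output_size and add, then map back to input_size
hidden = mem_out(controller_out) + read_vectors_weights(read_vectors)
final = output(hidden)
assert final.size() == T.Size([batch, input_size])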
@@ -183,19 +167,9 @@ class DNC(nn.Module):
     debug_obj['usage_vector'].append(mhx['usage_vector'][0].unsqueeze(0).data.cpu().numpy())
     return debug_obj

-  def _layer_forward(self, input, layer, hx=(None, None)):
+  def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
     (chx, mhx) = hx

-    if self.debug:
-      mem_debug = {
-        "memory": [],
-        "link_matrix": [],
-        "precedence": [],
-        "read_weights": [],
-        "write_weights": [],
-        "usage_vector": [],
-      }
-
     layer_input = input
     hchx = []
     for hlayer in range(self.num_hidden_layers):
@@ -210,20 +184,20 @@ class DNC(nn.Module):
     output = self.dropout_layer(self.mem_out(layer_input))

     # pass through memory
-    if self.share_memory:
-      read_vecs, mhx = self.memories[0](ξ, mhx)
+    if pass_through_memory:
+      if self.share_memory:
+        read_vecs, mhx = self.memories[0](ξ, mhx)
+      else:
+        read_vecs, mhx = self.memories[layer](ξ, mhx)
+      # the read vectors
+      read_vectors = read_vecs.view(-1, self.w * self.r)
     else:
-      read_vecs, mhx = self.memories[layer](ξ, mhx)
-
-    # the read vectors
-    read_vectors = read_vecs.view(-1, self.w * self.r)
+      read_vectors = None

-    if self.debug:
-      return output, read_vectors, mem_debug, (chx, mhx)
-    else:
-      return output, read_vectors, (chx, mhx)
+    return output, read_vectors, (chx, mhx)

-  def forward(self, input, hx=(None, None, None), reset_experience=False):
+  def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
     # handle packed data
     is_packed = type(input) is PackedSequence
     if is_packed:
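With the new flag, a caller can run the controller stack while skipping the memory update; the layer then returns read_vectors as None and downstream code must handle that case. A toy sketch mirroring the branch above; the memory stand-in and helper name are made up:

import torch as T

# toy stand-in: a "memory" returning random read vectors and unchanged state
w, r, batch = 16, 4, 10
memory = lambda xi, mhx: (T.randn(batch, r, w), mhx)

def memory_pass(xi, mhx, pass_through_memory=True):
  # update memory and merge the reads, or skip both entirely
  if pass_through_memory:
    read_vecs, mhx = memory(xi, mhx)
    read_vectors = read_vecs.view(-1, w * r)  # one row per batch element
  else:
    read_vectors = None
  return read_vectors, mhx

rv, _ = memory_pass(None, {}, pass_through_memory=False)
assert rv is None
rv, _ = memory_pass(None, {})
assert rv.size() == T.Size([batch, r * w])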
@@ -235,9 +209,10 @@ class DNC(nn.Module):

     batch_size = input.size(0) if self.batch_first else input.size(1)

     # make the data batch-first
     if not self.batch_first:
       input = input.transpose(0, 1)
-
+    # make the data time-first
+    inputs = [ input[:, x, :] for x in range(max_length) ]

     controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)
@@ -245,10 +220,8 @@ class DNC(nn.Module):
     if self.debug:
       viz = None

-    # read_vectors = [last_read] * max_length
-    # outs = [input[:, x, :] for x in range(max_length)]
-    inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
     outs = [None] * max_length
+    read_vectors = None

     for time in range(max_length):
       for layer in range(self.num_layers):
@@ -256,10 +229,7 @@ class DNC(nn.Module):
         chx = controller_hidden[layer]
         m = mem_hidden if self.share_memory else mem_hidden[layer]
         # pass through controller
-        if self.debug:
-          outs[time], read_vectors, mem_debug, (chx, m) = self._layer_forward(inputs[time],layer,(chx, m))
-        else:
-          outs[time], read_vectors, (chx, m) = self._layer_forward(inputs[time],layer,(chx, m))
+        outs[time], read_vectors, (chx, m) = self._layer_forward(inputs[time],layer,(chx, m), pass_through_memory)

         # debug memory
         if self.debug:
@@ -272,23 +242,22 @@ class DNC(nn.Module):
           mem_hidden[layer] = m
         controller_hidden[layer] = chx

-        # the controller output + read vectors go into next layer
-        outs[time] = (T.cat([outs[time], read_vectors], 1))
+        if read_vectors is not None:
+          # the controller output + read vectors go into next layer
+          outs[time] = outs[time] + self.read_vectors_weights(read_vectors)
         inputs[time] = outs[time]

     if self.debug:
       viz = { k: np.array(v) for k,v in viz.items() }
       viz = { k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k,v in viz.items() }

-    inputs = [ self.output_weights(i) for i in inputs ]
-    outputs = T.stack(inputs, 1)
+    inputs = [ self.output(i) for i in inputs ]
+    outputs = T.stack(inputs, 1 if self.batch_first else 0)

-    if not self.batch_first:
-      outputs = outputs.transpose(0, 1)
     if is_packed:
       outputs = pack(output, lengths)

     if self.debug:
-      return outputs, (controller_hidden, mem_hidden, read_vectors[-1]), viz
+      return outputs, (controller_hidden, mem_hidden, read_vectors), viz
     else:
-      return outputs, (controller_hidden, mem_hidden, read_vectors[-1])
+      return outputs, (controller_hidden, mem_hidden, read_vectors)
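Also folded into this hunk: the per-step outputs are now stacked directly on the correct axis instead of always stacking batch-first and transposing afterwards. The two forms are equivalent, which this sketch checks under hypothetical sizes:

import torch as T

batch, time, size = 10, 21, 6
steps = [T.randn(batch, size) for _ in range(time)]

# old: stack on dim 1 (batch-first), then transpose when time-first is wanted
old_time_first = T.stack(steps, 1).transpose(0, 1)

# new: stack straight onto dim 0 for time-first output
new_time_first = T.stack(steps, 0)
assert T.equal(old_time_first, new_time_first)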
@@ -140,7 +140,7 @@ if __name__ == '__main__':
   # input_data = input_data.transpose(0, 1).contiguous()
   target_output = target_output.transpose(0, 1).contiguous()

-  output, (chx, mhx, rv), v = rnn(input_data, None)
+  output, (chx, mhx, rv), v = rnn(input_data, None, pass_through_memory=True)
   # dncs operate batch first
   output = output.transpose(0, 1)
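For orientation, the task script drives a debug-mode DNC, whose forward returns a third viz value. A hedged construction sketch: the hyperparameter values here are made up, and the keyword names follow the repository's README rather than this diff:

import torch as T
from dnc import DNC

rnn = DNC(input_size=6, hidden_size=64, rnn_type='lstm', num_layers=1,
          nr_cells=16, cell_size=16, read_heads=4,
          batch_first=True, gpu_id=-1, debug=True)

input_data = T.randn(10, 21, 6)  # (batch, time, input_size)
output, (chx, mhx, rv), v = rnn(input_data, None, pass_through_memory=True)
assert output.size() == T.Size([10, 21, 6])  # final linear maps back to input_size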
@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([1])
+  assert rv.size() == T.Size([10, 1])


 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[1][2].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([51])
+  assert rv.size() == T.Size([10, 51])
@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([1])
+  assert rv.size() == T.Size([10, 1])


 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[0][0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([51])
+  assert rv.size() == T.Size([10, 51])
@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([1])
+  assert rv.size() == T.Size([10, 1])


 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[1][2].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([51])
+  assert rv.size() == T.Size([10, 51])
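All three test files encode the same consequence of the merge fix: the rv returned by forward is now the full (batch, r * w) read-vector matrix from the last step rather than a flattened slice. The expected sizes follow directly from the view call in _layer_forward; a quick check with the second test's dimensions (read_heads and cell_size inferred from the asserted sizes):

import torch as T

batch, r, w = 10, 3, 17         # matches T.Size([10, 51]) above: 3 * 17 == 51
read_vecs = T.randn(batch, r, w)
rv = read_vecs.view(-1, w * r)  # as in _layer_forward
assert rv.size() == T.Size([10, 51])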