cleanup, print more info for repr

This commit is contained in:
ixaxaar 2017-12-19 01:28:39 +05:30
parent 7344cdeb10
commit 2721d27d16
2 changed files with 44 additions and 1 deletions

View File

@ -271,3 +271,46 @@ class DNC(nn.Module):
return outputs, (controller_hidden, mem_hidden, read_vectors), viz
else:
return outputs, (controller_hidden, mem_hidden, read_vectors)
def __repr__(self):
s = "\n----------------------------------------\n"
s += '{name}({input_size}, {hidden_size}'
if self.rnn_type != 'lstm':
s += ', rnn_type={rnn_type}'
if self.num_layers != 1:
s += ', num_layers={num_layers}'
if self.num_hidden_layers != 2:
s += ', num_hidden_layers={num_hidden_layers}'
if self.bias != True:
s += ', bias={bias}'
if self.batch_first != True:
s += ', batch_first={batch_first}'
if self.dropout != 0:
s += ', dropout={dropout}'
if self.bidirectional != False:
s += ', bidirectional={bidirectional}'
if self.nr_cells != 5:
s += ', nr_cells={nr_cells}'
if self.read_heads != 2:
s += ', read_heads={read_heads}'
if self.cell_size != 10:
s += ', cell_size={cell_size}'
if self.nonlinearity != 'tanh':
s += ', nonlinearity={nonlinearity}'
if self.gpu_id != -1:
s += ', gpu_id={gpu_id}'
if self.independent_linears != False:
s += ', independent_linears={independent_linears}'
if self.share_memory != True:
s += ', share_memory={share_memory}'
if self.debug != False:
s += ', debug={debug}'
if self.clip != 20:
s += ', clip={clip}'
s += ")\n" + super(DNC, self).__repr__() + \
"\n----------------------------------------\n"
return s.format(name=self.__class__.__name__, **self.__dict__)

View File

@ -134,7 +134,7 @@ class SparseMemory(nn.Module):
def write_into_sparse_memory(self, hidden):
visible_memory = hidden['visible_memory']
positions = hidden['read_positions'].squeeze()
positions = hidden['read_positions']
(b, m, w) = hidden['memory'].size()
# update memory