cleanup, print more info for repr
This commit is contained in:
parent
7344cdeb10
commit
2721d27d16
43
dnc/dnc.py
43
dnc/dnc.py
@ -271,3 +271,46 @@ class DNC(nn.Module):
|
||||
return outputs, (controller_hidden, mem_hidden, read_vectors), viz
|
||||
else:
|
||||
return outputs, (controller_hidden, mem_hidden, read_vectors)
|
||||
|
||||
def __repr__(self):
|
||||
s = "\n----------------------------------------\n"
|
||||
s += '{name}({input_size}, {hidden_size}'
|
||||
if self.rnn_type != 'lstm':
|
||||
s += ', rnn_type={rnn_type}'
|
||||
if self.num_layers != 1:
|
||||
s += ', num_layers={num_layers}'
|
||||
if self.num_hidden_layers != 2:
|
||||
s += ', num_hidden_layers={num_hidden_layers}'
|
||||
if self.bias != True:
|
||||
s += ', bias={bias}'
|
||||
if self.batch_first != True:
|
||||
s += ', batch_first={batch_first}'
|
||||
if self.dropout != 0:
|
||||
s += ', dropout={dropout}'
|
||||
if self.bidirectional != False:
|
||||
s += ', bidirectional={bidirectional}'
|
||||
if self.nr_cells != 5:
|
||||
s += ', nr_cells={nr_cells}'
|
||||
if self.read_heads != 2:
|
||||
s += ', read_heads={read_heads}'
|
||||
if self.cell_size != 10:
|
||||
s += ', cell_size={cell_size}'
|
||||
if self.nonlinearity != 'tanh':
|
||||
s += ', nonlinearity={nonlinearity}'
|
||||
if self.gpu_id != -1:
|
||||
s += ', gpu_id={gpu_id}'
|
||||
if self.independent_linears != False:
|
||||
s += ', independent_linears={independent_linears}'
|
||||
if self.share_memory != True:
|
||||
s += ', share_memory={share_memory}'
|
||||
if self.debug != False:
|
||||
s += ', debug={debug}'
|
||||
if self.clip != 20:
|
||||
s += ', clip={clip}'
|
||||
|
||||
s += ")\n" + super(DNC, self).__repr__() + \
|
||||
"\n----------------------------------------\n"
|
||||
return s.format(name=self.__class__.__name__, **self.__dict__)
|
||||
|
||||
|
||||
|
||||
|
@ -134,7 +134,7 @@ class SparseMemory(nn.Module):
|
||||
|
||||
def write_into_sparse_memory(self, hidden):
|
||||
visible_memory = hidden['visible_memory']
|
||||
positions = hidden['read_positions'].squeeze()
|
||||
positions = hidden['read_positions']
|
||||
|
||||
(b, m, w) = hidden['memory'].size()
|
||||
# update memory
|
||||
|
Loading…
Reference in New Issue
Block a user