bugfix and update memory units

Joerg Franke 2018-07-05 09:05:22 +02:00
parent 7d25a0f972
commit 03464bc3e2
5 changed files with 7 additions and 7 deletions


@@ -31,6 +31,7 @@ class BaseMemoryUnitCell():
         self.h_N = memory_length
         self.h_W = memory_width
         self.h_RH = read_heads
+        self.h_B = 0  # batch size, will be set in call
         self.dnc_norm = dnc_norm
         self.bypass_dropout = bypass_dropout
@@ -44,7 +45,7 @@ class BaseMemoryUnitCell():
         pass
 
     @abstractmethod
-    def zero_state(self):
+    def zero_state(self, batch_size, dtype=tf.float32):
         pass
 
     @property

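Note on the zero_state change above: the new signature follows TensorFlow's RNNCell.zero_state(batch_size, dtype) convention, so the batch size is supplied when the initial state is built instead of being fixed at construction time, which is why h_B starts out as a 0 placeholder in the base class. Below is a minimal sketch of what a concrete subclass might do with it; the state layout (memory matrix, usage vector, read weightings) is illustrative only, not the repository's exact state tuple.

import tensorflow as tf

class DemoMemoryUnitCell:
    """Illustrative stand-in for a concrete BaseMemoryUnitCell subclass."""

    def __init__(self, memory_length, memory_width, read_heads):
        self.h_N = memory_length  # number of memory slots
        self.h_W = memory_width   # width of each memory slot
        self.h_RH = read_heads    # number of read heads
        self.h_B = 0              # batch size, set once the cell is called

    def zero_state(self, batch_size, dtype=tf.float32):
        # batch_size may be a Python int or a scalar tensor such as
        # tf.shape(inputs)[0], so the same graph serves any batch size.
        memory = tf.zeros([batch_size, self.h_N, self.h_W], dtype=dtype)
        usage = tf.zeros([batch_size, self.h_N], dtype=dtype)
        read_weightings = tf.fill([batch_size, self.h_RH, self.h_N],
                                  tf.cast(1.0 / self.h_N, dtype))
        return (memory, usage, read_weightings)

Because batch_size can be a tensor, cell.zero_state(tf.shape(inputs)[0]) works even when the batch dimension is fully dynamic.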

@@ -71,7 +71,7 @@ class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
         weighted_input = tf.matmul(inputs, w_x) + b_x
         if self.dnc_norm:
-            weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype)
+            weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype)
         return weighted_input


@@ -25,7 +25,7 @@ class DNCMemoryUnitCell(BaseMemoryUnitCell):
         super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
                          analyse, dtype, name)
-        self.h_B = 0 # will set in call
 
     @property
     def state_size(self):
@@ -142,7 +142,7 @@ class DNCMemoryUnitCell(BaseMemoryUnitCell):
         weighted_input = tf.matmul(inputs, w_x) + b_x
         if self.dnc_norm:
-            weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
+            weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
                                         collection='memory_unit')
         return weighted_input
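
The other substantive change, repeated across the memory-unit variants (it recurs verbatim in the two multi-write-head cells below), renames the normalization scope from 'dnc_norm' to 'layer_norm', which better describes the operation; self.dnc_norm remains the flag that enables it. The layer_norm helper itself is ADNC's own and its body is not part of this diff. The following is a rough sketch of such a function, assuming it normalizes over the feature axis, creates per-feature gain and bias variables under the given scope name, and uses the collection argument to tag those variables for later retrieval via tf.get_collection.

import tensorflow as tf  # TensorFlow 1.x style, matching the 2018 code base

def layer_norm(x, name='layer_norm', dtype=tf.float32, collection=None):
    # Normalize over the last (feature) axis, then apply a learned
    # per-feature gain and bias created under the given scope name.
    feature_size = x.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        gain = tf.get_variable('gain', [feature_size], dtype=dtype,
                               initializer=tf.ones_initializer())
        bias = tf.get_variable('bias', [feature_size], dtype=dtype,
                               initializer=tf.zeros_initializer())
        if collection is not None:
            # Tag the variables so callers can fetch them as a group.
            tf.add_to_collection(collection, gain)
            tf.add_to_collection(collection, bias)
        mean, variance = tf.nn.moments(x, axes=[-1], keep_dims=True)
        return (x - mean) * tf.rsqrt(variance + 1e-6) * gain + bias

One practical consequence of the rename is worth noting: TensorFlow variable names include their enclosing scope, so variables that used to be created under a dnc_norm scope now live under layer_norm, and checkpoints saved before this commit will not restore by name without remapping.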


@@ -116,7 +116,7 @@ class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
         weighted_input = tf.matmul(inputs, w_x) + b_x
         if self.dnc_norm:
-            weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
+            weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
                                         collection='memory_unit')
         return weighted_input


@@ -28,7 +28,6 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
         self.h_WH = write_heads
         super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
                          analyse, dtype, name)
-        self.h_B = 0 # will set in call
 
     @property
    def state_size(self):
@@ -144,7 +143,7 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
         weighted_input = tf.matmul(inputs, w_x) + b_x
         if self.dnc_norm:
-            weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
+            weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
                                         collection='memory_unit')
         return weighted_input
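
Taken together, the two changes make every cell variant handle batching the same way: h_B now lives only in BaseMemoryUnitCell, and the batch size flows in through zero_state. A hypothetical end-to-end usage, reusing the DemoMemoryUnitCell sketch above (the placeholder shapes and constructor arguments here are assumptions for illustration, not the repository's exact API):

inputs = tf.placeholder(tf.float32, [None, None, 64])  # [batch, time, features]
batch_size = tf.shape(inputs)[0]  # resolved at run time, not at graph-build time

cell = DemoMemoryUnitCell(memory_length=32, memory_width=16, read_heads=4)
initial_state = cell.zero_state(batch_size, dtype=tf.float32)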