Mirror of https://github.com/JoergFranke/ADNC.git, synced 2024-11-17 13:58:03 +08:00

bugfix and update memory units

commit 03464bc3e2 (parent 7d25a0f972)
@@ -31,6 +31,7 @@ class BaseMemoryUnitCell():
         self.h_N = memory_length
         self.h_W = memory_width
         self.h_RH = read_heads
+        self.h_B = 0  # batch size, will be set in call

         self.dnc_norm = dnc_norm
         self.bypass_dropout = bypass_dropout
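The batch size now lives on the base class and is filled in lazily. A minimal sketch of that pattern, assuming the TF 1.x static-shape API (the class name and method body here are illustrative, not ADNC's actual call logic):

import tensorflow as tf  # assumes TensorFlow 1.x

class BatchAwareCellSketch:
    def __init__(self, memory_length, memory_width, read_heads):
        self.h_N = memory_length  # number of memory slots
        self.h_W = memory_width   # width of one memory slot
        self.h_RH = read_heads    # number of read heads
        self.h_B = 0              # batch size, set on first call

    def __call__(self, inputs, pre_states):
        # Infer the batch size from the incoming [batch, features] tensor
        # instead of requiring it at construction time.
        self.h_B = inputs.get_shape()[0].value
        # ... memory read/write logic would follow here ...
        return inputs, pre_states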
@@ -44,7 +45,7 @@ class BaseMemoryUnitCell():
         pass

     @abstractmethod
-    def zero_state(self):
+    def zero_state(self, batch_size, dtype=tf.float32):
         pass

     @property
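With the old zero-argument signature the abstract method could not match TensorFlow's RNNCell.zero_state(batch_size, dtype) contract, which is presumably what this part of the bugfix addresses. A minimal sketch of a concrete override, with an illustrative state layout (not ADNC's exact state tuple):

import tensorflow as tf  # assumes TensorFlow 1.x

def zero_state(self, batch_size, dtype=tf.float32):
    # Illustrative layout: zeroed memory matrix, usage vector and read
    # vectors, all sized from the cell's hyperparameters.
    memory = tf.zeros([batch_size, self.h_N, self.h_W], dtype=dtype)
    usage = tf.zeros([batch_size, self.h_N], dtype=dtype)
    read_vectors = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
    return memory, usage, read_vectors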
@@ -71,7 +71,7 @@ class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
         weighted_input = tf.matmul(inputs, w_x) + b_x

         if self.dnc_norm:
-            weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype)
+            weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype)

         return weighted_input
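This and the remaining layer_norm hunks are the same one-word fix: the scope name passed to layer_norm changes from 'dnc_norm' to 'layer_norm', so all four memory units create their normalization parameters under a consistent scope. A sketch of such a helper, assuming TF 1.x variables with a learned gain and bias and rank-2 [batch, width] input (ADNC ships its own layer_norm; this is not its exact code):

import tensorflow as tf  # assumes TensorFlow 1.x

def layer_norm(inputs, name, dtype=tf.float32, collection=None, eps=1e-6):
    # Normalize the feature axis of a [batch, width] tensor with a
    # learned per-feature gain and bias.
    with tf.variable_scope(name):
        width = inputs.get_shape()[-1].value
        gain = tf.get_variable('gain', [width], dtype=dtype,
                               initializer=tf.ones_initializer())
        bias = tf.get_variable('bias', [width], dtype=dtype,
                               initializer=tf.zeros_initializer())
        if collection is not None:
            # Tag the parameters so they can be fetched later, e.g. for
            # analysis of the memory unit (an assumption about their use).
            tf.add_to_collection(collection, gain)
            tf.add_to_collection(collection, bias)
        mean, variance = tf.nn.moments(inputs, axes=[1], keep_dims=True)
        return gain * (inputs - mean) * tf.rsqrt(variance + eps) + bias

Because tf.get_variable keys variables by their enclosing scope, the rename also changes the checkpoint names of these parameters, which is worth keeping in mind when restoring checkpoints written before this commit.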
@@ -25,7 +25,7 @@ class DNCMemoryUnitCell(BaseMemoryUnitCell):

         super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
                          analyse, dtype, name)
-        self.h_B = 0  # will set in call
+

     @property
     def state_size(self):
@@ -142,7 +142,7 @@ class DNCMemoryUnitCell(BaseMemoryUnitCell):
         weighted_input = tf.matmul(inputs, w_x) + b_x

         if self.dnc_norm:
-            weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
+            weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
                                         collection='memory_unit')
         return weighted_input
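The DNC and multi-write units additionally pass collection='memory_unit' to the helper. A hypothetical retrieval of variables tagged this way (assuming the helper registers them in a graph collection, as sketched above):

import tensorflow as tf  # assumes TensorFlow 1.x

# Fetch every parameter the memory units tagged under 'memory_unit',
# e.g. to inspect or log the normalization gains and biases.
for var in tf.get_collection('memory_unit'):
    print(var.name, var.shape)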
@@ -116,7 +116,7 @@ class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):

         weighted_input = tf.matmul(inputs, w_x) + b_x
         if self.dnc_norm:
-            weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
+            weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
                                         collection='memory_unit')
         return weighted_input
@@ -28,7 +28,6 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
         self.h_WH = write_heads
         super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
                          analyse, dtype, name)
-        self.h_B = 0  # will set in call

     @property
     def state_size(self):
@@ -144,7 +143,7 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):

         weighted_input = tf.matmul(inputs, w_x) + b_x
         if self.dnc_norm:
-            weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
+            weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
                                         collection='memory_unit')
         return weighted_input