mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
bugfix and update memory units
This commit is contained in:
parent
7d25a0f972
commit
03464bc3e2
@ -31,6 +31,7 @@ class BaseMemoryUnitCell():
|
|||||||
self.h_N = memory_length
|
self.h_N = memory_length
|
||||||
self.h_W = memory_width
|
self.h_W = memory_width
|
||||||
self.h_RH = read_heads
|
self.h_RH = read_heads
|
||||||
|
self.h_B = 0 # batch size, will be set in call
|
||||||
|
|
||||||
self.dnc_norm = dnc_norm
|
self.dnc_norm = dnc_norm
|
||||||
self.bypass_dropout = bypass_dropout
|
self.bypass_dropout = bypass_dropout
|
||||||
@ -44,7 +45,7 @@ class BaseMemoryUnitCell():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def zero_state(self):
|
def zero_state(self, batch_size, dtype=tf.float32):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -71,7 +71,7 @@ class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
|
|||||||
weighted_input = tf.matmul(inputs, w_x) + b_x
|
weighted_input = tf.matmul(inputs, w_x) + b_x
|
||||||
|
|
||||||
if self.dnc_norm:
|
if self.dnc_norm:
|
||||||
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype)
|
weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype)
|
||||||
|
|
||||||
return weighted_input
|
return weighted_input
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ class DNCMemoryUnitCell(BaseMemoryUnitCell):
|
|||||||
|
|
||||||
super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
|
super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
|
||||||
analyse, dtype, name)
|
analyse, dtype, name)
|
||||||
self.h_B = 0 # will set in call
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state_size(self):
|
def state_size(self):
|
||||||
@ -142,7 +142,7 @@ class DNCMemoryUnitCell(BaseMemoryUnitCell):
|
|||||||
weighted_input = tf.matmul(inputs, w_x) + b_x
|
weighted_input = tf.matmul(inputs, w_x) + b_x
|
||||||
|
|
||||||
if self.dnc_norm:
|
if self.dnc_norm:
|
||||||
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
|
weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
|
||||||
collection='memory_unit')
|
collection='memory_unit')
|
||||||
return weighted_input
|
return weighted_input
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
|
|||||||
|
|
||||||
weighted_input = tf.matmul(inputs, w_x) + b_x
|
weighted_input = tf.matmul(inputs, w_x) + b_x
|
||||||
if self.dnc_norm:
|
if self.dnc_norm:
|
||||||
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
|
weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
|
||||||
collection='memory_unit')
|
collection='memory_unit')
|
||||||
return weighted_input
|
return weighted_input
|
||||||
|
|
||||||
|
@ -28,7 +28,6 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
|
|||||||
self.h_WH = write_heads
|
self.h_WH = write_heads
|
||||||
super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
|
super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
|
||||||
analyse, dtype, name)
|
analyse, dtype, name)
|
||||||
self.h_B = 0 # will set in call
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state_size(self):
|
def state_size(self):
|
||||||
@ -144,7 +143,7 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
|
|||||||
|
|
||||||
weighted_input = tf.matmul(inputs, w_x) + b_x
|
weighted_input = tf.matmul(inputs, w_x) + b_x
|
||||||
if self.dnc_norm:
|
if self.dnc_norm:
|
||||||
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
|
weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
|
||||||
collection='memory_unit')
|
collection='memory_unit')
|
||||||
return weighted_input
|
return weighted_input
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user