mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
update memory cells to tf 1.8
This commit is contained in:
parent
b070341277
commit
da7345df08
@ -80,7 +80,7 @@ class BaseMemoryUnitCell():
|
||||
similarity = tf.squeeze(similarity)
|
||||
adjusted_similarity = similarity * strengths
|
||||
|
||||
softmax_similarity = tf.nn.softmax(adjusted_similarity, dim=-1)
|
||||
softmax_similarity = tf.nn.softmax(adjusted_similarity, axis=-1)
|
||||
|
||||
return softmax_similarity
|
||||
|
||||
|
@ -15,8 +15,7 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
|
||||
from adnc.model.utils import oneplus
|
||||
from adnc.model.utils import unit_simplex_initialization
|
||||
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
|
||||
|
||||
|
||||
class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
|
||||
|
@ -16,9 +16,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
|
||||
from adnc.model.utils import layer_norm
|
||||
from adnc.model.utils import oneplus
|
||||
from adnc.model.utils import unit_simplex_initialization
|
||||
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
|
||||
|
||||
|
||||
class DNCMemoryUnitCell(BaseMemoryUnitCell):
|
||||
|
@ -33,7 +33,7 @@ def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, s
|
||||
if config['cell_type'] == 'dnc':
|
||||
mu_cell = DNCMemoryUnitCell(input_size, memory_length, memory_width, read_heads, bypass_dropout=bypass_dropout,
|
||||
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, name=name)
|
||||
elif config['cell_type'] == 'cmu':
|
||||
elif config['cell_type'] == 'cbmu':
|
||||
mu_cell = ContentBasedMemoryUnitCell(input_size, memory_length, memory_width, read_heads,
|
||||
bypass_dropout=bypass_dropout,
|
||||
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
||||
@ -43,7 +43,7 @@ def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, s
|
||||
bypass_dropout=bypass_dropout,
|
||||
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
||||
name=name)
|
||||
elif config['cell_type'] == 'mwcmu' and 'write_heads' in config:
|
||||
elif config['cell_type'] == 'mwcbmu' and 'write_heads' in config:
|
||||
mu_cell = MWContentMemoryUnitCell(input_size, memory_length, memory_width, read_heads, write_heads,
|
||||
bypass_dropout=bypass_dropout,
|
||||
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
||||
|
@ -15,9 +15,7 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
|
||||
from adnc.model.utils import layer_norm
|
||||
from adnc.model.utils import oneplus
|
||||
from adnc.model.utils import unit_simplex_initialization
|
||||
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
|
||||
|
||||
|
||||
class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
|
||||
|
@ -185,7 +185,7 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
|
||||
read_strengths = oneplus(read_strengths)
|
||||
read_strengths = tf.expand_dims(read_strengths, axis=2)
|
||||
read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 1 + 2 * self.h_WH])
|
||||
read_modes = tf.nn.softmax(read_modes, dim=2)
|
||||
read_modes = tf.nn.softmax(read_modes, axis=2)
|
||||
|
||||
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
|
||||
erase_vectors, read_keys, read_strengths, read_modes
|
||||
|
Loading…
Reference in New Issue
Block a user