diff --git a/adnc/model/memory_units/base_cell.py b/adnc/model/memory_units/base_cell.py
index 1758cae..4aeebe6 100644
--- a/adnc/model/memory_units/base_cell.py
+++ b/adnc/model/memory_units/base_cell.py
@@ -80,7 +80,7 @@ class BaseMemoryUnitCell():

         similarity = tf.squeeze(similarity)
         adjusted_similarity = similarity * strengths
-        softmax_similarity = tf.nn.softmax(adjusted_similarity, dim=-1)
+        softmax_similarity = tf.nn.softmax(adjusted_similarity, axis=-1)

         return softmax_similarity

diff --git a/adnc/model/memory_units/content_based_cell.py b/adnc/model/memory_units/content_based_cell.py
index 12af1ec..21b42d6 100755
--- a/adnc/model/memory_units/content_based_cell.py
+++ b/adnc/model/memory_units/content_based_cell.py
@@ -15,8 +15,7 @@
 import tensorflow as tf

 from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
-from adnc.model.utils import oneplus
-from adnc.model.utils import unit_simplex_initialization
+from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization


 class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
diff --git a/adnc/model/memory_units/dnc_cell.py b/adnc/model/memory_units/dnc_cell.py
index 6e85edd..dd21b13 100755
--- a/adnc/model/memory_units/dnc_cell.py
+++ b/adnc/model/memory_units/dnc_cell.py
@@ -16,9 +16,7 @@ import numpy as np
 import tensorflow as tf

 from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
-from adnc.model.utils import layer_norm
-from adnc.model.utils import oneplus
-from adnc.model.utils import unit_simplex_initialization
+from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization


 class DNCMemoryUnitCell(BaseMemoryUnitCell):
diff --git a/adnc/model/memory_units/memory_unit.py b/adnc/model/memory_units/memory_unit.py
index 344135e..e0ef025 100644
--- a/adnc/model/memory_units/memory_unit.py
+++ b/adnc/model/memory_units/memory_unit.py
@@ -33,7 +33,7 @@ def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, s
     if config['cell_type'] == 'dnc':
         mu_cell = DNCMemoryUnitCell(input_size, memory_length, memory_width, read_heads, bypass_dropout=bypass_dropout,
                                     dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, name=name)
-    elif config['cell_type'] == 'cmu':
+    elif config['cell_type'] == 'cbmu':
         mu_cell = ContentBasedMemoryUnitCell(input_size, memory_length, memory_width, read_heads,
                                              bypass_dropout=bypass_dropout, dnc_norm=dnc_norm, seed=seed,
                                              reuse=reuse, analyse=analyse, dtype=dtype,
@@ -43,7 +43,7 @@ def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, s
                                       bypass_dropout=bypass_dropout, dnc_norm=dnc_norm, seed=seed,
                                       reuse=reuse, analyse=analyse, dtype=dtype,
                                       name=name)
-    elif config['cell_type'] == 'mwcmu' and 'write_heads' in config:
+    elif config['cell_type'] == 'mwcbmu' and 'write_heads' in config:
         mu_cell = MWContentMemoryUnitCell(input_size, memory_length, memory_width, read_heads, write_heads,
                                           bypass_dropout=bypass_dropout, dnc_norm=dnc_norm, seed=seed,
                                           reuse=reuse, analyse=analyse, dtype=dtype,
diff --git a/adnc/model/memory_units/multi_write_content_based_cell.py b/adnc/model/memory_units/multi_write_content_based_cell.py
index 1b9628e..58515eb 100644
--- a/adnc/model/memory_units/multi_write_content_based_cell.py
+++ b/adnc/model/memory_units/multi_write_content_based_cell.py
@@ -15,9 +15,7 @@
 import tensorflow as tf

 from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
-from adnc.model.utils import layer_norm
-from adnc.model.utils import oneplus
-from adnc.model.utils import unit_simplex_initialization
+from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization


 class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
diff --git a/adnc/model/memory_units/multi_write_dnc_cell.py b/adnc/model/memory_units/multi_write_dnc_cell.py
index 819a898..8efd121 100644
--- a/adnc/model/memory_units/multi_write_dnc_cell.py
+++ b/adnc/model/memory_units/multi_write_dnc_cell.py
@@ -185,7 +185,7 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
         read_strengths = oneplus(read_strengths)
         read_strengths = tf.expand_dims(read_strengths, axis=2)
         read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 1 + 2 * self.h_WH])
-        read_modes = tf.nn.softmax(read_modes, dim=2)
+        read_modes = tf.nn.softmax(read_modes, axis=2)

         return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
                erase_vectors, read_keys, read_strengths, read_modes
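
The tf.nn.softmax edits track the TensorFlow API, where the dim keyword was deprecated in favour of axis. Below is a minimal sketch of how the renamed cell types ('cbmu', formerly 'cmu', and 'mwcbmu', formerly 'mwcmu') would be selected through get_memory_unit; the function signature is taken from the hunk header above, while the config values and the keys other than 'cell_type' and 'write_heads' are illustrative assumptions inferred from the constructor arguments visible in the diff:

    import tensorflow as tf

    from adnc.model.memory_units.memory_unit import get_memory_unit

    # Hypothetical config: 'cell_type' and 'write_heads' are the keys checked in
    # get_memory_unit; the remaining keys and all values are assumptions.
    config = {
        'cell_type': 'cbmu',      # renamed from 'cmu' in this change
        'memory_length': 64,
        'memory_width': 32,
        'read_heads': 4,
        'dnc_norm': True,
        'bypass_dropout': False,
    }
    # For the multi-write variant use 'mwcbmu' (renamed from 'mwcmu')
    # and additionally provide a 'write_heads' entry.

    mu_cell = get_memory_unit(input_size=128, config=config, name='mu')

    # The softmax calls now pass axis instead of the deprecated dim keyword:
    weights = tf.nn.softmax(tf.constant([[1.0, 2.0, 3.0]]), axis=-1)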