update memory cells to tf 1.8

This commit is contained in:
Joerg Franke 2018-06-25 01:47:54 +02:00
parent b070341277
commit da7345df08
6 changed files with 7 additions and 12 deletions

View File

@ -80,7 +80,7 @@ class BaseMemoryUnitCell():
similarity = tf.squeeze(similarity) similarity = tf.squeeze(similarity)
adjusted_similarity = similarity * strengths adjusted_similarity = similarity * strengths
softmax_similarity = tf.nn.softmax(adjusted_similarity, dim=-1) softmax_similarity = tf.nn.softmax(adjusted_similarity, axis=-1)
return softmax_similarity return softmax_similarity

View File

@ -15,8 +15,7 @@
import tensorflow as tf import tensorflow as tf
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
from adnc.model.utils import oneplus from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
from adnc.model.utils import unit_simplex_initialization
class ContentBasedMemoryUnitCell(DNCMemoryUnitCell): class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):

View File

@ -16,9 +16,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
from adnc.model.utils import layer_norm from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
from adnc.model.utils import oneplus
from adnc.model.utils import unit_simplex_initialization
class DNCMemoryUnitCell(BaseMemoryUnitCell): class DNCMemoryUnitCell(BaseMemoryUnitCell):

View File

@ -33,7 +33,7 @@ def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, s
if config['cell_type'] == 'dnc': if config['cell_type'] == 'dnc':
mu_cell = DNCMemoryUnitCell(input_size, memory_length, memory_width, read_heads, bypass_dropout=bypass_dropout, mu_cell = DNCMemoryUnitCell(input_size, memory_length, memory_width, read_heads, bypass_dropout=bypass_dropout,
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, name=name) dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, name=name)
elif config['cell_type'] == 'cmu': elif config['cell_type'] == 'cbmu':
mu_cell = ContentBasedMemoryUnitCell(input_size, memory_length, memory_width, read_heads, mu_cell = ContentBasedMemoryUnitCell(input_size, memory_length, memory_width, read_heads,
bypass_dropout=bypass_dropout, bypass_dropout=bypass_dropout,
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
@ -43,7 +43,7 @@ def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, s
bypass_dropout=bypass_dropout, bypass_dropout=bypass_dropout,
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
name=name) name=name)
elif config['cell_type'] == 'mwcmu' and 'write_heads' in config: elif config['cell_type'] == 'mwcbmu' and 'write_heads' in config:
mu_cell = MWContentMemoryUnitCell(input_size, memory_length, memory_width, read_heads, write_heads, mu_cell = MWContentMemoryUnitCell(input_size, memory_length, memory_width, read_heads, write_heads,
bypass_dropout=bypass_dropout, bypass_dropout=bypass_dropout,
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,

View File

@ -15,9 +15,7 @@
import tensorflow as tf import tensorflow as tf
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
from adnc.model.utils import layer_norm from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
from adnc.model.utils import oneplus
from adnc.model.utils import unit_simplex_initialization
class MWContentMemoryUnitCell(MWDNCMemoryUnitCell): class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):

View File

@ -185,7 +185,7 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
read_strengths = oneplus(read_strengths) read_strengths = oneplus(read_strengths)
read_strengths = tf.expand_dims(read_strengths, axis=2) read_strengths = tf.expand_dims(read_strengths, axis=2)
read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 1 + 2 * self.h_WH]) read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 1 + 2 * self.h_WH])
read_modes = tf.nn.softmax(read_modes, dim=2) read_modes = tf.nn.softmax(read_modes, axis=2)
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \ return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
erase_vectors, read_keys, read_strengths, read_modes erase_vectors, read_keys, read_strengths, read_modes