mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
update memory cells to tf 1.8
This commit is contained in:
parent
b070341277
commit
da7345df08
@ -80,7 +80,7 @@ class BaseMemoryUnitCell():
|
|||||||
similarity = tf.squeeze(similarity)
|
similarity = tf.squeeze(similarity)
|
||||||
adjusted_similarity = similarity * strengths
|
adjusted_similarity = similarity * strengths
|
||||||
|
|
||||||
softmax_similarity = tf.nn.softmax(adjusted_similarity, dim=-1)
|
softmax_similarity = tf.nn.softmax(adjusted_similarity, axis=-1)
|
||||||
|
|
||||||
return softmax_similarity
|
return softmax_similarity
|
||||||
|
|
||||||
|
@ -15,8 +15,7 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
|
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
|
||||||
from adnc.model.utils import oneplus
|
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
|
||||||
from adnc.model.utils import unit_simplex_initialization
|
|
||||||
|
|
||||||
|
|
||||||
class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
|
class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
|
||||||
|
@ -16,9 +16,7 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
|
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
|
||||||
from adnc.model.utils import layer_norm
|
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
|
||||||
from adnc.model.utils import oneplus
|
|
||||||
from adnc.model.utils import unit_simplex_initialization
|
|
||||||
|
|
||||||
|
|
||||||
class DNCMemoryUnitCell(BaseMemoryUnitCell):
|
class DNCMemoryUnitCell(BaseMemoryUnitCell):
|
||||||
|
@ -33,7 +33,7 @@ def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, s
|
|||||||
if config['cell_type'] == 'dnc':
|
if config['cell_type'] == 'dnc':
|
||||||
mu_cell = DNCMemoryUnitCell(input_size, memory_length, memory_width, read_heads, bypass_dropout=bypass_dropout,
|
mu_cell = DNCMemoryUnitCell(input_size, memory_length, memory_width, read_heads, bypass_dropout=bypass_dropout,
|
||||||
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, name=name)
|
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype, name=name)
|
||||||
elif config['cell_type'] == 'cmu':
|
elif config['cell_type'] == 'cbmu':
|
||||||
mu_cell = ContentBasedMemoryUnitCell(input_size, memory_length, memory_width, read_heads,
|
mu_cell = ContentBasedMemoryUnitCell(input_size, memory_length, memory_width, read_heads,
|
||||||
bypass_dropout=bypass_dropout,
|
bypass_dropout=bypass_dropout,
|
||||||
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
||||||
@ -43,7 +43,7 @@ def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, s
|
|||||||
bypass_dropout=bypass_dropout,
|
bypass_dropout=bypass_dropout,
|
||||||
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
||||||
name=name)
|
name=name)
|
||||||
elif config['cell_type'] == 'mwcmu' and 'write_heads' in config:
|
elif config['cell_type'] == 'mwcbmu' and 'write_heads' in config:
|
||||||
mu_cell = MWContentMemoryUnitCell(input_size, memory_length, memory_width, read_heads, write_heads,
|
mu_cell = MWContentMemoryUnitCell(input_size, memory_length, memory_width, read_heads, write_heads,
|
||||||
bypass_dropout=bypass_dropout,
|
bypass_dropout=bypass_dropout,
|
||||||
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
dnc_norm=dnc_norm, seed=seed, reuse=reuse, analyse=analyse, dtype=dtype,
|
||||||
|
@ -15,9 +15,7 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
|
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
|
||||||
from adnc.model.utils import layer_norm
|
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
|
||||||
from adnc.model.utils import oneplus
|
|
||||||
from adnc.model.utils import unit_simplex_initialization
|
|
||||||
|
|
||||||
|
|
||||||
class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
|
class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
|
||||||
|
@ -185,7 +185,7 @@ class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
|
|||||||
read_strengths = oneplus(read_strengths)
|
read_strengths = oneplus(read_strengths)
|
||||||
read_strengths = tf.expand_dims(read_strengths, axis=2)
|
read_strengths = tf.expand_dims(read_strengths, axis=2)
|
||||||
read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 1 + 2 * self.h_WH])
|
read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 1 + 2 * self.h_WH])
|
||||||
read_modes = tf.nn.softmax(read_modes, dim=2)
|
read_modes = tf.nn.softmax(read_modes, axis=2)
|
||||||
|
|
||||||
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
|
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
|
||||||
erase_vectors, read_keys, read_strengths, read_modes
|
erase_vectors, read_keys, read_strengths, read_modes
|
||||||
|
Loading…
Reference in New Issue
Block a user