dnc-jupyter/free-recall/rnn_controller.py
2017-03-01 09:29:54 -05:00

27 lines
886 B
Python
Executable File

import numpy as np
import tensorflow as tf
from controller import Controller
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
"""
RNN (cell type LSTM) with 128 hidden units
"""
class RNNController(Controller):
def init_controller_params(self):
self.rnn_dim = 128
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
def nn_step(self, X, state):
X = tf.convert_to_tensor(X)
return self.lstm_cell(X, state)
def update_state(self, update):
return tf.group(self.output.assign(update[0]), self.state.assign(update[1]))
def get_state(self):
return LSTMStateTuple(self.output, self.state)