dnc-jupyter/repeat-copy/rnn_controller.py

27 lines
887 B
Python
Raw Permalink Normal View History

2017-02-22 03:51:00 +08:00
import numpy as np
import tensorflow as tf
from controller import Controller
2017-02-22 05:38:44 +08:00
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
2017-02-22 03:51:00 +08:00
"""
2017-03-01 11:11:13 +08:00
RNN controller (cell type LSTM) with 128 hidden units
2017-02-22 03:51:00 +08:00
"""
class RNNController(Controller):
    """Controller whose network is a single BasicLSTMCell.

    The recurrent state lives in two non-trainable variables, ``output`` and
    ``state``, so it can be carried from one step to the next through explicit
    assign ops (see ``update_state`` / ``get_state``).
    """

    def init_controller_params(self):
        """Create the LSTM cell and the variables that hold its state.

        Reads ``self.batch_size``, which is not set in this class
        (presumably provided by the ``Controller`` base -- TODO confirm).
        """
        self.rnn_dim = 128
        self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)

        # Both variables are zero-initialised with the same
        # [batch_size, rnn_dim] shape. `state` is created before `output`
        # so the variable creation order matches the original graph.
        var_shape = [self.batch_size, self.rnn_dim]
        self.state = tf.Variable(tf.zeros(var_shape), trainable=False)
        self.output = tf.Variable(tf.zeros(var_shape), trainable=False)

    def nn_step(self, X, state):
        """Run the LSTM cell once.

        Args:
            X: input for this step; converted to a tensor before use.
            state: recurrent state handed straight to the cell.

        Returns:
            Whatever calling ``self.lstm_cell`` returns for this step.
        """
        step_input = tf.convert_to_tensor(X)
        return self.lstm_cell(step_input, state)

    def update_state(self, update):
        """Build an op that stores ``update`` into the state variables.

        Args:
            update: indexable pair; ``update[0]`` is written to
                ``self.output`` and ``update[1]`` to ``self.state``.

        Returns:
            A grouped op that performs both assignments when run.
        """
        write_output = self.output.assign(update[0])
        write_state = self.state.assign(update[1])
        return tf.group(write_output, write_state)

    def get_state(self):
        """Return the stored state variables wrapped in an LSTMStateTuple."""
        # Slot order mirrors update_state: `output` first, `state` second.
        first_slot = self.output
        second_slot = self.state
        return LSTMStateTuple(first_slot, second_slot)