27 lines
886 B
Python
27 lines
886 B
Python
|
import numpy as np
|
||
|
import tensorflow as tf
|
||
|
from controller import Controller
|
||
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||
|
|
||
|
"""
|
||
|
RNN (cell type LSTM) with 128 hidden units
|
||
|
"""
|
||
|
|
||
|
class RNNController(Controller):
|
||
|
|
||
|
def init_controller_params(self):
|
||
|
self.rnn_dim = 128
|
||
|
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
|
||
|
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||
|
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||
|
|
||
|
def nn_step(self, X, state):
|
||
|
X = tf.convert_to_tensor(X)
|
||
|
return self.lstm_cell(X, state)
|
||
|
|
||
|
def update_state(self, update):
|
||
|
return tf.group(self.output.assign(update[0]), self.state.assign(update[1]))
|
||
|
|
||
|
def get_state(self):
|
||
|
return LSTMStateTuple(self.output, self.state)
|