dnc-jupyter/repeat-copy/nn_controller.py

36 lines
1.1 KiB
Python
Raw Normal View History

2017-02-20 05:11:57 +08:00
import numpy as np
import tensorflow as tf
2017-02-20 05:19:59 +08:00
from controller import Controller
2017-02-20 05:11:57 +08:00
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
class NNController(Controller):
    """Two-layer feedforward controller with 128 and 256 hidden units.

    Each step maps the controller input through two fully connected layers,
    each followed by an ELU activation. The controller keeps no recurrent
    state: ``nn_step`` passes the incoming state through unchanged, and the
    state hooks (``update_state`` / ``get_state``) return placeholders.
    """

    # Hidden-layer widths. Subclasses may override these to resize the
    # network; the defaults reproduce the original 128 -> 256 architecture.
    H1_DIM = 128
    H2_DIM = 256

    def init_controller_params(self):
        """Create the weights and biases of both layers in ``self.params``.

        Reads ``self.chi_dim`` as the input width of the first layer and
        stores ``W1``, ``b1``, ``W2``, ``b2`` in ``self.params``. Variables
        are created with ``tf.get_variable``, so their names are prefixed by
        the enclosing variable scope.
        """
        init = tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32)
        # NOTE(review): biases use the same truncated-normal initializer as
        # the weights (zeros is the more common choice); kept as-is so that
        # training behaviour is unchanged.
        self.params['W1'] = tf.get_variable(
            "W1", [self.chi_dim, self.H1_DIM], initializer=init)
        self.params['b1'] = tf.get_variable(
            "b1", [self.H1_DIM], initializer=init)
        self.params['W2'] = tf.get_variable(
            "W2", [self.H1_DIM, self.H2_DIM], initializer=init)
        self.params['b2'] = tf.get_variable(
            "b2", [self.H2_DIM], initializer=init)

    def nn_step(self, X, state):
        """Run one feedforward step of the controller.

        Args:
            X: controller input, multiplied on the right by ``W1``; presumably
                a (batch, chi_dim) float32 tensor -- TODO confirm with caller.
            state: state supplied by the caller; not used here and returned
                unchanged.

        Returns:
            A ``(h2, state)`` pair, where ``h2`` is the ELU output of the
            second layer (``H2_DIM`` columns).
        """
        z1 = tf.matmul(X, self.params['W1']) + self.params['b1']
        h1 = tf.nn.elu(z1)
        z2 = tf.matmul(h1, self.params['W2']) + self.params['b2']
        h2 = tf.nn.elu(z2)
        return h2, state

    def update_state(self, update):
        """Return an op that does nothing; there is no state to update.

        ``update`` is ignored. ``tf.no_op()`` replaces the previous
        ``tf.group(tf.zeros(1), tf.zeros(1))``, which also produced a
        side-effect-free ``Operation`` but created two throwaway constants.
        """
        return tf.no_op()

    def get_state(self):
        """Return a placeholder ``LSTMStateTuple`` of two length-1 zero tensors.

        Presumably this exists so that callers written against an LSTM-style
        controller can treat this stateless controller uniformly -- confirm
        against the ``Controller`` base class.
        """
        return LSTMStateTuple(tf.zeros(1), tf.zeros(1))