2017-12-19 03:59:12 +08:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
import warnings
|
|
|
|
warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import getopt
|
|
|
|
import sys
|
|
|
|
import os
|
|
|
|
import math
|
|
|
|
import time
|
|
|
|
import argparse
|
|
|
|
from visdom import Visdom
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.join('..', '..'))
|
|
|
|
|
|
|
|
import torch as T
|
|
|
|
from torch.autograd import Variable as var
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.optim as optim
|
|
|
|
|
2019-04-05 14:14:41 +08:00
|
|
|
from torch.nn.utils import clip_grad_norm_
|
2017-12-19 03:59:12 +08:00
|
|
|
|
|
|
|
from dnc.dnc import DNC
|
|
|
|
from dnc.sdnc import SDNC
|
|
|
|
from dnc.sam import SAM
|
|
|
|
from dnc.util import *
|
|
|
|
|
|
|
|
# Command-line interface for the adding task.
#
# Options use a single leading dash (e.g. ``-nhid 64``). They are grouped below
# by what they configure: controller, optimisation, memory, and run control.
parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer')

# --- controller -------------------------------------------------------------
parser.add_argument('-input_size', type=int, default=6, help='dimension of input feature')
parser.add_argument('-rnn_type', type=str, default='lstm',
                    help='type of recurrent cells to use for the controller')
parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
parser.add_argument('-memory_type', type=str, default='dnc',
                    help='dense or sparse memory: dnc | sdnc | sam')

parser.add_argument('-nlayer', type=int, default=1, help='number of layers')
parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')

# --- optimisation -----------------------------------------------------------
parser.add_argument('-lr', type=float, default=1e-4, help='initial learning rate')
# The help text lists every optimizer name the training script dispatches on.
parser.add_argument('-optim', type=str, default='adam',
                    help='learning rule, supports adam|adamax|rmsprop|sgd|adagrad|adadelta')
parser.add_argument('-clip', type=float, default=50, help='gradient clipping')

# --- memory -----------------------------------------------------------------
parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
parser.add_argument('-mem_size', type=int, default=20, help='memory dimension')
parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots')
parser.add_argument('-read_heads', type=int, default=4, help='number of read heads')
parser.add_argument('-sparse_reads', type=int, default=10,
                    help='number of sparse reads per read head')
parser.add_argument('-temporal_reads', type=int, default=2, help='number of temporal reads')

# --- run control ------------------------------------------------------------
parser.add_argument('-sequence_max_length', type=int, default=1000, metavar='N',
                    help='sequence_max_length')
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')

parser.add_argument('-iterations', type=int, default=2000, metavar='N',
                    help='total number of iteration')
parser.add_argument('-summarize_freq', type=int, default=100, metavar='N',
                    help='summarize frequency')
parser.add_argument('-check_freq', type=int, default=100, metavar='N',
                    help='check point frequency')
parser.add_argument('-visdom', action='store_true',
                    help='plot memory content on visdom per -summarize_freq steps')
|
|
|
|
|
|
|
|
# Parse the command line once at import time and echo the resulting
# configuration so that every run logs the settings it used.
args = parser.parse_args()
print(args)

# Visdom client handle. The connection check is deliberately left disabled:
# assert viz.check_connection()
viz = Visdom()

# Announce the target device. The torch RNG is seeded only when a GPU id was
# supplied; CPU runs are left unseeded.
if args.cuda == -1:
    print('Using CPU.')
else:
    print('Using CUDA.')
    T.manual_seed(1111)
|
|
|
|
|
|
|
|
def llprint(message):
    """Write ``message`` to standard output without a newline and flush it.

    Used for in-place progress lines (messages starting with ``\\r``), which
    must reach the terminal immediately.
    """
    out = sys.stdout
    out.write(message)
    out.flush()
|
|
|
|
|
|
|
|
|
|
|
|
def onehot(x, n):
    """Return a float32 vector of length ``n`` that is 1.0 at index ``x``.

    Args:
        x: position to set to 1.0.
        n: length of the vector.

    Returns:
        A ``numpy.float32`` array that is zero everywhere except at ``x``.
    """
    encoded = np.zeros(n, dtype=np.float32)
    encoded[x] = 1.0
    return encoded
|
|
|
|
|
|
|
|
|
|
|
|
def generate_data(length, size):
    """Build one random example for the adding task.

    ``length`` integers are drawn uniformly from ``[0, size - 2]`` and one-hot
    encoded over ``size`` classes. One extra step is appended whose one-hot
    index is ``size - 1``; that index is never drawn as a value, so it acts as
    an end-of-sequence marker. The target is the sum of the drawn integers.

    Args:
        length: number of integers to draw.
        size: width of each one-hot vector.

    Returns:
        A 3-tuple ``(inputs, target, text)``:
        inputs -- float32 data of shape ``(1, length + 1, size)``,
        target -- float32 data of shape ``(1, 1, 1)`` holding the sum,
        text   -- the drawn integers rendered as ``"a + b + ... + "`` (note the
                  trailing separator).
        ``inputs`` and ``target`` are passed through ``cudavec`` with
        ``gpu_id=args.cuda`` and then ``.float()``; presumably this yields
        tensors on the selected device -- confirm against ``dnc.util``.
    """
    digits = np.random.randint(0, size - 1, length)

    # One step per digit plus the trailing marker step.
    steps = length + 1
    inputs = np.zeros((1, steps, size), dtype=np.float32)
    inputs[0, np.arange(length), digits] = 1.0
    inputs[0, length, size - 1] = 1.0

    # Keep the total in float64 until the final cast, as a running float sum would.
    target = np.array(digits.sum(), dtype=np.float64).reshape(1, 1, 1)

    text = "".join(str(d) + " + " for d in digits)

    inputs_out = cudavec(inputs.astype(np.float32), gpu_id=args.cuda).float()
    target_out = cudavec(target.astype(np.float32), gpu_id=args.cuda).float()
    return inputs_out, target_out, text
|
2017-12-19 03:59:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
def cross_entropy(prediction, target):
    """Return the element-wise squared error between prediction and target.

    NOTE: despite its name this is not a cross-entropy; it computes
    ``(prediction - target) ** 2`` with no reduction.
    """
    residual = prediction - target
    return residual ** 2
|
|
|
|
|
|
|
|
|
|
|
|
def _build_model(opts):
    """Instantiate the network selected by ``opts.memory_type``.

    The keyword arguments shared by all three model classes are collected once;
    each branch adds only what is specific to that memory type.

    Raises:
        Exception: if ``opts.memory_type`` is not one of dnc, sdnc or sam.
    """
    shared = dict(
        input_size=opts.input_size,
        hidden_size=opts.nhid,
        rnn_type=opts.rnn_type,
        num_layers=opts.nlayer,
        num_hidden_layers=opts.nhlayer,
        dropout=opts.dropout,
        nr_cells=opts.mem_slot,
        cell_size=opts.mem_size,
        read_heads=opts.read_heads,
        gpu_id=opts.cuda,
        debug=opts.visdom,
        batch_first=True,
    )
    if opts.memory_type == 'dnc':
        return DNC(independent_linears=True, **shared)
    if opts.memory_type == 'sdnc':
        return SDNC(
            sparse_reads=opts.sparse_reads,
            temporal_reads=opts.temporal_reads,
            independent_linears=False,
            **shared
        )
    if opts.memory_type == 'sam':
        return SAM(
            sparse_reads=opts.sparse_reads,
            independent_linears=False,
            **shared
        )
    raise Exception('Not recognized type of memory')


def _build_optimizer(name, parameters, lr):
    """Create the optimizer called ``name`` over ``parameters``.

    Raises:
        ValueError: if ``name`` is not a supported optimizer. Without this
            check an unknown name would surface later as a NameError.
    """
    if name == 'adam':
        return optim.Adam(parameters, lr=lr, eps=1e-9, betas=[0.9, 0.98])
    if name == 'adamax':
        return optim.Adamax(parameters, lr=lr, eps=1e-9, betas=[0.9, 0.98])
    if name == 'rmsprop':
        return optim.RMSprop(parameters, lr=lr, momentum=0.9, eps=1e-10)
    if name == 'sgd':
        return optim.SGD(parameters, lr=lr)
    if name == 'adagrad':
        return optim.Adagrad(parameters, lr=lr)
    if name == 'adadelta':
        return optim.Adadelta(parameters, lr=lr)
    raise ValueError('Not recognized optimizer: {}'.format(name))


def _predict_sum(rnn, input_data, mhx):
    """Run ``rnn`` on one sequence and reduce its output to a single value.

    The memory state ``mhx`` is threaded through; the other two hidden-state
    slots are always passed as ``None``. When ``rnn.debug`` is set the model
    returns one extra value, which is discarded here.

    Returns:
        ``(prediction, mhx)`` where ``prediction`` is the output summed over
        dimensions 2 and 1 (kept as size-1 dimensions) and ``mhx`` is the
        memory state returned by the model.
    """
    hidden = (None, mhx, None)
    if rnn.debug:
        output, (_, mhx, _), _ = rnn(
            input_data, hidden, reset_experience=True, pass_through_memory=True)
    else:
        output, (_, mhx, _) = rnn(
            input_data, hidden, reset_experience=True, pass_through_memory=True)
    prediction = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
    return prediction, mhx


def _print_prediction(target_output, prediction, lead=''):
    """Print the true sum next to the floored prediction and its raw value."""
    raw = prediction.data.cpu().numpy()
    # Both hold exactly one element; .item() avoids the deprecated implicit
    # conversion of a multi-dimensional array to a Python scalar.
    print(lead + "Real value: ", ' = ' + str(int(target_output.item())))
    print("Predicted: ", ' = ' + str(int(np.floor(raw.item()))) + " [" + str(raw) + "]")


if __name__ == '__main__':

    dirname = os.path.dirname(__file__)
    ckpts_dir = os.path.join(dirname, 'checkpoints')

    input_size = args.input_size
    sequence_max_length = args.sequence_max_length
    iterations = args.iterations
    summarize_freq = args.summarize_freq
    check_freq = args.check_freq

    rnn = _build_model(args)
    if args.cuda != -1:
        rnn = rnn.cuda(args.cuda)

    print(rnn)

    # Built after the optional .cuda() move so it sees the final parameters.
    optimizer = _build_optimizer(args.optim, rnn.parameters(), args.lr)

    # Losses gathered since the last summary line.
    last_100_losses = []
    mhx = None

    for epoch in range(iterations + 1):
        llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
        optimizer.zero_grad()

        # Train on sequences of random length in [2, sequence_max_length].
        random_length = np.random.randint(2, sequence_max_length + 1)
        input_data, target_output, _ = generate_data(random_length, input_size)

        output, mhx = _predict_sum(rnn, input_data, mhx)
        loss = cross_entropy(output, target_output)

        loss.backward()

        T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
        optimizer.step()
        loss_value = loss.item()

        # Detach memory from the graph so the next backward pass does not reach
        # into this iteration.
        mhx = {k: (v.detach() if isinstance(v, var) else v) for k, v in mhx.items()}

        summarize = (epoch % summarize_freq == 0)
        # Save every check_freq iterations and always after the last one.
        on_schedule = check_freq > 0 and epoch % check_freq == 0
        take_checkpoint = (epoch != 0) and (on_schedule or epoch == iterations)

        last_100_losses.append(loss_value)

        if summarize:
            llprint("\rIteration %d/%d" % (epoch, iterations))
            # The label is historical; the loss is the squared error.
            llprint("\nAvg. Logistic Loss: %.4f\n" % (np.mean(last_100_losses)))
            _print_prediction(target_output, output)
            last_100_losses = []

        if take_checkpoint:
            llprint("\nSaving Checkpoint ... ")
            # Create the directory on demand so a finished run is not lost.
            os.makedirs(ckpts_dir, exist_ok=True)
            check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))
            T.save(rnn.state_dict(), check_ptr)
            llprint("Done!\n")

    llprint("\nTesting generalization...\n")

    rnn.eval()

    test_iterations = int((iterations + 1) / 10)
    # No gradients are needed for evaluation; skipping graph construction keeps
    # memory bounded on the much longer test sequences.
    with T.no_grad():
        for i in range(test_iterations):
            llprint("\nIteration %d/%d" % (i, test_iterations))
            # Test on sequences up to ten times longer than those used in training.
            random_length = np.random.randint(2, int(sequence_max_length) * 10 + 1)
            input_data, target_output, _ = generate_data(random_length, input_size)

            output, mhx = _predict_sum(rnn, input_data, mhx)
            _print_prediction(target_output, output, lead="\n")
|