# dnc-with-demon/Models/Demon.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.distributions import Normal
from collections import namedtuple

# Bounds for the predicted log standard deviation, to keep exp() numerically safe.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
EPSILON = 1e-6

SavedAction = namedtuple("SavedAction", ["action", "log_prob", "mean"])

def linear_reset(module, gain=1.0):
    """Reinitialize a Linear layer with Xavier-uniform weights and zero bias."""
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()

class ZNet(nn.Module):
    """Predicts a strictly positive scalar per timestep from an LSTM encoding
    of the input sequence."""

    def __init__(self):
        super(ZNet, self).__init__()

    def reset_parameters(self):
        # Only Linear submodules are reset; the LSTM keeps its default init.
        for module in self.lstm:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))
        for module in self.hidden2z:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

    def init(self, input_size):
        # Layers are built lazily here because input_size is only known at runtime.
        self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True))
        self.hidden2z = nn.Sequential(nn.Linear(32, 1))
        self.reset_parameters()

    def forward(self, data):
        output, (hn, cn) = self.lstm(data)
        zvals = self.hidden2z(output)
        # softplus keeps the per-timestep output strictly positive.
        return F.softplus(zvals)

class FNet(nn.Module):
    """Same architecture as ZNet, but applies an ELU to the LSTM output and
    leaves the final scalar unbounded (no softplus)."""

    def __init__(self):
        super(FNet, self).__init__()

    def reset_parameters(self):
        # Only Linear submodules are reset; the LSTM keeps its default init.
        for module in self.lstm:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))
        for module in self.hidden2z:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

    def init(self, input_size):
        self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True))
        self.hidden2z = nn.Sequential(nn.Linear(32, 1))
        self.reset_parameters()

    def forward(self, data):
        output, (hn, cn) = self.lstm(data)
        output = F.elu(output)
        fvals = self.hidden2z(output)
        return fvals
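
# Illustrative usage of the two sequence nets (assumed shapes, not from the
# original source): both consume a (batch, seq, features) tensor and emit one
# scalar per timestep; ZNet's output is positive, FNet's can take any sign.
#     znet = ZNet(); znet.init(input_size=8)
#     z = znet(torch.randn(2, 5, 8))   # -> (2, 5, 1), z > 0 everywhere
#     fnet = FNet(); fnet.init(input_size=8)
#     f = fnet(torch.randn(2, 5, 8))   # -> (2, 5, 1), unbounded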
class Demon(torch.nn.Module):
    """
    Demon manipulates the external memory of the DNC. It is a Gaussian
    policy: an MLP trunk feeding separate heads for the mean and the log
    standard deviation of the action distribution.
    """

    def __init__(self, layer_sizes=[]):
        super(Demon, self).__init__()
        self.layer_sizes = layer_sizes
        self.action_scale = torch.tensor(1.0)
        self.action_bias = torch.tensor(0.0)
        self.saved_actions = []

    def get_output_size(self):
        return self.layer_sizes[-1]

    def reset_parameters(self):
        for module in self.model:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))
        linear_reset(self.embed_mean, gain=init.calculate_gain("relu"))
        linear_reset(self.embed_log_std, gain=init.calculate_gain("relu"))

    def init(self, input_size, output_size):
        # Build the MLP trunk: each hidden layer is Linear + ReLU, sized
        # input_size -> layer_sizes[0] -> ... -> layer_sizes[-1],
        # Xavier-initialized in reset_parameters().
        self.input_sizes = [input_size] + self.layer_sizes[:-1]
        layers = []
        for i, size in enumerate(self.layer_sizes):
            layers.append(nn.Linear(self.input_sizes[i], size))
            layers.append(nn.ReLU())
        self.model = nn.Sequential(*layers)
        # Separate heads predict the mean and log std of the Gaussian policy.
        self.embed_mean = nn.Linear(self.layer_sizes[-1], output_size)
        self.embed_log_std = nn.Linear(self.layer_sizes[-1], output_size)
        self.reset_parameters()

    def forward(self, data):
        x = self.model(data)
        x = F.relu(x)
        mean, log_std = self.embed_mean(x), self.embed_log_std(x)
        # Clamp the log std into a numerically safe range before exponentiating.
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        std = torch.exp(log_std)
        return mean, std

    def act(self, data):
        """
        Pathwise derivative estimator (reparameterization trick) for taking
        actions: sample via rsample() so gradients flow through the sample,
        then squash with a softmax.
        :param data: input features for the policy.
        :return: the squashed and scaled mean action.
        """
        mean, std = self.forward(data)
        normal = Normal(mean, std)
        x = normal.rsample()
        y = torch.softmax(x, dim=1)
        action = y * self.action_scale + self.action_bias
        # Log-probability of the pre-squash sample (not of the squashed action).
        log_prob = normal.log_prob(x)
        # Enforcing the action bound: this correction term is the tanh-squashing
        # Jacobian from SAC, kept here even though the squashing is a softmax.
        log_prob -= torch.log(self.action_scale * (1 - y.pow(2)) + EPSILON)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.softmax(mean, dim=1) * self.action_scale + self.action_bias
        self.saved_actions.append(SavedAction(action, log_prob, mean))
        return mean
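

# Minimal smoke test (a sketch, not part of the original file), assuming a
# 16-dimensional observation and a 4-dimensional action space:
if __name__ == "__main__":
    demon = Demon(layer_sizes=[32, 32])
    demon.init(input_size=16, output_size=4)
    mean_action = demon.act(torch.randn(2, 16))
    assert mean_action.shape == (2, 4)    # softmax-squashed mean action
    assert len(demon.saved_actions) == 1  # act() records a SavedAction
    action, log_prob, mean = demon.saved_actions[0]
    assert log_prob.shape == (2, 1)       # summed over the action dimension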