dnc/Models/DNC.py
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import torch.utils.data
import torch.nn.functional as F
import torch.nn.init as init
import functools
import math

def oneplus(t):
    return F.softplus(t, 1, 20) + 1.0
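
# oneplus(x) = 1 + log(1 + e^x) maps any real input to (1, inf); as in the DNC paper, it is used
# below to constrain the key strengths (betas) of the addressing mechanism.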

def get_next_tensor_part(src, dims, prev_pos=0):
    if not isinstance(dims, list):
        dims = [dims]
    n = functools.reduce(lambda x, y: x * y, dims)
    data = src.narrow(-1, prev_pos, n)
    return data.contiguous().view(list(data.size())[:-1] + dims) if len(dims) > 1 else data, prev_pos + n


def split_tensor(src, shapes):
    pos = 0
    res = []
    for s in shapes:
        d, pos = get_next_tensor_part(src, s, pos)
        res.append(d)
    return res
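
# Illustrative example (hypothetical shapes, not used elsewhere in this file): splitting a flat
# control vector of 10 channels per batch element into three named parts along the last dimension.
#   parts = split_tensor(torch.zeros(4, 10), [[2], [3, 2], [2]])
#   [tuple(p.shape) for p in parts]  ->  [(4, 2), (4, 3, 2), (4, 2)]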

def dict_get(dict, name):
    return dict.get(name) if dict is not None else None


def dict_append(dict, name, val):
    if dict is not None:
        l = dict.get(name)
        if not l:
            l = []
            dict[name] = l
        l.append(val)


def init_debug(debug, initial):
    if debug is not None and not debug:
        debug.update(initial)


def merge_debug_tensors(d, dim):
    if d is not None:
        for k, v in d.items():
            if isinstance(v, dict):
                merge_debug_tensors(v, dim)
            elif isinstance(v, list):
                d[k] = torch.stack(v, dim)


def linear_reset(module, gain=1.0):
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()


_EPS = 1e-6


class AllocationManager(torch.nn.Module):
    def __init__(self):
        super(AllocationManager, self).__init__()
        self.usages = None
        self.zero_usages = None
        self.debug_sequ_init = False
        self.one = None

    def _init_sequence(self, prev_read_distributions):
        # prev_read_distributions size is [batch, n_heads, cell count]
        s = prev_read_distributions.size()
        if self.zero_usages is None or list(self.zero_usages.size()) != [s[0], s[-1]]:
            self.zero_usages = torch.zeros(s[0], s[-1], device=prev_read_distributions.device)
            if self.debug_sequ_init:
                self.zero_usages += torch.arange(0, s[-1]).unsqueeze(0) * 1e-10

        self.usages = self.zero_usages

    def _init_consts(self, device):
        if self.one is None:
            self.one = torch.ones(1, device=device)

    def new_sequence(self):
        self.usages = None

    def update_usages(self, prev_write_distribution, prev_read_distributions, free_gates):
        # Read distributions shape: [batch, n_heads, cell count]
        # Free gates shape: [batch, n_heads]
        self._init_consts(prev_read_distributions.device)

        phi = torch.addcmul(self.one, free_gates.unsqueeze(-1), prev_read_distributions, value=-1).prod(-2)
        # phi is the memory retention vector (psi in the DNC paper), sized [batch, cell count]

        # If the memory usage counter doesn't exist yet, initialize it
        if self.usages is None:
            self._init_sequence(prev_read_distributions)
            # in the first timestep nothing has been written or read yet, so no further processing is needed
        else:
            self.usages = torch.addcmul(self.usages, prev_write_distribution.detach(), 1 - self.usages) * phi

        return phi

    def forward(self, prev_write_distribution, prev_read_distributions, free_gates):
        phi = self.update_usages(prev_write_distribution, prev_read_distributions, free_gates)

        sorted_usage, free_list = (self.usages * (1.0 - _EPS) + _EPS).sort(-1)

        u_prod = sorted_usage.cumprod(-1)
        one_minus_usage = 1.0 - sorted_usage
        sorted_scores = torch.cat([one_minus_usage[..., 0:1], one_minus_usage[..., 1:] * u_prod[..., :-1]], dim=-1)

        return sorted_scores.clone().scatter_(-1, free_list, sorted_scores), phi
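
# Allocation weighting as in the DNC paper: with phi_t the list of cell indices sorted by ascending
# usage, a_t[phi_t[j]] = (1 - u_t[phi_t[j]]) * prod_{i<j} u_t[phi_t[i]], so the least-used cell
# receives most of the allocation weight. The scatter_ above maps the scores computed in sorted
# order back to the original cell order.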


class ContentAddressGenerator(torch.nn.Module):
    def __init__(self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        super(ContentAddressGenerator, self).__init__()
        self.disable_content_norm = disable_content_norm
        self.mask_min = mask_min
        self.disable_key_masking = disable_key_masking

    def forward(self, memory, keys, betas, mask=None):
        # Memory shape [batch, cell count, word length]
        # Key shape [batch, n heads*, word length]
        # Betas shape [batch, n heads]
        if mask is not None and self.mask_min != 0:
            mask = mask * (1.0 - self.mask_min) + self.mask_min

        single_head = keys.dim() == 2
        if single_head:
            # Single head
            keys = keys.unsqueeze(1)
            if mask is not None:
                mask = mask.unsqueeze(1)

        memory = memory.unsqueeze(1)
        keys = keys.unsqueeze(-2)

        if mask is not None:
            mask = mask.unsqueeze(-2)
            memory = memory * mask
            if not self.disable_key_masking:
                keys = keys * mask

        # Shape [batch, n heads, cell count]
        norm = keys.norm(dim=-1)
        if not self.disable_content_norm:
            norm = norm * memory.norm(dim=-1)

        scores = (memory * keys).sum(-1) / (norm + _EPS)
        scores *= betas.unsqueeze(-1)

        res = F.softmax(scores, scores.dim() - 1)
        return res.squeeze(1) if single_head else res
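
# Content-based addressing as in the DNC paper: C(M, k, beta)[i] = softmax_i(beta * cos(M[i, :], k)),
# where cos is the cosine similarity (the dot product above divided by the product of norms, unless
# disable_content_norm is set). The optional mask restricts the comparison to a learned subset of the
# word, which is the memory-masking extension implemented in this repository.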


class WriteHead(torch.nn.Module):
    @staticmethod
    def create_write_archive(write_dist, erase_vector, write_vector, phi):
        return dict(write_dist=write_dist, erase_vector=erase_vector, write_vector=write_vector, phi=phi)

    def __init__(self, dealloc_content=True, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        super(WriteHead, self).__init__()
        self.write_content_generator = ContentAddressGenerator(disable_content_norm, mask_min=mask_min,
                                                               disable_key_masking=disable_key_masking)
        self.allocation_manager = AllocationManager()
        self.last_write = None
        self.dealloc_content = dealloc_content
        self.new_sequence()

    def new_sequence(self):
        self.last_write = None
        self.allocation_manager.new_sequence()

    @staticmethod
    def mem_update(memory, write_dist, erase_vector, write_vector, phi):
        # In the original paper the memory content is NOT deallocated, which makes content-based addressing
        # basically unusable when several similar steps have to be performed: the old contents are still
        # there, so a lookup will keep finding them, unless an allocation happens to overwrite the cell
        # before the next search, which is essentially random. Therefore the erase matrix here also takes
        # the free gates into account (it is multiplied by phi).
        write_dist = write_dist.unsqueeze(-1)

        erase_matrix = 1.0 - write_dist * erase_vector.unsqueeze(-2)
        if phi is not None:
            erase_matrix = erase_matrix * phi.unsqueeze(-1)

        update_matrix = write_dist * write_vector.unsqueeze(-2)
        return memory * erase_matrix + update_matrix
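
    # Without the phi factor this is exactly the memory update of the DNC paper:
    #   M_t = M_{t-1} o (E - w^w_t e_t^T) + w^w_t v_t^T
    # with E a matrix of ones, e_t the erase vector, v_t the write vector and w^w_t the write weighting.
    # The extra phi term additionally scales each row by the retention vector, so freed cells are wiped
    # instead of merely being marked as unused.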

    def forward(self, memory, write_content_key, write_beta, erase_vector, write_vector, alloc_gate, write_gate,
                free_gates, prev_read_dist, write_mask=None, debug=None):
        last_w_dist = self.last_write["write_dist"] if self.last_write is not None else None

        content_dist = self.write_content_generator(memory, write_content_key, write_beta, mask=write_mask)
        alloc_dist, phi = self.allocation_manager(last_w_dist, prev_read_dist, free_gates)

        # Shape [batch, cell count]
        write_dist = write_gate * (alloc_gate * alloc_dist + (1 - alloc_gate) * content_dist)
        self.last_write = WriteHead.create_write_archive(write_dist, erase_vector, write_vector,
                                                         phi if self.dealloc_content else None)

        dict_append(debug, "alloc_dist", alloc_dist)
        dict_append(debug, "write_dist", write_dist)
        dict_append(debug, "mem_usages", self.allocation_manager.usages)
        dict_append(debug, "free_gates", free_gates)
        dict_append(debug, "write_betas", write_beta)
        dict_append(debug, "write_gate", write_gate)
        dict_append(debug, "write_vector", write_vector)
        dict_append(debug, "alloc_gate", alloc_gate)
        dict_append(debug, "erase_vector", erase_vector)
        if write_mask is not None:
            dict_append(debug, "write_mask", write_mask)

        return WriteHead.mem_update(memory, **self.last_write)


class RawWriteHead(torch.nn.Module):
    def __init__(self, n_read_heads, word_length, use_mask=False, dealloc_content=True, disable_content_norm=False,
                 mask_min=0.0, disable_key_masking=False):
        super(RawWriteHead, self).__init__()
        self.write_head = WriteHead(dealloc_content=dealloc_content, disable_content_norm=disable_content_norm,
                                    mask_min=mask_min, disable_key_masking=disable_key_masking)
        self.word_length = word_length
        self.n_read_heads = n_read_heads
        self.use_mask = use_mask
        self.input_size = 3 * self.word_length + self.n_read_heads + 3 + (self.word_length if use_mask else 0)

    def new_sequence(self):
        self.write_head.new_sequence()

    def get_prev_write(self):
        return self.write_head.last_write

    def forward(self, memory, nn_output, prev_read_dist, debug):
        shapes = [[self.word_length]] * (4 if self.use_mask else 3) + [[self.n_read_heads]] + [[1]] * 3
        tensors = split_tensor(nn_output, shapes)

        if self.use_mask:
            write_mask = torch.sigmoid(tensors[0])
            tensors = tensors[1:]
        else:
            write_mask = None

        write_content_key, erase_vector, write_vector, free_gates, write_beta, alloc_gate, write_gate = tensors

        erase_vector = torch.sigmoid(erase_vector)
        free_gates = torch.sigmoid(free_gates)
        write_beta = oneplus(write_beta)
        alloc_gate = torch.sigmoid(alloc_gate)
        write_gate = torch.sigmoid(write_gate)

        return self.write_head(memory, write_content_key, write_beta, erase_vector, write_vector,
                               alloc_gate, write_gate, free_gates, prev_read_dist, debug=debug,
                               write_mask=write_mask)

    def get_neural_input_size(self):
        return self.input_size


class TemporalMemoryLinkage(torch.nn.Module):
    def __init__(self):
        super(TemporalMemoryLinkage, self).__init__()
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

        self.initial_temp_link_mat = None
        self.initial_precedence_weighting = None
        self.initial_diag_mask = None
        self.initial_shape = None

    def new_sequence(self):
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

    def _init_link(self, w_dist):
        s = list(w_dist.size())
        if self.initial_shape is None or s != self.initial_shape:
            self.initial_temp_link_mat = torch.zeros(s[0], s[-1], s[-1]).to(w_dist.device)
            self.initial_precedence_weighting = torch.zeros(s[0], s[-1]).to(w_dist.device)
            self.initial_diag_mask = (1.0 - torch.eye(s[-1]).unsqueeze(0).to(w_dist)).detach()

        self.temp_link_mat = self.initial_temp_link_mat
        self.precedence_weighting = self.initial_precedence_weighting
        self.diag_mask = self.initial_diag_mask

    def _update_precedence(self, w_dist):
        # w_dist shape: [ batch, cell count ]
        self.precedence_weighting = (1.0 - w_dist.sum(-1, keepdim=True)) * self.precedence_weighting + w_dist

    def _update_links(self, w_dist):
        if self.temp_link_mat is None:
            self._init_link(w_dist)

        wt_i = w_dist.unsqueeze(-1)
        wt_j = w_dist.unsqueeze(-2)
        pt_j = self.precedence_weighting.unsqueeze(-2)

        self.temp_link_mat = ((1 - wt_i - wt_j) * self.temp_link_mat + wt_i * pt_j) * self.diag_mask

    def forward(self, w_dist, prev_r_dists, debug=None):
        self._update_links(w_dist)
        self._update_precedence(w_dist)

        # prev_r_dists shape: [ batch, n heads, cell count ]
        # Emulate matrix-vector multiplication by broadcast and sum. This way we don't need to transpose the matrix.
        tlm_multi_head = self.temp_link_mat.unsqueeze(1)

        forward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-2)).sum(-1)
        backward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-1)).sum(-2)

        dict_append(debug, "forward_dists", forward_dist)
        dict_append(debug, "backward_dists", backward_dist)
        dict_append(debug, "precedence_weights", self.precedence_weighting)

        # output shapes [ batch, n_heads, cell_count ]
        return forward_dist, backward_dist
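
# Temporal linkage as defined in the DNC paper:
#   p_t = (1 - sum_i w^w_t[i]) * p_{t-1} + w^w_t                                  (precedence weighting)
#   L_t[i, j] = (1 - w^w_t[i] - w^w_t[j]) * L_{t-1}[i, j] + w^w_t[i] * p_{t-1}[j],  with L_t[i, i] = 0
#   f^i_t = L_t w^{r,i}_{t-1}   (forward dist),   b^i_t = L_t^T w^{r,i}_{t-1}   (backward dist)
# The diag_mask zeroes the diagonal; the broadcast-and-sum in forward() computes both matrix-vector
# products for all read heads at once.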


class ReadHead(torch.nn.Module):
    def __init__(self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        super(ReadHead, self).__init__()
        self.content_addr_generator = ContentAddressGenerator(disable_content_norm=disable_content_norm,
                                                              mask_min=mask_min,
                                                              disable_key_masking=disable_key_masking)
        self.read_dist = None
        self.read_data = None
        self.new_sequence()

    def new_sequence(self):
        self.read_dist = None
        self.read_data = None

    def forward(self, memory, read_content_keys, read_betas, forward_dist, backward_dist, gates, read_mask=None,
                debug=None):
        content_dist = self.content_addr_generator(memory, read_content_keys, read_betas, mask=read_mask)

        self.read_dist = backward_dist * gates[..., 0:1] + content_dist * gates[..., 1:2] + forward_dist * gates[..., 2:]

        # memory shape: [ batch, cell count, word_length ]
        # read_dist shape: [ batch, n heads, cell count ]
        # result shape: [ batch, n_heads, word_length ]
        self.read_data = (memory.unsqueeze(1) * self.read_dist.unsqueeze(-1)).sum(-2)

        dict_append(debug, "content_dist", content_dist)
        dict_append(debug, "balance", gates)
        dict_append(debug, "read_dist", self.read_dist)
        dict_append(debug, "read_content_keys", read_content_keys)
        dict_append(debug, "read_betas", read_betas.unsqueeze(-2))
        if read_mask is not None:
            dict_append(debug, "read_mask", read_mask)

        return self.read_data
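
# Read weighting as in the DNC paper: w^{r,i}_t = pi^i_t[1] * b^i_t + pi^i_t[2] * c^{r,i}_t + pi^i_t[3] * f^i_t,
# where pi^i_t (the read-mode "gates" above) is a softmax over {backward, content, forward}. The read
# vectors are r^i_t = M_t^T w^{r,i}_t, computed by the broadcast multiply-and-sum over the cell dimension.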


class RawReadHead(torch.nn.Module):
    def __init__(self, n_heads, word_length, use_mask=False, disable_content_norm=False, mask_min=0.0,
                 disable_key_masking=False):
        super(RawReadHead, self).__init__()
        self.read_head = ReadHead(disable_content_norm=disable_content_norm, mask_min=mask_min,
                                  disable_key_masking=disable_key_masking)
        self.n_heads = n_heads
        self.word_length = word_length
        self.use_mask = use_mask
        self.input_size = self.n_heads * (self.word_length * (2 if use_mask else 1) + 3 + 1)

    def get_prev_dist(self, memory):
        if self.read_head.read_dist is not None:
            return self.read_head.read_dist
        else:
            m_shape = memory.size()
            return torch.zeros(m_shape[0], self.n_heads, m_shape[1]).to(memory)

    def get_prev_data(self, memory):
        if self.read_head.read_data is not None:
            return self.read_head.read_data
        else:
            m_shape = memory.size()
            return torch.zeros(m_shape[0], self.n_heads, m_shape[-1]).to(memory)

    def new_sequence(self):
        self.read_head.new_sequence()

    def forward(self, memory, nn_output, forward_dist, backward_dist, debug):
        shapes = [[self.n_heads, self.word_length]] * (2 if self.use_mask else 1) + [[self.n_heads], [self.n_heads, 3]]
        tensors = split_tensor(nn_output, shapes)

        if self.use_mask:
            read_mask = torch.sigmoid(tensors[0])
            tensors = tensors[1:]
        else:
            read_mask = None

        keys, betas, gates = tensors

        betas = oneplus(betas)
        gates = F.softmax(gates, gates.dim() - 1)

        return self.read_head(memory, keys, betas, forward_dist, backward_dist, gates, debug=debug,
                              read_mask=read_mask)

    def get_neural_input_size(self):
        return self.input_size


class DistSharpnessEnhancer(torch.nn.Module):
    def __init__(self, n_heads):
        super(DistSharpnessEnhancer, self).__init__()
        self.n_heads = n_heads if isinstance(n_heads, list) else [n_heads]
        self.n_data = sum(self.n_heads)

    def forward(self, nn_input, *dists):
        assert len(dists) == len(self.n_heads)
        nn_input = oneplus(nn_input[..., :self.n_data])
        factors = split_tensor(nn_input, self.n_heads)

        res = []
        for i, d in enumerate(dists):
            s = list(d.size())
            ndim = d.dim()
            f = factors[i]

            if ndim == 2:
                assert self.n_heads[i] == 1
            elif ndim == 3:
                f = f.unsqueeze(-1)
            else:
                assert False

            d += _EPS
            d = d / d.max(dim=-1, keepdim=True)[0]
            d = d.pow(f)
            d = d / d.sum(dim=-1, keepdim=True)
            res.append(d)
        return res

    def get_neural_input_size(self):
        return self.n_data
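
# Link distribution sharpness control: each forward/backward distribution d is renormalized as
#   d' = normalize((d / max(d)) ** s),   with s = oneplus(controller output) >= 1,
# one scalar s per read head and direction. This lets the controller sharpen the temporal link
# distributions and limit the noise they accumulate over long sequences.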


class DNC(torch.nn.Module):
    def __init__(self, input_size, output_size, word_length, cell_count, n_read_heads, controller, batch_first=False,
                 clip_controller=20, bias=True, mask=False, dealloc_content=True, link_sharpness_control=True,
                 disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        super(DNC, self).__init__()

        self.clip_controller = clip_controller

        self.read_head = RawReadHead(n_read_heads, word_length, use_mask=mask,
                                     disable_content_norm=disable_content_norm,
                                     mask_min=mask_min, disable_key_masking=disable_key_masking)
        self.write_head = RawWriteHead(n_read_heads, word_length, use_mask=mask, dealloc_content=dealloc_content,
                                       disable_content_norm=disable_content_norm, mask_min=mask_min,
                                       disable_key_masking=disable_key_masking)
        self.temporal_link = TemporalMemoryLinkage()
        self.sharpness_control = DistSharpnessEnhancer([n_read_heads, n_read_heads]) if link_sharpness_control else None

        in_size = input_size + n_read_heads * word_length
        control_channels = self.read_head.get_neural_input_size() + self.write_head.get_neural_input_size() + \
                           (self.sharpness_control.get_neural_input_size() if self.sharpness_control is not None else 0)

        self.controller = controller
        controller.init(in_size)

        self.controller_to_controls = torch.nn.Linear(controller.get_output_size(), control_channels, bias=bias)
        self.controller_to_out = torch.nn.Linear(controller.get_output_size(), output_size, bias=bias)
        self.read_to_out = torch.nn.Linear(word_length * n_read_heads, output_size, bias=bias)

        self.cell_count = cell_count
        self.word_length = word_length

        self.memory = None
        self.reset_parameters()

        self.batch_first = batch_first
        self.zero_mem_tensor = None

    def reset_parameters(self):
        linear_reset(self.controller_to_controls)
        linear_reset(self.controller_to_out)
        linear_reset(self.read_to_out)
        self.controller.reset_parameters()

    def _step(self, in_data, debug):
        init_debug(debug, {
            "read_head": {},
            "write_head": {},
            "temporal_links": {}
        })

        # input shape: [ batch, channels ]
        batch_size = in_data.size(0)

        # run the controller
        prev_read_data = self.read_head.get_prev_data(self.memory).view([batch_size, -1])
        control_data = self.controller(torch.cat([in_data, prev_read_data], -1))

        # memory ops
        controls = self.controller_to_controls(control_data).contiguous()
        controls = controls.clamp(-self.clip_controller, self.clip_controller) if self.clip_controller is not None else controls

        shapes = [[self.write_head.get_neural_input_size()], [self.read_head.get_neural_input_size()]]
        if self.sharpness_control is not None:
            shapes.append(self.sharpness_control.get_neural_input_size())

        tensors = split_tensor(controls, shapes)

        write_head_control, read_head_control = tensors[:2]
        tensors = tensors[2:]

        prev_read_dist = self.read_head.get_prev_dist(self.memory)

        self.memory = self.write_head(self.memory, write_head_control, prev_read_dist,
                                      debug=dict_get(debug, "write_head"))

        prev_write = self.write_head.get_prev_write()
        forward_dist, backward_dist = self.temporal_link(prev_write["write_dist"] if prev_write is not None else None,
                                                         prev_read_dist, debug=dict_get(debug, "temporal_links"))

        if self.sharpness_control is not None:
            forward_dist, backward_dist = self.sharpness_control(tensors[0], forward_dist, backward_dist)

        read_data = self.read_head(self.memory, read_head_control, forward_dist, backward_dist,
                                   debug=dict_get(debug, "read_head"))

        # output:
        return self.controller_to_out(control_data) + self.read_to_out(read_data.view(batch_size, -1))

    def _mem_init(self, batch_size, device):
        if self.zero_mem_tensor is None or self.zero_mem_tensor.size(0) != batch_size:
            self.zero_mem_tensor = torch.zeros(batch_size, self.cell_count, self.word_length).to(device)

        self.memory = self.zero_mem_tensor

    def forward(self, in_data, debug=None):
        self.write_head.new_sequence()
        self.read_head.new_sequence()
        self.temporal_link.new_sequence()
        self.controller.new_sequence()

        self._mem_init(in_data.size(0 if self.batch_first else 1), in_data.device)

        out_tsteps = []

        if self.batch_first:
            # input format: [ batch, time, channels ]
            for t in range(in_data.size(1)):
                out_tsteps.append(self._step(in_data[:, t], debug))
        else:
            # input format: [ time, batch, channels ]
            for t in range(in_data.size(0)):
                out_tsteps.append(self._step(in_data[t], debug))

        merge_debug_tensors(debug, dim=1 if self.batch_first else 0)
        return torch.stack(out_tsteps, dim=1 if self.batch_first else 0)
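
# Any controller passed to DNC is expected to provide the duck-typed interface used above:
# init(input_size), reset_parameters(), new_sequence(), get_output_size() and forward(data), where
# forward receives the concatenation of the external input and the previous read vectors. Both
# controllers below implement this interface.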


class LSTMController(torch.nn.Module):
    def __init__(self, layer_sizes, out_from_all_layers=True):
        super(LSTMController, self).__init__()
        self.out_from_all_layers = out_from_all_layers
        self.layer_sizes = layer_sizes
        self.states = None
        self.outputs = None

    def new_sequence(self):
        self.states = [None] * len(self.layer_sizes)
        self.outputs = [None] * len(self.layer_sizes)

    def reset_parameters(self):
        def init_layer(l, index):
            size = self.layer_sizes[index]
            # Init the merged weight matrix: base (sigmoid) scale for the gate blocks, tanh gain for the
            # data-input block.
            a = math.sqrt(3.0) * self.stdevs[index]
            l.weight.data[0:-size].uniform_(-a, a)
            a *= init.calculate_gain("tanh")
            l.weight.data[-size:].uniform_(-a, a)
            if l.bias is not None:
                l.bias.data[size:].fill_(0)
                # init forget gate bias to a large number.
                l.bias.data[:size].fill_(1)

        # Xavier init of the merged input weights
        for i in range(len(self.layer_sizes)):
            init_layer(self.in_to_all[i], i)
            init_layer(self.out_to_all[i], i)
            if i > 0:
                init_layer(self.prev_to_all[i - 1], i)

    def _add_modules(self, name, m_list):
        for i, m in enumerate(m_list):
            self.add_module("%s_%d" % (name, i), m)

    def init(self, input_size):
        # Xavier init: the input to all gates of layer i is layer_sizes[i-1] + layer_sizes[i] + input_size
        # channels wide and the output is layer_sizes[i] wide, so compute the init scale accordingly.
        self.input_sizes = [(self.layer_sizes[i - 1] if i > 0 else 0) + self.layer_sizes[i] + input_size
                            for i in range(len(self.layer_sizes))]
        self.stdevs = [math.sqrt(2.0 / (self.layer_sizes[i] + self.input_sizes[i])) for i in range(len(self.layer_sizes))]

        self.in_to_all = [torch.nn.Linear(input_size, 4 * self.layer_sizes[i]) for i in range(len(self.layer_sizes))]
        self.out_to_all = [torch.nn.Linear(self.layer_sizes[i], 4 * self.layer_sizes[i], bias=False)
                           for i in range(len(self.layer_sizes))]
        self.prev_to_all = [torch.nn.Linear(self.layer_sizes[i - 1], 4 * self.layer_sizes[i], bias=False)
                            for i in range(1, len(self.layer_sizes))]

        self._add_modules("in_to_all", self.in_to_all)
        self._add_modules("out_to_all", self.out_to_all)
        self._add_modules("prev_to_all", self.prev_to_all)

        self.reset_parameters()

    def get_output_size(self):
        return sum(self.layer_sizes) if self.out_from_all_layers else self.layer_sizes[-1]

    def forward(self, data):
        for i, size in enumerate(self.layer_sizes):
            d = self.in_to_all[i](data)
            if self.outputs[i] is not None:
                d += self.out_to_all[i](self.outputs[i])
            if i > 0:
                d += self.prev_to_all[i - 1](self.outputs[i - 1])

            input_data = torch.tanh(d[..., -size:])
            forget_gate, input_gate, output_gate = torch.sigmoid(d[..., :-size]).chunk(3, dim=-1)

            state_update = input_gate * input_data

            if self.states[i] is not None:
                self.states[i] = self.states[i] * forget_gate + state_update
            else:
                self.states[i] = state_update

            self.outputs[i] = output_gate * torch.tanh(self.states[i])

        return torch.cat(self.outputs, -1) if self.out_from_all_layers else self.outputs[-1]
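
# The cell above is a standard LSTM (no peepholes) with the four weight blocks merged into a single
# Linear per input source. For layer l at time t:
#   [f; i; o; g] = W_in x_t + W_out h^l_{t-1} + W_prev h^{l-1}_t
#   c^l_t = sigmoid(f) * c^l_{t-1} + sigmoid(i) * tanh(g)
#   h^l_t = sigmoid(o) * tanh(c^l_t)
# The per-layer states and outputs are kept in plain Python lists and reset by new_sequence().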


class FeedforwardController(torch.nn.Module):
    def __init__(self, layer_sizes=[]):
        super(FeedforwardController, self).__init__()
        self.layer_sizes = layer_sizes

    def new_sequence(self):
        pass

    def reset_parameters(self):
        for module in self.model:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

    def get_output_size(self):
        return self.layer_sizes[-1]

    def init(self, input_size):
        # Each layer sees the output of the previous layer (or the external input for the first one).
        self.input_sizes = [input_size] + self.layer_sizes[:-1]

        layers = []
        for i, size in enumerate(self.layer_sizes):
            layers.append(torch.nn.Linear(self.input_sizes[i], self.layer_sizes[i]))
            layers.append(torch.nn.ReLU())

        self.model = torch.nn.Sequential(*layers)
        self.reset_parameters()

    def forward(self, data):
        return self.model(data)
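

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). The hyperparameters below are
    # arbitrary assumptions chosen only to illustrate how the pieces fit together; real experiments
    # should use the training scripts of this repository instead.
    controller = LSTMController([64, 64])
    net = DNC(input_size=8, output_size=8, word_length=16, cell_count=32, n_read_heads=2,
              controller=controller, batch_first=True, mask=True)

    x = torch.randn(4, 10, 8)      # [batch, time, channels] because batch_first=True
    debug = {}                     # optional: collects per-step head and linkage tensors
    y = net(x, debug=debug)        # -> [batch, time, output_size] == [4, 10, 8]
    print(y.shape, sorted(debug.keys()))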