Compare commits

31 commits:

33e35326db
d57776c45a
a660434d21
00bfa63bc5
be40616920
016b541223
79dc405f37
47140303e9
d45461db1c
c48d3d9ba4
fd347807a5
bc6359fb64
56f347a934
1db78511fe
3e477fcf18
4266c2e7aa
9fe6375518
db74f8ea57
9b3f68fbfd
b428bfac12
188548fa3c
f528a4c120
4178130e8f
b7d4e1cde2
bbf48e61e8
cc2c3bcebc
092bdb8f93
bcb1bf901e
2e24452dfa
19c759d2cf
4115e69155
@@ -8,7 +8,7 @@ before_install:
 - sudo apt-get install -yqq libopenblas-dev liblapack3 python3-numpy python3-dev swig
 - sudo ln -s /usr/lib/libopenblas.so /usr/lib/libopenblas.so.3
 install:
-- pip install -qqq http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
+- pip install -qqq https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
 - pip install -qqq numpy
 - pip install -qqq visdom
 - pip install -qqq pyflann3
CHANGELOG.md (new file, 121 lines)
@@ -0,0 +1,121 @@
# Change Log

## [Unreleased](https://github.com/ixaxaar/pytorch-dnc/tree/HEAD)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/1.0.1...HEAD)

**Merged pull requests:**

- Fixes for \#43 [\#44](https://github.com/ixaxaar/pytorch-dnc/pull/44) ([ixaxaar](https://github.com/ixaxaar))

## [1.0.1](https://github.com/ixaxaar/pytorch-dnc/tree/1.0.1) (2019-04-05)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/1.0.0...1.0.1)

**Closed issues:**

- When running adding task -- ModuleNotFoundError: No module named 'index' [\#39](https://github.com/ixaxaar/pytorch-dnc/issues/39)
- SyntaxError [\#36](https://github.com/ixaxaar/pytorch-dnc/issues/36)
- PySide dependency error [\#33](https://github.com/ixaxaar/pytorch-dnc/issues/33)
- Issues when using pytorch 0.4 [\#31](https://github.com/ixaxaar/pytorch-dnc/issues/31)
- TypeError: cat received an invalid combination of arguments - got \(list, int\), but expected one of: [\#29](https://github.com/ixaxaar/pytorch-dnc/issues/29)

**Merged pull requests:**

- Fixes \#36 and \#39 [\#42](https://github.com/ixaxaar/pytorch-dnc/pull/42) ([ixaxaar](https://github.com/ixaxaar))

## [1.0.0](https://github.com/ixaxaar/pytorch-dnc/tree/1.0.0) (2019-04-05)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.9...1.0.0)

**Closed issues:**

- Question about the running speed of Pyflann and Faiss for the SAM model [\#40](https://github.com/ixaxaar/pytorch-dnc/issues/40)
- SyntaxError [\#37](https://github.com/ixaxaar/pytorch-dnc/issues/37)
- Values in hidden become nan [\#35](https://github.com/ixaxaar/pytorch-dnc/issues/35)
- faiss error [\#32](https://github.com/ixaxaar/pytorch-dnc/issues/32)

**Merged pull requests:**

- Port to pytorch 1.x [\#41](https://github.com/ixaxaar/pytorch-dnc/pull/41) ([ixaxaar](https://github.com/ixaxaar))
- fix parens in example usage and gpu usage for SAM [\#30](https://github.com/ixaxaar/pytorch-dnc/pull/30) ([kierkegaard13](https://github.com/kierkegaard13))

## [0.0.9](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.9) (2018-04-23)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.7...0.0.9)

**Fixed bugs:**

- Use usage vector to determine least recently used memory [\#26](https://github.com/ixaxaar/pytorch-dnc/issues/26)
- Store entire memory after memory limit is reached [\#24](https://github.com/ixaxaar/pytorch-dnc/issues/24)

**Merged pull requests:**

- memory.py: fix indexing for read\_modes transform [\#28](https://github.com/ixaxaar/pytorch-dnc/pull/28) ([jbinas](https://github.com/jbinas))
- Bugfixes [\#27](https://github.com/ixaxaar/pytorch-dnc/pull/27) ([ixaxaar](https://github.com/ixaxaar))

## [0.0.7](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.7) (2017-12-20)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.6...0.0.7)

**Implemented enhancements:**

- GPU kNNs [\#21](https://github.com/ixaxaar/pytorch-dnc/issues/21)
- Implement temporal addressing for SDNCs [\#18](https://github.com/ixaxaar/pytorch-dnc/issues/18)
- Feature: Sparse Access Memory [\#4](https://github.com/ixaxaar/pytorch-dnc/issues/4)
- SAMs [\#22](https://github.com/ixaxaar/pytorch-dnc/pull/22) ([ixaxaar](https://github.com/ixaxaar))
- Temporal links for SDNC [\#19](https://github.com/ixaxaar/pytorch-dnc/pull/19) ([ixaxaar](https://github.com/ixaxaar))
- SDNC [\#16](https://github.com/ixaxaar/pytorch-dnc/pull/16) ([ixaxaar](https://github.com/ixaxaar))

**Merged pull requests:**

- Add more tasks [\#23](https://github.com/ixaxaar/pytorch-dnc/pull/23) ([ixaxaar](https://github.com/ixaxaar))
- Scale interface vectors, dynamic memory pass [\#17](https://github.com/ixaxaar/pytorch-dnc/pull/17) ([ixaxaar](https://github.com/ixaxaar))
- Update README.md [\#14](https://github.com/ixaxaar/pytorch-dnc/pull/14) ([MaxwellRebo](https://github.com/MaxwellRebo))

## [0.0.6](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.6) (2017-11-12)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.5.0...0.0.6)

**Implemented enhancements:**

- Re-write allocation vector code, use pytorch's cumprod [\#13](https://github.com/ixaxaar/pytorch-dnc/issues/13)

**Fixed bugs:**

- Stacked DNCs forward pass wrong [\#12](https://github.com/ixaxaar/pytorch-dnc/issues/12)
- Temporal debugging of memory [\#11](https://github.com/ixaxaar/pytorch-dnc/pull/11) ([ixaxaar](https://github.com/ixaxaar))

## [0.5.0](https://github.com/ixaxaar/pytorch-dnc/tree/0.5.0) (2017-11-01)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.3...0.5.0)

**Implemented enhancements:**

- Multiple hidden layers per controller layer [\#7](https://github.com/ixaxaar/pytorch-dnc/issues/7)
- Vizdom integration and fix cumprod bug \#5 [\#6](https://github.com/ixaxaar/pytorch-dnc/pull/6) ([ixaxaar](https://github.com/ixaxaar))

**Fixed bugs:**

- Use shifted cumprods, emulate tensorflow's cumprod with exclusive=True [\#5](https://github.com/ixaxaar/pytorch-dnc/issues/5)
- Vizdom integration and fix cumprod bug \#5 [\#6](https://github.com/ixaxaar/pytorch-dnc/pull/6) ([ixaxaar](https://github.com/ixaxaar))

**Closed issues:**

- Write unit tests [\#8](https://github.com/ixaxaar/pytorch-dnc/issues/8)
- broken links [\#3](https://github.com/ixaxaar/pytorch-dnc/issues/3)

**Merged pull requests:**

- Test travis build [\#10](https://github.com/ixaxaar/pytorch-dnc/pull/10) ([ixaxaar](https://github.com/ixaxaar))
- Implement Hidden layers, small enhancements, cleanups [\#9](https://github.com/ixaxaar/pytorch-dnc/pull/9) ([ixaxaar](https://github.com/ixaxaar))

## [0.0.3](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.3) (2017-10-27)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.2...0.0.3)

**Implemented enhancements:**

- Implementation of Dropout for controller [\#2](https://github.com/ixaxaar/pytorch-dnc/pull/2) ([ixaxaar](https://github.com/ixaxaar))
- Fix size issue for GRU and vanilla RNN [\#1](https://github.com/ixaxaar/pytorch-dnc/pull/1) ([ixaxaar](https://github.com/ixaxaar))

## [0.0.2](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.2) (2017-10-26)

[Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/v0.0.1...0.0.2)

## [v0.0.1](https://github.com/ixaxaar/pytorch-dnc/tree/v0.0.1) (2017-10-26)

\* *This Change Log was automatically generated by [github_changelog_generator](https://github.com/skywinder/Github-Changelog-Generator)*
README.md (14 changed lines)
@@ -122,7 +122,7 @@ rnn = DNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)

 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
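For orientation, here is a minimal standalone sketch of the corrected call pattern. The constructor arguments below are illustrative assumptions (only the input size of 64 is implied by the `torch.randn(10, 4, 64)` input in the hunk above); the point of the fix is that `reset_experience` is a keyword argument of the forward call, not an element of the hidden-state tuple.

```python
import torch
from dnc import DNC

# Illustrative settings; see the README sections above for the full argument list.
rnn = DNC(input_size=64, hidden_size=128, rnn_type='lstm', num_layers=1, gpu_id=-1)

(controller_hidden, memory, read_vectors) = (None, None, None)

for step in range(3):
    output, (controller_hidden, memory, read_vectors) = rnn(
        torch.randn(10, 4, 64),
        (controller_hidden, memory, read_vectors),
        reset_experience=(step == 0),  # keyword argument, outside the state tuple
    )
```

The same parenthesis fix is repeated in the DNC, SDNC and SAM snippets in the following hunks.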
@@ -150,7 +150,7 @@ rnn = DNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)

 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```

 Memory vectors returned by forward pass (`np.ndarray`):
@@ -223,7 +223,7 @@ rnn = SDNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)

 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
@@ -253,7 +253,7 @@ rnn = SDNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)

 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```

 Memory vectors returned by forward pass (`np.ndarray`):
@@ -327,7 +327,7 @@ rnn = SAM(
 (controller_hidden, memory, read_vectors) = (None, None, None)

 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
@@ -356,7 +356,7 @@ rnn = SAM(
 (controller_hidden, memory, read_vectors) = (None, None, None)

 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```

 Memory vectors returned by forward pass (`np.ndarray`):
@@ -456,7 +456,7 @@ python ./tasks/argmax_task.py -cuda 0 -lr 0.0001 -rnn_type lstm -memory_type dnc

 ## General noteworthy stuff

-1. SDNCs use the [FLANN approximate nearest neigbhour library](https://www.cs.ubc.ca/research/flann/), with its python binding [pyflann3](https://github.com/primetang/pyflann) and [FAISS](https://github.com/facebookresearch/faiss).
+1. SDNCs use the [FLANN approximate nearest neigbhour library](https://github.com/mariusmuja/flann), with its python binding [pyflann3](https://github.com/primetang/pyflann) and [FAISS](https://github.com/facebookresearch/faiss).

 FLANN can be installed either from pip (automatically as a dependency), or from source (e.g. for multithreading via OpenMP):
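To make the FLANN note above concrete, a small round-trip through the pyflann binding (a hedged sketch: the array sizes and the `kdtree` choice are arbitrary, and it assumes the `pyflann` module installed by the pyflann3 package):

```python
import numpy as np
from pyflann import FLANN  # module provided by the pyflann3 package

dataset = np.random.rand(1000, 16).astype(np.float32)   # candidate vectors
queries = np.random.rand(4, 16).astype(np.float32)      # lookup keys

flann = FLANN()
# Approximate k-nearest-neighbour query, the kind of lookup the sparse memories delegate to FLANN.
indices, dists = flann.nn(dataset, queries, num_neighbors=5, algorithm='kdtree')
print(indices.shape)  # (4, 5)
```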
README.rst (new file, 1411 lines): file diff suppressed because it is too large
@@ -13,7 +13,7 @@ from torch.nn.utils.rnn import PackedSequence
 from .util import *
 from .memory import *

-from torch.nn.init import orthogonal, xavier_uniform
+from torch.nn.init import orthogonal_, xavier_uniform_


 class DNC(nn.Module):
@@ -115,11 +115,12 @@ class DNC(nn.Module):

 # final output layer
 self.output = nn.Linear(self.nn_output_size, self.input_size)
-orthogonal(self.output.weight)
+orthogonal_(self.output.weight)

 if self.gpu_id != -1:
 [x.cuda(self.gpu_id) for x in self.rnns]
 [x.cuda(self.gpu_id) for x in self.memories]
 self.output.cuda()

 def _init_hidden(self, hx, batch_size, reset_experience):
 # create empty hidden states if not provided
@@ -130,7 +131,7 @@ class DNC(nn.Module):
 # initialize hidden state of the controller RNN
 if chx is None:
 h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
-xavier_uniform(h)
+xavier_uniform_(h)

 chx = [ (h, h) if self.rnn_type.lower() == 'lstm' else h for x in range(self.num_layers)]
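The renamed imports above are one instance of a pattern applied throughout this changeset. A compact, illustrative summary of the PyTorch 0.x to 1.x idioms involved (this sketch is not itself part of any file in the diff):

```python
import torch as T
import torch.nn as nn

layer = nn.Linear(16, 16)

# 1. Initializers are in-place and carry a trailing underscore.
T.nn.init.orthogonal_(layer.weight)                      # was: T.nn.init.orthogonal(layer.weight)

# 2. Element-wise activations live on torch, not torch.nn.functional.
x = T.randn(4, 16)
y = T.sigmoid(layer(x))                                  # was: F.sigmoid(...)

# 3. Gradient clipping is the in-place clip_grad_norm_.
loss = y.sum()
loss.backward()
T.nn.utils.clip_grad_norm_(layer.parameters(), 10.0)     # was: clip_grad_norm

# 4. Scalars come out of tensors via .item().
loss_value = loss.item()                                 # was: loss.data[0]
```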
@@ -214,47 +214,47 @@ class Memory(nn.Module):

 if self.independent_linears:
 # r read keys (b * r * w)
-read_keys = F.tanh(self.read_keys_transform(ξ).view(b, r, w))
+read_keys = T.tanh(self.read_keys_transform(ξ).view(b, r, w))
 # r read strengths (b * r)
 read_strengths = F.softplus(self.read_strengths_transform(ξ).view(b, r))
 # write key (b * 1 * w)
-write_key = F.tanh(self.write_key_transform(ξ).view(b, 1, w))
+write_key = T.tanh(self.write_key_transform(ξ).view(b, 1, w))
 # write strength (b * 1)
 write_strength = F.softplus(self.write_strength_transform(ξ).view(b, 1))
 # erase vector (b * 1 * w)
-erase_vector = F.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w))
+erase_vector = T.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w))
 # write vector (b * 1 * w)
-write_vector = F.tanh(self.write_vector_transform(ξ).view(b, 1, w))
+write_vector = T.tanh(self.write_vector_transform(ξ).view(b, 1, w))
 # r free gates (b * r)
-free_gates = F.sigmoid(self.free_gates_transform(ξ).view(b, r))
+free_gates = T.sigmoid(self.free_gates_transform(ξ).view(b, r))
 # allocation gate (b * 1)
-allocation_gate = F.sigmoid(self.allocation_gate_transform(ξ).view(b, 1))
+allocation_gate = T.sigmoid(self.allocation_gate_transform(ξ).view(b, 1))
 # write gate (b * 1)
-write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
+write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
 # read modes (b * r * 3)
-read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), 1)
+read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), -1)
 else:
 ξ = self.interface_weights(ξ)
 # r read keys (b * w * r)
-read_keys = F.tanh(ξ[:, :r * w].contiguous().view(b, r, w))
+read_keys = T.tanh(ξ[:, :r * w].contiguous().view(b, r, w))
 # r read strengths (b * r)
 read_strengths = F.softplus(ξ[:, r * w:r * w + r].contiguous().view(b, r))
 # write key (b * w * 1)
-write_key = F.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
+write_key = T.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
 # write strength (b * 1)
 write_strength = F.softplus(ξ[:, r * w + r + w].contiguous().view(b, 1))
 # erase vector (b * w)
-erase_vector = F.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
+erase_vector = T.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
 # write vector (b * w)
-write_vector = F.tanh(ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w))
+write_vector = T.tanh(ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w))
 # r free gates (b * r)
-free_gates = F.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r))
+free_gates = T.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r))
 # allocation gate (b * 1)
-allocation_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 1].contiguous().unsqueeze(1).view(b, 1))
+allocation_gate = T.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 1].contiguous().unsqueeze(1).view(b, 1))
 # write gate (b * 1)
-write_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
+write_gate = T.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
 # read modes (b * 3*r)
-read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 2: r * w + 5 * r + 3 * w + 2].contiguous().view(b, r, 3), 1)
+read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), -1)

 hidden = self.write(write_key, write_vector, erase_vector, free_gates,
 read_strengths, write_strength, write_gate, allocation_gate, hidden)
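A worked check of the slice arithmetic behind the last change above: in the flat interface vector the write gate sits at index `r * w + 2 * r + 3 * w + 2`, so the read-modes block has to start one element later. The numbers below are arbitrary; `r` and `w` are the read-head count and cell width used in the slices.

```python
# Illustrative layout of the flat interface vector (arbitrary sizes).
r, w = 4, 16   # read heads, cell width

read_keys       = (0,                       r * w)                       # r*w values
read_strengths  = (r * w,                   r * w + r)                   # r values
write_key       = (r * w + r,               r * w + r + w)               # w values
write_strength  = r * w + r + w                                          # 1 value
erase_vector    = (r * w + r + w + 1,       r * w + r + 2 * w + 1)       # w values
write_vector    = (r * w + r + 2 * w + 1,   r * w + r + 3 * w + 1)       # w values
free_gates      = (r * w + r + 3 * w + 1,   r * w + 2 * r + 3 * w + 1)   # r values
allocation_gate = r * w + 2 * r + 3 * w + 1                              # 1 value
write_gate      = r * w + 2 * r + 3 * w + 2                              # 1 value

# Read modes (3 per read head) therefore start one past the write gate,
# which is exactly the "+ 2" -> "+ 3" change in the hunk above.
read_modes = (r * w + 2 * r + 3 * w + 3, r * w + 5 * r + 3 * w + 3)
assert read_modes[1] - read_modes[0] == 3 * r
```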
@@ -9,7 +9,7 @@ import numpy as np
 from torch.nn.utils.rnn import pad_packed_sequence as pad
 from torch.nn.utils.rnn import pack_padded_sequence as pack
 from torch.nn.utils.rnn import PackedSequence
-from torch.nn.init import orthogonal, xavier_uniform
+from torch.nn.init import orthogonal_, xavier_uniform_

 from .util import *
 from .sparse_memory import SparseMemory

@@ -9,7 +9,7 @@ import numpy as np
 from torch.nn.utils.rnn import pad_packed_sequence as pad
 from torch.nn.utils.rnn import pack_padded_sequence as pack
 from torch.nn.utils.rnn import PackedSequence
-from torch.nn.init import orthogonal, xavier_uniform
+from torch.nn.init import orthogonal_, xavier_uniform_

 from .util import *
 from .sparse_temporal_memory import SparseTemporalMemory
@@ -48,23 +48,34 @@ class SparseMemory(nn.Module):
 self.c = (r * self.K) + 1

 if self.independent_linears:
 if self.gpu_id != -1:
 self.read_query_transform = nn.Linear(self.input_size, w * r).cuda()
 self.write_vector_transform = nn.Linear(self.input_size, w).cuda()
 self.interpolation_gate_transform = nn.Linear(self.input_size, self.c).cuda()
 self.write_gate_transform = nn.Linear(self.input_size, 1).cuda()
 else:
 self.read_query_transform = nn.Linear(self.input_size, w * r)
 self.write_vector_transform = nn.Linear(self.input_size, w)
 self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
 self.write_gate_transform = nn.Linear(self.input_size, 1)
-T.nn.init.orthogonal(self.read_query_transform.weight)
-T.nn.init.orthogonal(self.write_vector_transform.weight)
-T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
-T.nn.init.orthogonal(self.write_gate_transform.weight)
+T.nn.init.orthogonal_(self.read_query_transform.weight)
+T.nn.init.orthogonal_(self.write_vector_transform.weight)
+T.nn.init.orthogonal_(self.interpolation_gate_transform.weight)
+T.nn.init.orthogonal_(self.write_gate_transform.weight)
 else:
 self.interface_size = (r * w) + w + self.c + 1
 if self.gpu_id != -1:
 self.interface_weights = nn.Linear(self.input_size, self.interface_size).cuda()
 else:
 self.interface_weights = nn.Linear(self.input_size, self.interface_size)
-T.nn.init.orthogonal(self.interface_weights.weight)
+T.nn.init.orthogonal_(self.interface_weights.weight)

 self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
 self.δ = 0.005 # minimum usage
 self.timestep = 0
 self.mem_limit_reached = False
 if self.gpu_id != -1:
 self.cuda()

 def rebuild_indexes(self, hidden, erase=False):
 b = hidden['memory'].size(0)
@@ -288,9 +299,9 @@ class SparseMemory(nn.Module):
 # write key (b * 1 * w)
 write_vector = self.write_vector_transform(ξ).view(b, 1, w)
 # write vector (b * 1 * r)
-interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
+interpolation_gate = T.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
 # write gate (b * 1)
-write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
+write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
 else:
 ξ = self.interface_weights(ξ)
 # r read keys (b * r * w)
@@ -298,9 +309,9 @@ class SparseMemory(nn.Module):
 # write key (b * 1 * w)
 write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
 # write vector (b * 1 * r)
-interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
+interpolation_gate = T.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
 # write gate (b * 1)
-write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
+write_gate = T.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)

 self.timestep += 1
 hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
@@ -8,7 +8,6 @@ import torch.nn.functional as F
 import numpy as np
 import math

 from .flann_index import FLANNIndex
 from .util import *
 import time

@@ -55,14 +54,14 @@ class SparseTemporalMemory(nn.Module):
 self.write_vector_transform = nn.Linear(self.input_size, w)
 self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
 self.write_gate_transform = nn.Linear(self.input_size, 1)
-T.nn.init.orthogonal(self.read_query_transform.weight)
-T.nn.init.orthogonal(self.write_vector_transform.weight)
-T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
-T.nn.init.orthogonal(self.write_gate_transform.weight)
+T.nn.init.orthogonal_(self.read_query_transform.weight)
+T.nn.init.orthogonal_(self.write_vector_transform.weight)
+T.nn.init.orthogonal_(self.interpolation_gate_transform.weight)
+T.nn.init.orthogonal_(self.write_gate_transform.weight)
 else:
 self.interface_size = (r * w) + w + self.c + 1
 self.interface_weights = nn.Linear(self.input_size, self.interface_size)
-T.nn.init.orthogonal(self.interface_weights.weight)
+T.nn.init.orthogonal_(self.interface_weights.weight)

 self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
 self.δ = 0.005 # minimum usage
@@ -358,9 +357,9 @@ class SparseTemporalMemory(nn.Module):
 # write key (b * 1 * w)
 write_vector = self.write_vector_transform(ξ).view(b, 1, w)
 # write vector (b * 1 * r)
-interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
+interpolation_gate = T.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
 # write gate (b * 1)
-write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
+write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
 else:
 ξ = self.interface_weights(ξ)
 # r read keys (b * r * w)
@@ -368,9 +367,9 @@ class SparseTemporalMemory(nn.Module):
 # write key (b * 1 * w)
 write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
 # write vector (b * 1 * r)
-interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
+interpolation_gate = T.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
 # write gate (b * 1)
-write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
+write_gate = T.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)

 self.timestep += 1
 hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
dnc/util.py (53 changed lines)
@@ -4,7 +4,6 @@
 import torch.nn as nn
 import torch as T
 import torch.nn.functional as F
 from torch.autograd import Variable as var
 import numpy as np
 import torch
 from torch.autograd import Variable
@@ -24,49 +23,56 @@ def recursiveTrace(obj):


 def cuda(x, grad=False, gpu_id=-1):
   x = x.float() if T.is_tensor(x) else x
   if gpu_id == -1:
-    return var(x, requires_grad=grad)
+    t = T.FloatTensor(x)
+    t.requires_grad=grad
+    return t
   else:
-    return var(x.pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)
+    t = T.FloatTensor(x.pin_memory()).cuda(gpu_id)
+    t.requires_grad=grad
+    return t


 def cudavec(x, grad=False, gpu_id=-1):
   if gpu_id == -1:
-    return var(T.from_numpy(x), requires_grad=grad)
+    t = T.Tensor(T.from_numpy(x))
+    t.requires_grad = grad
+    return t
   else:
-    return var(T.from_numpy(x).pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)
+    t = T.Tensor(T.from_numpy(x).pin_memory()).cuda(gpu_id)
+    t.requires_grad = grad
+    return t


 def cudalong(x, grad=False, gpu_id=-1):
   if gpu_id == -1:
-    return var(T.from_numpy(x.astype(np.long)), requires_grad=grad)
+    t = T.LongTensor(T.from_numpy(x.astype(np.long)))
+    t.requires_grad = grad
+    return t
   else:
-    return var(T.from_numpy(x.astype(np.long)).pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)
+    t = T.LongTensor(T.from_numpy(x.astype(np.long)).pin_memory()).cuda(gpu_id)
+    t.requires_grad = grad
+    return t
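The rewritten helpers above reflect two PyTorch 1.x changes: `torch.autograd.Variable` is merged into `Tensor`, and the `async=True` keyword of `.cuda()` is gone (`non_blocking=True` is its replacement). A two-line illustration of the CPU path:

```python
import numpy as np
import torch as T

t = T.Tensor(T.from_numpy(np.zeros((4, 8), dtype=np.float32)))  # what the new cudavec() builds for gpu_id == -1
t.requires_grad = True                                          # grad flag set directly on the tensor
```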
-def θ(a, b, dimA=2, dimB=2, normBy=2):
-  """Batchwise Cosine distance
+def θ(a, b, normBy=2):
+  """Batchwise Cosine similarity

-  Cosine distance
+  Cosine similarity

   Arguments:
   a {Tensor} -- A 3D Tensor (b * m * w)
   b {Tensor} -- A 3D Tensor (b * r * w)

   Keyword Arguments:
   dimA {number} -- exponent value of the norm for `a` (default: {2})
   dimB {number} -- exponent value of the norm for `b` (default: {1})

   Returns:
-  Tensor -- Batchwise cosine distance (b * r * m)
+  Tensor -- Batchwise cosine similarity (b * r * m)
   """
-  a_norm = T.norm(a, normBy, dimA, keepdim=True).expand_as(a) + δ
-  b_norm = T.norm(b, normBy, dimB, keepdim=True).expand_as(b) + δ
-
-  x = T.bmm(a, b.transpose(1, 2)).transpose(1, 2) / (
-    T.bmm(a_norm, b_norm.transpose(1, 2)).transpose(1, 2) + δ)
-  # apply_dict(locals())
-  return x
+  dot = T.bmm(a, b.transpose(1,2))
+  a_norm = T.norm(a, normBy, dim=2).unsqueeze(2)
+  b_norm = T.norm(b, normBy, dim=2).unsqueeze(1)
+  cos = dot / (a_norm * b_norm + δ)
+  return cos.transpose(1,2).contiguous()

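A quick shape check of the rewritten similarity above (a standalone sketch; `δ` here is a local stand-in for the module-level constant in util.py):

```python
import torch as T

δ = 1e-6  # stand-in for the module-level epsilon

def θ(a, b, normBy=2):
    # Batchwise cosine similarity as rewritten above:
    # a is (b, m, w), b is (b, r, w); result is (b, r, m).
    dot = T.bmm(a, b.transpose(1, 2))
    a_norm = T.norm(a, normBy, dim=2).unsqueeze(2)
    b_norm = T.norm(b, normBy, dim=2).unsqueeze(1)
    cos = dot / (a_norm * b_norm + δ)
    return cos.transpose(1, 2).contiguous()

mem  = T.randn(2, 10, 16)   # b=2 batches, m=10 memory rows, w=16 cell width
keys = T.randn(2, 4, 16)    # r=4 read keys
print(θ(mem, keys).shape)   # torch.Size([2, 4, 10])
```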
 def σ(input, axis=1):
@@ -89,10 +95,7 @@ def σ(input, axis=1):
 trans_size = trans_input.size()

 input_2d = trans_input.contiguous().view(-1, trans_size[-1])
-if '0.3' in T.__version__:
 soft_max_2d = F.softmax(input_2d, -1)
-else:
-soft_max_2d = F.softmax(input_2d)
 soft_max_nd = soft_max_2d.view(*trans_size)
 return soft_max_nd.transpose(axis, len(input_size) - 1)
@@ -1,4 +1,4 @@
 pyflann3>=1.8.4.1
-torch>=0.2.0.post1
+torch==1.0.1.post2
 numpy>=1.13.3
 pytest>=3.0.0
setup.py (7 changed lines)
@@ -16,14 +16,13 @@ from os import path
 here = path.abspath(path.dirname(__file__))

 # Get the long description from the README file
-with open(path.join(here, 'README.md'), encoding='utf-8') as f:
+with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
   long_description = f.read()

 setup(
 name='dnc',
-version='0.0.8',
+version='1.1.0',
 description='Differentiable Neural Computer, for Pytorch',
 long_description=long_description,

@@ -57,7 +56,7 @@ setup(

 packages=find_packages(exclude=['contrib', 'docs', 'tests', 'tasks', 'scripts']),

-install_requires=['torch', 'numpy', 'pyflann3'],
+install_requires=['torch', 'numpy', 'flann'],

 extras_require={
 'dev': ['check-manifest'],
@@ -20,7 +20,7 @@ from torch.autograd import Variable as var
 import torch.nn.functional as F
 import torch.optim as optim

-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_

 from dnc.dnc import DNC
 from dnc.sdnc import SDNC
@@ -99,7 +99,9 @@ def generate_data(length, size):
 sums = np.array(sums)
 sums = sums.reshape(1, 1, 1)

-return cudavec(x_seq_list, gpu_id=args.cuda).float(), cudavec(sums, gpu_id=args.cuda).float(), sums_text
+return cudavec(x_seq_list.astype(np.float32), gpu_id=args.cuda).float(), \
+  cudavec(sums.astype(np.float32), gpu_id=args.cuda).float(), \
+  sums_text


 def cross_entropy(prediction, target):
@@ -219,9 +221,9 @@ if __name__ == '__main__':

 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
 optimizer.step()
-loss_value = loss.data[0]
+loss_value = loss.item()

 # detach memory from graph
 mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
@@ -20,7 +20,7 @@ from torch.autograd import Variable as var
 import torch.nn.functional as F
 import torch.optim as optim

-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_

 from dnc.dnc import DNC
 from dnc.sdnc import SDNC
@@ -225,9 +225,9 @@ if __name__ == '__main__':

 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
 optimizer.step()
-loss_value = loss.data[0]
+loss_value = loss.item()

 # detach memory from graph
 mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
@@ -20,7 +20,7 @@ from torch.autograd import Variable as var
 import torch.nn.functional as F
 import torch.optim as optim

-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_

 from dnc.dnc import DNC
 from dnc.sdnc import SDNC
@@ -212,9 +212,9 @@ if __name__ == '__main__':

 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
 optimizer.step()
-loss_value = loss.data[0]
+loss_value = loss.item()

 summarize = (epoch % summarize_freq == 0)
 take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
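The `generate_data` change above casts to `np.float32` before handing the arrays to `cudavec`; `torch.from_numpy` keeps the numpy dtype, so without the cast a default float64 array becomes a double tensor and clashes with the model's float32 weights. A minimal illustration:

```python
import numpy as np
import torch

a = np.ones((1, 1, 1))                                # numpy defaults to float64
print(torch.from_numpy(a).dtype)                      # torch.float64
print(torch.from_numpy(a.astype(np.float32)).dtype)   # torch.float32
```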
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -71,7 +71,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -127,7 +127,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -188,7 +188,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -70,7 +70,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -126,7 +126,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -187,7 +187,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -71,7 +71,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -127,7 +127,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -188,7 +188,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -72,7 +72,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -130,7 +130,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -191,7 +191,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -72,7 +72,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -130,7 +130,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -191,7 +191,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -72,7 +72,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -130,7 +130,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -191,7 +191,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -74,7 +74,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -134,7 +134,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -197,7 +197,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -74,7 +74,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -134,7 +134,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -197,7 +197,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch as T
 from torch.autograd import Variable as var
 import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_
 import torch.optim as optim
 import numpy as np

@@ -74,7 +74,7 @@ def test_rnn_1():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([21, 10, 100])
@@ -134,7 +134,7 @@ def test_rnn_n():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -197,7 +197,7 @@ def test_rnn_no_memory_pass():
 loss = criterion((output), target_output)
 loss.backward()

-T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
 optimizer.step()

 assert target_output.size() == T.Size([27, 10, 100])
@@ -28,6 +28,6 @@ def generate_data(batch_size, length, size, cuda=-1):

 def criterion(predictions, targets):
   return T.mean(
-    -1 * F.logsigmoid(predictions) * (targets) - T.log(1 - F.sigmoid(predictions) + 1e-9) * (1 - targets)
+    -1 * F.logsigmoid(predictions) * (targets) - T.log(1 - T.sigmoid(predictions) + 1e-9) * (1 - targets)
   )
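For reference, the hand-rolled criterion above is binary cross-entropy on logits, up to the `1e-9` stabiliser in the second term. A hedged sanity check (illustrative values only):

```python
import torch as T
import torch.nn.functional as F

predictions = T.randn(8, 5)                      # logits
targets = T.randint(0, 2, (8, 5)).float()

manual = T.mean(
    -1 * F.logsigmoid(predictions) * targets
    - T.log(1 - T.sigmoid(predictions) + 1e-9) * (1 - targets)
)
builtin = F.binary_cross_entropy_with_logits(predictions, targets)
print(T.allclose(manual, builtin, atol=1e-4))    # True, up to the stabiliser
```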