Initial commit
This commit is contained in:
commit
2a2b6bfd78
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
__pycache__
|
||||||
|
*.png
|
||||||
|
save
|
||||||
|
.idea
|
66
Dataset/Bitmap/AssociativeRecall.py
Normal file
66
Dataset/Bitmap/AssociativeRecall.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from .BitmapTask import BitmapTask
|
||||||
|
from Utils.Seed import get_randstate
|
||||||
|
|
||||||
|
class AssociativeRecall(BitmapTask):
|
||||||
|
def __init__(self, length=None, bit_w=8, block_w=3, transform=lambda x: x):
|
||||||
|
super(AssociativeRecall, self).__init__()
|
||||||
|
self.length = length
|
||||||
|
self.bit_w = bit_w
|
||||||
|
self.block_w = block_w
|
||||||
|
self.transform = transform
|
||||||
|
self.seed = None
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if self.seed is None:
|
||||||
|
self.seed = get_randstate()
|
||||||
|
|
||||||
|
length = self.length() if callable(self.length) else self.length
|
||||||
|
if length is None:
|
||||||
|
# Random length batch hack.
|
||||||
|
length = key
|
||||||
|
|
||||||
|
stride = self.block_w+1
|
||||||
|
|
||||||
|
d = self.seed.randint(0, 2, [length * (self.block_w+1), self.bit_w + 2]).astype(np.float32)
|
||||||
|
d[:,-2:] = 0
|
||||||
|
|
||||||
|
# Terminate input block
|
||||||
|
for i in range(1,length,1):
|
||||||
|
d[i * stride - 1, :] = 0
|
||||||
|
d[i * stride - 1, -2] = 1
|
||||||
|
|
||||||
|
# Terminate input sequence
|
||||||
|
d[-1, :] = 0
|
||||||
|
d[-1, -1] = 1
|
||||||
|
|
||||||
|
# Add and terminate query
|
||||||
|
ti = self.seed.randint(0, length-1)
|
||||||
|
d = np.concatenate((d, d[ti * stride: (ti+1) * stride-1], np.zeros([self.block_w+1, self.bit_w+2], np.float32)), axis=0)
|
||||||
|
d[-(1+self.block_w),-1] = 1
|
||||||
|
|
||||||
|
# Target
|
||||||
|
target = np.zeros_like(d)
|
||||||
|
target[-self.block_w:] = d[(ti+1) * stride: (ti+2) * stride-1]
|
||||||
|
|
||||||
|
return self.transform({
|
||||||
|
"input": d,
|
||||||
|
"output": target
|
||||||
|
})
|
||||||
|
|
47
Dataset/Bitmap/BitmapTask.py
Normal file
47
Dataset/Bitmap/BitmapTask.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from Visualize.BitmapTask import visualize_bitmap_task
|
||||||
|
from Utils import Visdom
|
||||||
|
from Utils import universal as U
|
||||||
|
|
||||||
|
|
||||||
|
class BitmapTask(torch.utils.data.Dataset):
|
||||||
|
def __init__(self):
|
||||||
|
super(BitmapTask, self).__init__()
|
||||||
|
|
||||||
|
self._img = Visdom.Image("preview")
|
||||||
|
|
||||||
|
def set_dump_dir(self, dir):
|
||||||
|
self._img.set_dump_dir(dir)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 0x7FFFFFFF
|
||||||
|
|
||||||
|
def visualize_preview(self, data, net_output):
|
||||||
|
img = visualize_bitmap_task(data["input"], [data["output"], U.sigmoid(net_output)])
|
||||||
|
self._img.draw(img)
|
||||||
|
|
||||||
|
def loss(self, net_output, target):
|
||||||
|
return F.binary_cross_entropy_with_logits(net_output, target, reduction="sum") / net_output.size(0)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state):
|
||||||
|
pass
|
55
Dataset/Bitmap/BitmapTaskRepeater.py
Normal file
55
Dataset/Bitmap/BitmapTaskRepeater.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from .BitmapTask import BitmapTask
|
||||||
|
|
||||||
|
class BitmapTaskRepeater(BitmapTask):
|
||||||
|
def __init__(self, dataset):
|
||||||
|
super(BitmapTaskRepeater, self).__init__()
|
||||||
|
self.dataset = dataset
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
r = [self.dataset[k] for k in key]
|
||||||
|
if len(r)==1:
|
||||||
|
return r[0]
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"input": np.concatenate([a["input"] for a in r], axis=0),
|
||||||
|
"output": np.concatenate([a["output"] for a in r], axis=0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def key_sampler(length, repeat):
|
||||||
|
def call_sampler(s):
|
||||||
|
if callable(s):
|
||||||
|
return s()
|
||||||
|
elif isinstance(s, list):
|
||||||
|
if len(s) == 2:
|
||||||
|
return random.randint(*s)
|
||||||
|
elif len(s) == 1:
|
||||||
|
return s[0]
|
||||||
|
else:
|
||||||
|
assert False, "Invalid sample parameter: %s" % s
|
||||||
|
else:
|
||||||
|
return s
|
||||||
|
|
||||||
|
def s():
|
||||||
|
r = call_sampler(repeat)
|
||||||
|
return [call_sampler(length) for i in range(r)]
|
||||||
|
|
||||||
|
return s
|
55
Dataset/Bitmap/CopyTask.py
Normal file
55
Dataset/Bitmap/CopyTask.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from .BitmapTask import BitmapTask
|
||||||
|
from Utils.Seed import get_randstate
|
||||||
|
|
||||||
|
class CopyData(BitmapTask):
|
||||||
|
def __init__(self, length=None, bit_w=8, transform=lambda x:x):
|
||||||
|
super(CopyData, self).__init__()
|
||||||
|
self.length = length
|
||||||
|
self.bit_w = bit_w
|
||||||
|
self.transform = transform
|
||||||
|
self.seed = None
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if self.seed is None:
|
||||||
|
self.seed = get_randstate()
|
||||||
|
|
||||||
|
length = self.length() if callable(self.length) else self.length
|
||||||
|
if length is None:
|
||||||
|
#Random length batch hack.
|
||||||
|
length = key
|
||||||
|
|
||||||
|
d = self.seed.randint(0,2,[length+1, self.bit_w+1]).astype(np.float32)
|
||||||
|
z = np.zeros_like(d)
|
||||||
|
|
||||||
|
d[-1] = 0
|
||||||
|
d[:, -1] = 0
|
||||||
|
d[-1, -1] = 1
|
||||||
|
|
||||||
|
i_p = np.concatenate((d, z), axis=0)
|
||||||
|
o_p = np.concatenate((z,d), axis=0)
|
||||||
|
|
||||||
|
return self.transform({
|
||||||
|
"input" : i_p,
|
||||||
|
"output": o_p
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
84
Dataset/Bitmap/KeyValue.py
Normal file
84
Dataset/Bitmap/KeyValue.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from .BitmapTask import BitmapTask
|
||||||
|
from Utils.Seed import get_randstate
|
||||||
|
|
||||||
|
|
||||||
|
class KeyValue(BitmapTask):
|
||||||
|
def __init__(self, length=None, bit_w=8, transform=lambda x: x):
|
||||||
|
assert bit_w % 2 == 0, "bit_w must be even"
|
||||||
|
super(KeyValue, self).__init__()
|
||||||
|
self.length = length
|
||||||
|
self.bit_w = bit_w
|
||||||
|
self.transform = transform
|
||||||
|
self.seed = None
|
||||||
|
self.key_w = self.bit_w//2
|
||||||
|
self.max_key = 2**self.key_w - 1
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if self.seed is None:
|
||||||
|
self.seed = get_randstate()
|
||||||
|
|
||||||
|
if self.length is None:
|
||||||
|
# Random length batch hack.
|
||||||
|
length = key
|
||||||
|
else:
|
||||||
|
length = self.length() if callable(self.length) else self.length
|
||||||
|
|
||||||
|
# keys must be unique
|
||||||
|
keys = None
|
||||||
|
last_size = 0
|
||||||
|
while last_size!=length:
|
||||||
|
res = self.seed.random_integers(0, self.max_key, size=(length - last_size))
|
||||||
|
if keys is not None:
|
||||||
|
keys = np.concatenate((res, keys))
|
||||||
|
else:
|
||||||
|
keys = res
|
||||||
|
|
||||||
|
keys = np.unique(keys)
|
||||||
|
last_size = keys.size
|
||||||
|
|
||||||
|
# view as bunch of uint8s, convert them to bit patterns, then cut the correct amount from it
|
||||||
|
keys = keys.view(np.uint8).reshape(length, -1)
|
||||||
|
keys = keys[:, :math.ceil(self.key_w/8)]
|
||||||
|
keys = np.unpackbits(np.expand_dims(keys,-1), axis=-1)
|
||||||
|
keys = np.flip(keys, axis=-1).reshape(keys.shape[0],-1)[:, :self.key_w]
|
||||||
|
keys = keys.astype(np.float32)
|
||||||
|
|
||||||
|
values = self.seed.randint(0,2, keys.shape).astype(np.float32)
|
||||||
|
|
||||||
|
perm = self.seed.permutation(length)
|
||||||
|
keys_perm = keys[perm,:]
|
||||||
|
values_perm = values[perm,:]
|
||||||
|
|
||||||
|
i_p = np.zeros((2*length+2, self.bit_w+1), dtype=np.float32)
|
||||||
|
i_p[:length,:self.key_w] = keys
|
||||||
|
i_p[:length,self.key_w:-1] = values
|
||||||
|
i_p[length+1:-1, :self.key_w] = keys_perm
|
||||||
|
|
||||||
|
i_p[length,-1] = 1
|
||||||
|
i_p[-1, -1] = 1
|
||||||
|
|
||||||
|
o_p = np.zeros((2*length+2, self.key_w), dtype=np.float32)
|
||||||
|
o_p[length+1:-1] = values_perm
|
||||||
|
|
||||||
|
return self.transform({
|
||||||
|
"input": i_p,
|
||||||
|
"output": o_p
|
||||||
|
})
|
92
Dataset/Bitmap/KeyValue2Way.py
Normal file
92
Dataset/Bitmap/KeyValue2Way.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from .BitmapTask import BitmapTask
|
||||||
|
from Utils.Seed import get_randstate
|
||||||
|
|
||||||
|
|
||||||
|
class KeyValue2Way(BitmapTask):
|
||||||
|
def __init__(self, length=None, bit_w=8, transform=lambda x: x):
|
||||||
|
assert bit_w % 2 == 0, "bit_w must be even"
|
||||||
|
super(KeyValue2Way, self).__init__()
|
||||||
|
self.length = length
|
||||||
|
self.bit_w = bit_w
|
||||||
|
self.transform = transform
|
||||||
|
self.seed = None
|
||||||
|
self.key_w = self.bit_w//2
|
||||||
|
self.max_key = 2**self.key_w - 1
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if self.seed is None:
|
||||||
|
self.seed = get_randstate()
|
||||||
|
|
||||||
|
if self.length is None:
|
||||||
|
# Random length batch hack.
|
||||||
|
length = key
|
||||||
|
else:
|
||||||
|
length = self.length() if callable(self.length) else self.length
|
||||||
|
|
||||||
|
# keys must be unique
|
||||||
|
keys = None
|
||||||
|
last_size = 0
|
||||||
|
while last_size!=length:
|
||||||
|
res = self.seed.random_integers(0, self.max_key, size=(length - last_size))
|
||||||
|
if keys is not None:
|
||||||
|
keys = np.concatenate((res, keys))
|
||||||
|
else:
|
||||||
|
keys = res
|
||||||
|
|
||||||
|
keys = np.unique(keys)
|
||||||
|
last_size = keys.size
|
||||||
|
|
||||||
|
# view as bunch of uint8s, convert them to bit patterns, then cut the correct amount from it
|
||||||
|
keys = keys.view(np.uint8).reshape(length, -1)
|
||||||
|
keys = keys[:, :math.ceil(self.key_w/8)]
|
||||||
|
keys = np.unpackbits(np.expand_dims(keys,-1), axis=-1)
|
||||||
|
keys = np.flip(keys, axis=-1).reshape(keys.shape[0],-1)[:, :self.key_w]
|
||||||
|
keys = keys.astype(np.float32)
|
||||||
|
|
||||||
|
values = self.seed.randint(0,2, keys.shape).astype(np.float32)
|
||||||
|
|
||||||
|
perm = self.seed.permutation(length)
|
||||||
|
keys_perm = keys[perm,:]
|
||||||
|
values_perm = values[perm,:]
|
||||||
|
|
||||||
|
i_p = np.zeros((3*(length+1), self.bit_w+2), dtype=np.float32)
|
||||||
|
o_p = np.zeros((3*(length+1), self.key_w), dtype=np.float32)
|
||||||
|
|
||||||
|
i_p[:length,:self.key_w] = keys
|
||||||
|
i_p[:length,self.key_w:-2] = values
|
||||||
|
i_p[length + 1:2*length + 1, :self.key_w] = keys_perm
|
||||||
|
o_p[length + 1:2 * length + 1] = values_perm
|
||||||
|
|
||||||
|
perm = self.seed.permutation(length)
|
||||||
|
keys_perm = keys[perm, :]
|
||||||
|
values_perm = values[perm, :]
|
||||||
|
|
||||||
|
o_p[2*(length + 1):-1] = keys_perm
|
||||||
|
i_p[2 * (length + 1):-1, :self.key_w] = values_perm
|
||||||
|
|
||||||
|
i_p[length, -2] = 1
|
||||||
|
i_p[2 * length + 1, -1] = 1
|
||||||
|
i_p[-1, -2:] = 1
|
||||||
|
|
||||||
|
return self.transform({
|
||||||
|
"input": i_p,
|
||||||
|
"output": o_p
|
||||||
|
})
|
0
Dataset/Bitmap/__init__.py
Normal file
0
Dataset/Bitmap/__init__.py
Normal file
1
Dataset/NLP/.gitignore
vendored
Normal file
1
Dataset/NLP/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
cache
|
103
Dataset/NLP/NLPTask.py
Normal file
103
Dataset/NLP/NLPTask.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import os
|
||||||
|
from .Vocabulary import Vocabulary
|
||||||
|
from Utils import Visdom
|
||||||
|
from Utils import universal as U
|
||||||
|
|
||||||
|
class NLPTask(torch.utils.data.Dataset):
|
||||||
|
def __init__(self):
|
||||||
|
super(NLPTask, self).__init__()
|
||||||
|
|
||||||
|
self.my_dir = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
self.cache_dir = os.path.join(self.my_dir, "cache")
|
||||||
|
|
||||||
|
if not os.path.isdir(self.cache_dir):
|
||||||
|
os.makedirs(self.cache_dir)
|
||||||
|
|
||||||
|
self.vocabulary = self._load_vocabulary()
|
||||||
|
|
||||||
|
self._preview = None
|
||||||
|
|
||||||
|
def _load_vocabulary(self):
|
||||||
|
cache_file = os.path.join(self.cache_dir, "vocabulary.pth")
|
||||||
|
if not os.path.isfile(cache_file):
|
||||||
|
print("WARNING: Vocabulary not found. Removing cached files.")
|
||||||
|
for f in os.listdir(self.cache_dir):
|
||||||
|
f = os.path.join(self.cache_dir, f)
|
||||||
|
if f.endswith(".pth"):
|
||||||
|
print(" "+f)
|
||||||
|
os.remove(f)
|
||||||
|
return Vocabulary()
|
||||||
|
else:
|
||||||
|
return torch.load(cache_file)
|
||||||
|
|
||||||
|
def save_vocabulary(self):
|
||||||
|
cache_file = os.path.join(self.cache_dir, "vocabulary.pth")
|
||||||
|
if os.path.isfile(cache_file):
|
||||||
|
os.remove(cache_file)
|
||||||
|
torch.save(self.vocabulary, cache_file)
|
||||||
|
|
||||||
|
def loss(self, net_output, target):
|
||||||
|
s = list(net_output.size())
|
||||||
|
return F.cross_entropy(net_output.view([s[0]*s[1], s[2]]), target.view([-1]), ignore_index=0,
|
||||||
|
reduction='sum')/s[0]
|
||||||
|
|
||||||
|
def generate_preview_text(self, data, net_output):
|
||||||
|
input = U.to_numpy(data["input"][0])
|
||||||
|
reference = U.to_numpy(data["output"][0])
|
||||||
|
net_out = U.argmax(net_output[0], -1)
|
||||||
|
net_out = U.to_numpy(net_out)
|
||||||
|
|
||||||
|
res = ""
|
||||||
|
start_index = 0
|
||||||
|
|
||||||
|
for i in range(input.shape[0]):
|
||||||
|
if reference[i] != 0:
|
||||||
|
if start_index < i:
|
||||||
|
end_index = i
|
||||||
|
while end_index>start_index and input[end_index]==0:
|
||||||
|
end_index -= 1
|
||||||
|
|
||||||
|
if end_index>start_index:
|
||||||
|
sentence = " ".join(self.vocabulary.indices_to_sentence(input[start_index:i].tolist())). \
|
||||||
|
replace(" .", ".").replace(" ,", ",").replace(" ?", "?").split(". ")
|
||||||
|
sentence = ". ".join([s.capitalize() for s in sentence])
|
||||||
|
res += sentence + "<br>"
|
||||||
|
|
||||||
|
start_index = i + 1
|
||||||
|
|
||||||
|
match = reference[i] == net_out[i]
|
||||||
|
res += "<b><font color=\"%s\">%s [%s]</font><br></b>" % ("green" if match else "red",
|
||||||
|
self.vocabulary.indices_to_sentence(
|
||||||
|
[net_out[i]])[0],
|
||||||
|
self.vocabulary.indices_to_sentence(
|
||||||
|
[reference[i]])[0])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def visualize_preview(self, data, net_output):
|
||||||
|
res = self.generate_preview_text(data, net_output)
|
||||||
|
|
||||||
|
if self._preview is None:
|
||||||
|
self._preview = Visdom.Text("Preview")
|
||||||
|
|
||||||
|
self._preview.set(res)
|
||||||
|
|
||||||
|
def set_dump_dir(self, dir):
|
||||||
|
pass
|
49
Dataset/NLP/Vocabulary.py
Normal file
49
Dataset/NLP/Vocabulary.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class Vocabulary:
|
||||||
|
def __init__(self):
|
||||||
|
self.words = {"-" : 0, "?": 1, "<UNK>": 2}
|
||||||
|
self.inv_words = {0 : "-", 1: "?", 2: "<UNK>"}
|
||||||
|
self.next_id = 3
|
||||||
|
self.punctations = [".", "?", ","]
|
||||||
|
|
||||||
|
def _process_word(self, w, add_words):
|
||||||
|
if not w.isalpha() and w not in self.punctations:
|
||||||
|
print("WARNING: word with unknown characters: %s", w)
|
||||||
|
w = "<UNK>"
|
||||||
|
|
||||||
|
if w not in self.words:
|
||||||
|
if add_words:
|
||||||
|
self.words[w] = self.next_id
|
||||||
|
self.inv_words[self.next_id] = w
|
||||||
|
self.next_id += 1
|
||||||
|
else:
|
||||||
|
w = "<UNK>"
|
||||||
|
|
||||||
|
return self.words[w]
|
||||||
|
|
||||||
|
def sentence_to_indices(self, sentence, add_words=True):
|
||||||
|
for p in self.punctations:
|
||||||
|
sentence = sentence.replace(p, " %s " % p)
|
||||||
|
|
||||||
|
return [self._process_word(w, add_words) for w in sentence.lower().split(" ") if w]
|
||||||
|
|
||||||
|
def indices_to_sentence(self, indices):
|
||||||
|
return [self.inv_words[i] for i in indices]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.words)
|
0
Dataset/NLP/__init__.py
Normal file
0
Dataset/NLP/__init__.py
Normal file
270
Dataset/NLP/bAbi.py
Normal file
270
Dataset/NLP/bAbi.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import torch
|
||||||
|
from collections import namedtuple
|
||||||
|
import numpy as np
|
||||||
|
from .NLPTask import NLPTask
|
||||||
|
from Utils import Visdom
|
||||||
|
|
||||||
|
Sentence = namedtuple('Sentence', ['sentence', 'answer', 'supporting_facts'])
|
||||||
|
|
||||||
|
class bAbiDataset(NLPTask):
|
||||||
|
URL = 'http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz'
|
||||||
|
DIR_NAME = "tasks_1-20_v1-2"
|
||||||
|
|
||||||
|
def __init__(self, dirs = ["en-10k"], sets=None, think_steps=0, dir_name=None, name=None):
|
||||||
|
super(bAbiDataset, self).__init__()
|
||||||
|
|
||||||
|
self._test_res_win = None
|
||||||
|
self._test_plot_win = None
|
||||||
|
self._think_steps = think_steps
|
||||||
|
|
||||||
|
if dir_name is None:
|
||||||
|
self._download()
|
||||||
|
dir_name = os.path.join(self.cache_dir, self.DIR_NAME)
|
||||||
|
|
||||||
|
self.data={}
|
||||||
|
for d in dirs:
|
||||||
|
self.data[d] = self._load_or_create(os.path.join(dir_name, d))
|
||||||
|
|
||||||
|
self.all_tasks=None
|
||||||
|
self.name = name
|
||||||
|
self.use(sets=sets)
|
||||||
|
|
||||||
|
def _make_active_list(self, tasks, sets, dirs):
|
||||||
|
def verify(name, checker):
|
||||||
|
if checker is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if callable(checker):
|
||||||
|
return checker(name)
|
||||||
|
elif isinstance(checker, list):
|
||||||
|
return name in checker
|
||||||
|
else:
|
||||||
|
return name==checker
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for dirname, setlist in self.data.items():
|
||||||
|
if not verify(dirname, dirs):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for sname, tasklist in setlist.items():
|
||||||
|
if not verify(sname, sets):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for task, data in tasklist.items():
|
||||||
|
name = task.split("_")[0][2:]
|
||||||
|
if not verify(name, tasks):
|
||||||
|
continue
|
||||||
|
|
||||||
|
res += [(d, dirname, task, sname) for d in data]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def use(self, tasks=None, sets=None, dirs=None):
|
||||||
|
self.all_tasks=self._make_active_list(tasks=tasks, sets=sets, dirs=dirs)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.all_tasks)
|
||||||
|
|
||||||
|
def _get_seq(self, index):
|
||||||
|
return self.all_tasks[index]
|
||||||
|
|
||||||
|
def _seq_to_nn_input(self, seq):
|
||||||
|
in_arr = []
|
||||||
|
out_arr = []
|
||||||
|
hasAnswer = False
|
||||||
|
for sentence in seq[0]:
|
||||||
|
in_arr += sentence.sentence
|
||||||
|
out_arr += [0] * len(sentence.sentence)
|
||||||
|
if sentence.answer is not None:
|
||||||
|
in_arr += [0] * (len(sentence.answer) + self._think_steps)
|
||||||
|
out_arr += [0] * self._think_steps + sentence.answer
|
||||||
|
hasAnswer = True
|
||||||
|
|
||||||
|
in_arr = np.asarray(in_arr, np.int64)
|
||||||
|
out_arr = np.asarray(out_arr, np.int64)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input": in_arr,
|
||||||
|
"output": out_arr,
|
||||||
|
"meta": {
|
||||||
|
"dir": seq[1],
|
||||||
|
"task": seq[2],
|
||||||
|
"set": seq[3]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
seq = self._get_seq(item)
|
||||||
|
return self._seq_to_nn_input(seq)
|
||||||
|
|
||||||
|
def _load_or_create(self, directory):
|
||||||
|
cache_name = directory.replace("/","_")
|
||||||
|
cache_file = os.path.join(self.cache_dir, cache_name+".pth")
|
||||||
|
if not os.path.isfile(cache_file):
|
||||||
|
print("bAbI: Loading %s" % directory)
|
||||||
|
res = self._load_dir(directory)
|
||||||
|
print("Write: ", cache_file)
|
||||||
|
self.save_vocabulary()
|
||||||
|
torch.save(res, cache_file)
|
||||||
|
else:
|
||||||
|
res = torch.load(cache_file)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _download(self):
|
||||||
|
if not os.path.isdir(os.path.join(self.cache_dir, self.DIR_NAME)):
|
||||||
|
print(self.URL)
|
||||||
|
print("bAbi data not found. Downloading...")
|
||||||
|
import requests, tarfile, io
|
||||||
|
request = requests.get(self.URL, headers={"User-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36"})
|
||||||
|
|
||||||
|
decompressed_file = tarfile.open(fileobj=io.BytesIO(request.content), mode='r|gz')
|
||||||
|
decompressed_file.extractall(self.cache_dir)
|
||||||
|
print("Done")
|
||||||
|
|
||||||
|
def _load_dir(self, directory, parse_name = lambda x: x.split(".")[0], parse_set = lambda x: x.split(".")[0].split("_")[-1]):
|
||||||
|
res = {}
|
||||||
|
for f in glob.glob(os.path.join(directory, '**', '*.txt'), recursive=True):
|
||||||
|
basename = os.path.basename(f)
|
||||||
|
task_name = parse_name(basename)
|
||||||
|
set = parse_set(basename)
|
||||||
|
print("Loading", f)
|
||||||
|
|
||||||
|
s = res.get(set)
|
||||||
|
if s is None:
|
||||||
|
s = {}
|
||||||
|
res[set] = s
|
||||||
|
s[task_name] = self._load_task(f, task_name)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _load_task(self, filename, task_name):
|
||||||
|
task = []
|
||||||
|
currTask = []
|
||||||
|
|
||||||
|
nextIndex = 1
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
line = [f.strip() for f in line.split("\t")]
|
||||||
|
line[0] = line[0].split(" ")
|
||||||
|
i = int(line[0][0])
|
||||||
|
line[0] = " ".join(line[0][1:])
|
||||||
|
|
||||||
|
if i!=nextIndex:
|
||||||
|
nextIndex = i
|
||||||
|
task.append(currTask)
|
||||||
|
currTask = []
|
||||||
|
|
||||||
|
isQuestion = len(line)>1
|
||||||
|
currTask.append(
|
||||||
|
Sentence(self.vocabulary.sentence_to_indices(line[0]), self.vocabulary.sentence_to_indices(line[1].replace(",", " "))
|
||||||
|
if isQuestion else None, [int(f) for f in line[2].split(" ")] if isQuestion else None)
|
||||||
|
)
|
||||||
|
|
||||||
|
nextIndex += 1
|
||||||
|
return task
|
||||||
|
|
||||||
|
def start_test(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def veify_result(self, test, data, net_output):
|
||||||
|
_, net_output = net_output.max(-1)
|
||||||
|
|
||||||
|
ref = data["output"]
|
||||||
|
|
||||||
|
mask = 1.0 - ref.eq(0).float()
|
||||||
|
|
||||||
|
correct = (torch.eq(net_output, ref).float() * mask).sum(-1)
|
||||||
|
total = mask.sum(-1)
|
||||||
|
|
||||||
|
correct = correct.data.cpu().numpy()
|
||||||
|
total = total.data.cpu().numpy()
|
||||||
|
|
||||||
|
for i in range(correct.shape[0]):
|
||||||
|
task = data["meta"][i]["task"]
|
||||||
|
if task not in test:
|
||||||
|
test[task] = {"total": 0, "correct": 0}
|
||||||
|
|
||||||
|
d = test[task]
|
||||||
|
d["total"] += total[i]
|
||||||
|
d["correct"] += correct[i]
|
||||||
|
|
||||||
|
def _ensure_test_wins_exists(self, legend = None):
|
||||||
|
if self._test_res_win is None:
|
||||||
|
n = (("[" + self.name + "]") if self.name is not None else "")
|
||||||
|
self._test_res_win = Visdom.Text("Test results" + n)
|
||||||
|
self._test_plot_win = Visdom.Plot2D("Test results" + n, legend=legend)
|
||||||
|
elif self._test_plot_win.legend is None:
|
||||||
|
self._test_plot_win.set_legend(legend=legend)
|
||||||
|
|
||||||
|
def show_test_results(self, iteration, test):
|
||||||
|
res = {k: v["correct"]/v["total"] for k, v in test.items()}
|
||||||
|
|
||||||
|
t = ""
|
||||||
|
|
||||||
|
all_keys = list(res.keys())
|
||||||
|
|
||||||
|
num_keys = [k for k in all_keys if k.startswith("qa")]
|
||||||
|
tmp = [i[0] for i in sorted(enumerate(num_keys), key=lambda x:int(x[1][2:].split("_")[0]))]
|
||||||
|
num_keys = [num_keys[j] for j in tmp]
|
||||||
|
|
||||||
|
all_keys = num_keys + sorted([k for k in all_keys if not k.startswith("qa")])
|
||||||
|
|
||||||
|
err_precent = [(1.0-res[k]) * 100.0 for k in all_keys]
|
||||||
|
|
||||||
|
n_passed = sum([int(p<=5) for p in err_precent])
|
||||||
|
n_total = len(err_precent)
|
||||||
|
err_precent = err_precent + [sum(err_precent) / len(err_precent)]
|
||||||
|
all_keys += ["mean"]
|
||||||
|
|
||||||
|
for i, k in enumerate(all_keys):
|
||||||
|
t += "<font color=\"%s\">%s: <b>%.2f%%</b></font><br>" % ("green" if err_precent[i] <= 5 else "red", k, err_precent[i])
|
||||||
|
|
||||||
|
t += "<br><b>Total: %d of %d passed.</b>" % (n_passed, n_total)
|
||||||
|
|
||||||
|
self._ensure_test_wins_exists(legend=[i.split("_")[0] if i.startswith("qa") else i for i in all_keys])
|
||||||
|
|
||||||
|
self._test_res_win.set(t)
|
||||||
|
self._test_plot_win.add_point(iteration, err_precent)
|
||||||
|
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
if self._test_res_win is not None:
|
||||||
|
return {
|
||||||
|
"_test_res_win" : self._test_res_win.state_dict(),
|
||||||
|
"_test_plot_win": self._test_plot_win.state_dict(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state):
|
||||||
|
if state:
|
||||||
|
self._ensure_test_wins_exists()
|
||||||
|
self._test_res_win.load_state_dict(state["_test_res_win"])
|
||||||
|
self._test_plot_win.load_state_dict(state["_test_plot_win"])
|
||||||
|
self._test_plot_win.legend = None
|
||||||
|
|
||||||
|
def visualize_preview(self, data, net_output):
|
||||||
|
res = self.generate_preview_text(data, net_output)
|
||||||
|
res = ("<b><u>%s</u></b><br>" % data["meta"][0]["task"]) + res
|
||||||
|
if self._preview is None:
|
||||||
|
self._preview = Visdom.Text("Preview")
|
||||||
|
|
||||||
|
self._preview.set(res)
|
1
Dataset/__init__.py
Normal file
1
Dataset/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
677
Models/DNC.py
Normal file
677
Models/DNC.py
Normal file
@ -0,0 +1,677 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn.init as init
|
||||||
|
import functools
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def oneplus(t):
|
||||||
|
return F.softplus(t, 1, 20) + 1.0
|
||||||
|
|
||||||
|
def get_next_tensor_part(src, dims, prev_pos=0):
|
||||||
|
if not isinstance(dims, list):
|
||||||
|
dims=[dims]
|
||||||
|
n = functools.reduce(lambda x, y: x * y, dims)
|
||||||
|
data = src.narrow(-1, prev_pos, n)
|
||||||
|
return data.contiguous().view(list(data.size())[:-1] + dims) if len(dims)>1 else data, prev_pos + n
|
||||||
|
|
||||||
|
def split_tensor(src, shapes):
|
||||||
|
pos = 0
|
||||||
|
res = []
|
||||||
|
for s in shapes:
|
||||||
|
d, pos = get_next_tensor_part(src, s, pos)
|
||||||
|
res.append(d)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def dict_get(dict,name):
|
||||||
|
return dict.get(name) if dict is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
def dict_append(dict, name, val):
|
||||||
|
if dict is not None:
|
||||||
|
l = dict.get(name)
|
||||||
|
if not l:
|
||||||
|
l = []
|
||||||
|
dict[name] = l
|
||||||
|
l.append(val)
|
||||||
|
|
||||||
|
|
||||||
|
def init_debug(debug, initial):
|
||||||
|
if debug is not None and not debug:
|
||||||
|
debug.update(initial)
|
||||||
|
|
||||||
|
def merge_debug_tensors(d, dim):
|
||||||
|
if d is not None:
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
merge_debug_tensors(v, dim)
|
||||||
|
elif isinstance(v, list):
|
||||||
|
d[k] = torch.stack(v, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_reset(module, gain=1.0):
|
||||||
|
assert isinstance(module, torch.nn.Linear)
|
||||||
|
init.xavier_uniform_(module.weight, gain=gain)
|
||||||
|
s = module.weight.size(1)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
_EPS = 1e-6
|
||||||
|
|
||||||
|
class AllocationManager(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(AllocationManager, self).__init__()
|
||||||
|
self.usages = None
|
||||||
|
self.zero_usages = None
|
||||||
|
self.debug_sequ_init = False
|
||||||
|
self.one = None
|
||||||
|
|
||||||
|
def _init_sequence(self, prev_read_distributions):
|
||||||
|
# prev_read_distributions size is [batch, n_heads, cell count]
|
||||||
|
s = prev_read_distributions.size()
|
||||||
|
if self.zero_usages is None or list(self.zero_usages.size())!=[s[0],s[-1]]:
|
||||||
|
self.zero_usages = torch.zeros(s[0], s[-1], device = prev_read_distributions.device)
|
||||||
|
if self.debug_sequ_init:
|
||||||
|
self.zero_usages += torch.arange(0, s[-1]).unsqueeze(0) * 1e-10
|
||||||
|
|
||||||
|
self.usages = self.zero_usages
|
||||||
|
|
||||||
|
def _init_consts(self, device):
|
||||||
|
if self.one is None:
|
||||||
|
self.one = torch.ones(1, device=device)
|
||||||
|
|
||||||
|
def new_sequence(self):
|
||||||
|
self.usages = None
|
||||||
|
|
||||||
|
def update_usages(self, prev_write_distribution, prev_read_distributions, free_gates):
|
||||||
|
# Read distributions shape: [batch, n_heads, cell count]
|
||||||
|
# Free gates shape: [batch, n_heads]
|
||||||
|
|
||||||
|
self._init_consts(prev_read_distributions.device)
|
||||||
|
phi = torch.addcmul(self.one, -1, free_gates.unsqueeze(-1), prev_read_distributions).prod(-2)
|
||||||
|
# Phi is the free tensor, sized [batch, cell count]
|
||||||
|
|
||||||
|
# If memory usage counter if doesn't exists
|
||||||
|
if self.usages is None:
|
||||||
|
self._init_sequence(prev_read_distributions)
|
||||||
|
# in first timestep nothing is written or read yet, so we don't need any further processing
|
||||||
|
else:
|
||||||
|
self.usages = torch.addcmul(self.usages, 1, prev_write_distribution.detach(), (1 - self.usages)) * phi
|
||||||
|
|
||||||
|
return phi
|
||||||
|
|
||||||
|
def forward(self, prev_write_distribution, prev_read_distributions, free_gates):
|
||||||
|
phi = self.update_usages(prev_write_distribution, prev_read_distributions, free_gates)
|
||||||
|
sorted_usage, free_list = (self.usages*(1.0-_EPS)+_EPS).sort(-1)
|
||||||
|
|
||||||
|
u_prod = sorted_usage.cumprod(-1)
|
||||||
|
one_minus_usage = 1.0 - sorted_usage
|
||||||
|
sorted_scores = torch.cat([one_minus_usage[..., 0:1], one_minus_usage[..., 1:] * u_prod[..., :-1]], dim=-1)
|
||||||
|
|
||||||
|
return sorted_scores.clone().scatter_(-1, free_list, sorted_scores), phi
|
||||||
|
|
||||||
|
|
||||||
|
class ContentAddressGenerator(torch.nn.Module):
|
||||||
|
def __init__(self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
|
||||||
|
super(ContentAddressGenerator, self).__init__()
|
||||||
|
self.disable_content_norm = disable_content_norm
|
||||||
|
self.mask_min = mask_min
|
||||||
|
self.disable_key_masking = disable_key_masking
|
||||||
|
|
||||||
|
def forward(self, memory, keys, betas, mask=None):
|
||||||
|
# Memory shape [batch, cell count, word length]
|
||||||
|
# Key shape [batch, n heads*, word length]
|
||||||
|
# Betas shape [batch, n heads]
|
||||||
|
if mask is not None and self.mask_min != 0:
|
||||||
|
mask = mask * (1.0-self.mask_min) + self.mask_min
|
||||||
|
|
||||||
|
single_head = keys.dim() == 2
|
||||||
|
if single_head:
|
||||||
|
# Single head
|
||||||
|
keys = keys.unsqueeze(1)
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
|
memory = memory.unsqueeze(1)
|
||||||
|
keys = keys.unsqueeze(-2)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.unsqueeze(-2)
|
||||||
|
memory = memory * mask
|
||||||
|
if not self.disable_key_masking:
|
||||||
|
keys = keys * mask
|
||||||
|
|
||||||
|
# Shape [batch, n heads, cell count]
|
||||||
|
norm = keys.norm(dim=-1)
|
||||||
|
if not self.disable_content_norm:
|
||||||
|
norm = norm * memory.norm(dim=-1)
|
||||||
|
|
||||||
|
scores = (memory * keys).sum(-1) / (norm + _EPS)
|
||||||
|
scores *= betas.unsqueeze(-1)
|
||||||
|
|
||||||
|
res = F.softmax(scores, scores.dim()-1)
|
||||||
|
return res.squeeze(1) if single_head else res
|
||||||
|
|
||||||
|
|
||||||
|
class WriteHead(torch.nn.Module):
|
||||||
|
@staticmethod
|
||||||
|
def create_write_archive(write_dist, erase_vector, write_vector, phi):
|
||||||
|
return dict(write_dist=write_dist, erase_vector=erase_vector, write_vector=write_vector, phi=phi)
|
||||||
|
|
||||||
|
def __init__(self, dealloc_content=True, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
|
||||||
|
super(WriteHead, self).__init__()
|
||||||
|
self.write_content_generator = ContentAddressGenerator(disable_content_norm, mask_min=mask_min, disable_key_masking=disable_key_masking)
|
||||||
|
self.allocation_manager = AllocationManager()
|
||||||
|
self.last_write = None
|
||||||
|
self.dealloc_content = dealloc_content
|
||||||
|
self.new_sequence()
|
||||||
|
|
||||||
|
def new_sequence(self):
|
||||||
|
self.last_write = None
|
||||||
|
self.allocation_manager.new_sequence()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mem_update(memory, write_dist, erase_vector, write_vector, phi):
|
||||||
|
# In original paper the memory content is NOT deallocated, which makes content based addressing basically
|
||||||
|
# unusable when multiple similar steps should be done. The reason for this is that the memory contents are
|
||||||
|
# still there, so the lookup will find them, unless an allocation clears it before the next search, which is
|
||||||
|
# completely random. So I'm arguing that erase matrix should also take in account the free gates (multiply it
|
||||||
|
# with phi)
|
||||||
|
write_dist = write_dist.unsqueeze(-1)
|
||||||
|
|
||||||
|
erase_matrix = 1.0 - write_dist * erase_vector.unsqueeze(-2)
|
||||||
|
if phi is not None:
|
||||||
|
erase_matrix = erase_matrix * phi.unsqueeze(-1)
|
||||||
|
|
||||||
|
update_matrix = write_dist * write_vector.unsqueeze(-2)
|
||||||
|
return memory * erase_matrix + update_matrix
|
||||||
|
|
||||||
|
def forward(self, memory, write_content_key, write_beta, erase_vector, write_vector, alloc_gate, write_gate,
|
||||||
|
free_gates, prev_read_dist, write_mask=None, debug=None):
|
||||||
|
last_w_dist = self.last_write["write_dist"] if self.last_write is not None else None
|
||||||
|
|
||||||
|
content_dist = self.write_content_generator(memory, write_content_key, write_beta, mask = write_mask)
|
||||||
|
alloc_dist, phi = self.allocation_manager(last_w_dist, prev_read_dist, free_gates)
|
||||||
|
|
||||||
|
# Shape [batch, cell count]
|
||||||
|
write_dist = write_gate * (alloc_gate * alloc_dist + (1-alloc_gate)*content_dist)
|
||||||
|
self.last_write = WriteHead.create_write_archive(write_dist, erase_vector, write_vector, phi if self.dealloc_content else None)
|
||||||
|
|
||||||
|
dict_append(debug, "alloc_dist", alloc_dist)
|
||||||
|
dict_append(debug, "write_dist", write_dist)
|
||||||
|
dict_append(debug, "mem_usages", self.allocation_manager.usages)
|
||||||
|
dict_append(debug, "free_gates", free_gates)
|
||||||
|
dict_append(debug, "write_betas", write_beta)
|
||||||
|
dict_append(debug, "write_gate", write_gate)
|
||||||
|
dict_append(debug, "write_vector", write_vector)
|
||||||
|
dict_append(debug, "alloc_gate", alloc_gate)
|
||||||
|
dict_append(debug, "erase_vector", erase_vector)
|
||||||
|
if write_mask is not None:
|
||||||
|
dict_append(debug, "write_mask", write_mask)
|
||||||
|
|
||||||
|
return WriteHead.mem_update(memory, **self.last_write)
|
||||||
|
|
||||||
|
class RawWriteHead(torch.nn.Module):
|
||||||
|
def __init__(self, n_read_heads, word_length, use_mask=False, dealloc_content=True, disable_content_norm=False,
|
||||||
|
mask_min=0.0, disable_key_masking=False):
|
||||||
|
super(RawWriteHead, self).__init__()
|
||||||
|
self.write_head = WriteHead(dealloc_content = dealloc_content, disable_content_norm = disable_content_norm,
|
||||||
|
mask_min=mask_min, disable_key_masking=disable_key_masking)
|
||||||
|
self.word_length = word_length
|
||||||
|
self.n_read_heads = n_read_heads
|
||||||
|
self.use_mask = use_mask
|
||||||
|
self.input_size = 3*self.word_length + self.n_read_heads + 3 + (self.word_length if use_mask else 0)
|
||||||
|
|
||||||
|
def new_sequence(self):
|
||||||
|
self.write_head.new_sequence()
|
||||||
|
|
||||||
|
def get_prev_write(self):
|
||||||
|
return self.write_head.last_write
|
||||||
|
|
||||||
|
def forward(self, memory, nn_output, prev_read_dist, debug):
|
||||||
|
shapes = [[self.word_length]] * (4 if self.use_mask else 3) + [[self.n_read_heads]] + [[1]] * 3
|
||||||
|
tensors = split_tensor(nn_output, shapes)
|
||||||
|
|
||||||
|
if self.use_mask:
|
||||||
|
write_mask = torch.sigmoid(tensors[0])
|
||||||
|
tensors=tensors[1:]
|
||||||
|
else:
|
||||||
|
write_mask = None
|
||||||
|
|
||||||
|
write_content_key, erase_vector, write_vector, free_gates, write_beta, alloc_gate, write_gate = tensors
|
||||||
|
|
||||||
|
erase_vector = torch.sigmoid(erase_vector)
|
||||||
|
free_gates = torch.sigmoid(free_gates)
|
||||||
|
write_beta = oneplus(write_beta)
|
||||||
|
alloc_gate = torch.sigmoid(alloc_gate)
|
||||||
|
write_gate = torch.sigmoid(write_gate)
|
||||||
|
|
||||||
|
return self.write_head(memory, write_content_key, write_beta, erase_vector, write_vector,
|
||||||
|
alloc_gate, write_gate, free_gates, prev_read_dist, debug=debug, write_mask=write_mask)
|
||||||
|
|
||||||
|
def get_neural_input_size(self):
|
||||||
|
return self.input_size
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalMemoryLinkage(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(TemporalMemoryLinkage, self).__init__()
|
||||||
|
self.temp_link_mat = None
|
||||||
|
self.precedence_weighting = None
|
||||||
|
self.diag_mask = None
|
||||||
|
|
||||||
|
self.initial_temp_link_mat = None
|
||||||
|
self.initial_precedence_weighting = None
|
||||||
|
self.initial_diag_mask = None
|
||||||
|
self.initial_shape = None
|
||||||
|
|
||||||
|
def new_sequence(self):
|
||||||
|
self.temp_link_mat = None
|
||||||
|
self.precedence_weighting = None
|
||||||
|
self.diag_mask = None
|
||||||
|
|
||||||
|
def _init_link(self, w_dist):
|
||||||
|
s = list(w_dist.size())
|
||||||
|
if self.initial_shape is None or s != self.initial_shape:
|
||||||
|
self.initial_temp_link_mat = torch.zeros(s[0], s[-1], s[-1]).to(w_dist.device)
|
||||||
|
self.initial_precedence_weighting = torch.zeros(s[0], s[-1]).to(w_dist.device)
|
||||||
|
self.initial_diag_mask = (1.0 - torch.eye(s[-1]).unsqueeze(0).to(w_dist)).detach()
|
||||||
|
|
||||||
|
self.temp_link_mat = self.initial_temp_link_mat
|
||||||
|
self.precedence_weighting = self.initial_precedence_weighting
|
||||||
|
self.diag_mask = self.initial_diag_mask
|
||||||
|
|
||||||
|
def _update_precedence(self, w_dist):
|
||||||
|
# w_dist shape: [ batch, cell count ]
|
||||||
|
self.precedence_weighting = (1.0 - w_dist.sum(-1, keepdim=True)) * self.precedence_weighting + w_dist
|
||||||
|
|
||||||
|
def _update_links(self, w_dist):
|
||||||
|
if self.temp_link_mat is None:
|
||||||
|
self._init_link(w_dist)
|
||||||
|
|
||||||
|
wt_i = w_dist.unsqueeze(-1)
|
||||||
|
wt_j = w_dist.unsqueeze(-2)
|
||||||
|
pt_j = self.precedence_weighting.unsqueeze(-2)
|
||||||
|
|
||||||
|
self.temp_link_mat = ((1 - wt_i - wt_j) * self.temp_link_mat + wt_i * pt_j) * self.diag_mask
|
||||||
|
|
||||||
|
def forward(self, w_dist, prev_r_dists, debug = None):
|
||||||
|
self._update_links(w_dist)
|
||||||
|
self._update_precedence(w_dist)
|
||||||
|
|
||||||
|
# prev_r_dists shape: [ batch, n heads, cell count ]
|
||||||
|
# Emulate matrix-vector multiplication by broadcast and sum. This way we don't need to transpose the matrix
|
||||||
|
tlm_multi_head = self.temp_link_mat.unsqueeze(1)
|
||||||
|
|
||||||
|
forward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-2)).sum(-1)
|
||||||
|
backward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-1)).sum(-2)
|
||||||
|
|
||||||
|
dict_append(debug, "forward_dists", forward_dist)
|
||||||
|
dict_append(debug, "backward_dists", backward_dist)
|
||||||
|
dict_append(debug, "precedence_weights", self.precedence_weighting)
|
||||||
|
|
||||||
|
# output shapes [ batch, n_heads, cell_count ]
|
||||||
|
return forward_dist, backward_dist
|
||||||
|
|
||||||
|
|
||||||
|
class ReadHead(torch.nn.Module):
|
||||||
|
def __init__(self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
|
||||||
|
super(ReadHead, self).__init__()
|
||||||
|
self.content_addr_generator = ContentAddressGenerator(disable_content_norm=disable_content_norm,
|
||||||
|
mask_min=mask_min,
|
||||||
|
disable_key_masking=disable_key_masking)
|
||||||
|
self.read_dist = None
|
||||||
|
self.read_data = None
|
||||||
|
self.new_sequence()
|
||||||
|
|
||||||
|
def new_sequence(self):
|
||||||
|
self.read_dist = None
|
||||||
|
self.read_data = None
|
||||||
|
|
||||||
|
def forward(self, memory, read_content_keys, read_betas, forward_dist, backward_dist, gates, read_mask=None, debug=None):
|
||||||
|
content_dist = self.content_addr_generator(memory, read_content_keys, read_betas, mask=read_mask)
|
||||||
|
|
||||||
|
self.read_dist = backward_dist * gates[..., 0:1] + content_dist * gates[...,1:2] + forward_dist * gates[..., 2:]
|
||||||
|
|
||||||
|
# memory shape: [ batch, cell count, word_length ]
|
||||||
|
# read_dist shape: [ batch, n heads, cell count ]
|
||||||
|
# result shape: [ batch, n_heads, word_length ]
|
||||||
|
self.read_data = (memory.unsqueeze(1) * self.read_dist.unsqueeze(-1)).sum(-2)
|
||||||
|
|
||||||
|
dict_append(debug, "content_dist", content_dist)
|
||||||
|
dict_append(debug, "balance", gates)
|
||||||
|
dict_append(debug, "read_dist", self.read_dist)
|
||||||
|
dict_append(debug, "read_content_keys", read_content_keys)
|
||||||
|
if read_mask is not None:
|
||||||
|
dict_append(debug, "read_mask", read_mask)
|
||||||
|
dict_append(debug, "read_betas", read_betas.unsqueeze(-2))
|
||||||
|
if read_mask is not None:
|
||||||
|
dict_append(debug, "read_mask", read_mask)
|
||||||
|
|
||||||
|
return self.read_data
|
||||||
|
|
||||||
|
|
||||||
|
class RawReadHead(torch.nn.Module):
|
||||||
|
def __init__(self, n_heads, word_length, use_mask=False, disable_content_norm=False, mask_min=0.0,
|
||||||
|
disable_key_masking=False):
|
||||||
|
super(RawReadHead, self).__init__()
|
||||||
|
self.read_head = ReadHead(disable_content_norm=disable_content_norm, mask_min=mask_min,
|
||||||
|
disable_key_masking=disable_key_masking)
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.word_length = word_length
|
||||||
|
self.use_mask = use_mask
|
||||||
|
self.input_size = self.n_heads * (self.word_length*(2 if use_mask else 1) + 3 + 1)
|
||||||
|
|
||||||
|
def get_prev_dist(self, memory):
|
||||||
|
if self.read_head.read_dist is not None:
|
||||||
|
return self.read_head.read_dist
|
||||||
|
else:
|
||||||
|
m_shape = memory.size()
|
||||||
|
return torch.zeros(m_shape[0], self.n_heads, m_shape[1]).to(memory)
|
||||||
|
|
||||||
|
def get_prev_data(self, memory):
|
||||||
|
if self.read_head.read_data is not None:
|
||||||
|
return self.read_head.read_data
|
||||||
|
else:
|
||||||
|
m_shape = memory.size()
|
||||||
|
return torch.zeros(m_shape[0], self.n_heads, m_shape[-1]).to(memory)
|
||||||
|
|
||||||
|
def new_sequence(self):
|
||||||
|
self.read_head.new_sequence()
|
||||||
|
|
||||||
|
def forward(self, memory, nn_output, forward_dist, backward_dist, debug):
|
||||||
|
shapes = [[self.n_heads, self.word_length]] * (2 if self.use_mask else 1) + [[self.n_heads], [self.n_heads, 3]]
|
||||||
|
tensors = split_tensor(nn_output, shapes)
|
||||||
|
|
||||||
|
if self.use_mask:
|
||||||
|
read_mask = torch.sigmoid(tensors[0])
|
||||||
|
tensors = tensors[1:]
|
||||||
|
else:
|
||||||
|
read_mask = None
|
||||||
|
|
||||||
|
keys, betas, gates = tensors
|
||||||
|
|
||||||
|
betas = oneplus(betas)
|
||||||
|
gates = F.softmax(gates, gates.dim()-1)
|
||||||
|
|
||||||
|
return self.read_head(memory, keys, betas, forward_dist, backward_dist, gates, debug=debug, read_mask=read_mask)
|
||||||
|
|
||||||
|
def get_neural_input_size(self):
|
||||||
|
return self.input_size
|
||||||
|
|
||||||
|
|
||||||
|
class DistSharpnessEnhancer(torch.nn.Module):
|
||||||
|
def __init__(self, n_heads):
|
||||||
|
super(DistSharpnessEnhancer, self).__init__()
|
||||||
|
self.n_heads = n_heads if isinstance(n_heads, list) else [n_heads]
|
||||||
|
self.n_data = sum(self.n_heads)
|
||||||
|
|
||||||
|
def forward(self, nn_input, *dists):
|
||||||
|
assert len(dists) == len(self.n_heads)
|
||||||
|
nn_input = oneplus(nn_input[..., :self.n_data])
|
||||||
|
factors = split_tensor(nn_input, self.n_heads)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for i, d in enumerate(dists):
|
||||||
|
s = list(d.size())
|
||||||
|
ndim = d.dim()
|
||||||
|
f = factors[i]
|
||||||
|
if ndim==2:
|
||||||
|
assert self.n_heads[i]==1
|
||||||
|
elif ndim==3:
|
||||||
|
f = f.unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
d += _EPS
|
||||||
|
d = d / d.max(dim=-1, keepdim=True)[0]
|
||||||
|
d = d.pow(f)
|
||||||
|
d = d / d.sum(dim=-1, keepdim=True)
|
||||||
|
res.append(d)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def get_neural_input_size(self):
|
||||||
|
return self.n_data
|
||||||
|
|
||||||
|
|
||||||
|
class DNC(torch.nn.Module):
|
||||||
|
def __init__(self, input_size, output_size, word_length, cell_count, n_read_heads, controller, batch_first=False, clip_controller=20,
|
||||||
|
bias=True, mask=False, dealloc_content=True, link_sharpness_control=True, disable_content_norm=False,
|
||||||
|
mask_min=0.0, disable_key_masking=False):
|
||||||
|
super(DNC, self).__init__()
|
||||||
|
|
||||||
|
self.clip_controller = clip_controller
|
||||||
|
|
||||||
|
self.read_head = RawReadHead(n_read_heads, word_length, use_mask=mask, disable_content_norm=disable_content_norm,
|
||||||
|
mask_min=mask_min, disable_key_masking=disable_key_masking)
|
||||||
|
self.write_head = RawWriteHead(n_read_heads, word_length, use_mask=mask, dealloc_content=dealloc_content,
|
||||||
|
disable_content_norm=disable_content_norm, mask_min=mask_min,
|
||||||
|
disable_key_masking=disable_key_masking)
|
||||||
|
self.temporal_link = TemporalMemoryLinkage()
|
||||||
|
self.sharpness_control = DistSharpnessEnhancer([n_read_heads, n_read_heads]) if link_sharpness_control else None
|
||||||
|
|
||||||
|
in_size = input_size + n_read_heads * word_length
|
||||||
|
control_channels = self.read_head.get_neural_input_size() + self.write_head.get_neural_input_size() +\
|
||||||
|
(self.sharpness_control.get_neural_input_size() if self.sharpness_control is not None else 0)
|
||||||
|
|
||||||
|
self.controller = controller
|
||||||
|
controller.init(in_size)
|
||||||
|
self.controller_to_controls = torch.nn.Linear(controller.get_output_size(), control_channels, bias=bias)
|
||||||
|
self.controller_to_out = torch.nn.Linear(controller.get_output_size(), output_size, bias=bias)
|
||||||
|
self.read_to_out = torch.nn.Linear(word_length * n_read_heads, output_size, bias=bias)
|
||||||
|
|
||||||
|
self.cell_count = cell_count
|
||||||
|
self.word_length = word_length
|
||||||
|
|
||||||
|
self.memory = None
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
self.batch_first = batch_first
|
||||||
|
self.zero_mem_tensor = None
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
linear_reset(self.controller_to_controls)
|
||||||
|
linear_reset(self.controller_to_out)
|
||||||
|
linear_reset(self.read_to_out)
|
||||||
|
self.controller.reset_parameters()
|
||||||
|
|
||||||
|
def _step(self, in_data, debug):
|
||||||
|
init_debug(debug, {
|
||||||
|
"read_head": {},
|
||||||
|
"write_head": {},
|
||||||
|
"temporal_links": {}
|
||||||
|
})
|
||||||
|
|
||||||
|
# input shape: [ batch, channels ]
|
||||||
|
batch_size = in_data.size(0)
|
||||||
|
|
||||||
|
# run the controller
|
||||||
|
prev_read_data = self.read_head.get_prev_data(self.memory).view([batch_size, -1])
|
||||||
|
|
||||||
|
control_data = self.controller(torch.cat([in_data, prev_read_data], -1))
|
||||||
|
|
||||||
|
# memory ops
|
||||||
|
controls = self.controller_to_controls(control_data).contiguous()
|
||||||
|
controls = controls.clamp(-self.clip_controller, self.clip_controller) if self.clip_controller is not None else controls
|
||||||
|
|
||||||
|
shapes = [[self.write_head.get_neural_input_size()], [self.read_head.get_neural_input_size()]]
|
||||||
|
if self.sharpness_control is not None:
|
||||||
|
shapes.append(self.sharpness_control.get_neural_input_size())
|
||||||
|
|
||||||
|
tensors = split_tensor(controls, shapes)
|
||||||
|
|
||||||
|
write_head_control, read_head_control = tensors[:2]
|
||||||
|
tensors = tensors[2:]
|
||||||
|
|
||||||
|
prev_read_dist = self.read_head.get_prev_dist(self.memory)
|
||||||
|
|
||||||
|
self.memory = self.write_head(self.memory, write_head_control, prev_read_dist, debug=dict_get(debug,"write_head"))
|
||||||
|
|
||||||
|
prev_write = self.write_head.get_prev_write()
|
||||||
|
forward_dist, backward_dist = self.temporal_link(prev_write["write_dist"] if prev_write is not None else None, prev_read_dist, debug=dict_get(debug, "temporal_links"))
|
||||||
|
|
||||||
|
if self.sharpness_control is not None:
|
||||||
|
forward_dist, backward_dist = self.sharpness_control(tensors[0], forward_dist, backward_dist)
|
||||||
|
|
||||||
|
read_data = self.read_head(self.memory, read_head_control, forward_dist, backward_dist, debug=dict_get(debug,"read_head"))
|
||||||
|
|
||||||
|
# output:
|
||||||
|
return self.controller_to_out(control_data) + self.read_to_out(read_data.view(batch_size,-1))
|
||||||
|
|
||||||
|
def _mem_init(self, batch_size, device):
|
||||||
|
if self.zero_mem_tensor is None or self.zero_mem_tensor.size(0)!=batch_size:
|
||||||
|
self.zero_mem_tensor = torch.zeros(batch_size, self.cell_count, self.word_length).to(device)
|
||||||
|
|
||||||
|
self.memory = self.zero_mem_tensor
|
||||||
|
|
||||||
|
def forward(self, in_data, debug=None):
|
||||||
|
self.write_head.new_sequence()
|
||||||
|
self.read_head.new_sequence()
|
||||||
|
self.temporal_link.new_sequence()
|
||||||
|
self.controller.new_sequence()
|
||||||
|
|
||||||
|
self._mem_init(in_data.size(0 if self.batch_first else 1), in_data.device)
|
||||||
|
|
||||||
|
out_tsteps = []
|
||||||
|
|
||||||
|
if self.batch_first:
|
||||||
|
# input format: batch, time, channels
|
||||||
|
for t in range(in_data.size(1)):
|
||||||
|
out_tsteps.append(self._step(in_data[:,t], debug))
|
||||||
|
else:
|
||||||
|
# input format: time, batch, channels
|
||||||
|
for t in range(in_data.size(0)):
|
||||||
|
out_tsteps.append(self._step(in_data[t], debug))
|
||||||
|
|
||||||
|
merge_debug_tensors(debug, dim=1 if self.batch_first else 0)
|
||||||
|
return torch.stack(out_tsteps, dim=1 if self.batch_first else 0)
|
||||||
|
|
||||||
|
class LSTMController(torch.nn.Module):
|
||||||
|
def __init__(self, layer_sizes, out_from_all_layers=True):
|
||||||
|
super(LSTMController, self).__init__()
|
||||||
|
self.out_from_all_layers = out_from_all_layers
|
||||||
|
self.layer_sizes = layer_sizes
|
||||||
|
self.states = None
|
||||||
|
self.outputs = None
|
||||||
|
|
||||||
|
def new_sequence(self):
|
||||||
|
self.states = [None] * len(self.layer_sizes)
|
||||||
|
self.outputs = [None] * len(self.layer_sizes)
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
def init_layer(l, index):
|
||||||
|
size = self.layer_sizes[index]
|
||||||
|
# Initialize all matrices to sigmoid, just data input to tanh
|
||||||
|
a=math.sqrt(3.0)*self.stdevs[i]
|
||||||
|
l.weight.data[0:-size].uniform_(-a,a)
|
||||||
|
a*=init.calculate_gain("tanh")
|
||||||
|
l.weight.data[-size:].uniform_(-a, a)
|
||||||
|
if l.bias is not None:
|
||||||
|
l.bias.data[self.layer_sizes[i]:].fill_(0)
|
||||||
|
# init forget gate to large number.
|
||||||
|
l.bias.data[:self.layer_sizes[i]].fill_(1)
|
||||||
|
|
||||||
|
# xavier init merged input weights
|
||||||
|
for i in range(len(self.layer_sizes)):
|
||||||
|
init_layer(self.in_to_all[i], i)
|
||||||
|
init_layer(self.out_to_all[i], i)
|
||||||
|
if i>0:
|
||||||
|
init_layer(self.prev_to_all[i-1], i)
|
||||||
|
|
||||||
|
def _add_modules(self, name, m_list):
|
||||||
|
for i, m in enumerate(m_list):
|
||||||
|
self.add_module("%s_%d" % (name,i), m)
|
||||||
|
|
||||||
|
def init(self, input_size):
|
||||||
|
self.layer_sizes = self.layer_sizes
|
||||||
|
|
||||||
|
# Xavier init: input to all gates is layers_sizes[i-1] + layer_sizes[i] + input_size -> layer_size big.
|
||||||
|
# So use xavier init according to this.
|
||||||
|
self.input_sizes = [(self.layer_sizes[i - 1] if i>0 else 0) + self.layer_sizes[i] + input_size
|
||||||
|
for i in range(len(self.layer_sizes))]
|
||||||
|
self.stdevs = [math.sqrt(2.0 / (self.layer_sizes[i] + self.input_sizes[i])) for i in range(len(self.layer_sizes))]
|
||||||
|
self.in_to_all= [torch.nn.Linear(input_size, 4*self.layer_sizes[i]) for i in range(len(self.layer_sizes))]
|
||||||
|
self.out_to_all = [torch.nn.Linear(self.layer_sizes[i], 4 * self.layer_sizes[i], bias=False) for i in range(len(self.layer_sizes))]
|
||||||
|
self.prev_to_all = [torch.nn.Linear(self.layer_sizes[i-1], 4 * self.layer_sizes[i], bias=False) for i in range(1,len(self.layer_sizes))]
|
||||||
|
|
||||||
|
self._add_modules("in_to_all", self.in_to_all)
|
||||||
|
self._add_modules("out_to_all", self.out_to_all)
|
||||||
|
self._add_modules("prev_to_all", self.prev_to_all)
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def get_output_size(self):
|
||||||
|
return sum(self.layer_sizes) if self.out_from_all_layers else self.layer_sizes[-1]
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
for i, size in enumerate(self.layer_sizes):
|
||||||
|
d = self.in_to_all[i](data)
|
||||||
|
if self.outputs[i] is not None:
|
||||||
|
d+=self.out_to_all[i](self.outputs[i])
|
||||||
|
if i>0:
|
||||||
|
d+=self.prev_to_all[i-1](self.outputs[i-1])
|
||||||
|
|
||||||
|
input_data = torch.tanh(d[...,-size:])
|
||||||
|
forget_gate, input_gate, output_gate = torch.sigmoid(d[...,:-size]).chunk(3,dim=-1)
|
||||||
|
|
||||||
|
state_update = input_gate * input_data
|
||||||
|
|
||||||
|
if self.states[i] is not None:
|
||||||
|
self.states[i] = self.states[i]*forget_gate + state_update
|
||||||
|
else:
|
||||||
|
self.states[i] = state_update
|
||||||
|
|
||||||
|
self.outputs[i] = output_gate * torch.tanh(self.states[i])
|
||||||
|
|
||||||
|
return torch.cat(self.outputs, -1) if self.out_from_all_layers else self.outputs[-1]
|
||||||
|
|
||||||
|
|
||||||
|
class FeedforwardController(torch.nn.Module):
|
||||||
|
def __init__(self, layer_sizes=[]):
|
||||||
|
super(FeedforwardController, self).__init__()
|
||||||
|
self.layer_sizes = layer_sizes
|
||||||
|
|
||||||
|
def new_sequence(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
for module in self.model:
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
linear_reset(module, gain=init.calculate_gain("relu"))
|
||||||
|
|
||||||
|
def get_output_size(self):
|
||||||
|
return self.layer_sizes[-1]
|
||||||
|
|
||||||
|
def init(self, input_size):
|
||||||
|
self.layer_sizes = self.layer_sizes
|
||||||
|
|
||||||
|
# Xavier init: input to all gates is layers_sizes[i-1] + layer_sizes[i] + input_size -> layer_size big.
|
||||||
|
# So use xavier init according to this.
|
||||||
|
self.input_sizes = [input_size] + self.layer_sizes[:-1]
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
for i, size in enumerate(self.layer_sizes):
|
||||||
|
layers.append(torch.nn.Linear(self.input_sizes[i], self.layer_sizes[i]))
|
||||||
|
layers.append(torch.nn.ReLU())
|
||||||
|
self.model = torch.nn.Sequential(*layers)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
return self.model(data)
|
61
README.md
Normal file
61
README.md
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
PyTorch implementation of custom DNC variants
|
||||||
|
=============================================
|
||||||
|
|
||||||
|
Tasks
|
||||||
|
-----
|
||||||
|
Supported tasks:
|
||||||
|
* bAbI
|
||||||
|
* copy
|
||||||
|
* repeated copy
|
||||||
|
* associative recall
|
||||||
|
* key-value recall
|
||||||
|
* 2 way key-value recall
|
||||||
|
|
||||||
|
Visualization and debugging
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
Many interesting internal states of the DNC are visualized inside Visdom. Check console output for the port.
|
||||||
|
|
||||||
|
![](./assets/preview.png)
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
|
||||||
|
Everything is done by main.py. Use -name to give some path (it will be created if doesn't exists), where the state of the training will be saved. Check out main.py for more information about the flags available.
|
||||||
|
|
||||||
|
Most of the trainings can be run by profiles:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./main.py -name <train dir> -profile babi
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported profiles: babi, repeat_copy, repeat_copy_simple, keyvalue, keyvalue2way, associative_recall.
|
||||||
|
|
||||||
|
If you want to train a pure DNC, use add "dnc" to the profile:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./main.py -name <train dir> -profile babi,dnc
|
||||||
|
```
|
||||||
|
|
||||||
|
For other options, see main.py.
|
||||||
|
|
||||||
|
DNC variants
|
||||||
|
------------
|
||||||
|
|
||||||
|
The variant of DNC can be specified as a profile. Supported variants:
|
||||||
|
dnc, dnc-msd, dnc-m, dnc-s, dnc-d, dnc-md, dnc-ms, dnc-sd.
|
||||||
|
|
||||||
|
Reusing the code
|
||||||
|
----------------
|
||||||
|
|
||||||
|
The DNC is implemented as a single file (Models/DNC.py) depending only on torch. You should be able to reuse it very easily. Please check main.py for details on its interface.
|
||||||
|
|
||||||
|
Dependencies
|
||||||
|
------------
|
||||||
|
|
||||||
|
PyTroch (1.0), Python 3. Others can be installed by running pip3 -r requirements.txt.
|
||||||
|
|
||||||
|
License
|
||||||
|
-------
|
||||||
|
|
||||||
|
The software is under Apache 2.0 license. See http://www.apache.org/licenses/LICENSE-2.0 for further details.
|
167
Utils/ArgumentParser.py
Normal file
167
Utils/ArgumentParser.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
class ArgumentParser:
|
||||||
|
class Parsed:
|
||||||
|
pass
|
||||||
|
|
||||||
|
_type = type
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def str_or_none(none_string="none"):
|
||||||
|
def parse(s):
|
||||||
|
return None if s.lower()==none_string else s
|
||||||
|
return parse
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def list_or_none(none_string="none", type=int):
|
||||||
|
def parse(s):
|
||||||
|
return None if s.lower() == none_string else [type(a) for a in s.split(",") if a]
|
||||||
|
return parse
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_args(args, new_args, arg_schemas):
|
||||||
|
for name, val in new_args.items():
|
||||||
|
old = args.get(name)
|
||||||
|
if old is None:
|
||||||
|
args[name] = val
|
||||||
|
else:
|
||||||
|
args[name] = arg_schemas[name]["updater"](old, val)
|
||||||
|
|
||||||
|
class Profile:
|
||||||
|
def __init__(self, name, args=None, include=[]):
|
||||||
|
assert not (args is None and not include), "One of args or include must be defined"
|
||||||
|
self.name = name
|
||||||
|
self.args = args
|
||||||
|
if not isinstance(include, list):
|
||||||
|
include=[include]
|
||||||
|
self.include = include
|
||||||
|
|
||||||
|
def get_args(self, arg_schemas, profile_by_name):
|
||||||
|
res = {}
|
||||||
|
|
||||||
|
for n in self.include:
|
||||||
|
p = profile_by_name.get(n)
|
||||||
|
assert p is not None, "Included profile %s doesn't exists" % n
|
||||||
|
|
||||||
|
ArgumentParser._merge_args(res, p.get_args(arg_schemas, profile_by_name), arg_schemas)
|
||||||
|
|
||||||
|
ArgumentParser._merge_args(res, self.args, arg_schemas)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, description=None):
|
||||||
|
self.parser = argparse.ArgumentParser(description=description)
|
||||||
|
self.loaded = {}
|
||||||
|
self.profiles = {}
|
||||||
|
self.args = {}
|
||||||
|
self.raw = None
|
||||||
|
self.parsed = None
|
||||||
|
self.parser.add_argument("-profile", type=str, help="Pre-defined profiles.")
|
||||||
|
|
||||||
|
def add_argument(self, name, type=None, default=None, help="", save=True, parser=lambda x:x, updater=lambda old, new:new):
|
||||||
|
assert name not in ["profile"], "Argument name %s is reserved" % name
|
||||||
|
assert not (type is None and default is None), "Either type or default must be given"
|
||||||
|
|
||||||
|
if type is None:
|
||||||
|
type = ArgumentParser._type(default)
|
||||||
|
|
||||||
|
self.parser.add_argument(name, type=int if type==bool else type, default=None, help=help)
|
||||||
|
if name[0] == '-':
|
||||||
|
name = name[1:]
|
||||||
|
|
||||||
|
self.args[name] = {
|
||||||
|
"type": type,
|
||||||
|
"default": int(default) if type==bool else default,
|
||||||
|
"save": save,
|
||||||
|
"parser": parser,
|
||||||
|
"updater": updater
|
||||||
|
}
|
||||||
|
|
||||||
|
def add_profile(self, prof):
|
||||||
|
if isinstance(prof, list):
|
||||||
|
for p in prof:
|
||||||
|
self.add_profile(p)
|
||||||
|
else:
|
||||||
|
self.profiles[prof.name] = prof
|
||||||
|
|
||||||
|
def do_parse_args(self, loaded={}):
|
||||||
|
self.raw = self.parser.parse_args()
|
||||||
|
|
||||||
|
profile = {}
|
||||||
|
if self.raw.profile:
|
||||||
|
assert not loaded, "Loading arguments from file, but profile given."
|
||||||
|
for pr in self.raw.profile.split(","):
|
||||||
|
p = self.profiles.get(pr)
|
||||||
|
assert p is not None, "Invalid profile: %s. Valid profiles: %s" % (pr, self.profiles.keys())
|
||||||
|
p = p.get_args(self.args, self.profiles)
|
||||||
|
self._merge_args(profile, p, self.args)
|
||||||
|
|
||||||
|
for k, v in self.raw.__dict__.items():
|
||||||
|
if k in ["profile"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if v is None:
|
||||||
|
if k in loaded and self.args[k]["save"]:
|
||||||
|
self.raw.__dict__[k] = loaded[k]
|
||||||
|
else:
|
||||||
|
self.raw.__dict__[k] = profile.get(k, self.args[k]["default"])
|
||||||
|
|
||||||
|
self.parsed = ArgumentParser.Parsed()
|
||||||
|
self.parsed.__dict__.update({k: self.args[k]["parser"](self.args[k]["type"](v)) if v is not None else None
|
||||||
|
for k,v in self.raw.__dict__.items() if k in self.args})
|
||||||
|
|
||||||
|
return self.parsed
|
||||||
|
|
||||||
|
def parse_or_cache(self):
|
||||||
|
if self.parsed is None:
|
||||||
|
self.do_parse_args()
|
||||||
|
|
||||||
|
def parse(self):
|
||||||
|
self.parse_or_cache()
|
||||||
|
return self.parsed
|
||||||
|
|
||||||
|
def save(self, fname):
|
||||||
|
self.parse_or_cache()
|
||||||
|
with open(fname, 'w') as outfile:
|
||||||
|
json.dump(self.raw.__dict__, outfile, indent=4)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load(self, fname):
|
||||||
|
if os.path.isfile(fname):
|
||||||
|
map = {}
|
||||||
|
with open(fname, "r") as data_file:
|
||||||
|
map = json.load(data_file)
|
||||||
|
|
||||||
|
self.do_parse_args(map)
|
||||||
|
return self.parsed
|
||||||
|
|
||||||
|
def sync(self, fname, save=True):
|
||||||
|
if os.path.isfile(fname):
|
||||||
|
self.load(fname)
|
||||||
|
|
||||||
|
if save:
|
||||||
|
dir = os.path.dirname(fname)
|
||||||
|
if not os.path.isdir(dir):
|
||||||
|
os.makedirs(dir)
|
||||||
|
|
||||||
|
self.save(fname)
|
||||||
|
return self.parsed
|
75
Utils/Collate.py
Normal file
75
Utils/Collate.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from operator import mul
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
class VarLengthCollate:
|
||||||
|
def __init__(self, ignore_symbol=0):
|
||||||
|
self.ignore_symbol = ignore_symbol
|
||||||
|
|
||||||
|
def _measure_array_max_dim(self, batch):
|
||||||
|
s=list(batch[0].size())
|
||||||
|
different=[False] * len(s)
|
||||||
|
for i in range(1, len(batch)):
|
||||||
|
ns = batch[i].size()
|
||||||
|
different = [different[j] or s[j]!=ns[j] for j in range(len(s))]
|
||||||
|
s=[max(s[j], ns[j]) for j in range(len(s))]
|
||||||
|
return s, different
|
||||||
|
|
||||||
|
def _merge_var_len_array(self, batch):
|
||||||
|
max_size, different = self._measure_array_max_dim(batch)
|
||||||
|
s=[len(batch)] + max_size
|
||||||
|
storage = batch[0].storage()._new_shared(reduce(mul, s, 1))
|
||||||
|
out = batch[0].new(storage).view(s).fill_(self.ignore_symbol)
|
||||||
|
for i, d in enumerate(batch):
|
||||||
|
this_o = out[i]
|
||||||
|
for j, diff in enumerate(different):
|
||||||
|
if different[j]:
|
||||||
|
this_o = this_o.narrow(j,0,d.size(j))
|
||||||
|
this_o.copy_(d)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self, batch):
|
||||||
|
if isinstance(batch[0], dict):
|
||||||
|
return {k: self([b[k] for b in batch]) for k in batch[0].keys()}
|
||||||
|
elif isinstance(batch[0], np.ndarray):
|
||||||
|
return self([torch.from_numpy(a) for a in batch])
|
||||||
|
elif torch.is_tensor(batch[0]):
|
||||||
|
return self._merge_var_len_array(batch)
|
||||||
|
else:
|
||||||
|
assert False, "Unknown type: %s" % type(batch[0])
|
||||||
|
|
||||||
|
class MetaCollate:
|
||||||
|
def __init__(self, meta_name="meta", collate=VarLengthCollate()):
|
||||||
|
self.meta_name = meta_name
|
||||||
|
self.collate = collate
|
||||||
|
|
||||||
|
def __call__(self, batch):
|
||||||
|
if isinstance(batch[0], dict):
|
||||||
|
meta = [b[self.meta_name] for b in batch]
|
||||||
|
batch = [{k: v for k,v in b.items() if k!=self.meta_name} for b in batch]
|
||||||
|
else:
|
||||||
|
meta = None
|
||||||
|
|
||||||
|
res = self.collate(batch)
|
||||||
|
if meta is not None:
|
||||||
|
res[self.meta_name] = meta
|
||||||
|
|
||||||
|
return res
|
133
Utils/Debug.py
Normal file
133
Utils/Debug.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import torch
|
||||||
|
|
||||||
|
enableDebug = False
|
||||||
|
|
||||||
|
def nan_check(arg, name=None, force=False):
|
||||||
|
if not enableDebug and not force:
|
||||||
|
return arg
|
||||||
|
is_nan = False
|
||||||
|
curr_nan = False
|
||||||
|
if isinstance(arg, torch.autograd.Variable):
|
||||||
|
curr_nan = not np.isfinite(arg.sum().cpu().data.numpy())
|
||||||
|
elif isinstance(arg, torch.nn.parameter.Parameter):
|
||||||
|
curr_nan = (not np.isfinite(arg.sum().cpu().data.numpy())) or (not np.isfinite(arg.grad.sum().cpu().data.numpy()))
|
||||||
|
elif isinstance(arg, float):
|
||||||
|
curr_nan = not np.isfinite(arg)
|
||||||
|
elif isinstance(arg, (list, tuple)):
|
||||||
|
for a in arg:
|
||||||
|
nan_check(a)
|
||||||
|
else:
|
||||||
|
assert False, "Unsupported type %s" % type(arg)
|
||||||
|
|
||||||
|
if curr_nan:
|
||||||
|
if sys.exc_info()[0] is not None:
|
||||||
|
trace = str(traceback.format_exc())
|
||||||
|
else:
|
||||||
|
trace = "".join(traceback.format_stack())
|
||||||
|
|
||||||
|
print(arg)
|
||||||
|
if name is not None:
|
||||||
|
print("NaN found in %s." % name)
|
||||||
|
else:
|
||||||
|
print("NaN found.")
|
||||||
|
if isinstance(arg, torch.autograd.Variable):
|
||||||
|
print(" Argument is a torch tensor. Shape: %s" % list(arg.size()))
|
||||||
|
|
||||||
|
print(trace)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
return arg
|
||||||
|
|
||||||
|
|
||||||
|
def assert_range(t, min=0.0, max=1.0):
|
||||||
|
if not enableDebug:
|
||||||
|
return
|
||||||
|
|
||||||
|
if t.min().cpu().data.numpy()<min or t.max().cpu().data.numpy()>max:
|
||||||
|
print(t)
|
||||||
|
assert False
|
||||||
|
|
||||||
|
|
||||||
|
def assert_dist(t, use_lower_limit=True):
|
||||||
|
if not enableDebug:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert_range(t)
|
||||||
|
|
||||||
|
if t.sum(-1).max().cpu().data.numpy()>1.001:
|
||||||
|
print("MAT:", t)
|
||||||
|
print("SUM:", t.sum(-1))
|
||||||
|
assert False
|
||||||
|
|
||||||
|
if use_lower_limit and t.sum(-1).max().cpu().data.numpy()<0.999:
|
||||||
|
print(t)
|
||||||
|
print("SUM:", t.sum(-1))
|
||||||
|
assert False
|
||||||
|
|
||||||
|
|
||||||
|
def print_stat(name, t):
|
||||||
|
if not enableDebug:
|
||||||
|
return
|
||||||
|
|
||||||
|
min = t.min().cpu().data.numpy()
|
||||||
|
max = t.max().cpu().data.numpy()
|
||||||
|
mean = t.mean().cpu().data.numpy()
|
||||||
|
|
||||||
|
print("%s: min: %g, mean: %g, max: %g" % (name, min, mean, max))
|
||||||
|
|
||||||
|
|
||||||
|
def dbg_print(*things):
|
||||||
|
if not enableDebug:
|
||||||
|
return
|
||||||
|
print(*things)
|
||||||
|
|
||||||
|
class GradPrinter(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, a):
|
||||||
|
return a
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, g):
|
||||||
|
print("Grad (print_grad): ", g[0])
|
||||||
|
return g
|
||||||
|
|
||||||
|
def print_grad(t):
|
||||||
|
return GradPrinter.apply(t)
|
||||||
|
|
||||||
|
def assert_equal(t1, ref, limit=1e-5, force=True):
|
||||||
|
if not (enableDebug or force):
|
||||||
|
return
|
||||||
|
|
||||||
|
assert t1.shape==ref.shape, "Tensor shapes differ: got %s, ref %s" % (t1.shape, ref.shape)
|
||||||
|
norm = ref.abs().sum() / ref.nonzero().sum().float()
|
||||||
|
threshold = norm * limit
|
||||||
|
|
||||||
|
errcnt = ((t1 - ref).abs() > threshold).sum()
|
||||||
|
if errcnt > 0:
|
||||||
|
print("Tensors differ. (max difference: %g, norm %f). No of errors: %d of %d" %
|
||||||
|
((t1 - ref).abs().max().item(), norm, errcnt, t1.numel()))
|
||||||
|
print("---------------------------------------------Out-----------------------------------------------")
|
||||||
|
print(t1)
|
||||||
|
print("---------------------------------------------Ref-----------------------------------------------")
|
||||||
|
print(ref)
|
||||||
|
print("-----------------------------------------------------------------------------------------------")
|
||||||
|
assert False
|
24
Utils/Helpers.py
Normal file
24
Utils/Helpers.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.autograd
|
||||||
|
|
||||||
|
def as_numpy(data):
|
||||||
|
if isinstance(data, (torch.Tensor, torch.autograd.Variable)):
|
||||||
|
return data.detach().cpu().numpy()
|
||||||
|
else:
|
||||||
|
return data
|
20
Utils/Index.py
Normal file
20
Utils/Index.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
def index_by_dim(arr, dim, i_start, i_end=None):
|
||||||
|
if dim<0:
|
||||||
|
dim += arr.ndim
|
||||||
|
return arr[tuple([slice(None,None)] * dim + [slice(i_start, i_end) if i_end is not None else i_start])]
|
51
Utils/Process.py
Normal file
51
Utils/Process.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import ctypes
|
||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
|
||||||
|
def run(cmd, hide_stderr = True):
|
||||||
|
libc_search_dirs = ["/lib", "/lib/x86_64-linux-gnu", "/lib/powerpc64le-linux-gnu"]
|
||||||
|
|
||||||
|
if sys.platform=="linux" :
|
||||||
|
found = None
|
||||||
|
for d in libc_search_dirs:
|
||||||
|
file = os.path.join(d, "libc.so.6")
|
||||||
|
if os.path.isfile(file):
|
||||||
|
found = file
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found:
|
||||||
|
print("WARNING: Cannot find libc.so.6. Cannot kill process when parent dies.")
|
||||||
|
killer = None
|
||||||
|
else:
|
||||||
|
libc = ctypes.CDLL(found)
|
||||||
|
PR_SET_PDEATHSIG = 1
|
||||||
|
KILL = 9
|
||||||
|
killer = lambda: libc.prctl(PR_SET_PDEATHSIG, KILL)
|
||||||
|
else:
|
||||||
|
print("WARNING: OS not linux. Cannot kill process when parent dies.")
|
||||||
|
killer = None
|
||||||
|
|
||||||
|
|
||||||
|
if hide_stderr:
|
||||||
|
stderr = open(os.devnull,'w')
|
||||||
|
else:
|
||||||
|
stderr = None
|
||||||
|
|
||||||
|
return subprocess.Popen(cmd.split(" "), stderr=stderr, preexec_fn=killer)
|
48
Utils/Profile.py
Normal file
48
Utils/Profile.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
|
||||||
|
ENABLED=False
|
||||||
|
|
||||||
|
_profiler = None
|
||||||
|
|
||||||
|
|
||||||
|
def construct():
|
||||||
|
global _profiler
|
||||||
|
if not ENABLED:
|
||||||
|
return
|
||||||
|
|
||||||
|
if _profiler is None:
|
||||||
|
from line_profiler import LineProfiler
|
||||||
|
_profiler = LineProfiler()
|
||||||
|
|
||||||
|
|
||||||
|
def do_profile(follow=[]):
|
||||||
|
construct()
|
||||||
|
def inner(func):
|
||||||
|
if _profiler is not None:
|
||||||
|
_profiler.add_function(func)
|
||||||
|
for f in follow:
|
||||||
|
_profiler.add_function(f)
|
||||||
|
_profiler.enable_by_count()
|
||||||
|
return func
|
||||||
|
return inner
|
||||||
|
|
||||||
|
@atexit.register
|
||||||
|
def print_prof():
|
||||||
|
if _profiler is not None:
|
||||||
|
_profiler.print_stats()
|
233
Utils/Saver.py
Normal file
233
Utils/Saver.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class SaverElement:
|
||||||
|
def save(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def load(self, saved_state):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackSaver(SaverElement):
|
||||||
|
def __init__(self, save_fn, load_fn):
|
||||||
|
super().__init__()
|
||||||
|
self.save = save_fn
|
||||||
|
self.load = load_fn
|
||||||
|
|
||||||
|
|
||||||
|
class StateSaver(SaverElement):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def load(self, state):
|
||||||
|
try:
|
||||||
|
self._model.load_state_dict(state)
|
||||||
|
except Exception as e:
|
||||||
|
if hasattr(self._model, "named_parameters"):
|
||||||
|
names = set([n for n, _ in self._model.named_parameters()])
|
||||||
|
loaded = set(self._model.keys())
|
||||||
|
if names!=loaded:
|
||||||
|
d = loaded.difference(names)
|
||||||
|
if d:
|
||||||
|
print("Loaded, but not in model: %s" % list(d))
|
||||||
|
d = names.difference(loaded)
|
||||||
|
if d:
|
||||||
|
print("In model, but not loaded: %s" % list(d))
|
||||||
|
if isinstance(self._model, torch.optim.Optimizer):
|
||||||
|
print("WARNING: optimizer parameters not loaded!")
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
return self._model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalVarSaver(SaverElement):
|
||||||
|
def __init__(self, name):
|
||||||
|
caller_frame = inspect.getouterframes(inspect.currentframe())[1]
|
||||||
|
self._vars = caller_frame.frame.f_globals
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
def load(self, state):
|
||||||
|
self._vars.update({self._name: state})
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
return self._vars[self._name]
|
||||||
|
|
||||||
|
|
||||||
|
class PyObjectSaver(SaverElement):
|
||||||
|
def __init__(self, obj):
|
||||||
|
self._obj = obj
|
||||||
|
|
||||||
|
def load(self, state):
|
||||||
|
def _load(target, state):
|
||||||
|
if isinstance(target, dict):
|
||||||
|
for k, v in state.items():
|
||||||
|
target[k] = _load(target.get(k), v)
|
||||||
|
elif isinstance(target, list):
|
||||||
|
if len(target)!=len(state):
|
||||||
|
target.clear()
|
||||||
|
for v in state:
|
||||||
|
target.append(v)
|
||||||
|
else:
|
||||||
|
for i, v in enumerate(state):
|
||||||
|
target[i] = _load(target[i], v)
|
||||||
|
|
||||||
|
elif hasattr(target, "__dict__"):
|
||||||
|
_load(target.__dict__, state)
|
||||||
|
else:
|
||||||
|
return state
|
||||||
|
return target
|
||||||
|
|
||||||
|
_load(self._obj, state)
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
def _save(target):
|
||||||
|
if isinstance(target, dict):
|
||||||
|
res = {k: _save(v) for k, v in target.items()}
|
||||||
|
elif isinstance(target, list):
|
||||||
|
res = [_save(v) for v in target]
|
||||||
|
elif hasattr(target, "__dict__"):
|
||||||
|
res = {k: _save(v) for k, v in target.__dict__.items()}
|
||||||
|
else:
|
||||||
|
res = target
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
return _save(self._obj)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def obj_supported(obj):
|
||||||
|
return isinstance(obj, (list, dict)) or hasattr(obj, "__dict__")
|
||||||
|
|
||||||
|
|
||||||
|
class Saver:
|
||||||
|
def __init__(self, dir, short_interval, keep_every_n_hours=4):
|
||||||
|
self.savers = {}
|
||||||
|
self.short_interval = short_interval
|
||||||
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
self.dir = dir
|
||||||
|
self._keep_every_n_seconds = keep_every_n_hours * 3600
|
||||||
|
|
||||||
|
def register(self, name, saver):
|
||||||
|
assert name not in self.savers, "Saver %s already registered" % name
|
||||||
|
|
||||||
|
if isinstance(saver, SaverElement):
|
||||||
|
self.savers[name] = saver
|
||||||
|
elif hasattr(saver, "state_dict") and callable(saver.state_dict):
|
||||||
|
self.savers[name] = StateSaver(saver)
|
||||||
|
elif PyObjectSaver.obj_supported(saver):
|
||||||
|
self.savers[name] = PyObjectSaver(saver)
|
||||||
|
else:
|
||||||
|
assert "Unsupported thing to save: %s" % type(saver)
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self.register(key, value)
|
||||||
|
|
||||||
|
def write(self, iter):
|
||||||
|
fname = os.path.join(self.dir, self.model_name_from_index(iter))
|
||||||
|
print("Saving %s" % fname)
|
||||||
|
|
||||||
|
state = {}
|
||||||
|
for name, fns in self.savers.items():
|
||||||
|
state[name] = fns.save()
|
||||||
|
|
||||||
|
torch.save(state, fname)
|
||||||
|
print("Saved.")
|
||||||
|
|
||||||
|
self._cleanup()
|
||||||
|
|
||||||
|
def tick(self, iter):
|
||||||
|
if iter % self.short_interval != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.write(iter)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def model_name_from_index(index):
|
||||||
|
return "model-%d.pth" % index
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_checkpoint_index_list(dir):
|
||||||
|
return list(reversed(sorted(
|
||||||
|
[int(fn.split(".")[0].split("-")[-1]) for fn in os.listdir(dir) if fn.split(".")[-1] == "pth"])))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_ckpts_in_time_window(dir, time_window_s, index_list=None):
|
||||||
|
if index_list is None:
|
||||||
|
index_list = Saver.get_checkpoint_index_list(dir)
|
||||||
|
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for i in index_list:
|
||||||
|
name = Saver.model_name_from_index(i)
|
||||||
|
mtime = os.path.getmtime(os.path.join(dir, name))
|
||||||
|
if now - mtime > time_window_s:
|
||||||
|
break
|
||||||
|
|
||||||
|
res.append(name)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_last_checkpoint(dir):
|
||||||
|
last_checkpoint = Saver.get_checkpoint_index_list(dir)
|
||||||
|
|
||||||
|
if last_checkpoint:
|
||||||
|
for index in last_checkpoint:
|
||||||
|
fname = Saver.model_name_from_index(index)
|
||||||
|
try:
|
||||||
|
print("Loading %s" % fname)
|
||||||
|
data = torch.load(os.path.join(dir, fname))
|
||||||
|
except:
|
||||||
|
print("WARNING: Loading %s failed. Maybe file is corrupted?" % fname)
|
||||||
|
continue
|
||||||
|
return data
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cleanup(self):
|
||||||
|
index_list = self.get_checkpoint_index_list(self.dir)
|
||||||
|
new_files = self.get_ckpts_in_time_window(self.dir, self._keep_every_n_seconds, index_list[2:])
|
||||||
|
new_files = new_files[:-1]
|
||||||
|
|
||||||
|
for f in new_files:
|
||||||
|
os.remove(os.path.join(self.dir, f))
|
||||||
|
|
||||||
|
def load(self, fname=None):
|
||||||
|
if fname is None:
|
||||||
|
state = self.load_last_checkpoint(self.dir)
|
||||||
|
if not state:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
state = torch.load(fname)
|
||||||
|
|
||||||
|
for k,s in state.items():
|
||||||
|
if k not in self.savers:
|
||||||
|
print("WARNING: failed to load state of %s. It doesn't exists." % k)
|
||||||
|
continue
|
||||||
|
self.savers[k].load(s)
|
||||||
|
|
||||||
|
return True
|
31
Utils/Seed.py
Normal file
31
Utils/Seed.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
seed = None
|
||||||
|
|
||||||
|
|
||||||
|
def fix():
|
||||||
|
global seed
|
||||||
|
seed = 0xB0C1FA52
|
||||||
|
|
||||||
|
|
||||||
|
def get_randstate():
|
||||||
|
if seed:
|
||||||
|
return np.random.RandomState(seed)
|
||||||
|
else:
|
||||||
|
return np.random.RandomState()
|
323
Utils/Visdom.py
Normal file
323
Utils/Visdom.py
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import time
|
||||||
|
import visdom
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
from . import Process
|
||||||
|
|
||||||
|
vis = None
|
||||||
|
port = None
|
||||||
|
visdom_fail_count = 0
|
||||||
|
|
||||||
|
|
||||||
|
def port_used(port):
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
result = sock.connect_ex(('127.0.0.1', port))
|
||||||
|
if result == 0:
|
||||||
|
sock.close()
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_port(start_from=7000):
|
||||||
|
while True:
|
||||||
|
if port_used(start_from):
|
||||||
|
print("Port already used: %d" % start_from)
|
||||||
|
start_from += 1
|
||||||
|
else:
|
||||||
|
return start_from
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_port(port, timeout=5):
|
||||||
|
star_time = time.time()
|
||||||
|
while not port_used(port):
|
||||||
|
if time.time() - star_time > timeout:
|
||||||
|
return False
|
||||||
|
|
||||||
|
time.sleep(0.1)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def start(on_port=None):
|
||||||
|
global vis
|
||||||
|
global port
|
||||||
|
global visdom_fail_count
|
||||||
|
assert vis is None, "Cannot start more than 1 visdom servers."
|
||||||
|
|
||||||
|
if visdom_fail_count>=3:
|
||||||
|
return
|
||||||
|
|
||||||
|
port = alloc_port() if on_port is None else on_port
|
||||||
|
|
||||||
|
print("Starting Visdom server on %d" % port)
|
||||||
|
Process.run("%s -m visdom.server -p %d" % (sys.executable, port))
|
||||||
|
if not wait_for_port(port):
|
||||||
|
print("ERROR: failed to start Visdom server. Server not responding.")
|
||||||
|
visdom_fail_count += 1
|
||||||
|
return
|
||||||
|
print("Done.")
|
||||||
|
|
||||||
|
vis = visdom.Visdom(port=port)
|
||||||
|
|
||||||
|
|
||||||
|
def _start_if_not_running():
|
||||||
|
if vis is None:
|
||||||
|
start()
|
||||||
|
|
||||||
|
|
||||||
|
def save_heatmap(dir, title, img):
|
||||||
|
if dir is not None:
|
||||||
|
fname = os.path.join(dir, title.replace(" ", "_") + ".npy")
|
||||||
|
d = os.path.dirname(fname)
|
||||||
|
os.makedirs(d, exist_ok=True)
|
||||||
|
np.save(fname, img)
|
||||||
|
|
||||||
|
|
||||||
|
class Plot2D:
|
||||||
|
TO_SAVE = ["x", "y", "curr_accu", "curr_cnt", "legend"]
|
||||||
|
|
||||||
|
def __init__(self, name, store_interval=1, legend=None, xlabel=None, ylabel=None):
|
||||||
|
_start_if_not_running()
|
||||||
|
|
||||||
|
self.x = []
|
||||||
|
self.y = []
|
||||||
|
self.store_interval = store_interval
|
||||||
|
self.curr_accu = None
|
||||||
|
self.curr_cnt = 0
|
||||||
|
self.name = name
|
||||||
|
self.legend = legend
|
||||||
|
self.xlabel = xlabel
|
||||||
|
self.ylabel = ylabel
|
||||||
|
|
||||||
|
self.replot = False
|
||||||
|
self.visplot = None
|
||||||
|
self.last_vis_update_pos = 0
|
||||||
|
|
||||||
|
def set_legend(self, legend):
|
||||||
|
if self.legend != legend:
|
||||||
|
self.legend = legend
|
||||||
|
self.replot = True
|
||||||
|
|
||||||
|
def _send_update(self):
|
||||||
|
if not self.x or vis is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.visplot is None or self.replot:
|
||||||
|
opts = {
|
||||||
|
"title": self.name
|
||||||
|
}
|
||||||
|
if self.xlabel:
|
||||||
|
opts["xlabel"] = self.xlabel
|
||||||
|
if self.ylabel:
|
||||||
|
opts["ylabel"] = self.ylabel
|
||||||
|
|
||||||
|
if self.legend is not None:
|
||||||
|
opts["legend"] = self.legend
|
||||||
|
|
||||||
|
self.visplot = vis.line(X=np.asfarray(self.x), Y=np.asfarray(self.y), opts=opts, win=self.visplot)
|
||||||
|
self.replot = False
|
||||||
|
else:
|
||||||
|
vis.line(
|
||||||
|
X=np.asfarray(self.x[self.last_vis_update_pos:]),
|
||||||
|
Y=np.asfarray(self.y[self.last_vis_update_pos:]),
|
||||||
|
win=self.visplot,
|
||||||
|
update='append'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.last_vis_update_pos = len(self.x) - 1
|
||||||
|
|
||||||
|
def add_point(self, x, y):
|
||||||
|
if not isinstance(y, list):
|
||||||
|
y = [y]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if self.curr_accu is None:
|
||||||
|
self.curr_accu = [0.0] * len(y)
|
||||||
|
|
||||||
|
if len(self.curr_accu) < len(y):
|
||||||
|
# The number of curves increased.
|
||||||
|
need_to_add = (len(y) - len(self.curr_accu))
|
||||||
|
|
||||||
|
self.curr_accu += [0.0] * need_to_add
|
||||||
|
count = len(self.x)
|
||||||
|
if count>0:
|
||||||
|
self.replot = True
|
||||||
|
if not isinstance(self.x[0], list):
|
||||||
|
self.x = [[x] for x in self.x]
|
||||||
|
self.y = [[y] for y in self.y]
|
||||||
|
|
||||||
|
nan = float("nan")
|
||||||
|
for a in self.x:
|
||||||
|
a += [nan] * need_to_add
|
||||||
|
for a in self.y:
|
||||||
|
a += [nan] * need_to_add
|
||||||
|
elif len(self.curr_accu) > len(y):
|
||||||
|
y = y[:] + [float("nan")] * (len(self.curr_accu) - len(y))
|
||||||
|
|
||||||
|
self.curr_accu = [self.curr_accu[i] + y[i] for i in range(len(y))]
|
||||||
|
self.curr_cnt += 1
|
||||||
|
if self.curr_cnt == self.store_interval:
|
||||||
|
if len(y) > 1:
|
||||||
|
self.x.append([x] * len(y))
|
||||||
|
self.y.append([a / self.curr_cnt for a in self.curr_accu])
|
||||||
|
else:
|
||||||
|
self.x.append(x)
|
||||||
|
self.y.append(self.curr_accu[0] / self.curr_cnt)
|
||||||
|
|
||||||
|
self.curr_accu = [0.0] * len(y)
|
||||||
|
self.curr_cnt = 0
|
||||||
|
|
||||||
|
self._send_update()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
s = {k: self.__dict__[k] for k in self.TO_SAVE}
|
||||||
|
return s
|
||||||
|
|
||||||
|
def load_state_dict(self, state):
|
||||||
|
if self.legend is not None:
|
||||||
|
# Load legend only if not given in the constructor.
|
||||||
|
state["legend"] = self.legend
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self.last_vis_update_pos = 0
|
||||||
|
|
||||||
|
# Read old format
|
||||||
|
if not isinstance(self.curr_accu, list) and self.curr_accu is not None:
|
||||||
|
self.curr_accu = [self.curr_accu]
|
||||||
|
|
||||||
|
self._send_update()
|
||||||
|
|
||||||
|
|
||||||
|
class Image:
|
||||||
|
def __init__(self, title, dumpdir=None):
|
||||||
|
_start_if_not_running()
|
||||||
|
|
||||||
|
self.win = None
|
||||||
|
self.opts = dict(title=title)
|
||||||
|
self.dumpdir = dumpdir
|
||||||
|
|
||||||
|
def set_dump_dir(self, dumpdir):
|
||||||
|
self.dumpdir = dumpdir
|
||||||
|
|
||||||
|
def draw(self, img):
|
||||||
|
if vis is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(img, list):
|
||||||
|
if self.win is None:
|
||||||
|
self.win = vis.images(img, opts=self.opts)
|
||||||
|
else:
|
||||||
|
vis.images(img, win=self.win, opts=self.opts)
|
||||||
|
else:
|
||||||
|
if len(img.shape)==2:
|
||||||
|
img = np.expand_dims(img, 0)
|
||||||
|
elif img.shape[-1] in [1,3] and img.shape[0] not in [1,3]:
|
||||||
|
# cv2 image
|
||||||
|
img = img.transpose(2,0,1)
|
||||||
|
img = img[::-1]
|
||||||
|
|
||||||
|
if img.dtype==np.uint8:
|
||||||
|
img = img.astype(np.float32)/255
|
||||||
|
|
||||||
|
self.opts["width"] = img.shape[2]
|
||||||
|
self.opts["height"] = img.shape[1]
|
||||||
|
|
||||||
|
save_heatmap(self.dumpdir, self.opts["title"], img)
|
||||||
|
if self.win is None:
|
||||||
|
self.win = vis.image(img, opts=self.opts)
|
||||||
|
else:
|
||||||
|
vis.image(img, win=self.win, opts=self.opts)
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
self.draw(img)
|
||||||
|
|
||||||
|
|
||||||
|
class Text:
|
||||||
|
def __init__(self, title):
|
||||||
|
_start_if_not_running()
|
||||||
|
|
||||||
|
self.win = None
|
||||||
|
self.title = title
|
||||||
|
self.curr_text = ""
|
||||||
|
|
||||||
|
def set(self, text):
|
||||||
|
self.curr_text = text
|
||||||
|
|
||||||
|
if vis is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.win is None:
|
||||||
|
self.win = vis.text(text, opts=dict(
|
||||||
|
title=self.title
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
vis.text(text, win=self.win)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {"text": self.curr_text}
|
||||||
|
|
||||||
|
def load_state_dict(self, state):
|
||||||
|
self.set(state["text"])
|
||||||
|
|
||||||
|
def __call__(self, text):
|
||||||
|
self.set(text)
|
||||||
|
|
||||||
|
|
||||||
|
class Heatmap:
|
||||||
|
def __init__(self, title, min=None, max=None, xlabel=None, ylabel=None, colormap='Viridis', dumpdir=None):
|
||||||
|
_start_if_not_running()
|
||||||
|
|
||||||
|
self.win = None
|
||||||
|
self.opt = dict(title=title, colormap=colormap)
|
||||||
|
self.dumpdir = dumpdir
|
||||||
|
if min is not None:
|
||||||
|
self.opt["xmin"] = min
|
||||||
|
if max is not None:
|
||||||
|
self.opt["xmax"] = max
|
||||||
|
|
||||||
|
if xlabel:
|
||||||
|
self.opt["xlabel"] = xlabel
|
||||||
|
if ylabel:
|
||||||
|
self.opt["ylabel"] = ylabel
|
||||||
|
|
||||||
|
def set_dump_dir(self, dumpdir):
|
||||||
|
self.dumpdir = dumpdir
|
||||||
|
|
||||||
|
def draw(self, img):
|
||||||
|
if vis is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
o = self.opt.copy()
|
||||||
|
if "xmin" not in o:
|
||||||
|
o["xmin"] = float(img.min())
|
||||||
|
|
||||||
|
if "xmax" not in o:
|
||||||
|
o["xmax"] = float(img.max())
|
||||||
|
|
||||||
|
save_heatmap(self.dumpdir, o["title"], img)
|
||||||
|
|
||||||
|
if self.win is None:
|
||||||
|
self.win = vis.heatmap(img, opts=o)
|
||||||
|
else:
|
||||||
|
vis.heatmap(img, win=self.win, opts=o)
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
self.draw(img)
|
189
Utils/download.py
Normal file
189
Utils/download.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import requests, tarfile, io, os, zipfile
|
||||||
|
|
||||||
|
from io import BytesIO, SEEK_SET, SEEK_END
|
||||||
|
|
||||||
|
|
||||||
|
class UrlStream:
|
||||||
|
def __init__(self, url):
|
||||||
|
self._url = url
|
||||||
|
headers = requests.head(url).headers
|
||||||
|
headers = {k.lower(): v for k,v in headers.items()}
|
||||||
|
self._seek_supported = headers.get('accept-ranges')=='bytes' and 'content-length' in headers
|
||||||
|
if self._seek_supported:
|
||||||
|
self._size = int(headers['content-length'])
|
||||||
|
self._curr_pos = 0
|
||||||
|
self._buf_start_pos = 0
|
||||||
|
self._iter = None
|
||||||
|
self._buffer = None
|
||||||
|
self._buf_size = 0
|
||||||
|
self._loaded_all = False
|
||||||
|
|
||||||
|
def seekable(self):
|
||||||
|
return self._seek_supported
|
||||||
|
|
||||||
|
def _load_all(self):
|
||||||
|
if self._loaded_all:
|
||||||
|
return
|
||||||
|
self._make_request()
|
||||||
|
old_buf_pos = self._buffer.tell()
|
||||||
|
self._buffer.seek(0, SEEK_END)
|
||||||
|
for chunk in self._iter:
|
||||||
|
self._buffer.write(chunk)
|
||||||
|
self._buf_size = self._buffer.tell()
|
||||||
|
self._buffer.seek(old_buf_pos, SEEK_SET)
|
||||||
|
self._loaded_all = True
|
||||||
|
|
||||||
|
def seek(self, position, whence=SEEK_SET):
|
||||||
|
if whence == SEEK_END:
|
||||||
|
assert position<=0
|
||||||
|
if self._seek_supported:
|
||||||
|
self.seek(self._size + position)
|
||||||
|
else:
|
||||||
|
self._load_all()
|
||||||
|
self._buffer.seek(position, SEEK_END)
|
||||||
|
self._curr_pos = self._buffer.tell()
|
||||||
|
elif whence==SEEK_SET:
|
||||||
|
if self._curr_pos != position:
|
||||||
|
self._curr_pos = position
|
||||||
|
if self._seek_supported:
|
||||||
|
self._iter = None
|
||||||
|
self._buffer = None
|
||||||
|
else:
|
||||||
|
self._load_until(position)
|
||||||
|
self._buffer.seek(position)
|
||||||
|
self._curr_pos = position
|
||||||
|
else:
|
||||||
|
assert "Invalid whence %s" % whence
|
||||||
|
|
||||||
|
return self.tell()
|
||||||
|
|
||||||
|
def tell(self):
|
||||||
|
return self._curr_pos
|
||||||
|
|
||||||
|
def _load_until(self, goal_position):
|
||||||
|
self._make_request()
|
||||||
|
old_buf_pos = self._buffer.tell()
|
||||||
|
current_position = self._buffer.seek(0, SEEK_END)
|
||||||
|
|
||||||
|
goal_position = goal_position - self._buf_start_pos
|
||||||
|
while current_position < goal_position:
|
||||||
|
try:
|
||||||
|
d = next(self._iter)
|
||||||
|
self._buffer.write(d)
|
||||||
|
current_position += len(d)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
self._buf_size = current_position
|
||||||
|
self._buffer.seek(old_buf_pos, SEEK_SET)
|
||||||
|
|
||||||
|
def _new_buffer(self):
|
||||||
|
remaining = self._buffer.read() if self._buffer is not None else None
|
||||||
|
self._buffer = BytesIO()
|
||||||
|
if remaining is not None:
|
||||||
|
self._buffer.write(remaining)
|
||||||
|
self._buf_start_pos = self._curr_pos
|
||||||
|
self._buf_size = 0 if remaining is None else len(remaining)
|
||||||
|
self._buffer.seek(0, SEEK_SET)
|
||||||
|
self._loaded_all = False
|
||||||
|
|
||||||
|
def _make_request(self):
|
||||||
|
if self._iter is None:
|
||||||
|
h = {
|
||||||
|
"User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36",
|
||||||
|
}
|
||||||
|
if self._seek_supported:
|
||||||
|
h["Range"] = "bytes=%d-%d" % (self._curr_pos, self._size - 1)
|
||||||
|
|
||||||
|
r = requests.get(self._url, headers=h, stream=True)
|
||||||
|
|
||||||
|
self._iter = r.iter_content(1024 * 1024)
|
||||||
|
self._new_buffer()
|
||||||
|
elif self._seek_supported and self._buf_size > 128 * 1024 * 1024:
|
||||||
|
self._new_buffer()
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
if self._seek_supported:
|
||||||
|
return self._size
|
||||||
|
else:
|
||||||
|
self._load_all()
|
||||||
|
return self._buf_size
|
||||||
|
|
||||||
|
def read(self, size=None):
|
||||||
|
if size is None:
|
||||||
|
size = self.size()
|
||||||
|
|
||||||
|
self._load_until(self._curr_pos + size)
|
||||||
|
if self._seek_supported:
|
||||||
|
self._curr_pos = min(self._curr_pos+size, self._size)
|
||||||
|
|
||||||
|
return self._buffer.read(size)
|
||||||
|
|
||||||
|
def iter_content(self, block_size):
|
||||||
|
while True:
|
||||||
|
d = self.read(block_size)
|
||||||
|
if not len(d):
|
||||||
|
break
|
||||||
|
yield d
|
||||||
|
|
||||||
|
|
||||||
|
def download(url, dest=None, extract=True, ignore_if_exists=False):
|
||||||
|
"""
|
||||||
|
Download a file from the internet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: the url to download
|
||||||
|
dest: destination file if extract=False, or destionation dir if extract=True. If None, it will be the last part of URL.
|
||||||
|
extract: extract a tar.gz or zip file?
|
||||||
|
ignore_if_exists: don't do anything if file exists
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the destination filename.
|
||||||
|
"""
|
||||||
|
|
||||||
|
base_url = url.split("?")[0]
|
||||||
|
|
||||||
|
if dest is None:
|
||||||
|
dest = [f for f in base_url.split("/") if f][-1]
|
||||||
|
|
||||||
|
if os.path.exists(dest) and ignore_if_exists:
|
||||||
|
return dest
|
||||||
|
|
||||||
|
stream = UrlStream(url)
|
||||||
|
extension = base_url.split(".")[-1].lower()
|
||||||
|
|
||||||
|
if extract and extension in ['gz', 'bz2', 'zip']:
|
||||||
|
os.makedirs(dest, exist_ok=True)
|
||||||
|
|
||||||
|
if extension in ['gz', 'bz2']:
|
||||||
|
decompressed_file = tarfile.open(fileobj=stream, mode='r|'+extension)
|
||||||
|
elif extension=='zip':
|
||||||
|
decompressed_file = zipfile.ZipFile(stream, mode='r')
|
||||||
|
else:
|
||||||
|
assert False, "Invalid extension: %s" % extension
|
||||||
|
|
||||||
|
decompressed_file.extractall(dest)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
with open(dest, 'wb') as f:
|
||||||
|
for d in stream.iter_content(1024*1024):
|
||||||
|
f.write(d)
|
||||||
|
except:
|
||||||
|
os.remove(dest)
|
||||||
|
raise
|
||||||
|
return dest
|
111
Utils/gpu_allocator.py
Normal file
111
Utils/gpu_allocator.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from Utils.lockfile import LockFile
|
||||||
|
|
||||||
|
def get_memory_usage():
|
||||||
|
try:
|
||||||
|
proc = subprocess.Popen("nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits".split(" "),
|
||||||
|
stdout=subprocess.PIPE)
|
||||||
|
lines = [s.strip().split(" ") for s in proc.communicate()[0].decode().split("\n") if s]
|
||||||
|
return {int(g[0][:-1]): int(g[1]) for g in lines}
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_free_gpus():
|
||||||
|
try:
|
||||||
|
free = []
|
||||||
|
proc = subprocess.Popen("nvidia-smi --query-compute-apps=gpu_uuid --format=csv,noheader,nounits".split(" "),
|
||||||
|
stdout=subprocess.PIPE)
|
||||||
|
uuids = [s.strip() for s in proc.communicate()[0].decode().split("\n") if s]
|
||||||
|
|
||||||
|
proc = subprocess.Popen("nvidia-smi --query-gpu=index,uuid --format=csv,noheader,nounits".split(" "),
|
||||||
|
stdout=subprocess.PIPE)
|
||||||
|
|
||||||
|
id_uid_pair = [s.strip().split(", ") for s in proc.communicate()[0].decode().split("\n") if s]
|
||||||
|
for i in id_uid_pair:
|
||||||
|
id, uid = i
|
||||||
|
|
||||||
|
if uid not in uuids:
|
||||||
|
free.append(int(id))
|
||||||
|
|
||||||
|
return free
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fix_order():
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = os.environ.get("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
|
||||||
|
|
||||||
|
def allocate(n:int = 1):
|
||||||
|
_fix_order()
|
||||||
|
with LockFile("/tmp/gpu_allocation_lock"):
|
||||||
|
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||||
|
print("WARNING: trying to allocate %d GPUs, but CUDA_VISIBLE_DEVICES already set to %s" %
|
||||||
|
(n, os.environ["CUDA_VISIBLE_DEVICES"]))
|
||||||
|
return
|
||||||
|
|
||||||
|
allocated = get_free_gpus()
|
||||||
|
if allocated is None:
|
||||||
|
print("WARNING: failed to allocate %d GPUs" % n)
|
||||||
|
return
|
||||||
|
allocated = allocated[:n]
|
||||||
|
|
||||||
|
if len(allocated) < n:
|
||||||
|
print("There is no more free GPUs. Allocating the one with least memory usage.")
|
||||||
|
usage = get_memory_usage()
|
||||||
|
if usage is None:
|
||||||
|
print("WARNING: failed to allocate %d GPUs" % n)
|
||||||
|
return
|
||||||
|
|
||||||
|
inv_usages = {}
|
||||||
|
|
||||||
|
for k, v in usage.items():
|
||||||
|
if v not in inv_usages:
|
||||||
|
inv_usages[v] = []
|
||||||
|
|
||||||
|
inv_usages[v].append(k)
|
||||||
|
|
||||||
|
min_usage = list(sorted(inv_usages.keys()))
|
||||||
|
min_usage_devs = []
|
||||||
|
for u in min_usage:
|
||||||
|
min_usage_devs += inv_usages[u]
|
||||||
|
|
||||||
|
min_usage_devs = [m for m in min_usage_devs if m not in allocated]
|
||||||
|
|
||||||
|
n2 = n - len(allocated)
|
||||||
|
if n2>len(min_usage_devs):
|
||||||
|
print("WARNING: trying to allocate %d GPUs but only %d available" % (n, len(min_usage_devs)+len(allocated)))
|
||||||
|
n2 = len(min_usage_devs)
|
||||||
|
|
||||||
|
allocated += min_usage_devs[:n2]
|
||||||
|
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"]=",".join([str(a) for a in allocated])
|
||||||
|
for i in range(len(allocated)):
|
||||||
|
a = torch.FloatTensor([0.0])
|
||||||
|
a.cuda(i)
|
||||||
|
|
||||||
|
def use_gpu(gpu="auto", n_autoalloc=1):
|
||||||
|
_fix_order()
|
||||||
|
|
||||||
|
gpu = gpu.lower()
|
||||||
|
if gpu in ["auto", ""]:
|
||||||
|
allocate(n_autoalloc)
|
||||||
|
else:
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
|
41
Utils/lockfile.py
Normal file
41
Utils/lockfile.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import os
|
||||||
|
import fcntl
|
||||||
|
|
||||||
|
|
||||||
|
class LockFile:
|
||||||
|
def __init__(self, fname):
|
||||||
|
self._fname = fname
|
||||||
|
self._fd = None
|
||||||
|
|
||||||
|
def acquire(self):
|
||||||
|
self._fd=open(self._fname, "w")
|
||||||
|
os.chmod(self._fname, 0o777)
|
||||||
|
|
||||||
|
fcntl.lockf(self._fd, fcntl.LOCK_EX)
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
fcntl.lockf(self._fd, fcntl.LOCK_UN)
|
||||||
|
self._fd.close()
|
||||||
|
self._fd = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.acquire()
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.release()
|
54
Utils/timer.py
Normal file
54
Utils/timer.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class OnceEvery:
|
||||||
|
def __init__(self, interval):
|
||||||
|
self._interval = interval
|
||||||
|
self._last_check = 0
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
now = time.time()
|
||||||
|
if now - self._last_check >= self._interval:
|
||||||
|
self._last_check = now
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class Measure:
|
||||||
|
def __init__(self, average=1):
|
||||||
|
self._start = None
|
||||||
|
self._average = average
|
||||||
|
self._accu_value = 0.0
|
||||||
|
self._history_list = []
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._start = time.time()
|
||||||
|
|
||||||
|
def passed(self):
|
||||||
|
if self._start is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
p = time.time() - self._start
|
||||||
|
self._history_list.append(p)
|
||||||
|
self._accu_value += p
|
||||||
|
if len(self._history_list) > self._average:
|
||||||
|
self._accu_value -= self._history_list.pop(0)
|
||||||
|
|
||||||
|
return self._accu_value / len(self._history_list)
|
263
Utils/universal.py
Normal file
263
Utils/universal.py
Normal file
@ -0,0 +1,263 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
float32 = [np.float32, torch.float32]
|
||||||
|
float64 = [np.float64, torch.float64]
|
||||||
|
uint8 = [np.uint8, torch.uint8]
|
||||||
|
|
||||||
|
_all_types = [float32, float64, uint8]
|
||||||
|
|
||||||
|
_dtype_numpy_map = {v[0]().dtype.name:v for v in _all_types}
|
||||||
|
_dtype_pytorch_map = {v[1]:v for v in _all_types}
|
||||||
|
|
||||||
|
|
||||||
|
def dtype(t):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return _dtype_pytorch_map[t.dtype]
|
||||||
|
else:
|
||||||
|
return _dtype_numpy_map[t.dtype.name]
|
||||||
|
|
||||||
|
|
||||||
|
def cast(t, type):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.type(type[1])
|
||||||
|
else:
|
||||||
|
return t.astype(type[0])
|
||||||
|
|
||||||
|
|
||||||
|
def to_numpy(t):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.detach().cpu().numpy()
|
||||||
|
else:
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def to_list(t):
|
||||||
|
t = to_numpy(t)
|
||||||
|
if isinstance(t, np.ndarray):
|
||||||
|
t = t.tolist()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def is_tensor(t):
|
||||||
|
return torch.is_tensor(t) or isinstance(t, np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
def first_batch(t):
|
||||||
|
if is_tensor(t):
|
||||||
|
return t[0]
|
||||||
|
else:
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def ndim(t):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.dim()
|
||||||
|
else:
|
||||||
|
return t.ndim
|
||||||
|
|
||||||
|
|
||||||
|
def shape(t):
|
||||||
|
return list(t.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def transpose(t, axis):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.permute(axis)
|
||||||
|
else:
|
||||||
|
return np.transpose(t, axis)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_recursive(d, fn, filter=None):
|
||||||
|
if isinstance(d, list):
|
||||||
|
return [apply_recursive(da, fn) for da in d]
|
||||||
|
elif isinstance(d, tuple):
|
||||||
|
return tuple(apply_recursive(list(d), fn))
|
||||||
|
elif isinstance(d, dict):
|
||||||
|
return {k: apply_recursive(v, fn) for k, v in d.items()}
|
||||||
|
else:
|
||||||
|
if filter is None or filter(d):
|
||||||
|
return fn(d)
|
||||||
|
else:
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def apply_to_tensors(d, fn):
|
||||||
|
return apply_recursive(d, fn, torch.is_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def recursive_decorator(apply_this_fn):
|
||||||
|
def decorator(func):
|
||||||
|
def wrapped_funct(*args, **kwargs):
|
||||||
|
args = apply_recursive(args, apply_this_fn)
|
||||||
|
kwargs = apply_recursive(kwargs, apply_this_fn)
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped_funct
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
untensor = recursive_decorator(to_numpy)
|
||||||
|
unnumpy = recursive_decorator(to_list)
|
||||||
|
|
||||||
|
def unbatch(only_if_dim_equal=None):
|
||||||
|
if only_if_dim_equal is not None and not isinstance(only_if_dim_equal, list):
|
||||||
|
only_if_dim_equal = [only_if_dim_equal]
|
||||||
|
|
||||||
|
def get_first_batch(t):
|
||||||
|
if is_tensor(t) and (only_if_dim_equal is None or ndim(t) in only_if_dim_equal):
|
||||||
|
return t[0]
|
||||||
|
else:
|
||||||
|
return t
|
||||||
|
|
||||||
|
return recursive_decorator(get_first_batch)
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid(t):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return torch.sigmoid(t)
|
||||||
|
else:
|
||||||
|
return 1.0 / (1.0 + np.exp(-t))
|
||||||
|
|
||||||
|
|
||||||
|
def argmax(t, dim):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
_, res = t.max(dim)
|
||||||
|
else:
|
||||||
|
res = np.argmax(t, axis=dim)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def flip(t, axis):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.flip(axis)
|
||||||
|
else:
|
||||||
|
return np.flip(t, axis)
|
||||||
|
|
||||||
|
|
||||||
|
def transpose(t, axes):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.permute(*axes)
|
||||||
|
else:
|
||||||
|
return np.transpose(t, axes)
|
||||||
|
|
||||||
|
|
||||||
|
def split_n(t, axis):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.split(1, dim=axis)
|
||||||
|
else:
|
||||||
|
return np.split(t, t.shape[axis], axis=axis)
|
||||||
|
|
||||||
|
|
||||||
|
def cat(array_of_tensors, axis):
|
||||||
|
if torch.is_tensor(array_of_tensors[0]):
|
||||||
|
return torch.cat(array_of_tensors, axis)
|
||||||
|
else:
|
||||||
|
return np.concatenate(array_of_tensors, axis)
|
||||||
|
|
||||||
|
|
||||||
|
def clamp(t, min=None, max=None):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.clamp(min, max)
|
||||||
|
else:
|
||||||
|
if min is not None:
|
||||||
|
t = np.maximum(t, min)
|
||||||
|
|
||||||
|
if max is not None:
|
||||||
|
t = np.minimum(t, max)
|
||||||
|
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def power(t, p):
|
||||||
|
if torch.is_tensor(t) or torch.is_tensor(p):
|
||||||
|
return torch.pow(t, p)
|
||||||
|
else:
|
||||||
|
return np.power(t, p)
|
||||||
|
|
||||||
|
|
||||||
|
def random_normal_as(a, mean, std, seed=None):
|
||||||
|
if torch.is_tensor(a):
|
||||||
|
return torch.randn_like(a) * std + mean
|
||||||
|
else:
|
||||||
|
if seed is None:
|
||||||
|
seed = np.random
|
||||||
|
return seed.normal(loc=mean, scale=std, size=shape(a))
|
||||||
|
|
||||||
|
|
||||||
|
def pad(t, pad):
|
||||||
|
assert ndim(t) == 4
|
||||||
|
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return F.pad(t, pad)
|
||||||
|
else:
|
||||||
|
assert np.pad(t, ([0,0], [0,0], pad[0:2], pad[2:]))
|
||||||
|
|
||||||
|
|
||||||
|
def dx(img):
|
||||||
|
lsh = img[:, :, :, 2:]
|
||||||
|
orig = img[:, :, :, :-2]
|
||||||
|
|
||||||
|
return pad(0.5 * (lsh - orig), (1, 1, 0, 0))
|
||||||
|
|
||||||
|
|
||||||
|
def dy(img):
|
||||||
|
ush = img[:, :, 2:, :]
|
||||||
|
orig = img[:, :, :-2, :]
|
||||||
|
|
||||||
|
return pad(0.5 * (ush - orig), (0, 0, 1, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def reshape(t, shape):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
return t.view(*shape)
|
||||||
|
else:
|
||||||
|
return t.reshape(*shape)
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_to_beginning(t, target):
|
||||||
|
if torch.is_tensor(t):
|
||||||
|
nd_target = target.dim()
|
||||||
|
t_shape = list(t.shape)
|
||||||
|
return t.view(*t_shape, *([1]*(nd_target-len(t_shape))))
|
||||||
|
else:
|
||||||
|
nd_target = target.ndim
|
||||||
|
t_shape = list(t.shape)
|
||||||
|
return t.reshape(*t_shape, *([1] * (nd_target - len(t_shape))))
|
||||||
|
|
||||||
|
|
||||||
|
def lin_combine(d1,w1, d2,w2, bcast_begin=False):
|
||||||
|
if isinstance(d1, (list, tuple)):
|
||||||
|
assert len(d1) == len(d2)
|
||||||
|
res = [lin_combine(d1[i], w1, d2[i], w2) for i in range(len(d1))]
|
||||||
|
if isinstance(d1, tuple):
|
||||||
|
res = tuple(d1)
|
||||||
|
elif isinstance(d1, dict):
|
||||||
|
res = {k: lin_combine(v, w1, d2[k], w2) for k, v in d1.items()}
|
||||||
|
else:
|
||||||
|
if bcast_begin:
|
||||||
|
w1 = broadcast_to_beginning(w1, d1)
|
||||||
|
w2 = broadcast_to_beginning(w2, d2)
|
||||||
|
|
||||||
|
res = d1 * w1 + d2 * w2
|
||||||
|
|
||||||
|
return res
|
73
Visualize/BitmapTask.py
Normal file
73
Visualize/BitmapTask.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
except:
|
||||||
|
cv2=None
|
||||||
|
|
||||||
|
from Utils.Helpers import as_numpy
|
||||||
|
|
||||||
|
def visualize_bitmap_task(i_data, o_data, zoom=8):
|
||||||
|
if not isinstance(o_data, list):
|
||||||
|
o_data = [o_data]
|
||||||
|
|
||||||
|
imgs = []
|
||||||
|
for d in [i_data]+o_data:
|
||||||
|
if d is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
d=as_numpy(d)
|
||||||
|
if d.ndim>2:
|
||||||
|
d=d[0]
|
||||||
|
|
||||||
|
imgs.append(np.expand_dims(d.T*255, -1).astype(np.uint8))
|
||||||
|
|
||||||
|
img = np.concatenate(imgs, 0)
|
||||||
|
return nearest_zoom(img, zoom)
|
||||||
|
|
||||||
|
def visualize_01(t, zoom=8):
|
||||||
|
return nearest_zoom(np.expand_dims(t*255,-1).astype(np.uint8), zoom)
|
||||||
|
|
||||||
|
def nearest_zoom(img, zoom=1):
|
||||||
|
if zoom>1 and cv2 is not None:
|
||||||
|
return cv2.resize(img, (img.shape[1] * zoom, img.shape[0] * zoom), interpolation=cv2.INTER_NEAREST)
|
||||||
|
else:
|
||||||
|
return img
|
||||||
|
|
||||||
|
def concatenate_tensors(tensors):
|
||||||
|
max_size = None
|
||||||
|
dtype = None
|
||||||
|
|
||||||
|
for t in tensors:
|
||||||
|
s = t.shape
|
||||||
|
if max_size is None:
|
||||||
|
max_size = list(s)
|
||||||
|
dtype = t.dtype
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert t.ndim ==len(max_size), "Can only concatenate tensors with same ndim."
|
||||||
|
assert t.dtype == dtype, "Tensors must have the same type"
|
||||||
|
max_size = [max(max_size[i], s[i]) for i in range(len(max_size))]
|
||||||
|
|
||||||
|
res = np.zeros([len(tensors)] + max_size, dtype=dtype)
|
||||||
|
for i, t in enumerate(tensors):
|
||||||
|
res[i][tuple([slice(0,t.shape[i]) for i in range(t.ndim)])] = t
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
1
Visualize/__init__.py
Normal file
1
Visualize/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .BitmapTask import *
|
64
Visualize/preview.py
Normal file
64
Visualize/preview.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import traceback
|
||||||
|
import sys
|
||||||
|
from Utils import universal as U
|
||||||
|
|
||||||
|
|
||||||
|
def preview(vis_interval=None, to_numpy=True, debatch=False):
|
||||||
|
def decorator(func):
|
||||||
|
state = {
|
||||||
|
"last_vis_time": 0,
|
||||||
|
"thread_running": False
|
||||||
|
}
|
||||||
|
|
||||||
|
def wrapper_function(*args, **kwargs):
|
||||||
|
if state["thread_running"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
if vis_interval is not None:
|
||||||
|
now = time.time()
|
||||||
|
if now - state["last_vis_time"] < vis_interval:
|
||||||
|
return
|
||||||
|
state["last_vis_time"] = now
|
||||||
|
|
||||||
|
state["thread_running"] = True
|
||||||
|
|
||||||
|
if debatch:
|
||||||
|
args = U.apply_recursive(args, U.first_batch)
|
||||||
|
kwargs = U.apply_recursive(kwargs, U.first_batch)
|
||||||
|
|
||||||
|
if to_numpy:
|
||||||
|
args = U.apply_recursive(args, U.to_numpy)
|
||||||
|
kwargs = U.apply_recursive(kwargs, U.to_numpy)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
try:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
state["thread_running"] = False
|
||||||
|
|
||||||
|
download_thread = threading.Thread(target=run)
|
||||||
|
download_thread.start()
|
||||||
|
|
||||||
|
return wrapper_function
|
||||||
|
return decorator
|
671
main.py
Executable file
671
main.py
Executable file
@ -0,0 +1,671 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
import Utils.Debug as debug
|
||||||
|
from Dataset.Bitmap.AssociativeRecall import AssociativeRecall
|
||||||
|
from Dataset.Bitmap.BitmapTaskRepeater import BitmapTaskRepeater
|
||||||
|
from Dataset.Bitmap.KeyValue import KeyValue
|
||||||
|
from Dataset.Bitmap.CopyTask import CopyData
|
||||||
|
from Dataset.Bitmap.KeyValue2Way import KeyValue2Way
|
||||||
|
from Dataset.NLP.bAbi import bAbiDataset
|
||||||
|
from Models.DNC import DNC, LSTMController, FeedforwardController
|
||||||
|
from Utils import Visdom
|
||||||
|
from Utils.ArgumentParser import ArgumentParser
|
||||||
|
from Utils.Index import index_by_dim
|
||||||
|
from Utils.Saver import Saver, GlobalVarSaver, StateSaver
|
||||||
|
from Utils.Collate import MetaCollate
|
||||||
|
from Utils import gpu_allocator
|
||||||
|
from Dataset.NLP.NLPTask import NLPTask
|
||||||
|
from tqdm import tqdm
|
||||||
|
from Visualize.preview import preview
|
||||||
|
from Utils.timer import OnceEvery
|
||||||
|
from Utils import Seed
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import signal
|
||||||
|
import math
|
||||||
|
from Utils import Profile
|
||||||
|
|
||||||
|
Profile.ENABLED=False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
global i
|
||||||
|
global loss_sum
|
||||||
|
global running
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("-bit_w", type=int, default=8, help="Bit vector length for copy task")
|
||||||
|
parser.add_argument("-block_w", type=int, default=3, help="Block width to associative recall task")
|
||||||
|
parser.add_argument("-len", type=str, default="4", help="Sequence length for copy task", parser=lambda x: [int(a) for a in x.split("-")])
|
||||||
|
parser.add_argument("-repeat", type=str, default="1", help="Sequence length for copy task", parser=lambda x: [int(a) for a in x.split("-")])
|
||||||
|
parser.add_argument("-batch_size", type=int, default=16, help="Sequence length for copy task")
|
||||||
|
parser.add_argument("-n_subbatch", type=str, default="auto", help="Average this much forward passes to a backward pass")
|
||||||
|
parser.add_argument("-max_input_count_per_batch", type=int, default=6000, help="Max batch_size*len that can fit into memory")
|
||||||
|
parser.add_argument("-lr", type=float, default=0.0001, help="Learning rate")
|
||||||
|
parser.add_argument("-wd", type=float, default=1e-5, help="Weight decay")
|
||||||
|
parser.add_argument("-optimizer", type=str, default="rmsprop", help="Optimizer algorithm")
|
||||||
|
parser.add_argument("-name", type=str, help="Save training to this directory")
|
||||||
|
parser.add_argument("-preview_interval", type=int, default=10, help="Show preview every nth iteration")
|
||||||
|
parser.add_argument("-info_interval", type=int, default=10, help="Show info every nth iteration")
|
||||||
|
parser.add_argument("-save_interval", type=int, default=500, help="Save network every nth iteration")
|
||||||
|
parser.add_argument("-masked_lookup", type=bool, default=1, help="Enable masking in content lookups")
|
||||||
|
parser.add_argument("-visport", type=int, default=-1, help="Port to run Visdom server on. -1 to disable")
|
||||||
|
parser.add_argument("-gpu", default="auto", type=str, help="Run on this GPU.")
|
||||||
|
parser.add_argument("-debug", type=bool, default=1, help="Enable debugging")
|
||||||
|
parser.add_argument("-task", type=str, default="copy", help="Task to learn")
|
||||||
|
parser.add_argument("-mem_count", type=int, default=16, help="Number of memory cells")
|
||||||
|
parser.add_argument("-data_word_size", type=int, default=128, help="Memory word size")
|
||||||
|
parser.add_argument("-n_read_heads", type=int, default=1, help="Number of read heads")
|
||||||
|
parser.add_argument("-layer_sizes", type=str, default="256", help="Controller layer sizes. Separate with ,. For example 512,256,256", parser=lambda x: [int(y) for y in x.split(",") if y])
|
||||||
|
parser.add_argument("-debug_log", type=bool, default=0, help="Enable debug log")
|
||||||
|
parser.add_argument("-controller_type", type=str, default="lstm", help="Controller type: lstm or linear")
|
||||||
|
parser.add_argument("-lstm_use_all_outputs", type=bool, default=1, help="Use all LSTM outputs as controller output vs use only the last layer")
|
||||||
|
parser.add_argument("-momentum", type=float, default=0.9, help="Momentum for optimizer")
|
||||||
|
parser.add_argument("-embedding_size", type=int, default=256, help="Size of word embedding for NLP tasks")
|
||||||
|
parser.add_argument("-test_interval", type=int, default=10000, help="Run test in this interval")
|
||||||
|
parser.add_argument("-dealloc_content", type=bool, default=1, help="Deallocate memory content, unlike DNC, which leaves it unchanged, just decreases the usage counter, causing problems with lookup")
|
||||||
|
parser.add_argument("-sharpness_control", type=bool, default=1, help="Distribution sharpness control for forward and backward links")
|
||||||
|
parser.add_argument("-think_steps", type=int, default=0, help="Iddle steps before requiring the answer (for bAbi)")
|
||||||
|
parser.add_argument("-dump_profile", type=str, save=False)
|
||||||
|
parser.add_argument("-test_on_start", default="0", save=False)
|
||||||
|
parser.add_argument("-dump_heatmaps", default=False, save=False)
|
||||||
|
parser.add_argument("-test_batch_size", default=16)
|
||||||
|
parser.add_argument("-mask_min", default=0.0)
|
||||||
|
parser.add_argument("-load", type=str, save=False)
|
||||||
|
parser.add_argument("-dataset_path", type=str, default="none", parser=ArgumentParser.str_or_none(), help="Specify babi path manually")
|
||||||
|
parser.add_argument("-babi_train_tasks", type=str, default="none", parser=ArgumentParser.list_or_none(type=str), help="babi task list to use for training")
|
||||||
|
parser.add_argument("-babi_test_tasks", type=str, default="none", parser=ArgumentParser.list_or_none(type=str), help="babi task list to use for testing")
|
||||||
|
parser.add_argument("-babi_train_sets", type=str, default="train", parser=ArgumentParser.list_or_none(type=str), help="babi train sets to use")
|
||||||
|
parser.add_argument("-babi_test_sets", type=str, default="test", parser=ArgumentParser.list_or_none(type=str), help="babi test sets to use")
|
||||||
|
parser.add_argument("-noargsave", type=bool, default=False, help="Do not save modified arguments", save=False)
|
||||||
|
parser.add_argument("-demo", type=bool, default=False, help="Do a single step with fixed seed", save=False)
|
||||||
|
parser.add_argument("-exit_after", type=int, help="Exit after this amount of steps. Useful for debugging.", save=False)
|
||||||
|
parser.add_argument("-grad_clip", type=float, default=10.0, help="Max gradient norm")
|
||||||
|
parser.add_argument("-clip_controller", type=float, default=20.0, help="Max gradient norm")
|
||||||
|
parser.add_argument("-print_test", default=False, save=False)
|
||||||
|
|
||||||
|
parser.add_profile([
|
||||||
|
ArgumentParser.Profile("babi", {
|
||||||
|
"preview_interval": 10,
|
||||||
|
"save_interval": 500,
|
||||||
|
"task": "babi",
|
||||||
|
"mem_count": 256,
|
||||||
|
"data_word_size": 64,
|
||||||
|
"n_read_heads": 4,
|
||||||
|
"layer_sizes": "256",
|
||||||
|
"controller_type": "lstm",
|
||||||
|
"lstm_use_all_outputs": True,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"embedding_size": 128,
|
||||||
|
"test_interval": 5000,
|
||||||
|
"think_steps": 3,
|
||||||
|
"batch_size": 2
|
||||||
|
}, include=["dnc-msd"]),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("repeat_copy", {
|
||||||
|
"bit_w": 8,
|
||||||
|
"repeat": "1-8",
|
||||||
|
"len": "2-14",
|
||||||
|
"task": "copy",
|
||||||
|
"think_steps": 1,
|
||||||
|
"preview_interval": 10,
|
||||||
|
"info_interval": 10,
|
||||||
|
"save_interval": 100,
|
||||||
|
"data_word_size": 16,
|
||||||
|
"layer_sizes": "32",
|
||||||
|
"n_subbatch": 1,
|
||||||
|
"controller_type": "lstm",
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("repeat_copy_simple", {
|
||||||
|
"repeat": "1-3",
|
||||||
|
}, include="repeat_copy"),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("dnc", {
|
||||||
|
"masked_lookup": False,
|
||||||
|
"sharpness_control": False,
|
||||||
|
"dealloc_content": False
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("dnc-m", {
|
||||||
|
"masked_lookup": True,
|
||||||
|
"sharpness_control": False,
|
||||||
|
"dealloc_content": False
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("dnc-s", {
|
||||||
|
"masked_lookup": False,
|
||||||
|
"sharpness_control": True,
|
||||||
|
"dealloc_content": False
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("dnc-d", {
|
||||||
|
"masked_lookup": False,
|
||||||
|
"sharpness_control": False,
|
||||||
|
"dealloc_content": True
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("dnc-md", {
|
||||||
|
"masked_lookup": True,
|
||||||
|
"sharpness_control": False,
|
||||||
|
"dealloc_content": True
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("dnc-ms", {
|
||||||
|
"masked_lookup": True,
|
||||||
|
"sharpness_control": True,
|
||||||
|
"dealloc_content": False
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("dnc-sd", {
|
||||||
|
"masked_lookup": False,
|
||||||
|
"sharpness_control": True,
|
||||||
|
"dealloc_content": True
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("dnc-msd", {
|
||||||
|
"masked_lookup": True,
|
||||||
|
"sharpness_control": True,
|
||||||
|
"dealloc_content": True
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("keyvalue", {
|
||||||
|
"repeat": "1",
|
||||||
|
"len": "2-16",
|
||||||
|
"mem_count": 16,
|
||||||
|
"task": "keyvalue",
|
||||||
|
"think_steps": 1,
|
||||||
|
"preview_interval": 10,
|
||||||
|
"info_interval": 10,
|
||||||
|
"data_word_size": 32,
|
||||||
|
"bit_w": 12,
|
||||||
|
"save_interval": 1000,
|
||||||
|
"layer_sizes": "32"
|
||||||
|
}),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("keyvalue2way", {
|
||||||
|
"task": "keyvalue2way",
|
||||||
|
}, include="keyvalue"),
|
||||||
|
|
||||||
|
ArgumentParser.Profile("associative_recall",{
|
||||||
|
"task": "recall",
|
||||||
|
"bit_w": 8,
|
||||||
|
"len": "2-16",
|
||||||
|
"mem_count": 64,
|
||||||
|
"data_word_size": 32,
|
||||||
|
"n_read_heads": 1,
|
||||||
|
"layer_sizes": "128",
|
||||||
|
"controller_type": "lstm",
|
||||||
|
"lstm_use_all_outputs": 1,
|
||||||
|
"think_steps": 1,
|
||||||
|
"mask_min": 0.1,
|
||||||
|
"info_interval": 10,
|
||||||
|
"save_interval": 1000,
|
||||||
|
"preview_interval": 10,
|
||||||
|
"n_subbatch": 1,
|
||||||
|
})
|
||||||
|
])
|
||||||
|
|
||||||
|
opt = parser.parse()
|
||||||
|
assert opt.name is not None, "Training dir (-name parameter) not given"
|
||||||
|
opt = parser.sync(os.path.join(opt.name, "args.json"), save=not opt.noargsave)
|
||||||
|
|
||||||
|
if opt.demo:
|
||||||
|
Seed.fix()
|
||||||
|
|
||||||
|
os.makedirs(os.path.join(opt.name,"save"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(opt.name,"preview"), exist_ok=True)
|
||||||
|
|
||||||
|
gpu_allocator.use_gpu(opt.gpu)
|
||||||
|
|
||||||
|
debug.enableDebug = opt.debug_log
|
||||||
|
|
||||||
|
if opt.visport>0:
|
||||||
|
Visdom.start(opt.visport)
|
||||||
|
|
||||||
|
Visdom.Text("Name").set(opt.name)
|
||||||
|
|
||||||
|
class LengthHackSampler:
|
||||||
|
def __init__(self, batch_size, length):
|
||||||
|
self.length = length
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
while True:
|
||||||
|
len = self.length() if callable(self.length) else self.length
|
||||||
|
yield [len] * self.batch_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 0x7FFFFFFF
|
||||||
|
|
||||||
|
embedding = None
|
||||||
|
test_set = None
|
||||||
|
curriculum = None
|
||||||
|
loader_reset = False
|
||||||
|
if opt.task=="copy":
|
||||||
|
dataset = CopyData(bit_w=opt.bit_w)
|
||||||
|
in_size = opt.bit_w + 1
|
||||||
|
out_size = in_size
|
||||||
|
elif opt.task=="recall":
|
||||||
|
dataset = AssociativeRecall(bit_w=opt.bit_w, block_w=opt.block_w)
|
||||||
|
in_size = opt.bit_w + 2
|
||||||
|
out_size = in_size
|
||||||
|
elif opt.task=="keyvalue":
|
||||||
|
assert opt.bit_w % 2==0, "Key-value datasets works only with even bit_w"
|
||||||
|
dataset = KeyValue(bit_w=opt.bit_w)
|
||||||
|
in_size = opt.bit_w + 1
|
||||||
|
out_size = opt.bit_w//2
|
||||||
|
elif opt.task=="keyvalue2way":
|
||||||
|
assert opt.bit_w % 2==0, "Key-value datasets works only with even bit_w"
|
||||||
|
dataset = KeyValue2Way(bit_w=opt.bit_w)
|
||||||
|
in_size = opt.bit_w + 2
|
||||||
|
out_size = opt.bit_w//2
|
||||||
|
elif opt.task=="babi":
|
||||||
|
dataset = bAbiDataset(think_steps=opt.think_steps, dir_name=opt.dataset_path)
|
||||||
|
test_set = bAbiDataset(think_steps=opt.think_steps, dir_name=opt.dataset_path, name="test")
|
||||||
|
dataset.use(opt.babi_train_tasks, opt.babi_train_sets)
|
||||||
|
in_size = opt.embedding_size
|
||||||
|
print("bAbi: loaded total of %d sequences." % len(dataset))
|
||||||
|
test_set.use(opt.babi_test_tasks, opt.babi_test_sets)
|
||||||
|
out_size = len(dataset.vocabulary)
|
||||||
|
print("bAbi: using %d sequences for training, %d for testing" % (len(dataset), len(test_set)))
|
||||||
|
else:
|
||||||
|
assert False, "Invalid task: %s" % opt.task
|
||||||
|
|
||||||
|
if opt.task in ["babi"]:
|
||||||
|
data_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, num_workers=4, pin_memory=True, shuffle=True, collate_fn=MetaCollate())
|
||||||
|
test_loader = torch.utils.data.DataLoader(test_set, batch_size=opt.test_batch_size, num_workers=opt.test_batch_size, pin_memory=True, shuffle=False, collate_fn=MetaCollate()) if test_set is not None else None
|
||||||
|
else:
|
||||||
|
dataset = BitmapTaskRepeater(dataset)
|
||||||
|
data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=LengthHackSampler(opt.batch_size, BitmapTaskRepeater.key_sampler(opt.len, opt.repeat)), num_workers=1, pin_memory=True)
|
||||||
|
|
||||||
|
if opt.controller_type == "lstm":
|
||||||
|
controller_constructor = functools.partial(LSTMController, out_from_all_layers=opt.lstm_use_all_outputs)
|
||||||
|
elif opt.controller_type == "linear":
|
||||||
|
controller_constructor = FeedforwardController
|
||||||
|
else:
|
||||||
|
assert False, "Invalid controller: %s" % opt.controller_type
|
||||||
|
|
||||||
|
model = DNC(in_size, out_size, opt.data_word_size, opt.mem_count, opt.n_read_heads, controller_constructor(opt.layer_sizes),
|
||||||
|
batch_first=True, mask=opt.masked_lookup, dealloc_content=opt.dealloc_content,
|
||||||
|
link_sharpness_control=opt.sharpness_control,
|
||||||
|
mask_min=opt.mask_min, clip_controller=opt.clip_controller)
|
||||||
|
|
||||||
|
params = [
|
||||||
|
{'params': [p for n, p in model.named_parameters() if not n.endswith(".bias")]},
|
||||||
|
{'params': [p for n, p in model.named_parameters() if n.endswith(".bias")], 'weight_decay': 0}
|
||||||
|
]
|
||||||
|
|
||||||
|
device = torch.device('cuda') if opt.gpu!="none" else torch.device("cpu")
|
||||||
|
print("DEVICE: ", device)
|
||||||
|
|
||||||
|
if isinstance(dataset, NLPTask):
|
||||||
|
embedding = torch.nn.Embedding(len(dataset.vocabulary), opt.embedding_size).to(device)
|
||||||
|
params.append({'params': embedding.parameters(), 'weight_decay': 0})
|
||||||
|
|
||||||
|
if opt.optimizer=="sgd":
|
||||||
|
optimizer = torch.optim.SGD(params, lr=opt.lr, weight_decay=opt.wd, momentum=opt.momentum)
|
||||||
|
elif opt.optimizer=="adam":
|
||||||
|
optimizer = torch.optim.Adam(params, lr=opt.lr, weight_decay=opt.wd)
|
||||||
|
elif opt.optimizer == "rmsprop":
|
||||||
|
optimizer = torch.optim.RMSprop(params, lr=opt.lr, weight_decay=opt.wd, momentum=opt.momentum, eps=1e-10)
|
||||||
|
else:
|
||||||
|
assert "Invalid optimizer: %s" % opt.optimizer
|
||||||
|
|
||||||
|
n_params = sum([sum([t.numel() for t in d['params']]) for d in params])
|
||||||
|
print("Number of parameters: %d" % n_params)
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
if embedding is not None and hasattr(embedding, "to"):
|
||||||
|
embedding = embedding.to(device)
|
||||||
|
|
||||||
|
i=0
|
||||||
|
loss_sum = 0
|
||||||
|
|
||||||
|
loss_plot = Visdom.Plot2D("loss", store_interval=opt.info_interval, xlabel="iterations", ylabel="loss")
|
||||||
|
|
||||||
|
if curriculum is not None:
|
||||||
|
curriculum_plot = Visdom.Plot2D("curriculum lesson" +
|
||||||
|
(" (last %d)" % (curriculum.n_lessons-1) if curriculum.n_lessons is not None else ""),
|
||||||
|
xlabel="iterations", ylabel="lesson")
|
||||||
|
curriculum_accuracy = Visdom.Plot2D("curriculum accuracy", xlabel="iterations", ylabel="accuracy")
|
||||||
|
|
||||||
|
saver = Saver(os.path.join(opt.name, "save"), short_interval=opt.save_interval)
|
||||||
|
saver.register("model", StateSaver(model))
|
||||||
|
saver.register("optimizer", StateSaver(optimizer))
|
||||||
|
saver.register("i", GlobalVarSaver("i"))
|
||||||
|
saver.register("loss_sum", GlobalVarSaver("loss_sum"))
|
||||||
|
saver.register("loss_plot", StateSaver(loss_plot))
|
||||||
|
saver.register("dataset", StateSaver(dataset))
|
||||||
|
if test_set:
|
||||||
|
saver.register("test_set", StateSaver(test_set))
|
||||||
|
|
||||||
|
if curriculum is not None:
|
||||||
|
saver.register("curriculum", StateSaver(curriculum))
|
||||||
|
saver.register("curriculum_plot", StateSaver(curriculum_plot))
|
||||||
|
saver.register("curriculum_accuracy", StateSaver(curriculum_accuracy))
|
||||||
|
|
||||||
|
if isinstance(dataset, NLPTask):
|
||||||
|
saver.register("word_embeddings", StateSaver(embedding))
|
||||||
|
elif embedding is not None:
|
||||||
|
saver.register("embeddings", StateSaver(embedding))
|
||||||
|
|
||||||
|
if not saver.load(opt.load):
|
||||||
|
model.reset_parameters()
|
||||||
|
if embedding is not None:
|
||||||
|
embedding.reset_parameters()
|
||||||
|
|
||||||
|
visualizers = {}
|
||||||
|
|
||||||
|
debug_schemas={
|
||||||
|
"read_head" : {
|
||||||
|
"list_dim" : 2
|
||||||
|
},
|
||||||
|
"temporal_links/forward_dists" : {
|
||||||
|
"list_dim" : 2
|
||||||
|
},
|
||||||
|
"temporal_links/backward_dists" : {
|
||||||
|
"list_dim" : 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def plot_debug(debug, prefix="", schema={}):
|
||||||
|
if debug is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for k, v in debug.items():
|
||||||
|
curr_name = prefix+k
|
||||||
|
if curr_name in debug_schemas:
|
||||||
|
curr_schema = schema.copy()
|
||||||
|
curr_schema.update(debug_schemas[curr_name])
|
||||||
|
else:
|
||||||
|
curr_schema = schema
|
||||||
|
|
||||||
|
if isinstance(v, dict):
|
||||||
|
plot_debug(v, curr_name+"/", curr_schema)
|
||||||
|
continue
|
||||||
|
|
||||||
|
data = v[0]
|
||||||
|
|
||||||
|
if curr_schema.get("list_dim",-1) > 0:
|
||||||
|
if data.ndim != 3:
|
||||||
|
print("WARNING: unknown data shape for array display: %s, tensor %s" % (data.shape, curr_name))
|
||||||
|
continue
|
||||||
|
|
||||||
|
n_steps = data.shape[curr_schema["list_dim"]-1]
|
||||||
|
if curr_name not in visualizers:
|
||||||
|
visualizers[curr_name] = [Visdom.Heatmap(curr_name+"_%d" % i, dumpdir=os.path.join(opt.name, "preview") if opt.dump_heatmaps else None) for i in range(n_steps)]
|
||||||
|
|
||||||
|
for i in range(n_steps):
|
||||||
|
visualizers[curr_name][i].draw(index_by_dim(data, curr_schema["list_dim"]-1, i))
|
||||||
|
else:
|
||||||
|
if data.ndim != 2:
|
||||||
|
print("WARNING: unknown data shape for simple display: %s, tensor %s" % (data.shape, curr_name))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if curr_name not in visualizers:
|
||||||
|
visualizers[curr_name] = Visdom.Heatmap(curr_name, dumpdir=os.path.join(opt.name, "preview") if opt.dump_heatmaps else None)
|
||||||
|
|
||||||
|
visualizers[curr_name].draw(data)
|
||||||
|
|
||||||
|
|
||||||
|
def run_model(input, debug=None):
|
||||||
|
if isinstance(dataset, NLPTask):
|
||||||
|
input = embedding(input["input"])
|
||||||
|
else:
|
||||||
|
input = input["input"] * 2.0 - 1.0
|
||||||
|
|
||||||
|
return model(input, debug=debug)
|
||||||
|
|
||||||
|
def multiply_grads(params, mul):
|
||||||
|
if mul==1:
|
||||||
|
return
|
||||||
|
|
||||||
|
for pa in params:
|
||||||
|
for p in pa["params"]:
|
||||||
|
p.grad.data *= mul
|
||||||
|
|
||||||
|
def test():
|
||||||
|
if test_set is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
print("TESTING...")
|
||||||
|
start_time=time.time()
|
||||||
|
t = test_set.start_test()
|
||||||
|
with torch.no_grad():
|
||||||
|
for data in tqdm(test_loader):
|
||||||
|
data = {k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items()}
|
||||||
|
if hasattr(dataset, "prepare"):
|
||||||
|
data = dataset.prepare(data)
|
||||||
|
|
||||||
|
net_out = run_model(data)
|
||||||
|
test_set.veify_result(t, data, net_out)
|
||||||
|
|
||||||
|
test_set.show_test_results(i, t)
|
||||||
|
print("Test done in %gs" % (time.time() - start_time))
|
||||||
|
|
||||||
|
if opt.test_on_start.lower() in ["on", "1", "true", "quit"]:
|
||||||
|
test()
|
||||||
|
if opt.test_on_start.lower() == "quit":
|
||||||
|
saver.write(i)
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
if opt.print_test:
|
||||||
|
model.eval()
|
||||||
|
total = 0
|
||||||
|
correct = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for data in tqdm(test_loader):
|
||||||
|
if not running:
|
||||||
|
return
|
||||||
|
|
||||||
|
data = {k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items()}
|
||||||
|
if hasattr(test_set, "prepare"):
|
||||||
|
data = test_set.prepare(data)
|
||||||
|
|
||||||
|
net_out = run_model(data)
|
||||||
|
|
||||||
|
c,t = test_set.curriculum_measure(net_out, data["output"])
|
||||||
|
total += t
|
||||||
|
correct += c
|
||||||
|
|
||||||
|
print("Test result: %2.f%% (%d out of %d correct)" % (100.0*correct/total, correct, total))
|
||||||
|
model.train()
|
||||||
|
return
|
||||||
|
|
||||||
|
iter_start_time = time.time() if i % opt.info_interval == 0 else None
|
||||||
|
data_load_total_time = 0
|
||||||
|
|
||||||
|
start_i = i
|
||||||
|
|
||||||
|
if opt.dump_profile:
|
||||||
|
profiler = torch.autograd.profiler.profile(use_cuda=True)
|
||||||
|
|
||||||
|
|
||||||
|
if opt.dump_heatmaps:
|
||||||
|
dataset.set_dump_dir(os.path.join(opt.name, "preview"))
|
||||||
|
|
||||||
|
@preview()
|
||||||
|
def do_visualize(raw_data, output, pos_map, debug):
|
||||||
|
if pos_map is not None:
|
||||||
|
output = embedding.backmap_output(output, pos_map, raw_data["output"].shape[1])
|
||||||
|
dataset.visualize_preview(raw_data, output)
|
||||||
|
|
||||||
|
if debug is not None:
|
||||||
|
plot_debug(debug)
|
||||||
|
|
||||||
|
preview_timer=OnceEvery(opt.preview_interval)
|
||||||
|
|
||||||
|
pos_map = None
|
||||||
|
start_iter = i
|
||||||
|
|
||||||
|
if curriculum is not None:
|
||||||
|
curriculum.init()
|
||||||
|
|
||||||
|
while running:
|
||||||
|
data_load_timer = time.time()
|
||||||
|
for data in data_loader:
|
||||||
|
if not running:
|
||||||
|
break
|
||||||
|
|
||||||
|
if loader_reset:
|
||||||
|
print("Loader reset requested. Resetting...")
|
||||||
|
loader_reset = False
|
||||||
|
if curriculum is not None:
|
||||||
|
curriculum.lesson_started()
|
||||||
|
break
|
||||||
|
|
||||||
|
if opt.dump_profile:
|
||||||
|
if i==start_i+1:
|
||||||
|
print("Starting profiler")
|
||||||
|
profiler.__enter__()
|
||||||
|
elif i==start_i+5+1:
|
||||||
|
print("Stopping profiler")
|
||||||
|
profiler.__exit__(None, None, None)
|
||||||
|
print("Average stats")
|
||||||
|
print(profiler.key_averages().table("cpu_time_total"))
|
||||||
|
print("Writing trace to file")
|
||||||
|
profiler.export_chrome_trace(opt.dump_profile)
|
||||||
|
print("Done.")
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
print("Step %d out of 5" % (i-start_i))
|
||||||
|
|
||||||
|
debug.dbg_print("-------------------------------------")
|
||||||
|
raw_data = data
|
||||||
|
|
||||||
|
data = {k: v.to(device) if torch.is_tensor(v) else v for k,v in data.items()}
|
||||||
|
if hasattr(dataset, "prepare"):
|
||||||
|
data = dataset.prepare(data)
|
||||||
|
|
||||||
|
data_load_total_time += time.time() - data_load_timer
|
||||||
|
|
||||||
|
need_preview = preview_timer()
|
||||||
|
debug_data = {} if opt.debug and need_preview else None
|
||||||
|
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if opt.n_subbatch=="auto":
|
||||||
|
n_subbatch = math.ceil(data["input"].numel() / opt.max_input_count_per_batch)
|
||||||
|
else:
|
||||||
|
n_subbatch = int(opt.n_subbatch)
|
||||||
|
|
||||||
|
real_batch = max(math.floor(opt.batch_size/n_subbatch),1)
|
||||||
|
n_subbatch = math.ceil(opt.batch_size/real_batch)
|
||||||
|
remaning_batch = opt.batch_size % real_batch
|
||||||
|
|
||||||
|
for subbatch in range(n_subbatch):
|
||||||
|
if not running:
|
||||||
|
break
|
||||||
|
input = data["input"]
|
||||||
|
target = data["output"]
|
||||||
|
|
||||||
|
if n_subbatch!=1:
|
||||||
|
input = input[subbatch * real_batch: (subbatch + 1) * real_batch]
|
||||||
|
target = target[subbatch * real_batch:(subbatch + 1) * real_batch]
|
||||||
|
|
||||||
|
f2 = data.copy()
|
||||||
|
f2["input"] = input
|
||||||
|
output = run_model(f2, debug=debug_data if subbatch==n_subbatch-1 else None)
|
||||||
|
l = dataset.loss(output, target)
|
||||||
|
debug.nan_check(l, force=True)
|
||||||
|
l.backward()
|
||||||
|
|
||||||
|
if curriculum is not None:
|
||||||
|
curriculum.update(*dataset.curriculum_measure(output, target))
|
||||||
|
|
||||||
|
if remaning_batch!=0 and subbatch == n_subbatch-2:
|
||||||
|
multiply_grads(params, real_batch/remaning_batch)
|
||||||
|
|
||||||
|
if n_subbatch!=1:
|
||||||
|
if remaning_batch==0:
|
||||||
|
multiply_grads(params, 1/n_subbatch)
|
||||||
|
else:
|
||||||
|
multiply_grads(params, remaning_batch / opt.batch_size)
|
||||||
|
|
||||||
|
for p in params:
|
||||||
|
torch.nn.utils.clip_grad_norm_(p["params"], opt.grad_clip)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
curr_loss = l.data.item()
|
||||||
|
loss_plot.add_point(i, curr_loss)
|
||||||
|
|
||||||
|
loss_sum += curr_loss
|
||||||
|
|
||||||
|
|
||||||
|
if i % opt.info_interval == 0:
|
||||||
|
tim = time.time()
|
||||||
|
loss_avg = loss_sum / opt.info_interval
|
||||||
|
|
||||||
|
if curriculum is not None:
|
||||||
|
curriculum_accuracy.add_point(i, curriculum.get_accuracy())
|
||||||
|
curriculum_plot.add_point(i, curriculum.step)
|
||||||
|
|
||||||
|
message = "Iteration %d, loss: %.4f" % (i, loss_avg)
|
||||||
|
if iter_start_time is not None:
|
||||||
|
message += " (%.2f ms/iter, load time %.2g ms/iter, visport: %s)" % (
|
||||||
|
(tim - iter_start_time) / opt.info_interval * 1000.0,
|
||||||
|
data_load_total_time / opt.info_interval * 1000.0,
|
||||||
|
Visdom.port)
|
||||||
|
print(message)
|
||||||
|
iter_start_time = tim
|
||||||
|
loss_sum = 0
|
||||||
|
data_load_total_time = 0
|
||||||
|
|
||||||
|
debug.dbg_print("Iteration %d, loss %g" % (i, curr_loss))
|
||||||
|
|
||||||
|
if need_preview:
|
||||||
|
do_visualize(raw_data, output, pos_map, debug_data)
|
||||||
|
|
||||||
|
if i % opt.test_interval==0:
|
||||||
|
test()
|
||||||
|
|
||||||
|
saver.tick(i)
|
||||||
|
|
||||||
|
if opt.demo and opt.exit_after is None:
|
||||||
|
running = False
|
||||||
|
input("Press enter to quit.")
|
||||||
|
|
||||||
|
if opt.exit_after is not None and (i-start_iter)>=opt.exit_after:
|
||||||
|
running=False
|
||||||
|
|
||||||
|
data_load_timer = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
global running
|
||||||
|
running = True
|
||||||
|
|
||||||
|
|
||||||
|
def signal_handler(signal, frame):
|
||||||
|
global running
|
||||||
|
print('You pressed Ctrl+C!')
|
||||||
|
running = False
|
||||||
|
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
main()
|
3
requirements.txt
Normal file
3
requirements.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
tqdm
|
||||||
|
visdom
|
||||||
|
numpy
|
Loading…
Reference in New Issue
Block a user