Initial commit

This commit is contained in:
Robert Csordas 2018-11-15 20:31:23 +01:00
commit 2a2b6bfd78
36 changed files with 4140 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
__pycache__
*.png
save
.idea

View File

@ -0,0 +1,66 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
from .BitmapTask import BitmapTask
from Utils.Seed import get_randstate
class AssociativeRecall(BitmapTask):
    """Associative recall bitmap task.

    The input is a sequence of random bit blocks separated by marker rows,
    followed by one of those blocks repeated as a query; the target is the
    block that followed the queried block in the input.
    """

    def __init__(self, length=None, bit_w=8, block_w=3, transform=lambda x: x):
        # length: number of blocks; an int, a callable returning an int, or
        #   None (then the sample key is used as the length, see __getitem__).
        # bit_w: data bits per row; 2 extra columns carry the end-of-block
        #   and end-of-sequence markers.
        # block_w: rows per block.
        # transform: applied to the {"input", "output"} dict before returning.
        super(AssociativeRecall, self).__init__()
        self.length = length
        self.bit_w = bit_w
        self.block_w = block_w
        self.transform = transform
        # Per-worker random state, created lazily on first access.
        self.seed = None

    def __getitem__(self, key):
        if self.seed is None:
            self.seed = get_randstate()

        length = self.length() if callable(self.length) else self.length
        if length is None:
            # Random length batch hack.
            length = key

        # Each block occupies block_w data rows plus one terminator row.
        stride = self.block_w+1

        d = self.seed.randint(0, 2, [length * (self.block_w+1), self.bit_w + 2]).astype(np.float32)
        d[:,-2:] = 0

        # Terminate input block: zeroed row with the -2 marker column set.
        for i in range(1,length,1):
            d[i * stride - 1, :] = 0
            d[i * stride - 1, -2] = 1

        # Terminate input sequence: zeroed row with the -1 marker column set.
        d[-1, :] = 0
        d[-1, -1] = 1

        # Add and terminate query. ti is drawn from [0, length-2] (np.random
        # randint has an exclusive upper bound), so the queried block always
        # has a successor. The query repeats block ti (without its terminator
        # row) and appends blank rows where the answer is expected.
        ti = self.seed.randint(0, length-1)
        d = np.concatenate((d, d[ti * stride: (ti+1) * stride-1], np.zeros([self.block_w+1, self.bit_w+2], np.float32)), axis=0)
        d[-(1+self.block_w),-1] = 1

        # Target: the block following the queried one, aligned with the blank
        # rows at the end of the input.
        target = np.zeros_like(d)
        target[-self.block_w:] = d[(ti+1) * stride: (ti+2) * stride-1]

        return self.transform({
            "input": d,
            "output": target
        })

View File

@ -0,0 +1,47 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import torch.nn.functional as F
from Visualize.BitmapTask import visualize_bitmap_task
from Utils import Visdom
from Utils import universal as U
class BitmapTask(torch.utils.data.Dataset):
    """Common base for the synthetic bitmap datasets.

    Provides the shared preview window, the BCE-with-logits loss and no-op
    checkpoint hooks; subclasses implement __getitem__.
    """

    def __init__(self):
        super(BitmapTask, self).__init__()
        # Visdom window used by visualize_preview().
        self._img = Visdom.Image("preview")

    def set_dump_dir(self, dir):
        # Forward the dump directory to the preview window.
        self._img.set_dump_dir(dir)

    def __len__(self):
        # Effectively infinite: samples are generated on the fly.
        return 0x7FFFFFFF

    def visualize_preview(self, data, net_output):
        # Render the reference output next to the sigmoid of the raw logits.
        prediction = U.sigmoid(net_output)
        rendered = visualize_bitmap_task(data["input"], [data["output"], prediction])
        self._img.draw(rendered)

    def loss(self, net_output, target):
        # Sum over all elements, then average over the batch dimension.
        total = F.binary_cross_entropy_with_logits(net_output, target, reduction="sum")
        return total / net_output.size(0)

    def state_dict(self):
        # Nothing to checkpoint.
        return {}

    def load_state_dict(self, state):
        # Nothing to restore.
        pass

View File

@ -0,0 +1,55 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
import random
from .BitmapTask import BitmapTask
class BitmapTaskRepeater(BitmapTask):
    """Wraps a bitmap dataset so that one sample is the concatenation of
    several independently drawn sequences from the wrapped dataset."""

    def __init__(self, dataset):
        super(BitmapTaskRepeater, self).__init__()
        self.dataset = dataset

    def __getitem__(self, key):
        # key is a list of per-repeat keys (see key_sampler below).
        samples = [self.dataset[k] for k in key]
        if len(samples) == 1:
            return samples[0]
        return {
            "input": np.concatenate([s["input"] for s in samples], axis=0),
            "output": np.concatenate([s["output"] for s in samples], axis=0)
        }

    @staticmethod
    def key_sampler(length, repeat):
        """Return a sampler producing key lists. Both length and repeat may
        be a constant, a callable, a [min, max] range or a 1-element list."""
        def call_sampler(s):
            if callable(s):
                return s()
            if isinstance(s, list):
                if len(s) == 2:
                    return random.randint(*s)
                if len(s) == 1:
                    return s[0]
                assert False, "Invalid sample parameter: %s" % s
            return s

        def s():
            # Sample the repeat count first, then one length per repeat,
            # matching the original order of random draws.
            count = call_sampler(repeat)
            return [call_sampler(length) for _ in range(count)]
        return s

View File

@ -0,0 +1,55 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
from .BitmapTask import BitmapTask
from Utils.Seed import get_randstate
class CopyData(BitmapTask):
    """Copy task: the input is a random bit sequence followed by blanks, and
    the target is the same sequence delayed by its own duration."""

    def __init__(self, length=None, bit_w=8, transform=lambda x:x):
        # length may be an int, a callable returning an int, or None (then
        # the sample key supplies the length).
        super(CopyData, self).__init__()
        self.length = length
        self.bit_w = bit_w
        self.transform = transform
        # Lazily created per-worker random state.
        self.seed = None

    def __getitem__(self, key):
        if self.seed is None:
            self.seed = get_randstate()

        length = self.length() if callable(self.length) else self.length
        if length is None:
            # Random length batch hack: the sampler supplies it as the key.
            length = key

        # length+1 rows: data plus one end-of-sequence marker row; the last
        # column is reserved for the marker bit.
        data = self.seed.randint(0,2,[length+1, self.bit_w+1]).astype(np.float32)
        blank = np.zeros_like(data)
        data[-1] = 0
        data[:, -1] = 0
        data[-1, -1] = 1

        return self.transform({
            "input" : np.concatenate((data, blank), axis=0),
            "output": np.concatenate((blank, data), axis=0)
        })

View File

@ -0,0 +1,84 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import math
import numpy as np
from .BitmapTask import BitmapTask
from Utils.Seed import get_randstate
class KeyValue(BitmapTask):
    """Key-value recall task: present (key, value) pairs, then the keys again
    in random order; the target is the matching value for each re-asked key."""

    def __init__(self, length=None, bit_w=8, transform=lambda x: x):
        assert bit_w % 2 == 0, "bit_w must be even"
        super(KeyValue, self).__init__()
        self.length = length
        self.bit_w = bit_w
        self.transform = transform
        self.seed = None
        # Keys and values each take one half of the data bits.
        self.key_w = self.bit_w//2
        self.max_key = 2**self.key_w - 1

    def __getitem__(self, key):
        if self.seed is None:
            self.seed = get_randstate()

        if self.length is None:
            # Random length batch hack.
            length = key
        else:
            length = self.length() if callable(self.length) else self.length

        # keys must be unique: keep redrawing and deduplicating until exactly
        # `length` distinct integer keys remain.
        # NOTE(review): random_integers is deprecated in NumPy; the modern
        # equivalent is randint(0, max_key + 1).
        keys = None
        last_size = 0
        while last_size!=length:
            res = self.seed.random_integers(0, self.max_key, size=(length - last_size))
            if keys is not None:
                keys = np.concatenate((res, keys))
            else:
                keys = res

            keys = np.unique(keys)
            last_size = keys.size

        # view as bunch of uint8s, convert them to bit patterns, then cut the correct amount from it
        keys = keys.view(np.uint8).reshape(length, -1)
        keys = keys[:, :math.ceil(self.key_w/8)]
        keys = np.unpackbits(np.expand_dims(keys,-1), axis=-1)
        keys = np.flip(keys, axis=-1).reshape(keys.shape[0],-1)[:, :self.key_w]
        keys = keys.astype(np.float32)

        values = self.seed.randint(0,2, keys.shape).astype(np.float32)

        # Re-ask the keys in a random order.
        perm = self.seed.permutation(length)
        keys_perm = keys[perm,:]
        values_perm = values[perm,:]

        # Layout: [pairs, separator row, permuted keys, terminator row];
        # the last input column carries the phase markers.
        i_p = np.zeros((2*length+2, self.bit_w+1), dtype=np.float32)
        i_p[:length,:self.key_w] = keys
        i_p[:length,self.key_w:-1] = values
        i_p[length+1:-1, :self.key_w] = keys_perm
        # End of the presentation phase.
        i_p[length,-1] = 1
        # End of the sequence.
        i_p[-1, -1] = 1

        # Target: the value of each re-asked key, during the query phase only.
        o_p = np.zeros((2*length+2, self.key_w), dtype=np.float32)
        o_p[length+1:-1] = values_perm

        return self.transform({
            "input": i_p,
            "output": o_p
        })

View File

@ -0,0 +1,92 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import math
import numpy as np
from .BitmapTask import BitmapTask
from Utils.Seed import get_randstate
class KeyValue2Way(BitmapTask):
    """Bidirectional key-value recall task.

    Phase 1 presents (key, value) pairs. Phase 2 re-asks the keys in random
    order (target: the matching values). Phase 3 re-asks the values in a
    fresh random order (target: the matching keys). Two extra input columns
    mark the phase boundaries.
    """

    def __init__(self, length=None, bit_w=8, transform=lambda x: x):
        assert bit_w % 2 == 0, "bit_w must be even"
        super(KeyValue2Way, self).__init__()
        self.length = length
        self.bit_w = bit_w
        self.transform = transform
        self.seed = None
        # Keys and values each take one half of the data bits.
        self.key_w = self.bit_w//2
        self.max_key = 2**self.key_w - 1

    def __getitem__(self, key):
        if self.seed is None:
            self.seed = get_randstate()

        if self.length is None:
            # Random length batch hack.
            length = key
        else:
            length = self.length() if callable(self.length) else self.length

        # keys must be unique: redraw and deduplicate until `length` distinct
        # integers remain. NOTE(review): random_integers is deprecated in
        # NumPy; randint(0, max_key + 1) is the modern equivalent.
        keys = None
        last_size = 0
        while last_size!=length:
            res = self.seed.random_integers(0, self.max_key, size=(length - last_size))
            if keys is not None:
                keys = np.concatenate((res, keys))
            else:
                keys = res

            keys = np.unique(keys)
            last_size = keys.size

        # view as bunch of uint8s, convert them to bit patterns, then cut the correct amount from it
        keys = keys.view(np.uint8).reshape(length, -1)
        keys = keys[:, :math.ceil(self.key_w/8)]
        keys = np.unpackbits(np.expand_dims(keys,-1), axis=-1)
        keys = np.flip(keys, axis=-1).reshape(keys.shape[0],-1)[:, :self.key_w]
        keys = keys.astype(np.float32)

        values = self.seed.randint(0,2, keys.shape).astype(np.float32)

        perm = self.seed.permutation(length)
        keys_perm = keys[perm,:]
        values_perm = values[perm,:]

        # Three phases of length+1 rows each (data rows + a marker row); the
        # last 2 input columns are the markers.
        i_p = np.zeros((3*(length+1), self.bit_w+2), dtype=np.float32)
        o_p = np.zeros((3*(length+1), self.key_w), dtype=np.float32)

        # Phase 1: keys and values presented side by side.
        i_p[:length,:self.key_w] = keys
        i_p[:length,self.key_w:-2] = values

        # Phase 2: permuted keys in, matching values expected out.
        i_p[length + 1:2*length + 1, :self.key_w] = keys_perm
        o_p[length + 1:2 * length + 1] = values_perm

        # Phase 3 uses a fresh permutation: values in, keys expected out.
        perm = self.seed.permutation(length)
        keys_perm = keys[perm, :]
        values_perm = values[perm, :]

        o_p[2*(length + 1):-1] = keys_perm
        i_p[2 * (length + 1):-1, :self.key_w] = values_perm

        # Phase boundary markers: end of phase 1 (-2), end of phase 2 (-1),
        # end of sequence (both).
        i_p[length, -2] = 1
        i_p[2 * length + 1, -1] = 1
        i_p[-1, -2:] = 1

        return self.transform({
            "input": i_p,
            "output": o_p
        })

View File

1
Dataset/NLP/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
cache

103
Dataset/NLP/NLPTask.py Normal file
View File

@ -0,0 +1,103 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import torch.nn.functional as F
import os
from .Vocabulary import Vocabulary
from Utils import Visdom
from Utils import universal as U
class NLPTask(torch.utils.data.Dataset):
    """Base class for NLP datasets: manages the cache directory, a persistent
    shared Vocabulary, the padding-masked cross-entropy loss and a Visdom
    text preview of predictions."""

    def __init__(self):
        super(NLPTask, self).__init__()
        self.my_dir = os.path.abspath(os.path.dirname(__file__))
        self.cache_dir = os.path.join(self.my_dir, "cache")
        if not os.path.isdir(self.cache_dir):
            os.makedirs(self.cache_dir)

        self.vocabulary = self._load_vocabulary()
        self._preview = None

    def _load_vocabulary(self):
        """Load the cached vocabulary. On a miss, purge all cached .pth files
        (they were tokenized with the lost vocabulary and would be stale) and
        start a fresh Vocabulary."""
        cache_file = os.path.join(self.cache_dir, "vocabulary.pth")
        if not os.path.isfile(cache_file):
            print("WARNING: Vocabulary not found. Removing cached files.")
            for f in os.listdir(self.cache_dir):
                f = os.path.join(self.cache_dir, f)
                if f.endswith(".pth"):
                    print("  "+f)
                    os.remove(f)
            return Vocabulary()
        else:
            return torch.load(cache_file)

    def save_vocabulary(self):
        """Persist the vocabulary to the cache, replacing any previous copy."""
        cache_file = os.path.join(self.cache_dir, "vocabulary.pth")
        if os.path.isfile(cache_file):
            os.remove(cache_file)
        torch.save(self.vocabulary, cache_file)

    def loss(self, net_output, target):
        """Cross entropy summed over tokens and averaged over the batch;
        target index 0 (padding / no answer) is ignored."""
        s = list(net_output.size())
        return F.cross_entropy(net_output.view([s[0]*s[1], s[2]]), target.view([-1]), ignore_index=0,
                               reduction='sum')/s[0]

    def generate_preview_text(self, data, net_output):
        """Render the first batch element as HTML: the input words as plain
        sentences, and at each answer position the prediction next to the
        reference, colored green on a match and red otherwise."""
        input = U.to_numpy(data["input"][0])
        reference = U.to_numpy(data["output"][0])
        net_out = U.argmax(net_output[0], -1)
        net_out = U.to_numpy(net_out)

        res = ""
        start_index = 0
        for i in range(input.shape[0]):
            if reference[i] != 0:
                # Nonzero reference marks an answer position. First flush the
                # input words accumulated since the previous answer.
                if start_index < i:
                    # Trim trailing padding (zeros) from the pending span.
                    end_index = i
                    while end_index>start_index and input[end_index]==0:
                        end_index -= 1

                    if end_index>start_index:
                        # Join words, fix spacing around punctuation, and
                        # capitalize each sentence.
                        sentence = " ".join(self.vocabulary.indices_to_sentence(input[start_index:i].tolist())). \
                            replace(" .", ".").replace(" ,", ",").replace(" ?", "?").split(". ")
                        sentence = ". ".join([s.capitalize() for s in sentence])
                        res += sentence + "<br>"

                start_index = i + 1

                match = reference[i] == net_out[i]
                res += "<b><font color=\"%s\">%s [%s]</font><br></b>" % ("green" if match else "red",
                                                                        self.vocabulary.indices_to_sentence(
                                                                            [net_out[i]])[0],
                                                                        self.vocabulary.indices_to_sentence(
                                                                            [reference[i]])[0])
        return res

    def visualize_preview(self, data, net_output):
        """Show the generated HTML preview in a (lazily created) Visdom
        text window."""
        res = self.generate_preview_text(data, net_output)

        if self._preview is None:
            self._preview = Visdom.Text("Preview")

        self._preview.set(res)

    def set_dump_dir(self, dir):
        # NLP tasks dump nothing; hook kept for interface compatibility.
        pass

49
Dataset/NLP/Vocabulary.py Normal file
View File

@ -0,0 +1,49 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
class Vocabulary:
    """Bidirectional word <-> index mapping with incremental registration.

    Index 0 ("-") is the padding / no-answer token, 1 is "?", and 2 is the
    unknown-word token "<UNK>". New words are assigned consecutive indices.
    """

    def __init__(self):
        self.words = {"-" : 0, "?": 1, "<UNK>": 2}
        self.inv_words = {0 : "-", 1: "?", 2: "<UNK>"}
        self.next_id = 3
        # Punctuation characters that are split off as separate tokens.
        self.punctations = [".", "?", ","]

    def _process_word(self, w, add_words):
        """Return the index of token w.

        Non-alphabetic tokens that are not known punctuation map to <UNK>.
        Unseen words are registered when add_words is True, otherwise they
        also map to <UNK>.
        """
        if not w.isalpha() and w not in self.punctations:
            # BUG FIX: the original passed w as a second positional argument
            # to print() instead of %-formatting it into the message.
            print("WARNING: word with unknown characters: %s" % w)
            w = "<UNK>"

        if w not in self.words:
            if add_words:
                self.words[w] = self.next_id
                self.inv_words[self.next_id] = w
                self.next_id += 1
            else:
                w = "<UNK>"

        return self.words[w]

    def sentence_to_indices(self, sentence, add_words=True):
        """Tokenize a sentence (lowercased, punctuation split off) and return
        the list of token indices; see _process_word for unknown handling."""
        for p in self.punctations:
            sentence = sentence.replace(p, " %s " % p)

        return [self._process_word(w, add_words) for w in sentence.lower().split(" ") if w]

    def indices_to_sentence(self, indices):
        """Map a list of indices back to their tokens."""
        return [self.inv_words[i] for i in indices]

    def __len__(self):
        """Number of known tokens, including the 3 special ones."""
        return len(self.words)

0
Dataset/NLP/__init__.py Normal file
View File

270
Dataset/NLP/bAbi.py Normal file
View File

@ -0,0 +1,270 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import os
import glob
import torch
from collections import namedtuple
import numpy as np
from .NLPTask import NLPTask
from Utils import Visdom
Sentence = namedtuple('Sentence', ['sentence', 'answer', 'supporting_facts'])
class bAbiDataset(NLPTask):
    """bAbI question-answering dataset (Facebook bAbI tasks 1-20).

    Downloads and caches the archive, tokenizes the task files with the
    shared vocabulary, and serves flat (input indices, output indices)
    pairs where output position 0 means "no target at this step".
    """

    URL = 'http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz'
    DIR_NAME = "tasks_1-20_v1-2"

    def __init__(self, dirs = ["en-10k"], sets=None, think_steps=0, dir_name=None, name=None):
        # dirs: dataset variants to load (e.g. "en", "en-10k"). The mutable
        #   default is only iterated, never mutated.
        # sets: which sets ("train"/"test"/...) to activate initially.
        # think_steps: extra blank steps between a question and the position
        #   where the answer is expected.
        # dir_name: pre-existing data directory; when None the archive is
        #   downloaded into the cache directory.
        # name: display name used on the Visdom test windows.
        super(bAbiDataset, self).__init__()
        self._test_res_win = None
        self._test_plot_win = None
        self._think_steps = think_steps

        if dir_name is None:
            self._download()
            dir_name = os.path.join(self.cache_dir, self.DIR_NAME)

        self.data={}
        for d in dirs:
            self.data[d] = self._load_or_create(os.path.join(dir_name, d))

        self.all_tasks=None
        self.name = name
        self.use(sets=sets)

    def _make_active_list(self, tasks, sets, dirs):
        """Flatten self.data into a list of (story, dir, task, set) tuples,
        keeping only entries accepted by the three filters. Each filter may
        be None (accept all), a callable, a list, or a single value."""
        def verify(name, checker):
            if checker is None:
                return True

            if callable(checker):
                return checker(name)
            elif isinstance(checker, list):
                return name in checker
            else:
                return name==checker

        res = []
        for dirname, setlist in self.data.items():
            if not verify(dirname, dirs):
                continue

            for sname, tasklist in setlist.items():
                if not verify(sname, sets):
                    continue

                for task, data in tasklist.items():
                    # "qa<N>_..." -> "<N>": the task number used for matching.
                    name = task.split("_")[0][2:]
                    if not verify(name, tasks):
                        continue

                    res += [(d, dirname, task, sname) for d in data]
        return res

    def use(self, tasks=None, sets=None, dirs=None):
        """Select the active subset of tasks/sets/dirs served by __getitem__."""
        self.all_tasks=self._make_active_list(tasks=tasks, sets=sets, dirs=dirs)

    def __len__(self):
        return len(self.all_tasks)

    def _get_seq(self, index):
        return self.all_tasks[index]

    def _seq_to_nn_input(self, seq):
        """Convert one (story, dir, task, set) tuple into flat int64 input
        and output arrays plus metadata."""
        in_arr = []
        out_arr = []
        hasAnswer = False
        for sentence in seq[0]:
            in_arr += sentence.sentence
            out_arr += [0] * len(sentence.sentence)
            if sentence.answer is not None:
                # Blank input steps (think steps + answer length) during
                # which the network must emit the answer tokens.
                in_arr += [0] * (len(sentence.answer) + self._think_steps)
                out_arr += [0] * self._think_steps + sentence.answer
                hasAnswer = True

        in_arr = np.asarray(in_arr, np.int64)
        out_arr = np.asarray(out_arr, np.int64)

        return {
            "input": in_arr,
            "output": out_arr,
            "meta": {
                "dir": seq[1],
                "task": seq[2],
                "set": seq[3]
            }
        }

    def __getitem__(self, item):
        seq = self._get_seq(item)
        return self._seq_to_nn_input(seq)

    def _load_or_create(self, directory):
        """Load the tokenized directory from its .pth cache, or parse the raw
        text files, save the (updated) vocabulary and cache the result."""
        cache_name = directory.replace("/","_")
        cache_file = os.path.join(self.cache_dir, cache_name+".pth")
        if not os.path.isfile(cache_file):
            print("bAbI: Loading %s" % directory)
            res = self._load_dir(directory)
            print("Write: ", cache_file)
            self.save_vocabulary()
            torch.save(res, cache_file)
        else:
            res = torch.load(cache_file)
        return res

    def _download(self):
        """Fetch and extract the bAbI archive into the cache directory when
        it is not already present."""
        if not os.path.isdir(os.path.join(self.cache_dir, self.DIR_NAME)):
            print(self.URL)
            print("bAbi data not found. Downloading...")
            import requests, tarfile, io
            request = requests.get(self.URL, headers={"User-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36"})

            decompressed_file = tarfile.open(fileobj=io.BytesIO(request.content), mode='r|gz')
            decompressed_file.extractall(self.cache_dir)
            print("Done")

    def _load_dir(self, directory, parse_name = lambda x: x.split(".")[0], parse_set = lambda x: x.split(".")[0].split("_")[-1]):
        """Parse every .txt task file under directory (recursively) into
        {set name: {task name: [stories]}}."""
        res = {}

        for f in glob.glob(os.path.join(directory, '**', '*.txt'), recursive=True):
            basename = os.path.basename(f)
            task_name = parse_name(basename)
            set = parse_set(basename)
            print("Loading", f)
            s = res.get(set)
            if s is None:
                s = {}
                res[set] = s
            s[task_name] = self._load_task(f, task_name)
        return res

    def _load_task(self, filename, task_name):
        """Parse one bAbI task file into a list of stories, each a list of
        Sentence tuples. A line-number reset starts a new story; a line with
        tab-separated fields is a question (sentence, answer, supporting
        fact indices).

        NOTE(review): the final story of the file is appended to `task` only
        on the next line-number reset; since no reset follows the last line,
        the last story of every file appears to be dropped — confirm whether
        this is intentional.
        """
        task = []
        currTask = []
        nextIndex = 1

        with open(filename, "r") as f:
            for line in f:
                line = [f.strip() for f in line.split("\t")]
                line[0] = line[0].split(" ")
                i = int(line[0][0])
                line[0] = " ".join(line[0][1:])

                if i!=nextIndex:
                    # Line numbering restarted: previous story is complete.
                    nextIndex = i
                    task.append(currTask)
                    currTask = []

                isQuestion = len(line)>1
                currTask.append(
                    Sentence(self.vocabulary.sentence_to_indices(line[0]), self.vocabulary.sentence_to_indices(line[1].replace(",", " "))
                             if isQuestion else None, [int(f) for f in line[2].split(" ")] if isQuestion else None)
                )
                nextIndex += 1
        return task

    def start_test(self):
        """Return a fresh accumulator dict for veify_result()."""
        return {}

    def veify_result(self, test, data, net_output):
        """Accumulate per-task correct/total answer-token counts into `test`.

        NOTE(review): the method name is a typo for "verify_result"; kept
        unchanged because external callers use this name.
        """
        # Greedy decoding: most probable word at each step.
        _, net_output = net_output.max(-1)

        ref = data["output"]

        # Only positions with a nonzero reference (answer tokens) count.
        mask = 1.0 - ref.eq(0).float()

        correct = (torch.eq(net_output, ref).float() * mask).sum(-1)
        total = mask.sum(-1)

        correct = correct.data.cpu().numpy()
        total = total.data.cpu().numpy()

        for i in range(correct.shape[0]):
            task = data["meta"][i]["task"]
            if task not in test:
                test[task] = {"total": 0, "correct": 0}

            d = test[task]
            d["total"] += total[i]
            d["correct"] += correct[i]

    def _ensure_test_wins_exists(self, legend = None):
        """Lazily create the Visdom result windows; update the plot legend
        once it becomes available."""
        if self._test_res_win is None:
            n = (("[" + self.name + "]") if self.name is not None else "")
            self._test_res_win = Visdom.Text("Test results" + n)
            self._test_plot_win = Visdom.Plot2D("Test results" + n, legend=legend)
        elif self._test_plot_win.legend is None:
            self._test_plot_win.set_legend(legend=legend)

    def show_test_results(self, iteration, test):
        """Render per-task error percentages (plus the mean) to the Visdom
        text and plot windows; a task counts as passed at <= 5% error."""
        res = {k: v["correct"]/v["total"] for k, v in test.items()}

        t = ""
        all_keys = list(res.keys())
        # Sort "qa<N>_..." tasks numerically by N, the rest alphabetically.
        num_keys = [k for k in all_keys if k.startswith("qa")]
        tmp = [i[0] for i in sorted(enumerate(num_keys), key=lambda x:int(x[1][2:].split("_")[0]))]
        num_keys = [num_keys[j] for j in tmp]
        all_keys = num_keys + sorted([k for k in all_keys if not k.startswith("qa")])

        err_precent = [(1.0-res[k]) * 100.0 for k in all_keys]

        n_passed = sum([int(p<=5) for p in err_precent])
        n_total = len(err_precent)

        err_precent = err_precent + [sum(err_precent) / len(err_precent)]
        all_keys += ["mean"]

        for i, k in enumerate(all_keys):
            t += "<font color=\"%s\">%s: <b>%.2f%%</b></font><br>" % ("green" if err_precent[i] <= 5 else "red", k, err_precent[i])

        t += "<br><b>Total: %d of %d passed.</b>" % (n_passed, n_total)

        self._ensure_test_wins_exists(legend=[i.split("_")[0] if i.startswith("qa") else i for i in all_keys])

        self._test_res_win.set(t)
        self._test_plot_win.add_point(iteration, err_precent)

    def state_dict(self):
        """Checkpoint the Visdom window states (empty if never created)."""
        if self._test_res_win is not None:
            return {
                "_test_res_win" : self._test_res_win.state_dict(),
                "_test_plot_win": self._test_plot_win.state_dict(),
            }
        else:
            return {}

    def load_state_dict(self, state):
        """Restore the Visdom windows from a checkpoint produced by
        state_dict(); the legend is re-supplied on the next result update."""
        if state:
            self._ensure_test_wins_exists()
            self._test_res_win.load_state_dict(state["_test_res_win"])
            self._test_plot_win.load_state_dict(state["_test_plot_win"])
            self._test_plot_win.legend = None

    def visualize_preview(self, data, net_output):
        """Like NLPTask.visualize_preview, but prefixed with the task name."""
        res = self.generate_preview_text(data, net_output)
        res = ("<b><u>%s</u></b><br>" % data["meta"][0]["task"]) + res

        if self._preview is None:
            self._preview = Visdom.Text("Preview")

        self._preview.set(res)

1
Dataset/__init__.py Normal file
View File

@ -0,0 +1 @@

677
Models/DNC.py Normal file
View File

@ -0,0 +1,677 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import torch.utils.data
import torch.nn.functional as F
import torch.nn.init as init
import functools
import math
def oneplus(t):
    """Map t smoothly onto (1, inf) via 1 + softplus(t); used for the DNC
    beta (key strength) parameters."""
    return 1.0 + F.softplus(t, 1, 20)
def get_next_tensor_part(src, dims, prev_pos=0):
    """Cut the next chunk of prod(dims) elements from the last dimension of
    src, starting at offset prev_pos.

    Returns (part, next_pos). When dims has more than one entry the part is
    reshaped so its last dimension(s) become dims.
    """
    if not isinstance(dims, list):
        dims = [dims]
    n = functools.reduce(lambda a, b: a * b, dims)
    part = src.narrow(-1, prev_pos, n)
    if len(dims) > 1:
        part = part.contiguous().view(list(part.size())[:-1] + dims)
    return part, prev_pos + n
def split_tensor(src, shapes):
    """Split the last dimension of src into consecutive parts, one per entry
    of shapes (each entry an int or a list of ints for a reshaped part)."""
    parts = []
    pos = 0
    for shape in shapes:
        part, pos = get_next_tensor_part(src, shape, pos)
        parts.append(part)
    return parts
def dict_get(dict, name):
    """Like dict.get(name), but tolerates dict being None."""
    if dict is None:
        return None
    return dict.get(name)
def dict_append(dict, name, val):
    """Append val to the list stored under dict[name], creating the list if
    the key is missing or its value is falsy; a None dict is a no-op."""
    if dict is None:
        return
    bucket = dict.get(name)
    if not bucket:
        bucket = []
        dict[name] = bucket
    bucket.append(val)
def init_debug(debug, initial):
    """Fill the debug dict with the initial entries, but only when the dict
    was provided and is still empty (so it is populated exactly once)."""
    if debug is None or debug:
        return
    debug.update(initial)
def merge_debug_tensors(d, dim):
    """Recursively walk the debug dict and replace every list of tensors
    with a single tensor stacked along dim; None is a no-op."""
    if d is None:
        return
    for key, value in d.items():
        if isinstance(value, dict):
            merge_debug_tensors(value, dim)
        elif isinstance(value, list):
            d[key] = torch.stack(value, dim)
def linear_reset(module, gain=1.0):
    """Reinitialize a torch.nn.Linear in place: Xavier-uniform weights
    (scaled by gain) and, when present, a zero bias.

    FIX: removed the unused local `s = module.weight.size(1)` from the
    original implementation.
    """
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()
_EPS = 1e-6
class AllocationManager(torch.nn.Module):
    """Tracks per-cell memory usage and produces the DNC allocation
    distribution that prefers the least-used cells.

    FIXES vs. the original:
    * torch.addcmul was called with the deprecated positional `value`
      argument (removed in recent PyTorch); replaced with the equivalent
      explicit arithmetic.
    * the debug arange was created on the CPU regardless of the usage
      tensor's device; it is now created on the same device.
    """

    def __init__(self):
        super(AllocationManager, self).__init__()
        self.usages = None          # [batch, cell count]; lazily initialized
        self.zero_usages = None     # cached zeros for sequence starts
        self.debug_sequ_init = False
        self.one = None             # cached constant 1 on the right device

    def _init_sequence(self, prev_read_distributions):
        # prev_read_distributions size is [batch, n_heads, cell count]
        s = prev_read_distributions.size()
        if self.zero_usages is None or list(self.zero_usages.size()) != [s[0], s[-1]]:
            self.zero_usages = torch.zeros(s[0], s[-1], device=prev_read_distributions.device)
            if self.debug_sequ_init:
                # Tiny monotone bias makes the initial allocation order
                # deterministic when debugging.
                self.zero_usages += torch.arange(0, s[-1], device=self.zero_usages.device).unsqueeze(0) * 1e-10

        self.usages = self.zero_usages

    def _init_consts(self, device):
        if self.one is None:
            self.one = torch.ones(1, device=device)

    def new_sequence(self):
        # Forget the usage state; it is rebuilt on the first timestep.
        self.usages = None

    def update_usages(self, prev_write_distribution, prev_read_distributions, free_gates):
        # Read distributions shape: [batch, n_heads, cell count]
        # Free gates shape: [batch, n_heads]
        self._init_consts(prev_read_distributions.device)

        # Retention vector phi, sized [batch, cell count]:
        # prod over heads of (1 - free_gate * read_dist).
        phi = (self.one - free_gates.unsqueeze(-1) * prev_read_distributions).prod(-2)

        # If memory usage counter doesn't exist yet, initialize it.
        if self.usages is None:
            self._init_sequence(prev_read_distributions)
            # in first timestep nothing is written or read yet, so we don't
            # need any further processing
        else:
            # u_t = (u_{t-1} + w_{t-1} * (1 - u_{t-1})) * phi_t
            self.usages = (self.usages + prev_write_distribution.detach() * (1 - self.usages)) * phi

        return phi

    def forward(self, prev_write_distribution, prev_read_distributions, free_gates):
        """Return (allocation distribution, phi). The allocation scores are
        computed over cells sorted by usage and scattered back to the
        original cell order."""
        phi = self.update_usages(prev_write_distribution, prev_read_distributions, free_gates)
        sorted_usage, free_list = (self.usages * (1.0 - _EPS) + _EPS).sort(-1)

        u_prod = sorted_usage.cumprod(-1)
        one_minus_usage = 1.0 - sorted_usage
        sorted_scores = torch.cat([one_minus_usage[..., 0:1], one_minus_usage[..., 1:] * u_prod[..., :-1]], dim=-1)

        return sorted_scores.clone().scatter_(-1, free_list, sorted_scores), phi
class ContentAddressGenerator(torch.nn.Module):
    """Cosine-similarity content addressing with optional masking, shared by
    the DNC read and write heads."""

    def __init__(self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        # disable_content_norm: if True, only the key norm (not the memory
        #   norm) appears in the similarity denominator.
        # mask_min: lower bound the mask is rescaled to (0 disables it).
        # disable_key_masking: if True, the mask is applied to memory only.
        super(ContentAddressGenerator, self).__init__()
        self.disable_content_norm = disable_content_norm
        self.mask_min = mask_min
        self.disable_key_masking = disable_key_masking

    def forward(self, memory, keys, betas, mask=None):
        # Memory shape [batch, cell count, word length]
        # Key shape [batch, n heads*, word length]
        # Betas shape [batch, n heads]
        if mask is not None and self.mask_min != 0:
            # Rescale the mask from [0, 1] to [mask_min, 1].
            mask = mask * (1.0-self.mask_min) + self.mask_min

        single_head = keys.dim() == 2
        if single_head:
            # Single head
            keys = keys.unsqueeze(1)
            if mask is not None:
                mask = mask.unsqueeze(1)

        # Broadcast memory against the head dimension and keys against the
        # cell dimension.
        memory = memory.unsqueeze(1)
        keys = keys.unsqueeze(-2)
        if mask is not None:
            mask = mask.unsqueeze(-2)
            memory = memory * mask
            if not self.disable_key_masking:
                keys = keys * mask

        # Shape [batch, n heads, cell count]
        norm = keys.norm(dim=-1)
        if not self.disable_content_norm:
            norm = norm * memory.norm(dim=-1)

        # Cosine similarity sharpened by the (>= 1) beta strengths.
        scores = (memory * keys).sum(-1) / (norm + _EPS)
        scores *= betas.unsqueeze(-1)

        # Softmax over the cells.
        res = F.softmax(scores, scores.dim()-1)
        return res.squeeze(1) if single_head else res
class WriteHead(torch.nn.Module):
    """DNC write head: blends allocation-based and content-based write
    addressing and applies the erase/write memory update."""

    @staticmethod
    def create_write_archive(write_dist, erase_vector, write_vector, phi):
        # Bundle everything mem_update() needs; also kept as self.last_write
        # so the next timestep can reuse the previous write distribution.
        return dict(write_dist=write_dist, erase_vector=erase_vector, write_vector=write_vector, phi=phi)

    def __init__(self, dealloc_content=True, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        # dealloc_content: if True, freed cells (phi) also scale the erase
        #   matrix; see the comment in mem_update().
        super(WriteHead, self).__init__()
        self.write_content_generator = ContentAddressGenerator(disable_content_norm, mask_min=mask_min, disable_key_masking=disable_key_masking)
        self.allocation_manager = AllocationManager()
        self.last_write = None
        self.dealloc_content = dealloc_content

        self.new_sequence()

    def new_sequence(self):
        # Reset the per-sequence state of the head and its allocator.
        self.last_write = None
        self.allocation_manager.new_sequence()

    @staticmethod
    def mem_update(memory, write_dist, erase_vector, write_vector, phi):
        # In original paper the memory content is NOT deallocated, which makes content based addressing basically
        # unusable when multiple similar steps should be done. The reason for this is that the memory contents are
        # still there, so the lookup will find them, unless an allocation clears it before the next search, which is
        # completely random. So I'm arguing that erase matrix should also take in account the free gates (multiply it
        # with phi)
        write_dist = write_dist.unsqueeze(-1)

        erase_matrix = 1.0 - write_dist * erase_vector.unsqueeze(-2)
        if phi is not None:
            erase_matrix = erase_matrix * phi.unsqueeze(-1)

        update_matrix = write_dist * write_vector.unsqueeze(-2)
        return memory * erase_matrix + update_matrix

    def forward(self, memory, write_content_key, write_beta, erase_vector, write_vector, alloc_gate, write_gate,
                free_gates, prev_read_dist, write_mask=None, debug=None):
        """Compute the write distribution, archive it for the next timestep,
        record debug tensors, and return the updated memory."""
        last_w_dist = self.last_write["write_dist"] if self.last_write is not None else None

        content_dist = self.write_content_generator(memory, write_content_key, write_beta, mask = write_mask)
        alloc_dist, phi = self.allocation_manager(last_w_dist, prev_read_dist, free_gates)

        # Shape [batch, cell count]: gated mix of allocation and content
        # addressing, scaled by the overall write gate.
        write_dist = write_gate * (alloc_gate * alloc_dist + (1-alloc_gate)*content_dist)
        self.last_write = WriteHead.create_write_archive(write_dist, erase_vector, write_vector, phi if self.dealloc_content else None)

        dict_append(debug, "alloc_dist", alloc_dist)
        dict_append(debug, "write_dist", write_dist)
        dict_append(debug, "mem_usages", self.allocation_manager.usages)
        dict_append(debug, "free_gates", free_gates)
        dict_append(debug, "write_betas", write_beta)
        dict_append(debug, "write_gate", write_gate)
        dict_append(debug, "write_vector", write_vector)
        dict_append(debug, "alloc_gate", alloc_gate)
        dict_append(debug, "erase_vector", erase_vector)
        if write_mask is not None:
            dict_append(debug, "write_mask", write_mask)

        return WriteHead.mem_update(memory, **self.last_write)
class RawWriteHead(torch.nn.Module):
    """Adapter around WriteHead that slices the controller's raw output
    vector into the individual write-head parameters and applies their
    activations (sigmoid gates, oneplus beta)."""

    def __init__(self, n_read_heads, word_length, use_mask=False, dealloc_content=True, disable_content_norm=False,
                 mask_min=0.0, disable_key_masking=False):
        super(RawWriteHead, self).__init__()
        self.write_head = WriteHead(dealloc_content = dealloc_content, disable_content_norm = disable_content_norm,
                                    mask_min=mask_min, disable_key_masking=disable_key_masking)
        self.word_length = word_length
        self.n_read_heads = n_read_heads
        self.use_mask = use_mask
        # Controller output budget: key + erase + write vectors (+ optional
        # mask), one free gate per read head, and 3 scalars (beta, alloc
        # gate, write gate).
        self.input_size = 3*self.word_length + self.n_read_heads + 3 + (self.word_length if use_mask else 0)

    def new_sequence(self):
        self.write_head.new_sequence()

    def get_prev_write(self):
        # The archive of the previous timestep's write (or None at start).
        return self.write_head.last_write

    def forward(self, memory, nn_output, prev_read_dist, debug):
        """Split nn_output into the write parameters, activate them, and run
        the wrapped WriteHead; returns the updated memory."""
        # Layout: [mask?], key, erase, write (word_length each), free gates
        # (n_read_heads), then beta, alloc gate, write gate (1 each).
        shapes = [[self.word_length]] * (4 if self.use_mask else 3) + [[self.n_read_heads]] + [[1]] * 3
        tensors = split_tensor(nn_output, shapes)

        if self.use_mask:
            write_mask = torch.sigmoid(tensors[0])
            tensors=tensors[1:]
        else:
            write_mask = None

        write_content_key, erase_vector, write_vector, free_gates, write_beta, alloc_gate, write_gate = tensors

        # Gates/vectors bounded to (0, 1); beta mapped to (1, inf).
        erase_vector = torch.sigmoid(erase_vector)
        free_gates = torch.sigmoid(free_gates)
        write_beta = oneplus(write_beta)
        alloc_gate = torch.sigmoid(alloc_gate)
        write_gate = torch.sigmoid(write_gate)

        return self.write_head(memory, write_content_key, write_beta, erase_vector, write_vector,
                               alloc_gate, write_gate, free_gates, prev_read_dist, debug=debug, write_mask=write_mask)

    def get_neural_input_size(self):
        # Number of controller outputs this head consumes.
        return self.input_size
class TemporalMemoryLinkage(torch.nn.Module):
    """DNC temporal link matrix: tracks the order in which cells were
    written so read heads can step forward/backward in write time."""

    def __init__(self):
        super(TemporalMemoryLinkage, self).__init__()
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

        # Cached zero-initialized state, reused across sequences of equal shape.
        self.initial_temp_link_mat = None
        self.initial_precedence_weighting = None
        self.initial_diag_mask = None
        self.initial_shape = None

    def new_sequence(self):
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

    def _init_link(self, w_dist):
        s = list(w_dist.size())
        if self.initial_shape is None or s != self.initial_shape:
            self.initial_temp_link_mat = torch.zeros(s[0], s[-1], s[-1]).to(w_dist.device)
            self.initial_precedence_weighting = torch.zeros(s[0], s[-1]).to(w_dist.device)
            self.initial_diag_mask = (1.0 - torch.eye(s[-1]).unsqueeze(0).to(w_dist)).detach()
            # Fix: remember the shape. The original never assigned it, so the
            # cache test above never matched and all three buffers were
            # re-created on every new sequence.
            self.initial_shape = s
        self.temp_link_mat = self.initial_temp_link_mat
        self.precedence_weighting = self.initial_precedence_weighting
        self.diag_mask = self.initial_diag_mask

    def _update_precedence(self, w_dist):
        # w_dist shape: [batch, cell count]
        self.precedence_weighting = (1.0 - w_dist.sum(-1, keepdim=True)) * self.precedence_weighting + w_dist

    def _update_links(self, w_dist):
        if self.temp_link_mat is None:
            self._init_link(w_dist)

        wt_i = w_dist.unsqueeze(-1)
        wt_j = w_dist.unsqueeze(-2)
        pt_j = self.precedence_weighting.unsqueeze(-2)

        # Diagonal is masked out: a cell never links to itself.
        self.temp_link_mat = ((1 - wt_i - wt_j) * self.temp_link_mat + wt_i * pt_j) * self.diag_mask

    def forward(self, w_dist, prev_r_dists, debug = None):
        # Links must be updated with the *old* precedence weighting, so the
        # order of these two calls matters.
        self._update_links(w_dist)
        self._update_precedence(w_dist)

        # prev_r_dists shape: [batch, n heads, cell count]
        # Emulate matrix-vector multiplication by broadcast and sum, so no
        # transpose of the link matrix is needed.
        tlm_multi_head = self.temp_link_mat.unsqueeze(1)

        forward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-2)).sum(-1)
        backward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-1)).sum(-2)

        dict_append(debug, "forward_dists", forward_dist)
        dict_append(debug, "backward_dists", backward_dist)
        dict_append(debug, "precedence_weights", self.precedence_weighting)

        # output shapes: [batch, n_heads, cell_count]
        return forward_dist, backward_dist
class ReadHead(torch.nn.Module):
    """Read head mixing content-based addressing with temporal-link
    forward/backward distributions via a 3-way gate."""

    def __init__(self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        super(ReadHead, self).__init__()
        self.content_addr_generator = ContentAddressGenerator(disable_content_norm=disable_content_norm,
                                                              mask_min=mask_min,
                                                              disable_key_masking=disable_key_masking)
        self.read_dist = None
        self.read_data = None
        self.new_sequence()

    def new_sequence(self):
        self.read_dist = None
        self.read_data = None

    def forward(self, memory, read_content_keys, read_betas, forward_dist, backward_dist, gates, read_mask=None, debug=None):
        """Return read vectors; also caches read_dist/read_data for the next step.

        gates[..., 0]=backward, gates[..., 1]=content, gates[..., 2]=forward.
        """
        content_dist = self.content_addr_generator(memory, read_content_keys, read_betas, mask=read_mask)

        self.read_dist = backward_dist * gates[..., 0:1] + content_dist * gates[..., 1:2] + forward_dist * gates[..., 2:]

        # memory shape: [ batch, cell count, word_length ]
        # read_dist shape: [ batch, n heads, cell count ]
        # result shape: [ batch, n_heads, word_length ]
        self.read_data = (memory.unsqueeze(1) * self.read_dist.unsqueeze(-1)).sum(-2)

        dict_append(debug, "content_dist", content_dist)
        dict_append(debug, "balance", gates)
        dict_append(debug, "read_dist", self.read_dist)
        dict_append(debug, "read_content_keys", read_content_keys)
        # Fix: read_mask was appended twice to the debug dict; keep one.
        if read_mask is not None:
            dict_append(debug, "read_mask", read_mask)
        dict_append(debug, "read_betas", read_betas.unsqueeze(-2))

        return self.read_data
class RawReadHead(torch.nn.Module):
    """Decodes the raw controller output into per-head read parameters
    (keys, betas, mode gates, optional mask) and runs the wrapped ReadHead."""

    def __init__(self, n_heads, word_length, use_mask=False, disable_content_norm=False, mask_min=0.0,
                 disable_key_masking=False):
        super(RawReadHead, self).__init__()
        self.read_head = ReadHead(disable_content_norm=disable_content_norm, mask_min=mask_min,
                                  disable_key_masking=disable_key_masking)
        self.n_heads = n_heads
        self.word_length = word_length
        self.use_mask = use_mask
        # Per head: key (+ optional mask) of word_length, 1 beta, 3 mode gates.
        per_head = self.word_length * (2 if use_mask else 1) + 3 + 1
        self.input_size = self.n_heads * per_head

    def get_prev_dist(self, memory):
        prev = self.read_head.read_dist
        if prev is None:
            prev = torch.zeros(memory.size(0), self.n_heads, memory.size(1)).to(memory)
        return prev

    def get_prev_data(self, memory):
        prev = self.read_head.read_data
        if prev is None:
            prev = torch.zeros(memory.size(0), self.n_heads, memory.size(-1)).to(memory)
        return prev

    def new_sequence(self):
        self.read_head.new_sequence()

    def forward(self, memory, nn_output, forward_dist, backward_dist, debug):
        shapes = [[self.n_heads, self.word_length]] * (2 if self.use_mask else 1) + [[self.n_heads], [self.n_heads, 3]]
        parts = split_tensor(nn_output, shapes)

        if self.use_mask:
            read_mask = torch.sigmoid(parts[0])
            parts = parts[1:]
        else:
            read_mask = None

        keys, betas, gates = parts
        return self.read_head(memory, keys, oneplus(betas), forward_dist, backward_dist,
                              F.softmax(gates, gates.dim() - 1), debug=debug, read_mask=read_mask)

    def get_neural_input_size(self):
        return self.input_size
class DistSharpnessEnhancer(torch.nn.Module):
    """Sharpens address distributions by raising them to a learned
    (oneplus-activated) power and renormalizing."""

    def __init__(self, n_heads):
        super(DistSharpnessEnhancer, self).__init__()
        self.n_heads = n_heads if isinstance(n_heads, list) else [n_heads]
        self.n_data = sum(self.n_heads)

    def forward(self, nn_input, *dists):
        assert len(dists) == len(self.n_heads)
        nn_input = oneplus(nn_input[..., :self.n_data])
        factors = split_tensor(nn_input, self.n_heads)

        res = []
        for i, d in enumerate(dists):
            ndim = d.dim()
            f = factors[i]
            if ndim == 2:
                assert self.n_heads[i] == 1
            elif ndim == 3:
                f = f.unsqueeze(-1)
            else:
                assert False

            # Fix: the original used "d += _EPS", an in-place add that mutates
            # the caller's tensor (and an autograd intermediate). Use an
            # out-of-place add instead.
            d = d + _EPS
            # Divide by max before pow for numerical stability, then renormalize.
            d = d / d.max(dim=-1, keepdim=True)[0]
            d = d.pow(f)
            d = d / d.sum(dim=-1, keepdim=True)
            res.append(d)
        return res

    def get_neural_input_size(self):
        return self.n_data
class DNC(torch.nn.Module):
    """Differentiable Neural Computer with optional masking, content
    deallocation and link-distribution sharpening extensions.

    Input is [time, batch, channels] ([batch, time, channels] when
    batch_first). Each step runs the controller on the input concatenated
    with the previous read vectors, decodes all head controls from the
    controller output, writes memory, updates temporal links, then reads.
    """

    def __init__(self, input_size, output_size, word_length, cell_count, n_read_heads, controller, batch_first=False, clip_controller=20,
                 bias=True, mask=False, dealloc_content=True, link_sharpness_control=True, disable_content_norm=False,
                 mask_min=0.0, disable_key_masking=False):
        super(DNC, self).__init__()

        # Controls are clamped to +-clip_controller (None disables clipping).
        self.clip_controller = clip_controller

        self.read_head = RawReadHead(n_read_heads, word_length, use_mask=mask, disable_content_norm=disable_content_norm,
                                     mask_min=mask_min, disable_key_masking=disable_key_masking)
        self.write_head = RawWriteHead(n_read_heads, word_length, use_mask=mask, dealloc_content=dealloc_content,
                                       disable_content_norm=disable_content_norm, mask_min=mask_min,
                                       disable_key_masking=disable_key_masking)
        self.temporal_link = TemporalMemoryLinkage()
        self.sharpness_control = DistSharpnessEnhancer([n_read_heads, n_read_heads]) if link_sharpness_control else None

        # Controller sees the input plus all previous read vectors.
        in_size = input_size + n_read_heads * word_length
        control_channels = self.read_head.get_neural_input_size() + self.write_head.get_neural_input_size() +\
                           (self.sharpness_control.get_neural_input_size() if self.sharpness_control is not None else 0)

        self.controller = controller
        controller.init(in_size)
        self.controller_to_controls = torch.nn.Linear(controller.get_output_size(), control_channels, bias=bias)
        self.controller_to_out = torch.nn.Linear(controller.get_output_size(), output_size, bias=bias)
        self.read_to_out = torch.nn.Linear(word_length * n_read_heads, output_size, bias=bias)

        self.cell_count = cell_count
        self.word_length = word_length

        self.memory = None
        self.reset_parameters()
        self.batch_first = batch_first
        self.zero_mem_tensor = None

    def reset_parameters(self):
        linear_reset(self.controller_to_controls)
        linear_reset(self.controller_to_out)
        linear_reset(self.read_to_out)
        self.controller.reset_parameters()

    def _step(self, in_data, debug):
        """One time step. in_data: [batch, channels]."""
        init_debug(debug, {
            "read_head": {},
            "write_head": {},
            "temporal_links": {}
        })

        # input shape: [ batch, channels ]
        batch_size = in_data.size(0)

        # run the controller on input + previous read vectors
        prev_read_data = self.read_head.get_prev_data(self.memory).view([batch_size, -1])
        control_data = self.controller(torch.cat([in_data, prev_read_data], -1))

        # memory ops: decode head controls, optionally clipped for stability
        controls = self.controller_to_controls(control_data).contiguous()
        controls = controls.clamp(-self.clip_controller, self.clip_controller) if self.clip_controller is not None else controls

        shapes = [[self.write_head.get_neural_input_size()], [self.read_head.get_neural_input_size()]]
        if self.sharpness_control is not None:
            shapes.append(self.sharpness_control.get_neural_input_size())
        tensors = split_tensor(controls, shapes)

        write_head_control, read_head_control = tensors[:2]
        tensors = tensors[2:]

        prev_read_dist = self.read_head.get_prev_dist(self.memory)

        # Write first, then update the temporal links, then read.
        self.memory = self.write_head(self.memory, write_head_control, prev_read_dist, debug=dict_get(debug,"write_head"))

        prev_write = self.write_head.get_prev_write()
        forward_dist, backward_dist = self.temporal_link(prev_write["write_dist"] if prev_write is not None else None, prev_read_dist, debug=dict_get(debug, "temporal_links"))
        if self.sharpness_control is not None:
            forward_dist, backward_dist = self.sharpness_control(tensors[0], forward_dist, backward_dist)

        read_data = self.read_head(self.memory, read_head_control, forward_dist, backward_dist, debug=dict_get(debug,"read_head"))

        # output: controller projection + read-vector projection
        return self.controller_to_out(control_data) + self.read_to_out(read_data.view(batch_size,-1))

    def _mem_init(self, batch_size, device):
        # Reuse a cached zero tensor as the initial memory for this batch size.
        if self.zero_mem_tensor is None or self.zero_mem_tensor.size(0)!=batch_size:
            self.zero_mem_tensor = torch.zeros(batch_size, self.cell_count, self.word_length).to(device)
        self.memory = self.zero_mem_tensor

    def forward(self, in_data, debug=None):
        """Run a whole sequence; returns per-step outputs stacked on the time dim."""
        self.write_head.new_sequence()
        self.read_head.new_sequence()
        self.temporal_link.new_sequence()
        self.controller.new_sequence()

        self._mem_init(in_data.size(0 if self.batch_first else 1), in_data.device)

        out_tsteps = []

        if self.batch_first:
            # input format: batch, time, channels
            for t in range(in_data.size(1)):
                out_tsteps.append(self._step(in_data[:,t], debug))
        else:
            # input format: time, batch, channels
            for t in range(in_data.size(0)):
                out_tsteps.append(self._step(in_data[t], debug))

        merge_debug_tensors(debug, dim=1 if self.batch_first else 0)
        return torch.stack(out_tsteps, dim=1 if self.batch_first else 0)
class LSTMController(torch.nn.Module):
    """Multi-layer LSTM controller with merged gate projections.

    Each layer has linear maps: input->gates, own-previous-output->gates,
    and previous-layer-output->gates. In the merged 4*size projection the
    first 3 blocks are the sigmoid gates (forget, input, output, in that
    order) and the last block is the tanh cell candidate.
    """

    def __init__(self, layer_sizes, out_from_all_layers=True):
        super(LSTMController, self).__init__()
        self.out_from_all_layers = out_from_all_layers
        self.layer_sizes = layer_sizes
        self.states = None
        self.outputs = None

    def new_sequence(self):
        self.states = [None] * len(self.layer_sizes)
        self.outputs = [None] * len(self.layer_sizes)

    def reset_parameters(self):
        """Xavier-style init: tanh gain on the candidate block, forget-gate
        bias set to 1 so early training doesn't forget immediately."""
        def init_layer(l, index):
            size = self.layer_sizes[index]
            # Fix: use the `index` parameter instead of the loop variable `i`
            # captured from the enclosing scope (they only coincided by luck).
            a = math.sqrt(3.0) * self.stdevs[index]
            # Gate blocks (sigmoid): plain uniform Xavier.
            l.weight.data[0:-size].uniform_(-a, a)
            # Candidate block: scaled by the tanh gain.
            a *= init.calculate_gain("tanh")
            l.weight.data[-size:].uniform_(-a, a)
            if l.bias is not None:
                l.bias.data[size:].fill_(0)
                # init forget gate bias to a large number
                l.bias.data[:size].fill_(1)

        for i in range(len(self.layer_sizes)):
            init_layer(self.in_to_all[i], i)
            init_layer(self.out_to_all[i], i)
            if i > 0:
                init_layer(self.prev_to_all[i - 1], i)

    def _add_modules(self, name, m_list):
        # Register plain-list members so their parameters are tracked.
        for i, m in enumerate(m_list):
            self.add_module("%s_%d" % (name, i), m)

    def init(self, input_size):
        """Create the parameter matrices once the input size is known."""
        # Xavier: the effective input to layer i's gates is
        # layer_sizes[i-1] + layer_sizes[i] + input_size.
        self.input_sizes = [(self.layer_sizes[i - 1] if i > 0 else 0) + self.layer_sizes[i] + input_size
                            for i in range(len(self.layer_sizes))]
        self.stdevs = [math.sqrt(2.0 / (self.layer_sizes[i] + self.input_sizes[i])) for i in range(len(self.layer_sizes))]
        self.in_to_all = [torch.nn.Linear(input_size, 4 * self.layer_sizes[i]) for i in range(len(self.layer_sizes))]
        self.out_to_all = [torch.nn.Linear(self.layer_sizes[i], 4 * self.layer_sizes[i], bias=False) for i in range(len(self.layer_sizes))]
        self.prev_to_all = [torch.nn.Linear(self.layer_sizes[i - 1], 4 * self.layer_sizes[i], bias=False) for i in range(1, len(self.layer_sizes))]

        self._add_modules("in_to_all", self.in_to_all)
        self._add_modules("out_to_all", self.out_to_all)
        self._add_modules("prev_to_all", self.prev_to_all)

        self.reset_parameters()

    def get_output_size(self):
        return sum(self.layer_sizes) if self.out_from_all_layers else self.layer_sizes[-1]

    def forward(self, data):
        for i, size in enumerate(self.layer_sizes):
            d = self.in_to_all[i](data)
            if self.outputs[i] is not None:
                d += self.out_to_all[i](self.outputs[i])
            if i > 0:
                d += self.prev_to_all[i - 1](self.outputs[i - 1])

            input_data = torch.tanh(d[..., -size:])
            forget_gate, input_gate, output_gate = torch.sigmoid(d[..., :-size]).chunk(3, dim=-1)

            state_update = input_gate * input_data

            if self.states[i] is not None:
                self.states[i] = self.states[i] * forget_gate + state_update
            else:
                self.states[i] = state_update

            self.outputs[i] = output_gate * torch.tanh(self.states[i])

        return torch.cat(self.outputs, -1) if self.out_from_all_layers else self.outputs[-1]
class FeedforwardController(torch.nn.Module):
    """Stateless MLP controller: a stack of Linear+ReLU layers."""

    def __init__(self, layer_sizes=[]):
        super(FeedforwardController, self).__init__()
        self.layer_sizes = layer_sizes

    def new_sequence(self):
        # Stateless controller: nothing to reset between sequences.
        pass

    def reset_parameters(self):
        for module in self.model:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

    def get_output_size(self):
        return self.layer_sizes[-1]

    def init(self, input_size):
        """Create the layers once the input size is known.

        (Fix: removed the no-op `self.layer_sizes = self.layer_sizes`.)
        """
        # Layer i consumes the previous layer's output (or the network input).
        self.input_sizes = [input_size] + self.layer_sizes[:-1]

        layers = []
        for i, size in enumerate(self.layer_sizes):
            layers.append(torch.nn.Linear(self.input_sizes[i], size))
            layers.append(torch.nn.ReLU())
        self.model = torch.nn.Sequential(*layers)
        self.reset_parameters()

    def forward(self, data):
        return self.model(data)

61
README.md Normal file
View File

@ -0,0 +1,61 @@
PyTorch implementation of custom DNC variants
=============================================
Tasks
-----
Supported tasks:
* bAbI
* copy
* repeated copy
* associative recall
* key-value recall
* 2 way key-value recall
Visualization and debugging
---------------------------
Many interesting internal states of the DNC are visualized inside Visdom. Check console output for the port.
![](./assets/preview.png)
Usage
-----
Everything is done by main.py. Use -name to give a path (it will be created if it doesn't exist) where the state of the training will be saved. Check out main.py for more information about the available flags.
Most of the trainings can be run by profiles:
```bash
./main.py -name <train dir> -profile babi
```
Supported profiles: babi, repeat_copy, repeat_copy_simple, keyvalue, keyvalue2way, associative_recall.
If you want to train a pure DNC, add "dnc" to the profile:
```bash
./main.py -name <train dir> -profile babi,dnc
```
For other options, see main.py.
DNC variants
------------
The variant of DNC can be specified as a profile. Supported variants:
dnc, dnc-msd, dnc-m, dnc-s, dnc-d, dnc-md, dnc-ms, dnc-sd.
Reusing the code
----------------
The DNC is implemented as a single file (Models/DNC.py) depending only on torch. You should be able to reuse it very easily. Please check main.py for details on its interface.
Dependencies
------------
PyTorch (1.0), Python 3. Other dependencies can be installed by running pip3 install -r requirements.txt.
License
-------
The software is under Apache 2.0 license. See http://www.apache.org/licenses/LICENSE-2.0 for further details.

167
Utils/ArgumentParser.py Normal file
View File

@ -0,0 +1,167 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import os
import json
import argparse
class ArgumentParser:
    """argparse wrapper with named profiles (bundles of defaults that can
    include each other), per-argument parse/update hooks, and JSON
    save/load of the parsed arguments."""

    class Parsed:
        # Plain namespace for the post-processed arguments.
        pass

    _type = type  # keep a reference; add_argument shadows `type` below

    @staticmethod
    def str_or_none(none_string="none"):
        """Argument parser hook mapping the string `none_string` to None."""
        def parse(s):
            return None if s.lower() == none_string else s
        return parse

    @staticmethod
    def list_or_none(none_string="none", type=int):
        """Argument parser hook: comma-separated list of `type`, or None."""
        def parse(s):
            return None if s.lower() == none_string else [type(a) for a in s.split(",") if a]
        return parse

    @staticmethod
    def _merge_args(args, new_args, arg_schemas):
        # Merge new_args into args, delegating to the per-argument updater
        # when a value is already present.
        for name, val in new_args.items():
            old = args.get(name)
            if old is None:
                args[name] = val
            else:
                args[name] = arg_schemas[name]["updater"](old, val)

    class Profile:
        """Named bundle of argument values; may include other profiles."""

        def __init__(self, name, args=None, include=None):
            # Fix: avoid the mutable default argument `include=[]`.
            include = [] if include is None else include
            assert not (args is None and not include), "One of args or include must be defined"
            self.name = name
            self.args = args
            if not isinstance(include, list):
                include = [include]
            self.include = include

        def get_args(self, arg_schemas, profile_by_name):
            res = {}
            # Included profiles first, so this profile's own args win.
            for n in self.include:
                p = profile_by_name.get(n)
                assert p is not None, "Included profile %s doesn't exists" % n
                ArgumentParser._merge_args(res, p.get_args(arg_schemas, profile_by_name), arg_schemas)
            # Fix: a profile may consist of includes only (args=None); the
            # original crashed on None.items() in that case.
            if self.args is not None:
                ArgumentParser._merge_args(res, self.args, arg_schemas)
            return res

    def __init__(self, description=None):
        self.parser = argparse.ArgumentParser(description=description)
        self.loaded = {}
        self.profiles = {}
        self.args = {}
        self.raw = None
        self.parsed = None
        self.parser.add_argument("-profile", type=str, help="Pre-defined profiles.")

    def add_argument(self, name, type=None, default=None, help="", save=True, parser=lambda x: x,
                     updater=lambda old, new: new):
        """Register an argument. `parser` post-processes the raw value;
        `updater` merges profile values; `save` controls JSON round-trip."""
        assert name not in ["profile"], "Argument name %s is reserved" % name
        assert not (type is None and default is None), "Either type or default must be given"
        if type is None:
            type = ArgumentParser._type(default)

        # Bools are exposed on the command line as ints (0/1).
        self.parser.add_argument(name, type=int if type == bool else type, default=None, help=help)
        if name[0] == '-':
            name = name[1:]
        self.args[name] = {
            "type": type,
            "default": int(default) if type == bool else default,
            "save": save,
            "parser": parser,
            "updater": updater
        }

    def add_profile(self, prof):
        if isinstance(prof, list):
            for p in prof:
                self.add_profile(p)
        else:
            self.profiles[prof.name] = prof

    def do_parse_args(self, loaded=None):
        # Fix: avoid the mutable default argument `loaded={}`.
        loaded = {} if loaded is None else loaded
        self.raw = self.parser.parse_args()

        profile = {}
        if self.raw.profile:
            assert not loaded, "Loading arguments from file, but profile given."
            for pr in self.raw.profile.split(","):
                p = self.profiles.get(pr)
                assert p is not None, "Invalid profile: %s. Valid profiles: %s" % (pr, self.profiles.keys())
                p = p.get_args(self.args, self.profiles)
                self._merge_args(profile, p, self.args)

        for k, v in self.raw.__dict__.items():
            if k in ["profile"]:
                continue
            if v is None:
                # Priority: saved file > profile > schema default.
                if k in loaded and self.args[k]["save"]:
                    self.raw.__dict__[k] = loaded[k]
                else:
                    self.raw.__dict__[k] = profile.get(k, self.args[k]["default"])

        self.parsed = ArgumentParser.Parsed()
        self.parsed.__dict__.update({k: self.args[k]["parser"](self.args[k]["type"](v)) if v is not None else None
                                     for k, v in self.raw.__dict__.items() if k in self.args})
        return self.parsed

    def parse_or_cache(self):
        if self.parsed is None:
            self.do_parse_args()

    def parse(self):
        self.parse_or_cache()
        return self.parsed

    def save(self, fname):
        self.parse_or_cache()
        with open(fname, 'w') as outfile:
            json.dump(self.raw.__dict__, outfile, indent=4)
        return True

    def load(self, fname):
        if os.path.isfile(fname):
            # Fix: renamed local `map` (shadowed the builtin).
            with open(fname, "r") as data_file:
                loaded_args = json.load(data_file)
            self.do_parse_args(loaded_args)
        return self.parsed

    def sync(self, fname, save=True):
        """Load args from fname if it exists, then (optionally) save back."""
        if os.path.isfile(fname):
            self.load(fname)
        if save:
            dirname = os.path.dirname(fname)
            if not os.path.isdir(dirname):
                os.makedirs(dirname)
            self.save(fname)
        return self.parsed

75
Utils/Collate.py Normal file
View File

@ -0,0 +1,75 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
import torch
from operator import mul
from functools import reduce
class VarLengthCollate:
    """DataLoader collate fn that pads variable-length tensors of a batch up
    to the per-dimension maximum, filling with `ignore_symbol`. Handles
    tensors, numpy arrays, and dicts of these."""

    def __init__(self, ignore_symbol=0):
        self.ignore_symbol = ignore_symbol

    def _measure_array_max_dim(self, batch):
        # Returns (per-dim maximum size, per-dim "sizes differ" flags).
        max_size = list(batch[0].size())
        varies = [False] * len(max_size)
        for t in batch[1:]:
            ts = t.size()
            varies = [varies[d] or max_size[d] != ts[d] for d in range(len(max_size))]
            max_size = [max(max_size[d], ts[d]) for d in range(len(max_size))]
        return max_size, varies

    def _merge_var_len_array(self, batch):
        max_size, varies = self._measure_array_max_dim(batch)
        full_size = [len(batch)] + max_size
        # Allocate in shared memory so DataLoader workers can hand it over cheaply.
        storage = batch[0].storage()._new_shared(reduce(mul, full_size, 1))
        out = batch[0].new(storage).view(full_size).fill_(self.ignore_symbol)
        for i, t in enumerate(batch):
            dst = out[i]
            for dim, is_var in enumerate(varies):
                if is_var:
                    dst = dst.narrow(dim, 0, t.size(dim))
            dst.copy_(t)
        return out

    def __call__(self, batch):
        first = batch[0]
        if isinstance(first, dict):
            return {k: self([b[k] for b in batch]) for k in first.keys()}
        elif isinstance(first, np.ndarray):
            return self([torch.from_numpy(a) for a in batch])
        elif torch.is_tensor(first):
            return self._merge_var_len_array(batch)
        else:
            assert False, "Unknown type: %s" % type(first)
class MetaCollate:
    """Collate wrapper that passes the `meta` entry of dict samples through
    uncollated (as a plain list) and collates everything else."""

    def __init__(self, meta_name="meta", collate=VarLengthCollate()):
        self.meta_name = meta_name
        self.collate = collate

    def __call__(self, batch):
        if not isinstance(batch[0], dict):
            return self.collate(batch)

        meta = [sample[self.meta_name] for sample in batch]
        stripped = [{k: v for k, v in sample.items() if k != self.meta_name} for sample in batch]
        res = self.collate(stripped)
        res[self.meta_name] = meta
        return res

133
Utils/Debug.py Normal file
View File

@ -0,0 +1,133 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
import sys
import traceback
import torch
enableDebug = False
def nan_check(arg, name=None, force=False):
    """Debug guard: terminate the process with a stack trace if `arg`
    contains NaN/Inf. No-op unless enableDebug (or force) is set.

    Accepts tensors/Variables (checked via their sum), Parameters (value
    and gradient), floats, and (possibly nested) lists/tuples of these.
    Returns `arg` unchanged so it can be used inline.
    """
    if not enableDebug and not force:
        return arg

    is_nan = False  # NOTE(review): never updated; curr_nan carries the result
    curr_nan = False
    if isinstance(arg, torch.autograd.Variable):
        curr_nan = not np.isfinite(arg.sum().cpu().data.numpy())
    elif isinstance(arg, torch.nn.parameter.Parameter):
        curr_nan = (not np.isfinite(arg.sum().cpu().data.numpy())) or (not np.isfinite(arg.grad.sum().cpu().data.numpy()))
    elif isinstance(arg, float):
        curr_nan = not np.isfinite(arg)
    elif isinstance(arg, (list, tuple)):
        # Recurse into containers; a nested failure exits inside the call.
        # NOTE(review): name/force are not propagated to the nested calls -
        # with force=True and enableDebug off, nested elements are skipped.
        for a in arg:
            nan_check(a)
    else:
        assert False, "Unsupported type %s" % type(arg)

    if curr_nan:
        # Prefer the active exception's traceback if one is being handled.
        if sys.exc_info()[0] is not None:
            trace = str(traceback.format_exc())
        else:
            trace = "".join(traceback.format_stack())

        print(arg)
        if name is not None:
            print("NaN found in %s." % name)
        else:
            print("NaN found.")
        if isinstance(arg, torch.autograd.Variable):
            print(" Argument is a torch tensor. Shape: %s" % list(arg.size()))

        print(trace)
        sys.exit(-1)
    return arg
def assert_range(t, min=0.0, max=1.0):
    """Debug check: assert all elements of t lie in [min, max]. No-op
    unless the module-level enableDebug flag is set."""
    if not enableDebug:
        return
    too_low = t.min().cpu().data.numpy() < min
    too_high = t.max().cpu().data.numpy() > max
    if too_low or too_high:
        print(t)
        assert False
def assert_dist(t, use_lower_limit=True):
    """Debug check: t must be a probability distribution along its last
    dim - elements in [0,1], sums close to 1 (lower bound optional).
    No-op unless enableDebug is set."""
    if not enableDebug:
        return
    assert_range(t)
    sums = t.sum(-1)
    if sums.max().cpu().data.numpy() > 1.001:
        print("MAT:", t)
        print("SUM:", sums)
        assert False
    # NOTE(review): the lower bound also tests the *max* of the sums, so a
    # batch passes if any single row sums to ~1 - verify this is intended.
    if use_lower_limit and sums.max().cpu().data.numpy() < 0.999:
        print(t)
        print("SUM:", sums)
        assert False
def print_stat(name, t):
    """Debug helper: print min/mean/max of tensor t (no-op unless enableDebug)."""
    if not enableDebug:
        return
    lo = t.min().cpu().data.numpy()
    hi = t.max().cpu().data.numpy()
    avg = t.mean().cpu().data.numpy()
    print("%s: min: %g, mean: %g, max: %g" % (name, lo, avg, hi))
def dbg_print(*things):
    """print() that is active only while debugging is enabled."""
    if enableDebug:
        print(*things)
class GradPrinter(torch.autograd.Function):
    """Identity autograd function that prints the incoming gradient during
    the backward pass. Use via print_grad(t)."""

    @staticmethod
    def forward(ctx, a):
        return a

    @staticmethod
    def backward(ctx, g):
        # NOTE(review): g is the full gradient tensor; g[0] prints only its
        # first slice - presumably to keep output short, verify intent.
        print("Grad (print_grad): ", g[0])
        return g
def print_grad(t):
    """Return t unchanged, printing its gradient when backprop reaches it."""
    return GradPrinter.apply(t)
def assert_equal(t1, ref, limit=1e-5, force=True):
    """Debug check: element-wise compare t1 against ref, with a tolerance
    relative to the mean magnitude of ref's nonzero elements. Prints both
    tensors and asserts on mismatch. Runs when force=True (default) or
    when enableDebug is set.
    """
    if not (enableDebug or force):
        return

    assert t1.shape == ref.shape, "Tensor shapes differ: got %s, ref %s" % (t1.shape, ref.shape)

    # Mean magnitude of ref's nonzero entries. Fix: the original divided by
    # ref.nonzero().sum() - the sum of the nonzero *indices* - instead of
    # the number of nonzero elements. Also guard against an all-zero ref.
    n_nonzero = ref.nonzero().size(0)
    norm = ref.abs().sum() / max(n_nonzero, 1)
    threshold = norm * limit

    errcnt = ((t1 - ref).abs() > threshold).sum()
    if errcnt > 0:
        print("Tensors differ. (max difference: %g, norm %f). No of errors: %d of %d" %
              ((t1 - ref).abs().max().item(), norm, errcnt, t1.numel()))
        print("---------------------------------------------Out-----------------------------------------------")
        print(t1)
        print("---------------------------------------------Ref-----------------------------------------------")
        print(ref)
        print("-----------------------------------------------------------------------------------------------")
        assert False

24
Utils/Helpers.py Normal file
View File

@ -0,0 +1,24 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import torch.autograd
def as_numpy(data):
    """Convert a torch tensor/Variable to a numpy array; anything else is
    passed through unchanged."""
    if isinstance(data, (torch.Tensor, torch.autograd.Variable)):
        return data.detach().cpu().numpy()
    return data

20
Utils/Index.py Normal file
View File

@ -0,0 +1,20 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
def index_by_dim(arr, dim, i_start, i_end=None):
    """Index `arr` along dimension `dim`: the slice [i_start:i_end] when
    i_end is given, otherwise the single index i_start. Negative dims count
    from the end."""
    if dim < 0:
        dim += arr.ndim
    selector = slice(i_start, i_end) if i_end is not None else i_start
    return arr[(slice(None),) * dim + (selector,)]

51
Utils/Process.py Normal file
View File

@ -0,0 +1,51 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import sys
import ctypes
import subprocess
import os
def run(cmd, hide_stderr = True):
    """Launch `cmd` as a subprocess; on Linux, arrange for the child to be
    SIGKILLed when the parent dies (via prctl(PR_SET_PDEATHSIG)).

    Returns the subprocess.Popen object. stderr goes to /dev/null when
    hide_stderr is set.
    """
    libc_search_dirs = ["/lib", "/lib/x86_64-linux-gnu", "/lib/powerpc64le-linux-gnu"]

    if sys.platform=="linux" :
        found = None
        for d in libc_search_dirs:
            file = os.path.join(d, "libc.so.6")
            if os.path.isfile(file):
                found = file
                break

        if not found:
            print("WARNING: Cannot find libc.so.6. Cannot kill process when parent dies.")
            killer = None
        else:
            libc = ctypes.CDLL(found)
            PR_SET_PDEATHSIG = 1
            KILL = 9
            # Runs in the child between fork and exec (preexec_fn below).
            killer = lambda: libc.prctl(PR_SET_PDEATHSIG, KILL)
    else:
        print("WARNING: OS not linux. Cannot kill process when parent dies.")
        killer = None

    if hide_stderr:
        # NOTE(review): this devnull handle is never explicitly closed - verify.
        stderr = open(os.devnull,'w')
    else:
        stderr = None

    return subprocess.Popen(cmd.split(" "), stderr=stderr, preexec_fn=killer)

48
Utils/Profile.py Normal file
View File

@ -0,0 +1,48 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import atexit
ENABLED=False
_profiler = None
def construct():
    """Lazily create the module-global LineProfiler instance (only if ENABLED)."""
    global _profiler
    if not ENABLED:
        return
    if _profiler is None:
        # Imported lazily so line_profiler is only required when profiling.
        from line_profiler import LineProfiler
        _profiler = LineProfiler()
def do_profile(follow=()):
    """Decorator factory: register the decorated function (plus any in
    `follow`) with the global line profiler. A transparent pass-through
    when profiling is disabled.

    Fix: use an immutable default instead of the mutable `follow=[]`.
    """
    construct()

    def inner(func):
        if _profiler is not None:
            _profiler.add_function(func)
            for f in follow:
                _profiler.add_function(f)
            _profiler.enable_by_count()
        return func
    return inner
@atexit.register
def print_prof():
    """Dump the collected line-profiler stats when the interpreter exits."""
    if _profiler is not None:
        _profiler.print_stats()

233
Utils/Saver.py Normal file
View File

@ -0,0 +1,233 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import os
import inspect
import time
class SaverElement:
    """Interface for objects the Saver can checkpoint and restore."""
    def save(self):
        # Return a serializable snapshot of the wrapped object.
        raise NotImplementedError
    def load(self, saved_state):
        # Restore the wrapped object from a snapshot produced by save().
        raise NotImplementedError
class CallbackSaver(SaverElement):
    """SaverElement adapter wrapping explicit save/load callbacks."""
    def __init__(self, save_fn, load_fn):
        super().__init__()
        # The callbacks directly become the save/load interface methods.
        self.save = save_fn
        self.load = load_fn
class StateSaver(SaverElement):
    """Persists any object exposing state_dict()/load_state_dict()
    (e.g. torch modules and optimizers)."""

    def __init__(self, model):
        super().__init__()
        self._model = model

    def load(self, state):
        """Load *state* into the wrapped object, printing key diffs on failure.

        Optimizer mismatches only warn (training can continue); anything else
        re-raises the original error.
        """
        try:
            self._model.load_state_dict(state)
        except Exception as e:
            if hasattr(self._model, "named_parameters"):
                names = set([n for n, _ in self._model.named_parameters()])
                # BUG FIX: compare against the keys of the *loaded* state dict;
                # the original called self._model.keys(), which modules don't
                # have (it raised AttributeError and masked the real error).
                loaded = set(state.keys())
                if names != loaded:
                    d = loaded.difference(names)
                    if d:
                        print("Loaded, but not in model: %s" % list(d))
                    d = names.difference(loaded)
                    if d:
                        print("In model, but not loaded: %s" % list(d))
            if isinstance(self._model, torch.optim.Optimizer):
                # Optimizer state mismatch is survivable; warn instead of aborting.
                print("WARNING: optimizer parameters not loaded!")
            else:
                raise e

    def save(self):
        return self._model.state_dict()
class GlobalVarSaver(SaverElement):
    """Persists a global variable that lives in the *caller's* module namespace."""

    def __init__(self, name):
        # Capture the globals of whoever constructed us, so load() can rebind
        # the variable inside that module.
        caller_frame = inspect.getouterframes(inspect.currentframe())[1]
        self._vars = caller_frame.frame.f_globals
        self._name = name

    def load(self, state):
        self._vars[self._name] = state

    def save(self):
        return self._vars[self._name]
class PyObjectSaver(SaverElement):
    """Persists plain Python data (dicts, lists, attribute-bearing objects)."""
    def __init__(self, obj):
        self._obj = obj
    def load(self, state):
        # Recursively copy *state* into the existing object, mutating
        # containers in place so external references to them stay valid.
        def _load(target, state):
            if isinstance(target, dict):
                for k, v in state.items():
                    target[k] = _load(target.get(k), v)
            elif isinstance(target, list):
                if len(target)!=len(state):
                    # Length changed since saving: rebuild the list wholesale.
                    target.clear()
                    for v in state:
                        target.append(v)
                else:
                    for i, v in enumerate(state):
                        target[i] = _load(target[i], v)
            elif hasattr(target, "__dict__"):
                # Objects: restore their attribute dict recursively.
                _load(target.__dict__, state)
            else:
                # Leaf value: replace it with the saved one.
                return state
            return target
        _load(self._obj, state)
    def save(self):
        # Recursively snapshot dicts, lists and object __dict__s; any other
        # value is stored as-is.
        def _save(target):
            if isinstance(target, dict):
                res = {k: _save(v) for k, v in target.items()}
            elif isinstance(target, list):
                res = [_save(v) for v in target]
            elif hasattr(target, "__dict__"):
                res = {k: _save(v) for k, v in target.__dict__.items()}
            else:
                res = target
            return res
        return _save(self._obj)
    @staticmethod
    def obj_supported(obj):
        # Anything dict-like, list-like, or with attributes can be saved.
        return isinstance(obj, (list, dict)) or hasattr(obj, "__dict__")
class Saver:
    """Periodic checkpointer writing registered objects to <dir>/model-<i>.pth.

    Old checkpoints are pruned, always keeping the two newest plus roughly one
    per *keep_every_n_hours* hours.
    """

    def __init__(self, dir, short_interval, keep_every_n_hours=4):
        self.savers = {}
        self.short_interval = short_interval
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self._keep_every_n_seconds = keep_every_n_hours * 3600

    def register(self, name, saver):
        """Register *saver* under *name*, wrapping it in a SaverElement if needed.

        Raises:
            ValueError: if the object cannot be wrapped.
        """
        assert name not in self.savers, "Saver %s already registered" % name
        if isinstance(saver, SaverElement):
            self.savers[name] = saver
        elif hasattr(saver, "state_dict") and callable(saver.state_dict):
            self.savers[name] = StateSaver(saver)
        elif PyObjectSaver.obj_supported(saver):
            self.savers[name] = PyObjectSaver(saver)
        else:
            # BUG FIX: this was `assert "Unsupported..."`, which always passes
            # (a non-empty string is truthy), silently ignoring the object.
            raise ValueError("Unsupported thing to save: %s" % type(saver))

    def __setitem__(self, key, value):
        self.register(key, value)

    def write(self, iter):
        """Write a checkpoint for iteration *iter*, then prune old files."""
        fname = os.path.join(self.dir, self.model_name_from_index(iter))
        print("Saving %s" % fname)
        state = {}
        for name, fns in self.savers.items():
            state[name] = fns.save()
        torch.save(state, fname)
        print("Saved.")
        self._cleanup()

    def tick(self, iter):
        """Write a checkpoint when *iter* is a multiple of short_interval."""
        if iter % self.short_interval != 0:
            return
        self.write(iter)

    @staticmethod
    def model_name_from_index(index):
        return "model-%d.pth" % index

    @staticmethod
    def get_checkpoint_index_list(dir):
        """Return checkpoint indices found in *dir*, newest (highest) first."""
        return list(reversed(sorted(
            [int(fn.split(".")[0].split("-")[-1]) for fn in os.listdir(dir) if fn.split(".")[-1] == "pth"])))

    @staticmethod
    def get_ckpts_in_time_window(dir, time_window_s, index_list=None):
        """Return names of checkpoints younger than *time_window_s* seconds."""
        if index_list is None:
            index_list = Saver.get_checkpoint_index_list(dir)
        now = time.time()
        res = []
        for i in index_list:
            name = Saver.model_name_from_index(i)
            mtime = os.path.getmtime(os.path.join(dir, name))
            if now - mtime > time_window_s:
                # index_list is newest-first, so everything after this is older.
                break
            res.append(name)
        return res

    @staticmethod
    def load_last_checkpoint(dir):
        """Load the newest readable checkpoint in *dir*, or None if none load."""
        last_checkpoint = Saver.get_checkpoint_index_list(dir)
        if last_checkpoint:
            for index in last_checkpoint:
                fname = Saver.model_name_from_index(index)
                try:
                    print("Loading %s" % fname)
                    data = torch.load(os.path.join(dir, fname))
                # BUG FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                except Exception:
                    print("WARNING: Loading %s failed. Maybe file is corrupted?" % fname)
                    continue
                return data
        return None

    def _cleanup(self):
        """Prune recent checkpoints, keeping the 2 newest and the oldest
        checkpoint inside the retention window as its representative."""
        index_list = self.get_checkpoint_index_list(self.dir)
        new_files = self.get_ckpts_in_time_window(self.dir, self._keep_every_n_seconds, index_list[2:])
        new_files = new_files[:-1]
        for f in new_files:
            os.remove(os.path.join(self.dir, f))

    def load(self, fname=None):
        """Restore all registered objects from *fname* (or the latest checkpoint).

        Returns:
            True on success, False if no checkpoint could be loaded.
        """
        if fname is None:
            state = self.load_last_checkpoint(self.dir)
            if not state:
                return False
        else:
            state = torch.load(fname)
        for k, s in state.items():
            if k not in self.savers:
                print("WARNING: failed to load state of %s. It doesn't exists." % k)
                continue
            self.savers[k].load(s)
        return True

31
Utils/Seed.py Normal file
View File

@ -0,0 +1,31 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
# Global seed; None means "unseeded" (every RandomState is independent).
seed = None

def fix():
    """Fix the global seed so get_randstate() becomes deterministic."""
    global seed
    seed = 0xB0C1FA52

def get_randstate():
    """Return a numpy RandomState, seeded if fix() was called.

    BUG FIX: compare against None explicitly — the original `if seed:` would
    silently ignore a legal seed value of 0.
    """
    if seed is not None:
        return np.random.RandomState(seed)
    else:
        return np.random.RandomState()

323
Utils/Visdom.py Normal file
View File

@ -0,0 +1,323 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import time
import visdom
import sys
import numpy as np
import os
import socket
from . import Process
vis = None
port = None
visdom_fail_count = 0
def port_used(port):
    """Return True if something accepts TCP connections on 127.0.0.1:*port*."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns an errno instead of raising; 0 means connected.
        return sock.connect_ex(('127.0.0.1', port)) == 0
    finally:
        # BUG FIX: the original closed the socket only on the success branch,
        # leaking a file descriptor for every failed probe.
        sock.close()
def alloc_port(start_from=7000):
    """Return the first free port number at or above *start_from*."""
    candidate = start_from
    while port_used(candidate):
        print("Port already used: %d" % candidate)
        candidate += 1
    return candidate
def wait_for_port(port, timeout=5):
    """Poll until *port* accepts connections; False if *timeout* seconds elapse."""
    deadline = time.time() + timeout
    while not port_used(port):
        if time.time() > deadline:
            return False
        time.sleep(0.1)
    return True
def start(on_port=None):
    """Launch a visdom server (on *on_port* or a freshly allocated port) and
    connect the global `vis` client to it.

    Gives up permanently after 3 failed launch attempts.
    """
    global vis
    global port
    global visdom_fail_count
    assert vis is None, "Cannot start more than 1 visdom servers."
    if visdom_fail_count>=3:
        return
    port = alloc_port() if on_port is None else on_port
    print("Starting Visdom server on %d" % port)
    # Run the server as a child of this process (killed with us on Linux).
    Process.run("%s -m visdom.server -p %d" % (sys.executable, port))
    if not wait_for_port(port):
        print("ERROR: failed to start Visdom server. Server not responding.")
        visdom_fail_count += 1
        return
    print("Done.")
    vis = visdom.Visdom(port=port)
def _start_if_not_running():
    # Lazily start the visdom server the first time any widget is created.
    if vis is None:
        start()
def save_heatmap(dir, title, img):
    """Dump *img* to <dir>/<title>.npy (spaces become underscores); no-op if dir is None."""
    if dir is None:
        return
    fname = os.path.join(dir, title.replace(" ", "_") + ".npy")
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    np.save(fname, img)
class Plot2D:
    """A visdom 2D line plot that averages points over an interval and can be
    checkpointed/restored via state_dict()/load_state_dict()."""
    # Attributes persisted by state_dict()/load_state_dict().
    TO_SAVE = ["x", "y", "curr_accu", "curr_cnt", "legend"]
    def __init__(self, name, store_interval=1, legend=None, xlabel=None, ylabel=None):
        _start_if_not_running()
        self.x = []
        self.y = []
        self.store_interval = store_interval
        self.curr_accu = None
        self.curr_cnt = 0
        self.name = name
        self.legend = legend
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.replot = False
        self.visplot = None
        self.last_vis_update_pos = 0
    def set_legend(self, legend):
        # A legend change forces a full redraw on the next update.
        if self.legend != legend:
            self.legend = legend
            self.replot = True
    def _send_update(self):
        # Push pending points to visdom: full redraw when the window is new or
        # replot was requested, otherwise append only the new points.
        if not self.x or vis is None:
            return
        if self.visplot is None or self.replot:
            opts = {
                "title": self.name
            }
            if self.xlabel:
                opts["xlabel"] = self.xlabel
            if self.ylabel:
                opts["ylabel"] = self.ylabel
            if self.legend is not None:
                opts["legend"] = self.legend
            self.visplot = vis.line(X=np.asfarray(self.x), Y=np.asfarray(self.y), opts=opts, win=self.visplot)
            self.replot = False
        else:
            vis.line(
                X=np.asfarray(self.x[self.last_vis_update_pos:]),
                Y=np.asfarray(self.y[self.last_vis_update_pos:]),
                win=self.visplot,
                update='append'
            )
        self.last_vis_update_pos = len(self.x) - 1
    def add_point(self, x, y):
        # Accumulate y (a scalar or one value per curve); every store_interval
        # calls, the average is appended to the plot data and sent to visdom.
        if not isinstance(y, list):
            y = [y]
        if self.curr_accu is None:
            self.curr_accu = [0.0] * len(y)
        if len(self.curr_accu) < len(y):
            # The number of curves increased.
            need_to_add = (len(y) - len(self.curr_accu))
            self.curr_accu += [0.0] * need_to_add
            count = len(self.x)
            if count>0:
                # Backfill existing points with NaNs for the new curves.
                self.replot = True
                if not isinstance(self.x[0], list):
                    self.x = [[x] for x in self.x]
                    self.y = [[y] for y in self.y]
                nan = float("nan")
                for a in self.x:
                    a += [nan] * need_to_add
                for a in self.y:
                    a += [nan] * need_to_add
        elif len(self.curr_accu) > len(y):
            # Fewer values than curves: pad this sample with NaNs.
            y = y[:] + [float("nan")] * (len(self.curr_accu) - len(y))
        self.curr_accu = [self.curr_accu[i] + y[i] for i in range(len(y))]
        self.curr_cnt += 1
        if self.curr_cnt == self.store_interval:
            if len(y) > 1:
                self.x.append([x] * len(y))
                self.y.append([a / self.curr_cnt for a in self.curr_accu])
            else:
                self.x.append(x)
                self.y.append(self.curr_accu[0] / self.curr_cnt)
            self.curr_accu = [0.0] * len(y)
            self.curr_cnt = 0
            self._send_update()
    def state_dict(self):
        # Persist only the plot data, not the visdom window handles.
        s = {k: self.__dict__[k] for k in self.TO_SAVE}
        return s
    def load_state_dict(self, state):
        if self.legend is not None:
            # Load legend only if not given in the constructor.
            state["legend"] = self.legend
        self.__dict__.update(state)
        # Window handles are not saved; redraw everything from scratch.
        self.last_vis_update_pos = 0
        # Read old format
        if not isinstance(self.curr_accu, list) and self.curr_accu is not None:
            self.curr_accu = [self.curr_accu]
        self._send_update()
class Image:
    """A visdom image panel; optionally dumps every drawn image to *dumpdir*."""
    def __init__(self, title, dumpdir=None):
        _start_if_not_running()
        self.win = None
        self.opts = dict(title=title)
        self.dumpdir = dumpdir
    def set_dump_dir(self, dumpdir):
        self.dumpdir = dumpdir
    def draw(self, img):
        # Accepts a list of images (shown as a grid) or a single image in
        # either CHW or HWC layout.
        if vis is None:
            return
        if isinstance(img, list):
            if self.win is None:
                self.win = vis.images(img, opts=self.opts)
            else:
                vis.images(img, win=self.win, opts=self.opts)
        else:
            if len(img.shape)==2:
                # Grayscale: add the missing channel dimension.
                img = np.expand_dims(img, 0)
            elif img.shape[-1] in [1,3] and img.shape[0] not in [1,3]:
                # cv2 image
                # HWC -> CHW; the reversal below flips the channel order
                # (presumably BGR -> RGB — confirm with callers).
                img = img.transpose(2,0,1)
                img = img[::-1]
            if img.dtype==np.uint8:
                # Normalize byte images to [0, 1] floats.
                img = img.astype(np.float32)/255
            self.opts["width"] = img.shape[2]
            self.opts["height"] = img.shape[1]
            save_heatmap(self.dumpdir, self.opts["title"], img)
            if self.win is None:
                self.win = vis.image(img, opts=self.opts)
            else:
                vis.image(img, win=self.win, opts=self.opts)
    def __call__(self, img):
        self.draw(img)
class Text:
    """A visdom text panel that remembers its content for checkpointing."""

    def __init__(self, title):
        _start_if_not_running()
        self.win = None
        self.title = title
        self.curr_text = ""

    def set(self, text):
        """Remember *text* and display it (if the visdom client is up)."""
        self.curr_text = text
        if vis is None:
            return
        if self.win is not None:
            vis.text(text, win=self.win)
        else:
            self.win = vis.text(text, opts={"title": self.title})

    def state_dict(self):
        return {"text": self.curr_text}

    def load_state_dict(self, state):
        self.set(state["text"])

    def __call__(self, text):
        self.set(text)
class Heatmap:
    """A visdom heatmap panel; optionally dumps drawn matrices to *dumpdir*."""

    def __init__(self, title, min=None, max=None, xlabel=None, ylabel=None, colormap='Viridis', dumpdir=None):
        _start_if_not_running()
        self.win = None
        self.opt = dict(title=title, colormap=colormap)
        self.dumpdir = dumpdir
        if min is not None:
            self.opt["xmin"] = min
        if max is not None:
            self.opt["xmax"] = max
        if xlabel:
            self.opt["xlabel"] = xlabel
        if ylabel:
            self.opt["ylabel"] = ylabel

    def set_dump_dir(self, dumpdir):
        self.dumpdir = dumpdir

    def draw(self, img):
        """Draw *img*; the color range defaults to the data min/max unless fixed."""
        if vis is None:
            return
        opts = self.opt.copy()
        if "xmin" not in opts:
            opts["xmin"] = float(img.min())
        if "xmax" not in opts:
            opts["xmax"] = float(img.max())
        save_heatmap(self.dumpdir, opts["title"], img)
        if self.win is not None:
            vis.heatmap(img, win=self.win, opts=opts)
        else:
            self.win = vis.heatmap(img, opts=opts)

    def __call__(self, img):
        self.draw(img)

189
Utils/download.py Normal file
View File

@ -0,0 +1,189 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import requests, tarfile, io, os, zipfile
from io import BytesIO, SEEK_SET, SEEK_END
class UrlStream:
    """Read-only, file-like HTTP stream with optional seek support.

    When the server advertises `Accept-Ranges: bytes` and a Content-Length,
    seeks re-issue a ranged request at the target offset; otherwise the body
    is buffered sequentially and seeks only move within the buffered data.
    """
    def __init__(self, url):
        self._url = url
        headers = requests.head(url).headers
        headers = {k.lower(): v for k,v in headers.items()}
        # Seeking needs both range support and a known total size.
        self._seek_supported = headers.get('accept-ranges')=='bytes' and 'content-length' in headers
        if self._seek_supported:
            self._size = int(headers['content-length'])
        self._curr_pos = 0
        self._buf_start_pos = 0
        self._iter = None
        self._buffer = None
        self._buf_size = 0
        self._loaded_all = False

    def seekable(self):
        return self._seek_supported

    def _load_all(self):
        # Drain the rest of the response into the buffer.
        if self._loaded_all:
            return
        self._make_request()
        old_buf_pos = self._buffer.tell()
        self._buffer.seek(0, SEEK_END)
        for chunk in self._iter:
            self._buffer.write(chunk)
        self._buf_size = self._buffer.tell()
        self._buffer.seek(old_buf_pos, SEEK_SET)
        self._loaded_all = True

    def seek(self, position, whence=SEEK_SET):
        """Move the read position; returns the new absolute position.

        Raises:
            ValueError: for an unsupported *whence* value.
        """
        if whence == SEEK_END:
            assert position<=0
            if self._seek_supported:
                self.seek(self._size + position)
            else:
                # Without range support the whole body must be fetched first.
                self._load_all()
                self._buffer.seek(position, SEEK_END)
                self._curr_pos = self._buffer.tell()
        elif whence==SEEK_SET:
            if self._curr_pos != position:
                self._curr_pos = position
                if self._seek_supported:
                    # Drop the connection; the next read issues a new ranged request.
                    self._iter = None
                    self._buffer = None
                else:
                    self._load_until(position)
                    self._buffer.seek(position)
                    self._curr_pos = position
        else:
            # BUG FIX: this was `assert "Invalid whence %s" % whence`, which
            # always passes (non-empty strings are truthy) and silently
            # accepted invalid whence values.
            raise ValueError("Invalid whence %s" % whence)
        return self.tell()

    def tell(self):
        return self._curr_pos

    def _load_until(self, goal_position):
        # Pull chunks until the buffer covers goal_position (or EOF).
        self._make_request()
        old_buf_pos = self._buffer.tell()
        current_position = self._buffer.seek(0, SEEK_END)
        goal_position = goal_position - self._buf_start_pos
        while current_position < goal_position:
            try:
                d = next(self._iter)
                self._buffer.write(d)
                current_position += len(d)
            except StopIteration:
                break
        self._buf_size = current_position
        self._buffer.seek(old_buf_pos, SEEK_SET)

    def _new_buffer(self):
        # Start a fresh buffer, carrying over any not-yet-consumed bytes.
        remaining = self._buffer.read() if self._buffer is not None else None
        self._buffer = BytesIO()
        if remaining is not None:
            self._buffer.write(remaining)
        self._buf_start_pos = self._curr_pos
        self._buf_size = 0 if remaining is None else len(remaining)
        self._buffer.seek(0, SEEK_SET)
        self._loaded_all = False

    def _make_request(self):
        if self._iter is None:
            h = {
                "User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36",
            }
            if self._seek_supported:
                # Request everything from the current position onward.
                h["Range"] = "bytes=%d-%d" % (self._curr_pos, self._size - 1)
            r = requests.get(self._url, headers=h, stream=True)
            self._iter = r.iter_content(1024 * 1024)
            self._new_buffer()
        elif self._seek_supported and self._buf_size > 128 * 1024 * 1024:
            # Cap buffered data at 128 MB when we can re-fetch via ranges.
            self._new_buffer()

    def size(self):
        """Total stream length in bytes (may download everything if unknown)."""
        if self._seek_supported:
            return self._size
        else:
            self._load_all()
            return self._buf_size

    def read(self, size=None):
        """Read up to *size* bytes (the whole remaining stream if None)."""
        if size is None:
            size = self.size()
        self._load_until(self._curr_pos + size)
        if self._seek_supported:
            self._curr_pos = min(self._curr_pos+size, self._size)
        return self._buffer.read(size)

    def iter_content(self, block_size):
        """Yield the remaining stream in chunks of *block_size* bytes."""
        while True:
            d = self.read(block_size)
            if not len(d):
                break
            yield d
def download(url, dest=None, extract=True, ignore_if_exists=False):
    """Download *url*, optionally extracting tar.gz/tar.bz2/zip archives.

    Args:
        url: the URL to download.
        dest: destination file if extract=False, destination directory if
            extract=True. Defaults to the last path component of the URL.
        extract: whether to unpack recognized archive types.
        ignore_if_exists: return immediately if *dest* already exists.

    Returns:
        The destination filename.
    """
    base_url = url.split("?")[0]

    if dest is None:
        dest = [f for f in base_url.split("/") if f][-1]

    if ignore_if_exists and os.path.exists(dest):
        return dest

    stream = UrlStream(url)
    extension = base_url.split(".")[-1].lower()

    if extract and extension in ['gz', 'bz2', 'zip']:
        os.makedirs(dest, exist_ok=True)
        if extension == 'zip':
            decompressed_file = zipfile.ZipFile(stream, mode='r')
        elif extension in ['gz', 'bz2']:
            decompressed_file = tarfile.open(fileobj=stream, mode='r|' + extension)
        else:
            assert False, "Invalid extension: %s" % extension
        decompressed_file.extractall(dest)
    else:
        try:
            with open(dest, 'wb') as f:
                for d in stream.iter_content(1024 * 1024):
                    f.write(d)
        except:
            # Don't leave a partial file behind; the error still propagates.
            os.remove(dest)
            raise

    return dest

111
Utils/gpu_allocator.py Normal file
View File

@ -0,0 +1,111 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import subprocess
import os
import torch
from Utils.lockfile import LockFile
def get_memory_usage():
    """Return {gpu_index: used_MiB} from nvidia-smi, or None if it fails."""
    cmd = "nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits"
    try:
        proc = subprocess.Popen(cmd.split(" "), stdout=subprocess.PIPE)
        usage = {}
        for raw in proc.communicate()[0].decode().split("\n"):
            if not raw:
                continue
            fields = raw.strip().split(" ")
            # The index field ends with a comma ("0, 1234"); drop it.
            usage[int(fields[0][:-1])] = int(fields[1])
        return usage
    except:
        # Deliberate best-effort: any failure (no nvidia-smi, bad output)
        # means "unknown".
        return None
def get_free_gpus():
    """Return indices of GPUs with no running compute apps, or None on failure."""
    try:
        apps_cmd = "nvidia-smi --query-compute-apps=gpu_uuid --format=csv,noheader,nounits"
        proc = subprocess.Popen(apps_cmd.split(" "), stdout=subprocess.PIPE)
        busy_uuids = [s.strip() for s in proc.communicate()[0].decode().split("\n") if s]

        gpus_cmd = "nvidia-smi --query-gpu=index,uuid --format=csv,noheader,nounits"
        proc = subprocess.Popen(gpus_cmd.split(" "), stdout=subprocess.PIPE)
        pairs = [s.strip().split(", ") for s in proc.communicate()[0].decode().split("\n") if s]

        free = []
        for index, uuid in pairs:
            if uuid not in busy_uuids:
                free.append(int(index))
        return free
    except:
        # Deliberate best-effort: any failure means "unknown".
        return None
def _fix_order():
os.environ["CUDA_DEVICE_ORDER"] = os.environ.get("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
def allocate(n:int = 1):
    """Set CUDA_VISIBLE_DEVICES to *n* GPUs, guarded by a system-wide lock file.

    Prefers fully free GPUs; if not enough are free, falls back to the ones
    with the least memory used. Warns and does nothing if CUDA_VISIBLE_DEVICES
    is already set or the nvidia-smi queries fail.
    """
    _fix_order()
    with LockFile("/tmp/gpu_allocation_lock"):
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            print("WARNING: trying to allocate %d GPUs, but CUDA_VISIBLE_DEVICES already set to %s" %
                  (n, os.environ["CUDA_VISIBLE_DEVICES"]))
            return
        allocated = get_free_gpus()
        if allocated is None:
            print("WARNING: failed to allocate %d GPUs" % n)
            return
        allocated = allocated[:n]
        if len(allocated) < n:
            print("There is no more free GPUs. Allocating the one with least memory usage.")
            usage = get_memory_usage()
            if usage is None:
                print("WARNING: failed to allocate %d GPUs" % n)
                return
            # Group device ids by used memory, then walk the groups from the
            # least-used upward.
            inv_usages = {}
            for k, v in usage.items():
                if v not in inv_usages:
                    inv_usages[v] = []
                inv_usages[v].append(k)
            min_usage = list(sorted(inv_usages.keys()))
            min_usage_devs = []
            for u in min_usage:
                min_usage_devs += inv_usages[u]
            min_usage_devs = [m for m in min_usage_devs if m not in allocated]
            n2 = n - len(allocated)
            if n2>len(min_usage_devs):
                print("WARNING: trying to allocate %d GPUs but only %d available" % (n, len(min_usage_devs)+len(allocated)))
                n2 = len(min_usage_devs)
            allocated += min_usage_devs[:n2]
        os.environ["CUDA_VISIBLE_DEVICES"]=",".join([str(a) for a in allocated])
        # Touch every allocated device while the lock is still held —
        # presumably so the allocation is visible to other processes before
        # the lock is released (TODO confirm).
        for i in range(len(allocated)):
            a = torch.FloatTensor([0.0])
            a.cuda(i)
def use_gpu(gpu="auto", n_autoalloc=1):
    """Select GPUs: auto-allocate *n_autoalloc* devices, or use an explicit list."""
    _fix_order()
    gpu = gpu.lower()
    if gpu not in ["auto", ""]:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    else:
        allocate(n_autoalloc)

41
Utils/lockfile.py Normal file
View File

@ -0,0 +1,41 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import os
import fcntl
class LockFile:
    """Inter-process mutual exclusion via an fcntl-locked file.

    Usable directly (acquire/release) or as a context manager.
    """

    def __init__(self, fname):
        self._fname = fname
        self._fd = None

    def acquire(self):
        """Open the lock file and block until the exclusive lock is obtained."""
        self._fd=open(self._fname, "w")
        # World-writable so different users can share the same lock file.
        # NOTE(review): chmod can fail if the file already exists and belongs
        # to another user — confirm whether that case needs handling.
        os.chmod(self._fname, 0o777)
        fcntl.lockf(self._fd, fcntl.LOCK_EX)

    def release(self):
        """Unlock and close the file."""
        fcntl.lockf(self._fd, fcntl.LOCK_UN)
        self._fd.close()
        self._fd = None

    def __enter__(self):
        self.acquire()
        # BUG FIX: return self so `with LockFile(...) as lock:` binds the lock
        # object instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()

54
Utils/timer.py Normal file
View File

@ -0,0 +1,54 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import time
class OnceEvery:
    """Callable rate limiter: returns True at most once per *interval* seconds."""

    def __init__(self, interval):
        self._interval = interval
        self._last_check = 0

    def __call__(self):
        now = time.time()
        if now - self._last_check < self._interval:
            return False
        self._last_check = now
        return True
class Measure:
    """Measures elapsed time since start(), averaged over a sliding window."""

    def __init__(self, average=1):
        self._start = None
        self._average = average
        self._accu_value = 0.0
        self._history_list = []

    def start(self):
        """Mark the beginning of a measured interval."""
        self._start = time.time()

    def passed(self):
        """Return the windowed average of elapsed times, or None if never started."""
        if self._start is None:
            return None
        elapsed = time.time() - self._start
        self._history_list.append(elapsed)
        self._accu_value += elapsed
        # Keep at most `average` samples in the running sum.
        while len(self._history_list) > self._average:
            self._accu_value -= self._history_list.pop(0)
        return self._accu_value / len(self._history_list)

263
Utils/universal.py Normal file
View File

@ -0,0 +1,263 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import torch.nn.functional as F
import numpy as np
# Each type is a [numpy_dtype, torch_dtype] pair usable with either backend.
float32 = [np.float32, torch.float32]
float64 = [np.float64, torch.float64]
uint8 = [np.uint8, torch.uint8]

_all_types = [float32, float64, uint8]
# Lookup tables from a concrete numpy/torch dtype back to its pair.
_dtype_numpy_map = {v[0]().dtype.name: v for v in _all_types}
_dtype_pytorch_map = {v[1]: v for v in _all_types}

def dtype(t):
    """Return the [numpy, torch] dtype pair of tensor/array *t*."""
    if torch.is_tensor(t):
        return _dtype_pytorch_map[t.dtype]
    return _dtype_numpy_map[t.dtype.name]
def cast(t, type):
    """Cast *t* to the [numpy, torch] dtype pair *type*."""
    return t.type(type[1]) if torch.is_tensor(t) else t.astype(type[0])
def to_numpy(t):
    """Convert a torch tensor to a numpy array; pass anything else through."""
    if not torch.is_tensor(t):
        return t
    return t.detach().cpu().numpy()
def to_list(t):
    """Convert tensors/arrays to (nested) plain Python lists; pass others through."""
    converted = to_numpy(t)
    return converted.tolist() if isinstance(converted, np.ndarray) else converted
def is_tensor(t):
    """True for both torch tensors and numpy arrays."""
    return isinstance(t, np.ndarray) or torch.is_tensor(t)
def first_batch(t):
    """Take the first batch element of tensors/arrays; leave other values untouched."""
    return t[0] if is_tensor(t) else t
def ndim(t):
    """Dimension count of a torch tensor or numpy array."""
    return t.dim() if torch.is_tensor(t) else t.ndim
def shape(t):
    """Shape of *t* as a plain Python list."""
    return [s for s in t.shape]
def transpose(t, axis):
    # NOTE(review): this definition is shadowed by the later
    # `transpose(t, axes)` in this module (which unpacks the axes for
    # permute); this earlier one is dead at import time.
    if torch.is_tensor(t):
        return t.permute(axis)
    else:
        return np.transpose(t, axis)
def apply_recursive(d, fn, filter=None):
    """Apply *fn* to every leaf of a nested list/tuple/dict structure.

    Args:
        d: arbitrarily nested lists/tuples/dicts; anything else is a leaf.
        fn: function applied to each leaf.
        filter: optional predicate; leaves for which it returns False are
            returned unchanged.

    Returns:
        A structure of the same shape with fn applied to the (filtered) leaves.
    """
    if isinstance(d, list):
        # BUG FIX: the filter is now forwarded into nested structures;
        # previously it was silently dropped after the top level.
        return [apply_recursive(da, fn, filter) for da in d]
    elif isinstance(d, tuple):
        return tuple(apply_recursive(list(d), fn, filter))
    elif isinstance(d, dict):
        return {k: apply_recursive(v, fn, filter) for k, v in d.items()}
    else:
        if filter is None or filter(d):
            return fn(d)
        else:
            return d
def apply_to_tensors(d, fn):
    """Apply *fn* to every tensor leaf in a nested structure, leaving other leaves alone."""
    return apply_recursive(d, fn, filter=torch.is_tensor)
def recursive_decorator(apply_this_fn):
    """Build a decorator that maps *apply_this_fn* over all (nested) args and kwargs."""
    def decorator(func):
        def wrapped_funct(*args, **kwargs):
            converted_args = apply_recursive(args, apply_this_fn)
            converted_kwargs = apply_recursive(kwargs, apply_this_fn)
            return func(*converted_args, **converted_kwargs)
        return wrapped_funct
    return decorator

# Decorators converting all tensor arguments to numpy arrays / plain lists.
untensor = recursive_decorator(to_numpy)
unnumpy = recursive_decorator(to_list)
def unbatch(only_if_dim_equal=None):
    """Decorator factory: replace tensor args by their first batch element.

    Args:
        only_if_dim_equal: optional ndim (or list of ndims); tensors with a
            different ndim are left untouched.
    """
    dims = only_if_dim_equal
    if dims is not None and not isinstance(dims, list):
        dims = [dims]

    def get_first_batch(t):
        if is_tensor(t) and (dims is None or ndim(t) in dims):
            return t[0]
        return t

    return recursive_decorator(get_first_batch)
def sigmoid(t):
    """Logistic sigmoid for torch tensors and numpy arrays."""
    if torch.is_tensor(t):
        return torch.sigmoid(t)
    return 1.0 / (1.0 + np.exp(-t))
def argmax(t, dim):
    """Index of the maximum along *dim* for tensors or arrays."""
    if torch.is_tensor(t):
        return t.max(dim)[1]
    return np.argmax(t, axis=dim)
def flip(t, axis):
    """Reverse *t* along *axis*."""
    return t.flip(axis) if torch.is_tensor(t) else np.flip(t, axis)
def transpose(t, axes):
    """Permute the dimensions of *t* according to *axes*."""
    return t.permute(*axes) if torch.is_tensor(t) else np.transpose(t, axes)
def split_n(t, axis):
    """Split *t* into size-1 slices along *axis*."""
    if torch.is_tensor(t):
        return t.split(1, dim=axis)
    return np.split(t, t.shape[axis], axis=axis)
def cat(array_of_tensors, axis):
    """Concatenate a sequence of tensors/arrays along *axis*."""
    joiner = torch.cat if torch.is_tensor(array_of_tensors[0]) else np.concatenate
    return joiner(array_of_tensors, axis)
def clamp(t, min=None, max=None):
    """Clip *t* elementwise to [min, max]; either bound may be None."""
    if torch.is_tensor(t):
        return t.clamp(min, max)
    res = t
    if min is not None:
        res = np.maximum(res, min)
    if max is not None:
        res = np.minimum(res, max)
    return res
def power(t, p):
    """Elementwise t**p, dispatching to torch when either operand is a tensor."""
    use_torch = torch.is_tensor(t) or torch.is_tensor(p)
    return torch.pow(t, p) if use_torch else np.power(t, p)
def random_normal_as(a, mean, std, seed=None):
    """Gaussian noise with the same shape as *a* (torch tensor or numpy array).

    Args:
        seed: optional numpy RandomState used for the numpy branch.
    """
    if torch.is_tensor(a):
        return torch.randn_like(a) * std + mean
    rng = np.random if seed is None else seed
    return rng.normal(loc=mean, scale=std, size=shape(a))
def pad(t, pad):
    """Pad a 4D tensor/array; *pad* follows F.pad order for the last two dims:
    (last_dim_before, last_dim_after, second_last_before, second_last_after).

    Returns:
        The padded tensor/array.
    """
    assert ndim(t) == 4
    if torch.is_tensor(t):
        return F.pad(t, pad)
    else:
        # BUG FIX: the original `assert np.pad(...)` discarded the result
        # (and would raise on multi-element arrays, whose truth value is
        # ambiguous), and it applied the pad pairs to the wrong axes relative
        # to the F.pad branch. F.pad pads the *last* dimension with pad[0:2]
        # and the second-to-last with pad[2:4].
        return np.pad(t, ((0, 0), (0, 0), (pad[2], pad[3]), (pad[0], pad[1])))
def dx(img):
    """Horizontal central difference of a 4D image, zero-padded at the borders."""
    shifted = img[:, :, :, 2:]
    base = img[:, :, :, :-2]
    return pad(0.5 * (shifted - base), (1, 1, 0, 0))
def dy(img):
    """Vertical central difference of a 4D image, zero-padded at the borders."""
    shifted = img[:, :, 2:, :]
    base = img[:, :, :-2, :]
    return pad(0.5 * (shifted - base), (0, 0, 1, 1))
def reshape(t, shape):
    """Reshape a tensor/array to *shape*."""
    return t.view(*shape) if torch.is_tensor(t) else t.reshape(*shape)
def broadcast_to_beginning(t, target):
    """Append singleton dims to *t* until it has as many dims as *target*
    (so broadcasting aligns the *leading* dimensions)."""
    t_shape = list(t.shape)
    if torch.is_tensor(t):
        extra = [1] * (target.dim() - len(t_shape))
        return t.view(*t_shape, *extra)
    extra = [1] * (target.ndim - len(t_shape))
    return t.reshape(*t_shape, *extra)
def lin_combine(d1, w1, d2, w2, bcast_begin=False):
    """Elementwise w1*d1 + w2*d2 over matching nested lists/tuples/dicts.

    Args:
        d1, d2: structures of identical shape whose leaves support * and +.
        w1, w2: the combination weights.
        bcast_begin: if True, reshape the weights with trailing singleton dims
            so they broadcast against the *leading* dims of the leaves.
    """
    if isinstance(d1, (list, tuple)):
        assert len(d1) == len(d2)
        # BUG FIX: bcast_begin is now propagated into nested structures.
        res = [lin_combine(d1[i], w1, d2[i], w2, bcast_begin) for i in range(len(d1))]
        if isinstance(d1, tuple):
            # BUG FIX: previously converted d1 (the *input*) to a tuple,
            # discarding the combined values.
            res = tuple(res)
    elif isinstance(d1, dict):
        res = {k: lin_combine(v, w1, d2[k], w2, bcast_begin) for k, v in d1.items()}
    else:
        if bcast_begin:
            w1 = broadcast_to_beginning(w1, d1)
            w2 = broadcast_to_beginning(w2, d2)
        res = d1 * w1 + d2 * w2
    return res

73
Visualize/BitmapTask.py Normal file
View File

@ -0,0 +1,73 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
try:
import cv2
except:
cv2=None
from Utils.Helpers import as_numpy
def visualize_bitmap_task(i_data, o_data, zoom=8):
    """Render input and output bit sequences as vertically stacked grayscale images."""
    outputs = o_data if isinstance(o_data, list) else [o_data]
    imgs = []
    for data in [i_data] + outputs:
        if data is None:
            continue
        arr = as_numpy(data)
        if arr.ndim > 2:
            # Batched input: visualize the first element only.
            arr = arr[0]
        imgs.append(np.expand_dims(arr.T * 255, -1).astype(np.uint8))
    return nearest_zoom(np.concatenate(imgs, 0), zoom)
def visualize_01(t, zoom=8):
    """Render a [0,1]-valued array as a zoomed grayscale image."""
    img = np.expand_dims(t * 255, -1).astype(np.uint8)
    return nearest_zoom(img, zoom)
def nearest_zoom(img, zoom=1):
    """Integer-upscale *img* with nearest-neighbor interpolation.

    Returns the image unchanged when zoom <= 1 or cv2 is unavailable.
    """
    if zoom > 1 and cv2 is not None:
        return cv2.resize(img, (img.shape[1] * zoom, img.shape[0] * zoom), interpolation=cv2.INTER_NEAREST)
    return img
def concatenate_tensors(tensors):
    """Stack variable-shaped arrays into one array, zero-padding to the max shape.

    All inputs must share ndim and dtype; the result has shape
    [len(tensors), *elementwise_max_shape].
    """
    max_size = None
    dtype = None
    for t in tensors:
        if max_size is None:
            max_size = list(t.shape)
            dtype = t.dtype
            continue
        assert t.ndim ==len(max_size), "Can only concatenate tensors with same ndim."
        assert t.dtype == dtype, "Tensors must have the same type"
        max_size = [max(m, s) for m, s in zip(max_size, t.shape)]

    res = np.zeros([len(tensors)] + max_size, dtype=dtype)
    for n, t in enumerate(tensors):
        # Copy each tensor into the top-left corner of its padded slot.
        res[n][tuple(slice(0, extent) for extent in t.shape)] = t
    return res

1
Visualize/__init__.py Normal file
View File

@ -0,0 +1 @@
from .BitmapTask import *

64
Visualize/preview.py Normal file
View File

@ -0,0 +1,64 @@
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import time
import threading
import traceback
import sys
from Utils import universal as U
def preview(vis_interval=None, to_numpy=True, debatch=False):
    """Decorator that runs the wrapped visualization function on a background thread.

    Calls are single-flight: while one visualization thread is alive, further
    calls are dropped. If vis_interval is given, calls closer together than
    that many seconds are also dropped.

    to_numpy converts all (possibly nested) tensor arguments to numpy arrays
    before handing them to the thread; debatch keeps only the first batch
    element of each argument.
    """
    def decorator(func):
        state = {
            "last_vis_time": 0,
            "thread_running": False
        }

        def wrapper_function(*args, **kwargs):
            if state["thread_running"]:
                return

            if vis_interval is not None:
                now = time.time()
                if now - state["last_vis_time"] < vis_interval:
                    return
                state["last_vis_time"] = now

            state["thread_running"] = True

            # Convert on the caller's thread so the worker never touches live
            # (possibly GPU / autograd-tracked) tensors.
            if debatch:
                args = U.apply_recursive(args, U.first_batch)
                kwargs = U.apply_recursive(kwargs, U.first_batch)

            if to_numpy:
                args = U.apply_recursive(args, U.to_numpy)
                kwargs = U.apply_recursive(kwargs, U.to_numpy)

            def run():
                # BUGFIX: reset the single-flight flag in a finally block.
                # Previously it was reset only on the success path, so a
                # failing callback (sys.exit kills just this thread) left the
                # flag True forever and silently disabled all future previews.
                try:
                    func(*args, **kwargs)
                except Exception:
                    traceback.print_exc()
                    sys.exit(-1)
                finally:
                    state["thread_running"] = False

            worker = threading.Thread(target=run)
            worker.start()

        return wrapper_function
    return decorator

671
main.py Executable file
View File

@ -0,0 +1,671 @@
#!/usr/bin/env python3
#
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import functools
import os
import torch.utils.data
import Utils.Debug as debug
from Dataset.Bitmap.AssociativeRecall import AssociativeRecall
from Dataset.Bitmap.BitmapTaskRepeater import BitmapTaskRepeater
from Dataset.Bitmap.KeyValue import KeyValue
from Dataset.Bitmap.CopyTask import CopyData
from Dataset.Bitmap.KeyValue2Way import KeyValue2Way
from Dataset.NLP.bAbi import bAbiDataset
from Models.DNC import DNC, LSTMController, FeedforwardController
from Utils import Visdom
from Utils.ArgumentParser import ArgumentParser
from Utils.Index import index_by_dim
from Utils.Saver import Saver, GlobalVarSaver, StateSaver
from Utils.Collate import MetaCollate
from Utils import gpu_allocator
from Dataset.NLP.NLPTask import NLPTask
from tqdm import tqdm
from Visualize.preview import preview
from Utils.timer import OnceEvery
from Utils import Seed
import time
import sys
import signal
import math
from Utils import Profile
Profile.ENABLED=False  # globally disable the Utils.Profile instrumentation for normal runs
def main():
    """Entry point: parse arguments, build the dataset/model/optimizer for the
    selected task, restore any saved state, then run the training loop with
    periodic preview, testing and checkpointing.

    Uses module-level globals `i` (iteration counter), `loss_sum` (running loss
    accumulator) and `running` (cleared by the SIGINT handler) so that the
    Saver can snapshot/restore them and Ctrl+C can stop the loop.
    """
    global i
    global loss_sum
    global running

    # ---- argument parsing -------------------------------------------------
    parser = ArgumentParser()
    parser.add_argument("-bit_w", type=int, default=8, help="Bit vector length for copy task")
    parser.add_argument("-block_w", type=int, default=3, help="Block width to associative recall task")
    parser.add_argument("-len", type=str, default="4", help="Sequence length for copy task", parser=lambda x: [int(a) for a in x.split("-")])
    parser.add_argument("-repeat", type=str, default="1", help="Sequence length for copy task", parser=lambda x: [int(a) for a in x.split("-")])
    parser.add_argument("-batch_size", type=int, default=16, help="Sequence length for copy task")
    parser.add_argument("-n_subbatch", type=str, default="auto", help="Average this much forward passes to a backward pass")
    parser.add_argument("-max_input_count_per_batch", type=int, default=6000, help="Max batch_size*len that can fit into memory")
    parser.add_argument("-lr", type=float, default=0.0001, help="Learning rate")
    parser.add_argument("-wd", type=float, default=1e-5, help="Weight decay")
    parser.add_argument("-optimizer", type=str, default="rmsprop", help="Optimizer algorithm")
    parser.add_argument("-name", type=str, help="Save training to this directory")
    parser.add_argument("-preview_interval", type=int, default=10, help="Show preview every nth iteration")
    parser.add_argument("-info_interval", type=int, default=10, help="Show info every nth iteration")
    parser.add_argument("-save_interval", type=int, default=500, help="Save network every nth iteration")
    parser.add_argument("-masked_lookup", type=bool, default=1, help="Enable masking in content lookups")
    parser.add_argument("-visport", type=int, default=-1, help="Port to run Visdom server on. -1 to disable")
    parser.add_argument("-gpu", default="auto", type=str, help="Run on this GPU.")
    parser.add_argument("-debug", type=bool, default=1, help="Enable debugging")
    parser.add_argument("-task", type=str, default="copy", help="Task to learn")
    parser.add_argument("-mem_count", type=int, default=16, help="Number of memory cells")
    parser.add_argument("-data_word_size", type=int, default=128, help="Memory word size")
    parser.add_argument("-n_read_heads", type=int, default=1, help="Number of read heads")
    parser.add_argument("-layer_sizes", type=str, default="256", help="Controller layer sizes. Separate with ,. For example 512,256,256", parser=lambda x: [int(y) for y in x.split(",") if y])
    parser.add_argument("-debug_log", type=bool, default=0, help="Enable debug log")
    parser.add_argument("-controller_type", type=str, default="lstm", help="Controller type: lstm or linear")
    parser.add_argument("-lstm_use_all_outputs", type=bool, default=1, help="Use all LSTM outputs as controller output vs use only the last layer")
    parser.add_argument("-momentum", type=float, default=0.9, help="Momentum for optimizer")
    parser.add_argument("-embedding_size", type=int, default=256, help="Size of word embedding for NLP tasks")
    parser.add_argument("-test_interval", type=int, default=10000, help="Run test in this interval")
    parser.add_argument("-dealloc_content", type=bool, default=1, help="Deallocate memory content, unlike DNC, which leaves it unchanged, just decreases the usage counter, causing problems with lookup")
    parser.add_argument("-sharpness_control", type=bool, default=1, help="Distribution sharpness control for forward and backward links")
    parser.add_argument("-think_steps", type=int, default=0, help="Iddle steps before requiring the answer (for bAbi)")
    parser.add_argument("-dump_profile", type=str, save=False)
    parser.add_argument("-test_on_start", default="0", save=False)
    parser.add_argument("-dump_heatmaps", default=False, save=False)
    parser.add_argument("-test_batch_size", default=16)
    parser.add_argument("-mask_min", default=0.0)
    parser.add_argument("-load", type=str, save=False)
    parser.add_argument("-dataset_path", type=str, default="none", parser=ArgumentParser.str_or_none(), help="Specify babi path manually")
    parser.add_argument("-babi_train_tasks", type=str, default="none", parser=ArgumentParser.list_or_none(type=str), help="babi task list to use for training")
    parser.add_argument("-babi_test_tasks", type=str, default="none", parser=ArgumentParser.list_or_none(type=str), help="babi task list to use for testing")
    parser.add_argument("-babi_train_sets", type=str, default="train", parser=ArgumentParser.list_or_none(type=str), help="babi train sets to use")
    parser.add_argument("-babi_test_sets", type=str, default="test", parser=ArgumentParser.list_or_none(type=str), help="babi test sets to use")
    parser.add_argument("-noargsave", type=bool, default=False, help="Do not save modified arguments", save=False)
    parser.add_argument("-demo", type=bool, default=False, help="Do a single step with fixed seed", save=False)
    parser.add_argument("-exit_after", type=int, help="Exit after this amount of steps. Useful for debugging.", save=False)
    parser.add_argument("-grad_clip", type=float, default=10.0, help="Max gradient norm")
    parser.add_argument("-clip_controller", type=float, default=20.0, help="Max gradient norm")
    parser.add_argument("-print_test", default=False, save=False)

    # Named presets of the arguments above.
    parser.add_profile([
        ArgumentParser.Profile("babi", {
            "preview_interval": 10,
            "save_interval": 500,
            "task": "babi",
            "mem_count": 256,
            "data_word_size": 64,
            "n_read_heads": 4,
            "layer_sizes": "256",
            "controller_type": "lstm",
            "lstm_use_all_outputs": True,
            "momentum": 0.9,
            "embedding_size": 128,
            "test_interval": 5000,
            "think_steps": 3,
            "batch_size": 2
        }, include=["dnc-msd"]),

        ArgumentParser.Profile("repeat_copy", {
            "bit_w": 8,
            "repeat": "1-8",
            "len": "2-14",
            "task": "copy",
            "think_steps": 1,
            "preview_interval": 10,
            "info_interval": 10,
            "save_interval": 100,
            "data_word_size": 16,
            "layer_sizes": "32",
            "n_subbatch": 1,
            "controller_type": "lstm",
        }),

        ArgumentParser.Profile("repeat_copy_simple", {
            "repeat": "1-3",
        }, include="repeat_copy"),

        # dnc-<flags>: m = masked lookup, s = sharpness control, d = dealloc content.
        ArgumentParser.Profile("dnc", {
            "masked_lookup": False,
            "sharpness_control": False,
            "dealloc_content": False
        }),

        ArgumentParser.Profile("dnc-m", {
            "masked_lookup": True,
            "sharpness_control": False,
            "dealloc_content": False
        }),

        ArgumentParser.Profile("dnc-s", {
            "masked_lookup": False,
            "sharpness_control": True,
            "dealloc_content": False
        }),

        ArgumentParser.Profile("dnc-d", {
            "masked_lookup": False,
            "sharpness_control": False,
            "dealloc_content": True
        }),

        ArgumentParser.Profile("dnc-md", {
            "masked_lookup": True,
            "sharpness_control": False,
            "dealloc_content": True
        }),

        ArgumentParser.Profile("dnc-ms", {
            "masked_lookup": True,
            "sharpness_control": True,
            "dealloc_content": False
        }),

        ArgumentParser.Profile("dnc-sd", {
            "masked_lookup": False,
            "sharpness_control": True,
            "dealloc_content": True
        }),

        ArgumentParser.Profile("dnc-msd", {
            "masked_lookup": True,
            "sharpness_control": True,
            "dealloc_content": True
        }),

        ArgumentParser.Profile("keyvalue", {
            "repeat": "1",
            "len": "2-16",
            "mem_count": 16,
            "task": "keyvalue",
            "think_steps": 1,
            "preview_interval": 10,
            "info_interval": 10,
            "data_word_size": 32,
            "bit_w": 12,
            "save_interval": 1000,
            "layer_sizes": "32"
        }),

        ArgumentParser.Profile("keyvalue2way", {
            "task": "keyvalue2way",
        }, include="keyvalue"),

        ArgumentParser.Profile("associative_recall",{
            "task": "recall",
            "bit_w": 8,
            "len": "2-16",
            "mem_count": 64,
            "data_word_size": 32,
            "n_read_heads": 1,
            "layer_sizes": "128",
            "controller_type": "lstm",
            "lstm_use_all_outputs": 1,
            "think_steps": 1,
            "mask_min": 0.1,
            "info_interval": 10,
            "save_interval": 1000,
            "preview_interval": 10,
            "n_subbatch": 1,
        })
    ])

    opt = parser.parse()
    assert opt.name is not None, "Training dir (-name parameter) not given"
    # Merge with previously saved arguments of this training run.
    opt = parser.sync(os.path.join(opt.name, "args.json"), save=not opt.noargsave)

    if opt.demo:
        Seed.fix()

    os.makedirs(os.path.join(opt.name, "save"), exist_ok=True)
    os.makedirs(os.path.join(opt.name, "preview"), exist_ok=True)

    gpu_allocator.use_gpu(opt.gpu)

    debug.enableDebug = opt.debug_log

    if opt.visport > 0:
        Visdom.start(opt.visport)

    Visdom.Text("Name").set(opt.name)

    class LengthHackSampler:
        # Batch sampler that yields the *sequence length* batch_size times;
        # the dataset's __getitem__ interprets the key as the length (see
        # the "Random length batch hack" in the bitmap datasets).
        def __init__(self, batch_size, length):
            self.length = length
            self.batch_size = batch_size

        def __iter__(self):
            while True:
                length = self.length() if callable(self.length) else self.length
                yield [length] * self.batch_size

        def __len__(self):
            return 0x7FFFFFFF

    embedding = None
    test_set = None
    curriculum = None
    loader_reset = False

    # ---- task / dataset selection ----------------------------------------
    if opt.task == "copy":
        dataset = CopyData(bit_w=opt.bit_w)
        in_size = opt.bit_w + 1
        out_size = in_size
    elif opt.task == "recall":
        dataset = AssociativeRecall(bit_w=opt.bit_w, block_w=opt.block_w)
        in_size = opt.bit_w + 2
        out_size = in_size
    elif opt.task == "keyvalue":
        assert opt.bit_w % 2 == 0, "Key-value datasets works only with even bit_w"
        dataset = KeyValue(bit_w=opt.bit_w)
        in_size = opt.bit_w + 1
        out_size = opt.bit_w // 2
    elif opt.task == "keyvalue2way":
        assert opt.bit_w % 2 == 0, "Key-value datasets works only with even bit_w"
        dataset = KeyValue2Way(bit_w=opt.bit_w)
        in_size = opt.bit_w + 2
        out_size = opt.bit_w // 2
    elif opt.task == "babi":
        dataset = bAbiDataset(think_steps=opt.think_steps, dir_name=opt.dataset_path)
        test_set = bAbiDataset(think_steps=opt.think_steps, dir_name=opt.dataset_path, name="test")
        dataset.use(opt.babi_train_tasks, opt.babi_train_sets)
        in_size = opt.embedding_size
        print("bAbi: loaded total of %d sequences." % len(dataset))
        test_set.use(opt.babi_test_tasks, opt.babi_test_sets)
        out_size = len(dataset.vocabulary)
        print("bAbi: using %d sequences for training, %d for testing" % (len(dataset), len(test_set)))
    else:
        assert False, "Invalid task: %s" % opt.task

    if opt.task in ["babi"]:
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, num_workers=4, pin_memory=True, shuffle=True, collate_fn=MetaCollate())
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=opt.test_batch_size, num_workers=opt.test_batch_size, pin_memory=True, shuffle=False, collate_fn=MetaCollate()) if test_set is not None else None
    else:
        dataset = BitmapTaskRepeater(dataset)
        data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=LengthHackSampler(opt.batch_size, BitmapTaskRepeater.key_sampler(opt.len, opt.repeat)), num_workers=1, pin_memory=True)

    # ---- model / optimizer ------------------------------------------------
    if opt.controller_type == "lstm":
        controller_constructor = functools.partial(LSTMController, out_from_all_layers=opt.lstm_use_all_outputs)
    elif opt.controller_type == "linear":
        controller_constructor = FeedforwardController
    else:
        assert False, "Invalid controller: %s" % opt.controller_type

    model = DNC(in_size, out_size, opt.data_word_size, opt.mem_count, opt.n_read_heads, controller_constructor(opt.layer_sizes),
                batch_first=True, mask=opt.masked_lookup, dealloc_content=opt.dealloc_content,
                link_sharpness_control=opt.sharpness_control,
                mask_min=opt.mask_min, clip_controller=opt.clip_controller)

    # Two parameter groups: biases get no weight decay.
    params = [
        {'params': [p for n, p in model.named_parameters() if not n.endswith(".bias")]},
        {'params': [p for n, p in model.named_parameters() if n.endswith(".bias")], 'weight_decay': 0}
    ]

    device = torch.device('cuda') if opt.gpu != "none" else torch.device("cpu")
    print("DEVICE: ", device)

    if isinstance(dataset, NLPTask):
        embedding = torch.nn.Embedding(len(dataset.vocabulary), opt.embedding_size).to(device)
        params.append({'params': embedding.parameters(), 'weight_decay': 0})

    if opt.optimizer == "sgd":
        optimizer = torch.optim.SGD(params, lr=opt.lr, weight_decay=opt.wd, momentum=opt.momentum)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(params, lr=opt.lr, weight_decay=opt.wd)
    elif opt.optimizer == "rmsprop":
        optimizer = torch.optim.RMSprop(params, lr=opt.lr, weight_decay=opt.wd, momentum=opt.momentum, eps=1e-10)
    else:
        # BUGFIX: the original `assert "Invalid optimizer: %s" % opt.optimizer`
        # asserted a non-empty string (always true), so an invalid optimizer
        # name silently fell through and crashed later with a NameError.
        assert False, "Invalid optimizer: %s" % opt.optimizer

    n_params = sum([sum([t.numel() for t in d['params']]) for d in params])
    print("Number of parameters: %d" % n_params)

    model = model.to(device)
    if embedding is not None and hasattr(embedding, "to"):
        embedding = embedding.to(device)

    i = 0
    loss_sum = 0

    loss_plot = Visdom.Plot2D("loss", store_interval=opt.info_interval, xlabel="iterations", ylabel="loss")

    if curriculum is not None:
        curriculum_plot = Visdom.Plot2D("curriculum lesson" +
                                        (" (last %d)" % (curriculum.n_lessons - 1) if curriculum.n_lessons is not None else ""),
                                        xlabel="iterations", ylabel="lesson")
        curriculum_accuracy = Visdom.Plot2D("curriculum accuracy", xlabel="iterations", ylabel="accuracy")

    # ---- checkpointing ----------------------------------------------------
    saver = Saver(os.path.join(opt.name, "save"), short_interval=opt.save_interval)
    saver.register("model", StateSaver(model))
    saver.register("optimizer", StateSaver(optimizer))
    saver.register("i", GlobalVarSaver("i"))
    saver.register("loss_sum", GlobalVarSaver("loss_sum"))
    saver.register("loss_plot", StateSaver(loss_plot))
    saver.register("dataset", StateSaver(dataset))
    if test_set:
        saver.register("test_set", StateSaver(test_set))

    if curriculum is not None:
        saver.register("curriculum", StateSaver(curriculum))
        saver.register("curriculum_plot", StateSaver(curriculum_plot))
        saver.register("curriculum_accuracy", StateSaver(curriculum_accuracy))

    if isinstance(dataset, NLPTask):
        saver.register("word_embeddings", StateSaver(embedding))
    elif embedding is not None:
        saver.register("embeddings", StateSaver(embedding))

    if not saver.load(opt.load):
        # Fresh start: no checkpoint found/requested.
        model.reset_parameters()
        if embedding is not None:
            embedding.reset_parameters()

    visualizers = {}

    # Per-tensor display options for plot_debug; list_dim marks tensors that
    # carry a list of heatmaps along that (1-based) dimension.
    debug_schemas = {
        "read_head": {
            "list_dim": 2
        },
        "temporal_links/forward_dists": {
            "list_dim": 2
        },
        "temporal_links/backward_dists": {
            "list_dim": 2
        }
    }

    def plot_debug(debug, prefix="", schema={}):
        # Recursively draw the (nested dict of) debug tensors as Visdom
        # heatmaps. `schema` is never mutated, so the mutable default is safe.
        if debug is None:
            return

        for k, v in debug.items():
            curr_name = prefix + k
            if curr_name in debug_schemas:
                curr_schema = schema.copy()
                curr_schema.update(debug_schemas[curr_name])
            else:
                curr_schema = schema

            if isinstance(v, dict):
                plot_debug(v, curr_name + "/", curr_schema)
                continue

            data = v[0]  # first sample of the batch

            if curr_schema.get("list_dim", -1) > 0:
                if data.ndim != 3:
                    print("WARNING: unknown data shape for array display: %s, tensor %s" % (data.shape, curr_name))
                    continue

                n_steps = data.shape[curr_schema["list_dim"] - 1]
                if curr_name not in visualizers:
                    visualizers[curr_name] = [Visdom.Heatmap(curr_name + "_%d" % i, dumpdir=os.path.join(opt.name, "preview") if opt.dump_heatmaps else None) for i in range(n_steps)]

                for i in range(n_steps):
                    visualizers[curr_name][i].draw(index_by_dim(data, curr_schema["list_dim"] - 1, i))
            else:
                if data.ndim != 2:
                    print("WARNING: unknown data shape for simple display: %s, tensor %s" % (data.shape, curr_name))
                    continue

                if curr_name not in visualizers:
                    visualizers[curr_name] = Visdom.Heatmap(curr_name, dumpdir=os.path.join(opt.name, "preview") if opt.dump_heatmaps else None)

                visualizers[curr_name].draw(data)

    def run_model(feed, debug=None):
        # Embed word indices for NLP tasks; rescale bits from {0,1} to {-1,1}
        # for bitmap tasks. (Renamed param: it previously shadowed builtin `input`.)
        if isinstance(dataset, NLPTask):
            x = embedding(feed["input"])
        else:
            x = feed["input"] * 2.0 - 1.0

        return model(x, debug=debug)

    def multiply_grads(params, mul):
        # Scale all accumulated gradients in-place (used for sub-batch averaging).
        if mul == 1:
            return

        for pa in params:
            for p in pa["params"]:
                p.grad.data *= mul

    def test():
        if test_set is None:
            return

        print("TESTING...")
        start_time = time.time()
        t = test_set.start_test()
        with torch.no_grad():
            for data in tqdm(test_loader):
                data = {k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items()}
                if hasattr(dataset, "prepare"):
                    data = dataset.prepare(data)

                net_out = run_model(data)
                # NOTE(review): 'veify_result' looks like a typo for
                # 'verify_result', but it must match the method name defined on
                # the dataset class — confirm there before renaming.
                test_set.veify_result(t, data, net_out)

        test_set.show_test_results(i, t)
        print("Test done in %gs" % (time.time() - start_time))

    if opt.test_on_start.lower() in ["on", "1", "true", "quit"]:
        test()
        if opt.test_on_start.lower() == "quit":
            saver.write(i)
            sys.exit(-1)

    if opt.print_test:
        # One-off accuracy measurement on the test set, then exit main().
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            for data in tqdm(test_loader):
                if not running:
                    return

                data = {k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items()}
                if hasattr(test_set, "prepare"):
                    data = test_set.prepare(data)

                net_out = run_model(data)
                c, t = test_set.curriculum_measure(net_out, data["output"])
                total += t
                correct += c

        print("Test result: %2.f%% (%d out of %d correct)" % (100.0 * correct / total, correct, total))
        model.train()
        return

    # ---- training loop ----------------------------------------------------
    iter_start_time = time.time() if i % opt.info_interval == 0 else None
    data_load_total_time = 0

    start_i = i

    if opt.dump_profile:
        profiler = torch.autograd.profiler.profile(use_cuda=True)

    if opt.dump_heatmaps:
        dataset.set_dump_dir(os.path.join(opt.name, "preview"))

    @preview()
    def do_visualize(raw_data, output, pos_map, debug):
        if pos_map is not None:
            output = embedding.backmap_output(output, pos_map, raw_data["output"].shape[1])
        dataset.visualize_preview(raw_data, output)

        if debug is not None:
            plot_debug(debug)

    preview_timer = OnceEvery(opt.preview_interval)

    pos_map = None
    start_iter = i

    if curriculum is not None:
        curriculum.init()

    while running:
        data_load_timer = time.time()
        for data in data_loader:
            if not running:
                break
            if loader_reset:
                print("Loader reset requested. Resetting...")
                loader_reset = False
                if curriculum is not None:
                    curriculum.lesson_started()
                break

            if opt.dump_profile:
                # Profile iterations start_i+1 .. start_i+5, then dump and exit.
                if i == start_i + 1:
                    print("Starting profiler")
                    profiler.__enter__()
                elif i == start_i + 5 + 1:
                    print("Stopping profiler")
                    profiler.__exit__(None, None, None)
                    print("Average stats")
                    print(profiler.key_averages().table("cpu_time_total"))
                    print("Writing trace to file")
                    profiler.export_chrome_trace(opt.dump_profile)
                    print("Done.")
                    sys.exit(0)
                else:
                    print("Step %d out of 5" % (i - start_i))

            debug.dbg_print("-------------------------------------")

            raw_data = data

            data = {k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items()}
            if hasattr(dataset, "prepare"):
                data = dataset.prepare(data)

            data_load_total_time += time.time() - data_load_timer

            need_preview = preview_timer()
            debug_data = {} if opt.debug and need_preview else None

            optimizer.zero_grad()

            # Split the batch into sub-batches so batch_size*len fits in memory;
            # gradients are averaged over the sub-batches below.
            if opt.n_subbatch == "auto":
                n_subbatch = math.ceil(data["input"].numel() / opt.max_input_count_per_batch)
            else:
                n_subbatch = int(opt.n_subbatch)

            real_batch = max(math.floor(opt.batch_size / n_subbatch), 1)
            n_subbatch = math.ceil(opt.batch_size / real_batch)
            remaning_batch = opt.batch_size % real_batch

            for subbatch in range(n_subbatch):
                if not running:
                    break

                # BUGFIX: the local was named `input`, shadowing the builtin,
                # which made the demo-mode `input("Press enter to quit.")`
                # below call a tensor instead of reading from stdin.
                net_input = data["input"]
                target = data["output"]
                if n_subbatch != 1:
                    net_input = net_input[subbatch * real_batch: (subbatch + 1) * real_batch]
                    target = target[subbatch * real_batch:(subbatch + 1) * real_batch]

                f2 = data.copy()
                f2["input"] = net_input

                output = run_model(f2, debug=debug_data if subbatch == n_subbatch - 1 else None)
                l = dataset.loss(output, target)
                debug.nan_check(l, force=True)
                l.backward()

                if curriculum is not None:
                    curriculum.update(*dataset.curriculum_measure(output, target))

                if remaning_batch != 0 and subbatch == n_subbatch - 2:
                    multiply_grads(params, real_batch / remaning_batch)

            if n_subbatch != 1:
                if remaning_batch == 0:
                    multiply_grads(params, 1 / n_subbatch)
                else:
                    multiply_grads(params, remaning_batch / opt.batch_size)

            for p in params:
                torch.nn.utils.clip_grad_norm_(p["params"], opt.grad_clip)

            optimizer.step()

            i += 1

            curr_loss = l.data.item()
            loss_plot.add_point(i, curr_loss)

            loss_sum += curr_loss

            if i % opt.info_interval == 0:
                tim = time.time()
                loss_avg = loss_sum / opt.info_interval

                if curriculum is not None:
                    curriculum_accuracy.add_point(i, curriculum.get_accuracy())
                    curriculum_plot.add_point(i, curriculum.step)

                message = "Iteration %d, loss: %.4f" % (i, loss_avg)
                if iter_start_time is not None:
                    message += " (%.2f ms/iter, load time %.2g ms/iter, visport: %s)" % (
                        (tim - iter_start_time) / opt.info_interval * 1000.0,
                        data_load_total_time / opt.info_interval * 1000.0,
                        Visdom.port)
                print(message)
                iter_start_time = tim
                loss_sum = 0
                data_load_total_time = 0

            debug.dbg_print("Iteration %d, loss %g" % (i, curr_loss))

            if need_preview:
                do_visualize(raw_data, output, pos_map, debug_data)

            if i % opt.test_interval == 0:
                test()

            saver.tick(i)

            if opt.demo and opt.exit_after is None:
                running = False
                input("Press enter to quit.")

            if opt.exit_after is not None and (i - start_iter) >= opt.exit_after:
                running = False

            data_load_timer = time.time()
if __name__ == "__main__":
    # Module-level flag polled by the training loop; cleared by the SIGINT
    # handler so Ctrl+C shuts training down gracefully.
    running = True

    def signal_handler(signum, frame):
        """Request a graceful stop of the training loop on Ctrl+C."""
        global running
        print('You pressed Ctrl+C!')
        running = False

    signal.signal(signal.SIGINT, signal_handler)
    main()

3
requirements.txt Normal file
View File

@ -0,0 +1,3 @@
tqdm
visdom
numpy