diff --git a/Dataset/.DS_Store b/Dataset/.DS_Store new file mode 100644 index 0000000..3f2c165 Binary files /dev/null and b/Dataset/.DS_Store differ diff --git a/Dataset/Bitmap/.DS_Store b/Dataset/Bitmap/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/Dataset/Bitmap/.DS_Store differ diff --git a/Dataset/Bitmap/AssociativeRecall.py b/Dataset/Bitmap/AssociativeRecall.py new file mode 100644 index 0000000..86f9434 --- /dev/null +++ b/Dataset/Bitmap/AssociativeRecall.py @@ -0,0 +1,72 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +import numpy as np +from .BitmapTask import BitmapTask +from Utils.Seed import get_randstate + + +class AssociativeRecall(BitmapTask): + def __init__(self, length=None, bit_w=8, block_w=3, transform=lambda x: x): + super(AssociativeRecall, self).__init__() + self.length = length + self.bit_w = bit_w + self.block_w = block_w + self.transform = transform + self.seed = None + + def __getitem__(self, key): + if self.seed is None: + self.seed = get_randstate() + + length = self.length() if callable(self.length) else self.length + if length is None: + # Random length batch hack. 
+ length = key + + stride = self.block_w + 1 + + d = self.seed.randint( + 0, 2, [length * (self.block_w + 1), self.bit_w + 2] + ).astype(np.float32) + d[:, -2:] = 0 + + # Terminate input block + for i in range(1, length, 1): + d[i * stride - 1, :] = 0 + d[i * stride - 1, -2] = 1 + + # Terminate input sequence + d[-1, :] = 0 + d[-1, -1] = 1 + + # Add and terminate query + ti = self.seed.randint(0, length - 1) + d = np.concatenate( + ( + d, + d[ti * stride : (ti + 1) * stride - 1], + np.zeros([self.block_w + 1, self.bit_w + 2], np.float32), + ), + axis=0, + ) + d[-(1 + self.block_w), -1] = 1 + + # Target + target = np.zeros_like(d) + target[-self.block_w :] = d[(ti + 1) * stride : (ti + 2) * stride - 1] + + return self.transform({"input": d, "output": target}) diff --git a/Dataset/Bitmap/BitmapTask.py b/Dataset/Bitmap/BitmapTask.py new file mode 100644 index 0000000..eac03e4 --- /dev/null +++ b/Dataset/Bitmap/BitmapTask.py @@ -0,0 +1,82 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ============================================================================== + +import torch +import torch.nn.functional as F +from Visualize.BitmapTask import visualize_bitmap_task +from Utils import Visdom +from Utils import universal as U + +import numpy as np + + +class BitmapTask(torch.utils.data.Dataset): + def __init__(self): + super(BitmapTask, self).__init__() + + self._img = Visdom.Image("preview") + + def set_dump_dir(self, dir): + self._img.set_dump_dir(dir) + + def __len__(self): + return 0x7FFFFFFF + + def visualize_preview(self, data, net_output): + img = visualize_bitmap_task( + data["input"], [data["output"], U.sigmoid(net_output)] + ) + self._img.draw(img) + + def loss(self, net_output, target): + return F.binary_cross_entropy_with_logits( + net_output, target, reduction="sum" + ) / net_output.size(0) + + def accuracy(self, net_output, target): + return F.binary_cross_entropy_with_logits( + net_output, target, reduction="sum" + ) / net_output.size(0) + + def demon_loss(self, net_output, target, saved_actions, device): + """ + computes the loss for the demon + :param net_output: + :param target: + :param saved_actions: + :return: + """ + net_output = net_output.detach() + loss = F.binary_cross_entropy_with_logits( + net_output, target, reduction="none" + ).sum(dim=-1) + + policy_losses = [] # list to save actor (policy) loss + + discount_factor = 0.99 + for i in range(0, loss.size(1)): # computing expected total reward + discount_vector = torch.from_numpy(np.array([np.power(discount_factor,i) for i in range(loss.size(1)-i)])).to(device) + policy_losses.append(((saved_actions[i].log_prob).squeeze(1) * (discount_vector*loss[:, i:]).mean(dim=1))) + + demon_loss = torch.stack(policy_losses).mean(dim=0)/loss.size(1) + + return demon_loss + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass diff --git a/Dataset/Bitmap/BitmapTaskRepeater.py b/Dataset/Bitmap/BitmapTaskRepeater.py new file mode 100644 index 
0000000..0457b08 --- /dev/null +++ b/Dataset/Bitmap/BitmapTaskRepeater.py @@ -0,0 +1,56 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +import numpy as np +import random +from .BitmapTask import BitmapTask + + +class BitmapTaskRepeater(BitmapTask): + def __init__(self, dataset): + super(BitmapTaskRepeater, self).__init__() + self.dataset = dataset + + def __getitem__(self, key): + r = [self.dataset[k] for k in key] + if len(r) == 1: + return r[0] + else: + return { + "input": np.concatenate([a["input"] for a in r], axis=0), + "output": np.concatenate([a["output"] for a in r], axis=0), + } + + @staticmethod + def key_sampler(length, repeat): + def call_sampler(s): + if callable(s): + return s() + elif isinstance(s, list): + if len(s) == 2: + return random.randint(*s) + elif len(s) == 1: + return s[0] + else: + assert False, "Invalid sample parameter: %s" % s + else: + return s + + def s(): + r = call_sampler(repeat) + return [call_sampler(length) for i in range(r)] + + return s diff --git a/Dataset/Bitmap/CopyTask.py b/Dataset/Bitmap/CopyTask.py new file mode 100644 index 0000000..953f9f3 --- /dev/null +++ b/Dataset/Bitmap/CopyTask.py @@ -0,0 +1,49 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +import numpy as np +from .BitmapTask import BitmapTask +from Utils.Seed import get_randstate + + +class CopyData(BitmapTask): + def __init__(self, length=None, bit_w=8, transform=lambda x: x): + super(CopyData, self).__init__() + self.length = length + self.bit_w = bit_w + self.transform = transform + self.seed = None + + def __getitem__(self, key): + if self.seed is None: + self.seed = get_randstate() + + length = self.length() if callable(self.length) else self.length + if length is None: + # Random length batch hack. + length = key + + d = self.seed.randint(0, 2, [length + 1, self.bit_w + 1]).astype(np.float32) + z = np.zeros_like(d) + + d[-1] = 0 + d[:, -1] = 0 + d[-1, -1] = 1 + + i_p = np.concatenate((d, z), axis=0) + o_p = np.concatenate((z, d), axis=0) + + return self.transform({"input": i_p, "output": o_p}) diff --git a/Dataset/Bitmap/KeyValue.py b/Dataset/Bitmap/KeyValue.py new file mode 100644 index 0000000..527dc9b --- /dev/null +++ b/Dataset/Bitmap/KeyValue.py @@ -0,0 +1,81 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +import math +import numpy as np +from .BitmapTask import BitmapTask +from Utils.Seed import get_randstate + + +class KeyValue(BitmapTask): + def __init__(self, length=None, bit_w=8, transform=lambda x: x): + assert bit_w % 2 == 0, "bit_w must be even" + super(KeyValue, self).__init__() + self.length = length + self.bit_w = bit_w + self.transform = transform + self.seed = None + self.key_w = self.bit_w // 2 + self.max_key = 2 ** self.key_w - 1 + + def __getitem__(self, key): + if self.seed is None: + self.seed = get_randstate() + + if self.length is None: + # Random length batch hack. 
+ length = key + else: + length = self.length() if callable(self.length) else self.length + + # keys must be unique + keys = None + last_size = 0 + while last_size != length: + res = self.seed.random_integers(0, self.max_key, size=(length - last_size)) + if keys is not None: + keys = np.concatenate((res, keys)) + else: + keys = res + + keys = np.unique(keys) + last_size = keys.size + + # view as bunch of uint8s, convert them to bit patterns, then cut the correct amount from it + keys = keys.view(np.uint8).reshape(length, -1) + keys = keys[:, : math.ceil(self.key_w / 8)] + keys = np.unpackbits(np.expand_dims(keys, -1), axis=-1) + keys = np.flip(keys, axis=-1).reshape(keys.shape[0], -1)[:, : self.key_w] + keys = keys.astype(np.float32) + + values = self.seed.randint(0, 2, keys.shape).astype(np.float32) + + perm = self.seed.permutation(length) + keys_perm = keys[perm, :] + values_perm = values[perm, :] + + i_p = np.zeros((2 * length + 2, self.bit_w + 1), dtype=np.float32) + i_p[:length, : self.key_w] = keys + i_p[:length, self.key_w : -1] = values + i_p[length + 1 : -1, : self.key_w] = keys_perm + + i_p[length, -1] = 1 + i_p[-1, -1] = 1 + + o_p = np.zeros((2 * length + 2, self.key_w), dtype=np.float32) + o_p[length + 1 : -1] = values_perm + + return self.transform({"input": i_p, "output": o_p}) diff --git a/Dataset/Bitmap/KeyValue2Way.py b/Dataset/Bitmap/KeyValue2Way.py new file mode 100644 index 0000000..0656b10 --- /dev/null +++ b/Dataset/Bitmap/KeyValue2Way.py @@ -0,0 +1,89 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +import math +import numpy as np +from .BitmapTask import BitmapTask +from Utils.Seed import get_randstate + + +class KeyValue2Way(BitmapTask): + def __init__(self, length=None, bit_w=8, transform=lambda x: x): + assert bit_w % 2 == 0, "bit_w must be even" + super(KeyValue2Way, self).__init__() + self.length = length + self.bit_w = bit_w + self.transform = transform + self.seed = None + self.key_w = self.bit_w // 2 + self.max_key = 2 ** self.key_w - 1 + + def __getitem__(self, key): + if self.seed is None: + self.seed = get_randstate() + + if self.length is None: + # Random length batch hack. 
+ length = key + else: + length = self.length() if callable(self.length) else self.length + + # keys must be unique + keys = None + last_size = 0 + while last_size != length: + res = self.seed.random_integers(0, self.max_key, size=(length - last_size)) + if keys is not None: + keys = np.concatenate((res, keys)) + else: + keys = res + + keys = np.unique(keys) + last_size = keys.size + + # view as bunch of uint8s, convert them to bit patterns, then cut the correct amount from it + keys = keys.view(np.uint8).reshape(length, -1) + keys = keys[:, : math.ceil(self.key_w / 8)] + keys = np.unpackbits(np.expand_dims(keys, -1), axis=-1) + keys = np.flip(keys, axis=-1).reshape(keys.shape[0], -1)[:, : self.key_w] + keys = keys.astype(np.float32) + + values = self.seed.randint(0, 2, keys.shape).astype(np.float32) + + perm = self.seed.permutation(length) + keys_perm = keys[perm, :] + values_perm = values[perm, :] + + i_p = np.zeros((3 * (length + 1), self.bit_w + 2), dtype=np.float32) + o_p = np.zeros((3 * (length + 1), self.key_w), dtype=np.float32) + + i_p[:length, : self.key_w] = keys + i_p[:length, self.key_w : -2] = values + i_p[length + 1 : 2 * length + 1, : self.key_w] = keys_perm + o_p[length + 1 : 2 * length + 1] = values_perm + + perm = self.seed.permutation(length) + keys_perm = keys[perm, :] + values_perm = values[perm, :] + + o_p[2 * (length + 1) : -1] = keys_perm + i_p[2 * (length + 1) : -1, : self.key_w] = values_perm + + i_p[length, -2] = 1 + i_p[2 * length + 1, -1] = 1 + i_p[-1, -2:] = 1 + + return self.transform({"input": i_p, "output": o_p}) diff --git a/Dataset/Bitmap/__init__.py b/Dataset/Bitmap/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Dataset/NLP/.DS_Store b/Dataset/NLP/.DS_Store new file mode 100644 index 0000000..e98ece8 Binary files /dev/null and b/Dataset/NLP/.DS_Store differ diff --git a/Dataset/NLP/.gitignore b/Dataset/NLP/.gitignore new file mode 100644 index 0000000..06cf653 --- /dev/null +++ b/Dataset/NLP/.gitignore @@ 
-0,0 +1 @@ +cache diff --git a/Dataset/NLP/NLPTask.py b/Dataset/NLP/NLPTask.py new file mode 100644 index 0000000..419347e --- /dev/null +++ b/Dataset/NLP/NLPTask.py @@ -0,0 +1,153 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +import torch +import torch.nn.functional as F +import os +from .Vocabulary import Vocabulary +from Utils import Visdom +from Utils import universal as U + +import numpy as np + + +class NLPTask(torch.utils.data.Dataset): + def __init__(self): + super(NLPTask, self).__init__() + + self.my_dir = os.path.abspath(os.path.dirname(__file__)) + self.cache_dir = os.path.join(self.my_dir, "cache") + + if not os.path.isdir(self.cache_dir): + os.makedirs(self.cache_dir) + + self.vocabulary = self._load_vocabulary() + + self._preview = None + + def _load_vocabulary(self): + cache_file = os.path.join(self.cache_dir, "vocabulary.pth") + if not os.path.isfile(cache_file): + print("WARNING: Vocabulary not found. 
Removing cached files.") + for f in os.listdir(self.cache_dir): + f = os.path.join(self.cache_dir, f) + if f.endswith(".pth"): + print(" " + f) + os.remove(f) + return Vocabulary() + else: + return torch.load(cache_file) + + def save_vocabulary(self): + cache_file = os.path.join(self.cache_dir, "vocabulary.pth") + if os.path.isfile(cache_file): + os.remove(cache_file) + torch.save(self.vocabulary, cache_file) + + def loss(self, net_output, target): + s = list(net_output.size()) + return ( + F.cross_entropy( + net_output.view([s[0] * s[1], s[2]]), + target.view([-1]), + ignore_index=0, + reduction="sum", + ) + / s[0] + ) + + + def demon_loss(self, net_output, target, saved_actions, device): + """ + computes the loss for the demon + :param net_output: + :param target: + :param saved_actions: + :return: + """ + net_output = net_output.detach() + s = list(net_output.size()) + loss = F.cross_entropy( + net_output.view([s[0] * s[1], s[2]]), + target.view([-1]), + ignore_index=0, + reduction="none", + ).view(s[0], s[1]) + + policy_losses = [] # list to save actor (policy) loss + + discount_factor = 0.99 + for i in range(0, loss.size(1)): # computing expected total reward + discount_vector = torch.from_numpy( + np.array([np.power(discount_factor, i) for i in range(loss.size(1) - i)])).to(device) + policy_losses.append(((saved_actions[i].log_prob).squeeze(1) * (discount_vector * loss[:, i:]).mean(dim=1))) + + demon_loss = torch.stack(policy_losses).mean(dim=0) + + return demon_loss + + + def generate_preview_text(self, data, net_output): + input = U.to_numpy(data["input"][0]) + reference = U.to_numpy(data["output"][0]) + net_out = U.argmax(net_output[0], -1) + net_out = U.to_numpy(net_out) + + res = "" + start_index = 0 + + for i in range(input.shape[0]): + if reference[i] != 0: + if start_index < i: + end_index = i + while end_index > start_index and input[end_index] == 0: + end_index -= 1 + + if end_index > start_index: + sentence = ( + " ".join( + 
self.vocabulary.indices_to_sentence( + input[start_index:i].tolist() + ) + ) + .replace(" .", ".") + .replace(" ,", ",") + .replace(" ?", "?") + .split(". ") + ) + sentence = ". ".join([s.capitalize() for s in sentence]) + res += sentence + "
" + + start_index = i + 1 + + match = reference[i] == net_out[i] + res += '%s [%s]
' % ( + "green" if match else "red", + self.vocabulary.indices_to_sentence([net_out[i]])[0], + self.vocabulary.indices_to_sentence([reference[i]])[0], + ) + return res + + def visualize_preview(self, data, net_output): + res = self.generate_preview_text(data, net_output) + + if self._preview is None: + self._preview = Visdom.Text("Preview") + + self._preview.set(res) + + def set_dump_dir(self, dir): + pass diff --git a/Dataset/NLP/Vocabulary.py b/Dataset/NLP/Vocabulary.py new file mode 100644 index 0000000..5b77fdc --- /dev/null +++ b/Dataset/NLP/Vocabulary.py @@ -0,0 +1,52 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ============================================================================== + + +class Vocabulary: + def __init__(self): + self.words = {"-": 0, "?": 1, "": 2} + self.inv_words = {0: "-", 1: "?", 2: ""} + self.next_id = 3 + self.punctations = [".", "?", ","] + + def _process_word(self, w, add_words): + if not w.isalpha() and w not in self.punctations: + print("WARNING: word with unknown characters: %s", w) + w = "" + + if w not in self.words: + if add_words: + self.words[w] = self.next_id + self.inv_words[self.next_id] = w + self.next_id += 1 + else: + w = "" + + return self.words[w] + + def sentence_to_indices(self, sentence, add_words=True): + for p in self.punctations: + sentence = sentence.replace(p, " %s " % p) + + return [ + self._process_word(w, add_words) for w in sentence.lower().split(" ") if w + ] + + def indices_to_sentence(self, indices): + return [self.inv_words[i] for i in indices] + + def __len__(self): + return len(self.words) diff --git a/Dataset/NLP/__init__.py b/Dataset/NLP/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Dataset/NLP/bAbi.py b/Dataset/NLP/bAbi.py new file mode 100644 index 0000000..013775f --- /dev/null +++ b/Dataset/NLP/bAbi.py @@ -0,0 +1,297 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ============================================================================== + +import os +import glob +import torch +from collections import namedtuple +import numpy as np +from .NLPTask import NLPTask +from Utils import Visdom + +Sentence = namedtuple("Sentence", ["sentence", "answer", "supporting_facts"]) + + +class bAbiDataset(NLPTask): + URL = "http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz" + DIR_NAME = "tasks_1-20_v1-2" + + def __init__( + self, dirs=["en-10k"], sets=None, think_steps=0, dir_name=None, name=None + ): + super(bAbiDataset, self).__init__() + + self._test_res_win = None + self._test_plot_win = None + self._think_steps = think_steps + + if dir_name is None: + self._download() + dir_name = os.path.join(self.cache_dir, self.DIR_NAME) + + self.data = {} + for d in dirs: + self.data[d] = self._load_or_create(os.path.join(dir_name, d)) + + self.all_tasks = None + self.name = name + self.use(sets=sets) + + def _make_active_list(self, tasks, sets, dirs): + def verify(name, checker): + if checker is None: + return True + + if callable(checker): + return checker(name) + elif isinstance(checker, list): + return name in checker + else: + return name == checker + + res = [] + for dirname, setlist in self.data.items(): + if not verify(dirname, dirs): + continue + + for sname, tasklist in setlist.items(): + if not verify(sname, sets): + continue + + for task, data in tasklist.items(): + name = task.split("_")[0][2:] + if not verify(name, tasks): + continue + + res += [(d, dirname, task, sname) for d in data] + + return res + + def use(self, tasks=None, sets=None, dirs=None): + self.all_tasks = self._make_active_list(tasks=tasks, sets=sets, dirs=dirs) + + def __len__(self): + return len(self.all_tasks) + + def _get_seq(self, index): + return self.all_tasks[index] + + def _seq_to_nn_input(self, seq): + in_arr = [] + out_arr = [] + hasAnswer = False + for sentence in seq[0]: + in_arr += sentence.sentence + out_arr += [0] * 
len(sentence.sentence) + if sentence.answer is not None: + in_arr += [0] * (len(sentence.answer) + self._think_steps) + out_arr += [0] * self._think_steps + sentence.answer + hasAnswer = True + + in_arr = np.asarray(in_arr, np.int64) + out_arr = np.asarray(out_arr, np.int64) + + return { + "input": in_arr, + "output": out_arr, + "meta": {"dir": seq[1], "task": seq[2], "set": seq[3]}, + } + + def __getitem__(self, item): + seq = self._get_seq(item) + return self._seq_to_nn_input(seq) + + def _load_or_create(self, directory): + cache_name = directory.replace("/", "_") + cache_file = os.path.join(self.cache_dir, cache_name + ".pth") + if not os.path.isfile(cache_file): + print("bAbI: Loading %s" % directory) + res = self._load_dir(directory) + print("Write: ", cache_file) + self.save_vocabulary() + torch.save(res, cache_file) + else: + res = torch.load(cache_file) + return res + + def _download(self): + if not os.path.isdir(os.path.join(self.cache_dir, self.DIR_NAME)): + print(self.URL) + print("bAbi data not found. 
Downloading...") + import requests, tarfile, io + + request = requests.get( + self.URL, + headers={ + "User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36" + }, + ) + + decompressed_file = tarfile.open( + fileobj=io.BytesIO(request.content), mode="r|gz" + ) + decompressed_file.extractall(self.cache_dir) + print("Done") + + def _load_dir( + self, + directory, + parse_name=lambda x: x.split(".")[0], + parse_set=lambda x: x.split(".")[0].split("_")[-1], + ): + res = {} + for f in glob.glob(os.path.join(directory, "**", "*.txt"), recursive=True): + basename = os.path.basename(f) + task_name = parse_name(basename) + set = parse_set(basename) + print("Loading", f) + + s = res.get(set) + if s is None: + s = {} + res[set] = s + s[task_name] = self._load_task(f, task_name) + + return res + + def _load_task(self, filename, task_name): + task = [] + currTask = [] + + nextIndex = 1 + with open(filename, "r") as f: + for line in f: + line = [f.strip() for f in line.split("\t")] + line[0] = line[0].split(" ") + i = int(line[0][0]) + line[0] = " ".join(line[0][1:]) + + if i != nextIndex: + nextIndex = i + task.append(currTask) + currTask = [] + + isQuestion = len(line) > 1 + currTask.append( + Sentence( + self.vocabulary.sentence_to_indices(line[0]), + self.vocabulary.sentence_to_indices(line[1].replace(",", " ")) + if isQuestion + else None, + [int(f) for f in line[2].split(" ")] if isQuestion else None, + ) + ) + + nextIndex += 1 + return task + + def start_test(self): + return {} + + def veify_result(self, test, data, net_output): + _, net_output = net_output.max(-1) + + ref = data["output"] + + mask = 1.0 - ref.eq(0).float() + + correct = (torch.eq(net_output, ref).float() * mask).sum(-1) + total = mask.sum(-1) + + correct = correct.data.cpu().numpy() + total = total.data.cpu().numpy() + + for i in range(correct.shape[0]): + task = data["meta"][i]["task"] + if task not in test: + test[task] = {"total": 0, 
"correct": 0} + + d = test[task] + d["total"] += total[i] + d["correct"] += correct[i] + + def _ensure_test_wins_exists(self, legend=None): + if self._test_res_win is None: + n = ("[" + self.name + "]") if self.name is not None else "" + self._test_res_win = Visdom.Text("Test results" + n) + self._test_plot_win = Visdom.Plot2D("Test results" + n, legend=legend) + elif self._test_plot_win.legend is None: + self._test_plot_win.set_legend(legend=legend) + + def show_test_results(self, iteration, test): + res = {k: v["correct"] / v["total"] for k, v in test.items()} + + t = "" + + all_keys = list(res.keys()) + + num_keys = [k for k in all_keys if k.startswith("qa")] + tmp = [ + i[0] + for i in sorted( + enumerate(num_keys), key=lambda x: int(x[1][2:].split("_")[0]) + ) + ] + num_keys = [num_keys[j] for j in tmp] + + all_keys = num_keys + sorted([k for k in all_keys if not k.startswith("qa")]) + + err_precent = [(1.0 - res[k]) * 100.0 for k in all_keys] + + n_passed = sum([int(p <= 5) for p in err_precent]) + n_total = len(err_precent) + err_precent = err_precent + [sum(err_precent) / len(err_precent)] + all_keys += ["mean"] + + for i, k in enumerate(all_keys): + t += '%s: %.2f%%
' % ( + "green" if err_precent[i] <= 5 else "red", + k, + err_precent[i], + ) + + t += "
Total: %d of %d passed." % (n_passed, n_total) + + self._ensure_test_wins_exists( + legend=[i.split("_")[0] if i.startswith("qa") else i for i in all_keys] + ) + + self._test_res_win.set(t) + self._test_plot_win.add_point(iteration, err_precent) + + def state_dict(self): + if self._test_res_win is not None: + return { + "_test_res_win": self._test_res_win.state_dict(), + "_test_plot_win": self._test_plot_win.state_dict(), + } + else: + return {} + + def load_state_dict(self, state): + if state: + self._ensure_test_wins_exists() + self._test_res_win.load_state_dict(state["_test_res_win"]) + self._test_plot_win.load_state_dict(state["_test_plot_win"]) + self._test_plot_win.legend = None + + def visualize_preview(self, data, net_output): + res = self.generate_preview_text(data, net_output) + res = ("%s
" % data["meta"][0]["task"]) + res + if self._preview is None: + self._preview = Visdom.Text("Preview") + + self._preview.set(res) diff --git a/Dataset/__init__.py b/Dataset/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/Dataset/__init__.py @@ -0,0 +1 @@ + diff --git a/Models/.DS_Store b/Models/.DS_Store new file mode 100644 index 0000000..e98ece8 Binary files /dev/null and b/Models/.DS_Store differ diff --git a/Models/DNCA.py b/Models/DNCA.py new file mode 100644 index 0000000..6dd7a30 --- /dev/null +++ b/Models/DNCA.py @@ -0,0 +1,965 @@ +# The Initial DNC Copyright 2017 Robert Csordas. All Rights Reserved. +# The modification of the initial DNC implementation by Ari Azarafrooz. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def oneplus(t):
    """Return 1 + softplus(t); always >= 1, used for DNC "beta" strengths."""
    return F.softplus(t, 1, 20) + 1.0


def get_next_tensor_part(src, dims, prev_pos=0):
    """Slice the next prod(dims) channels from the last axis of `src`.

    Returns (part, new_pos); `part` is reshaped to `dims` when more than one
    dimension is requested, otherwise returned flat.
    """
    if not isinstance(dims, list):
        dims = [dims]
    n = functools.reduce(lambda x, y: x * y, dims)
    data = src.narrow(-1, prev_pos, n)
    return (
        data.contiguous().view(list(data.size())[:-1] + dims)
        if len(dims) > 1
        else data,
        prev_pos + n,
    )


def split_tensor(src, shapes):
    """Split the last axis of `src` into consecutive parts described by `shapes`."""
    pos = 0
    res = []
    for s in shapes:
        d, pos = get_next_tensor_part(src, s, pos)
        res.append(d)
    return res


def dict_get(dict, name):
    """`dict[name]` lookup that tolerates `dict` being None (debug is optional)."""
    return dict.get(name) if dict is not None else None


def dict_append(dict, name, val):
    """Append `val` to the list under `name`, creating it on first use.

    No-op when `dict` is None, so debug collection can be disabled cheaply.
    """
    if dict is not None:
        l = dict.get(name)
        if not l:
            l = []
            dict[name] = l
        l.append(val)


def init_debug(debug, initial):
    """Populate an *empty* debug dict with the initial sub-dict layout."""
    if debug is not None and not debug:
        debug.update(initial)


def merge_debug_tensors(d, dim):
    """Recursively stack per-timestep debug tensor lists along `dim`."""
    if d is not None:
        for k, v in d.items():
            if isinstance(v, dict):
                merge_debug_tensors(v, dim)
            elif isinstance(v, list):
                d[k] = torch.stack(v, dim)


def linear_reset(module, gain=1.0):
    """Xavier-init a Linear layer's weight and zero its bias."""
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()


_EPS = 1e-6


class AllocationManager(torch.nn.Module):
    """Tracks per-cell memory usage and produces DNC allocation weightings."""

    def __init__(self):
        super(AllocationManager, self).__init__()
        self.usages = None
        self.zero_usages = None
        self.debug_sequ_init = False
        self.one = None

    def _init_sequence(self, prev_read_distributions):
        # prev_read_distributions size is [batch, n_heads, cell count]
        s = prev_read_distributions.size()
        if self.zero_usages is None or list(self.zero_usages.size()) != [s[0], s[-1]]:
            self.zero_usages = torch.zeros(
                s[0], s[-1], device=prev_read_distributions.device
            )
            if self.debug_sequ_init:
                # Tiny ramp makes the initial sort order deterministic when debugging.
                self.zero_usages += torch.arange(0, s[-1]).unsqueeze(0) * 1e-10

        self.usages = self.zero_usages

    def _init_consts(self, device):
        if self.one is None:
            self.one = torch.ones(1, device=device)

    def new_sequence(self):
        self.usages = None

    def update_usages(
        self, prev_write_distribution, prev_read_distributions, free_gates
    ):
        """Update the usage vector; returns the retention vector phi.

        Read distributions shape: [batch, n_heads, cell count]
        Free gates shape: [batch, n_heads]
        """
        self._init_consts(prev_read_distributions.device)

        # Retention phi = prod over read heads of (1 - f_i * r_i), [batch, cells].
        # Bug fix: the original used torch.addcmul(tensor, value, t1, t2) with a
        # positional `value`, a deprecated signature that newer PyTorch rejects.
        phi = (self.one - free_gates.unsqueeze(-1) * prev_read_distributions).prod(-2)

        if self.usages is None:
            self._init_sequence(prev_read_distributions)
            # First timestep: nothing was written or read yet, so no update needed.
        else:
            self.usages = (
                self.usages + prev_write_distribution.detach() * (1 - self.usages)
            ) * phi

        return phi

    def forward(self, prev_write_distribution, prev_read_distributions, free_gates):
        """Return (allocation distribution over cells, retention vector phi)."""
        phi = self.update_usages(
            prev_write_distribution, prev_read_distributions, free_gates
        )
        # Clamp usages away from exact 0/1 for numerical stability, then sort
        # ascending so the least-used cells come first.
        sorted_usage, free_list = (self.usages * (1.0 - _EPS) + _EPS).sort(-1)

        u_prod = sorted_usage.cumprod(-1)
        one_minus_usage = 1.0 - sorted_usage
        sorted_scores = torch.cat(
            [one_minus_usage[..., 0:1], one_minus_usage[..., 1:] * u_prod[..., :-1]],
            dim=-1,
        )

        # Scatter scores back to the original (unsorted) cell order.
        return sorted_scores.clone().scatter_(-1, free_list, sorted_scores), phi
class ContentAddressGenerator(torch.nn.Module):
    """Cosine-similarity content addressing with optional key/memory masking."""

    def __init__(
        self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False
    ):
        super(ContentAddressGenerator, self).__init__()
        self.disable_content_norm = disable_content_norm
        self.mask_min = mask_min
        self.disable_key_masking = disable_key_masking

    def forward(self, memory, keys, betas, mask=None):
        """Return a softmax address distribution over memory cells.

        memory: [batch, cell count, word length]
        keys:   [batch, n heads, word length] ([batch, word length] for one head)
        betas:  per-head sharpness factors
        """
        if mask is not None and self.mask_min != 0:
            # Keep the mask bounded away from zero.
            mask = mask * (1.0 - self.mask_min) + self.mask_min

        one_head = keys.dim() == 2
        if one_head:
            keys = keys.unsqueeze(1)
            mask = mask.unsqueeze(1) if mask is not None else None

        memory = memory.unsqueeze(1)
        keys = keys.unsqueeze(-2)

        if mask is not None:
            mask = mask.unsqueeze(-2)
            memory = memory * mask
            if not self.disable_key_masking:
                keys = keys * mask

        # Cosine-similarity denominator, shape [batch, n heads, cell count].
        denom = keys.norm(dim=-1)
        if not self.disable_content_norm:
            denom = denom * memory.norm(dim=-1)

        similarity = (memory * keys).sum(-1) / (denom + 1e-6)  # 1e-6 == module _EPS
        similarity = similarity * betas.unsqueeze(-1)

        dist = F.softmax(similarity, similarity.dim() - 1)
        return dist.squeeze(1) if one_head else dist


class WriteHead(torch.nn.Module):
    """DNC write head: blends content addressing with usage-based allocation."""

    @staticmethod
    def create_write_archive(write_dist, erase_vector, write_vector, phi):
        """Bundle everything mem_update / next-step bookkeeping needs."""
        return dict(
            write_dist=write_dist,
            erase_vector=erase_vector,
            write_vector=write_vector,
            phi=phi,
        )

    def __init__(
        self,
        dealloc_content=True,
        disable_content_norm=False,
        mask_min=0.0,
        disable_key_masking=False,
    ):
        super(WriteHead, self).__init__()
        self.write_content_generator = ContentAddressGenerator(
            disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.allocation_manager = AllocationManager()
        self.last_write = None
        self.dealloc_content = dealloc_content
        self.new_sequence()

    def new_sequence(self):
        self.last_write = None
        self.allocation_manager.new_sequence()

    @staticmethod
    def mem_update(memory, write_dist, erase_vector, write_vector, phi):
        """Erase-then-add memory update.

        Unlike the original DNC paper, the erase matrix may also be scaled by
        the retention vector phi so that freed cells really get cleared;
        otherwise stale contents keep matching later content lookups.
        """
        col = write_dist.unsqueeze(-1)

        erase = 1.0 - col * erase_vector.unsqueeze(-2)
        if phi is not None:
            erase = erase * phi.unsqueeze(-1)

        add = col * write_vector.unsqueeze(-2)
        return memory * erase + add

    def forward(
        self,
        demon_action,
        memory,
        write_content_key,
        write_beta,
        erase_vector,
        write_vector,
        alloc_gate,
        write_gate,
        free_gates,
        prev_read_dist,
        write_mask=None,
        debug=None,
    ):
        prev_dist = None if self.last_write is None else self.last_write["write_dist"]

        content_dist = self.write_content_generator(
            memory, write_content_key, write_beta, mask=write_mask
        )
        alloc_dist, phi = self.allocation_manager(
            prev_dist, prev_read_dist, free_gates
        )

        # [batch, cell count]: gate between allocation and content addressing.
        write_dist = write_gate * (
            alloc_gate * alloc_dist + (1 - alloc_gate) * content_dist
        )
        self.last_write = WriteHead.create_write_archive(
            write_dist,
            erase_vector,
            write_vector,
            phi if self.dealloc_content else None,
        )

        dict_append(debug, "alloc_dist", alloc_dist)
        dict_append(debug, "write_dist", write_dist)
        dict_append(debug, "mem_usages", self.allocation_manager.usages)
        dict_append(debug, "free_gates", free_gates)
        dict_append(debug, "write_betas", write_beta)
        dict_append(debug, "write_gate", write_gate)
        dict_append(debug, "write_vector", write_vector)
        dict_append(debug, "alloc_gate", alloc_gate)
        dict_append(debug, "erase_vector", erase_vector)
        if write_mask is not None:
            dict_append(debug, "write_mask", write_mask)

        return WriteHead.mem_update(memory, **self.last_write)
class RawWriteHead(torch.nn.Module):
    """Decodes the controller's flat output vector into WriteHead arguments."""

    def __init__(
        self,
        n_read_heads,
        word_length,
        use_mask=False,
        dealloc_content=True,
        disable_content_norm=False,
        mask_min=0.0,
        disable_key_masking=False,
    ):
        super(RawWriteHead, self).__init__()
        self.write_head = WriteHead(
            dealloc_content=dealloc_content,
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.word_length = word_length
        self.n_read_heads = n_read_heads
        self.use_mask = use_mask
        # key + erase + write vectors (+ optional mask), one free gate per read
        # head, plus beta / alloc gate / write gate scalars.
        self.input_size = (
            3 * self.word_length
            + self.n_read_heads
            + 3
            + (self.word_length if use_mask else 0)
        )

    def new_sequence(self):
        self.write_head.new_sequence()

    def get_prev_write(self):
        return self.write_head.last_write

    def forward(self, demon_action, memory, nn_output, prev_read_dist, debug):
        shapes = (
            [[self.word_length]] * (4 if self.use_mask else 3)
            + [[self.n_read_heads]]
            + [[1]] * 3
        )
        tensors = split_tensor(nn_output, shapes)

        if self.use_mask:
            write_mask = torch.sigmoid(tensors[0])
            tensors = tensors[1:]
        else:
            write_mask = None

        (
            write_content_key,
            erase_vector,
            write_vector,
            free_gates,
            write_beta,
            alloc_gate,
            write_gate,
        ) = tensors

        return self.write_head(
            demon_action,
            memory,
            write_content_key,
            oneplus(write_beta),
            torch.sigmoid(erase_vector),
            write_vector,
            torch.sigmoid(alloc_gate),
            torch.sigmoid(write_gate),
            torch.sigmoid(free_gates),
            prev_read_dist,
            debug=debug,
            write_mask=write_mask,
        )

    def get_neural_input_size(self):
        return self.input_size


class TemporalMemoryLinkage(torch.nn.Module):
    """Tracks the order in which cells were written (DNC temporal link matrix)."""

    def __init__(self):
        super(TemporalMemoryLinkage, self).__init__()
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

        self.initial_temp_link_mat = None
        self.initial_precedence_weighting = None
        self.initial_diag_mask = None
        self.initial_shape = None

    def new_sequence(self):
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

    def _init_link(self, w_dist):
        s = list(w_dist.size())
        if self.initial_shape is None or s != self.initial_shape:
            self.initial_temp_link_mat = torch.zeros(s[0], s[-1], s[-1]).to(
                w_dist.device
            )
            self.initial_precedence_weighting = torch.zeros(s[0], s[-1]).to(
                w_dist.device
            )
            self.initial_diag_mask = (
                1.0 - torch.eye(s[-1]).unsqueeze(0).to(w_dist)
            ).detach()
            # Bug fix: the cache key was never stored, so the condition above
            # was always true and these tensors were reallocated on every new
            # sequence. Safe to reuse: the update methods below are out-of-place.
            self.initial_shape = s

        self.temp_link_mat = self.initial_temp_link_mat
        self.precedence_weighting = self.initial_precedence_weighting
        self.diag_mask = self.initial_diag_mask

    def _update_precedence(self, w_dist):
        # w_dist shape: [batch, cell count]
        self.precedence_weighting = (
            1.0 - w_dist.sum(-1, keepdim=True)
        ) * self.precedence_weighting + w_dist

    def _update_links(self, w_dist):
        if self.temp_link_mat is None:
            self._init_link(w_dist)

        wt_i = w_dist.unsqueeze(-1)
        wt_j = w_dist.unsqueeze(-2)
        pt_j = self.precedence_weighting.unsqueeze(-2)

        self.temp_link_mat = (
            (1 - wt_i - wt_j) * self.temp_link_mat + wt_i * pt_j
        ) * self.diag_mask

    def forward(self, w_dist, prev_r_dists, debug=None):
        """Return (forward, backward) link-follow distributions per read head.

        prev_r_dists shape: [batch, n heads, cell count]; outputs match it.
        """
        self._update_links(w_dist)
        self._update_precedence(w_dist)

        # Emulate matrix-vector multiplication by broadcast and sum, so the
        # link matrix never needs to be transposed.
        tlm_multi_head = self.temp_link_mat.unsqueeze(1)

        forward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-2)).sum(-1)
        backward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-1)).sum(-2)

        dict_append(debug, "forward_dists", forward_dist)
        dict_append(debug, "backward_dists", backward_dist)
        dict_append(debug, "precedence_weights", self.precedence_weighting)

        return forward_dist, backward_dist
class ReadHead(torch.nn.Module):
    """DNC read head: mixes content lookup with temporal forward/backward links."""

    def __init__(
        self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False
    ):
        super(ReadHead, self).__init__()
        self.content_addr_generator = ContentAddressGenerator(
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.read_dist = None
        self.read_data = None
        self.new_sequence()

    def new_sequence(self):
        """Drop cached read state at the start of a new input sequence."""
        self.read_dist = None
        self.read_data = None

    def forward(
        self,
        memory,
        read_content_keys,
        read_betas,
        forward_dist,
        backward_dist,
        gates,
        read_mask=None,
        debug=None,
    ):
        """Read from memory.

        memory:    [batch, cell count, word_length]
        read_dist: [batch, n heads, cell count]
        returns:   [batch, n heads, word_length]
        `gates` mixes backward-link, content and forward-link addressing.
        """
        content_dist = self.content_addr_generator(
            memory, read_content_keys, read_betas, mask=read_mask
        )

        self.read_dist = (
            backward_dist * gates[..., 0:1]
            + content_dist * gates[..., 1:2]
            + forward_dist * gates[..., 2:]
        )

        # Weighted sum of memory words per head.
        self.read_data = (memory.unsqueeze(1) * self.read_dist.unsqueeze(-1)).sum(-2)

        dict_append(debug, "content_dist", content_dist)
        dict_append(debug, "balance", gates)
        dict_append(debug, "read_dist", self.read_dist)
        dict_append(debug, "read_content_keys", read_content_keys)
        dict_append(debug, "read_betas", read_betas.unsqueeze(-2))
        # Bug fix: read_mask used to be appended twice per step, duplicating
        # entries in the debug trace; append it exactly once.
        if read_mask is not None:
            dict_append(debug, "read_mask", read_mask)

        return self.read_data


class RawReadHead(torch.nn.Module):
    """Decodes the controller's flat output vector into ReadHead arguments."""

    def __init__(
        self,
        n_heads,
        word_length,
        use_mask=False,
        disable_content_norm=False,
        mask_min=0.0,
        disable_key_masking=False,
    ):
        super(RawReadHead, self).__init__()
        self.read_head = ReadHead(
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.n_heads = n_heads
        self.word_length = word_length
        self.use_mask = use_mask
        # Per head: key (+ optional mask), one beta, and 3 mixing gates.
        self.input_size = self.n_heads * (
            self.word_length * (2 if use_mask else 1) + 3 + 1
        )

    def get_prev_dist(self, memory):
        """Previous read distribution, or zeros on the first timestep."""
        if self.read_head.read_dist is not None:
            return self.read_head.read_dist
        m_shape = memory.size()
        return torch.zeros(m_shape[0], self.n_heads, m_shape[1]).to(memory)

    def get_prev_data(self, memory):
        """Previously read words, or zeros on the first timestep."""
        if self.read_head.read_data is not None:
            return self.read_head.read_data
        m_shape = memory.size()
        return torch.zeros(m_shape[0], self.n_heads, m_shape[-1]).to(memory)

    def new_sequence(self):
        self.read_head.new_sequence()

    def forward(self, memory, nn_output, forward_dist, backward_dist, debug):
        shapes = [[self.n_heads, self.word_length]] * (2 if self.use_mask else 1) + [
            [self.n_heads],
            [self.n_heads, 3],
        ]
        tensors = split_tensor(nn_output, shapes)

        if self.use_mask:
            read_mask = torch.sigmoid(tensors[0])
            tensors = tensors[1:]
        else:
            read_mask = None

        keys, betas, gates = tensors

        betas = oneplus(betas)
        gates = F.softmax(gates, gates.dim() - 1)

        return self.read_head(
            memory,
            keys,
            betas,
            forward_dist,
            backward_dist,
            gates,
            debug=debug,
            read_mask=read_mask,
        )

    def get_neural_input_size(self):
        return self.input_size
class DistSharpnessEnhancer(torch.nn.Module):
    """Raises link distributions to a learned power to re-sharpen them."""

    def __init__(self, n_heads):
        super(DistSharpnessEnhancer, self).__init__()
        self.n_heads = n_heads if isinstance(n_heads, list) else [n_heads]
        self.n_data = sum(self.n_heads)

    def forward(self, nn_input, *dists):
        assert len(dists) == len(self.n_heads)
        nn_input = oneplus(nn_input[..., : self.n_data])
        factors = split_tensor(nn_input, self.n_heads)

        res = []
        for i, d in enumerate(dists):
            ndim = d.dim()
            f = factors[i]
            if ndim == 2:
                assert self.n_heads[i] == 1
            elif ndim == 3:
                f = f.unsqueeze(-1)
            else:
                assert False

            # Bug fix: was `d += _EPS`, which mutated the caller's distribution
            # tensor in place (and in-place ops on autograd-tracked tensors can
            # corrupt the graph); use out-of-place arithmetic throughout.
            d = d + _EPS
            d = d / d.max(dim=-1, keepdim=True)[0]
            d = d.pow(f)
            d = d / d.sum(dim=-1, keepdim=True)
            res.append(d)
        return res

    def get_neural_input_size(self):
        return self.n_data


class DNC(torch.nn.Module):
    """Differentiable Neural Computer with an optional memory-perturbing demon."""

    def __init__(
        self,
        input_size,
        output_size,
        word_length,
        cell_count,
        n_read_heads,
        controller,
        batch_first=False,
        clip_controller=20,
        bias=True,
        mask=False,
        dealloc_content=True,
        link_sharpness_control=True,
        disable_content_norm=False,
        mask_min=0.0,
        disable_key_masking=False,
    ):
        super(DNC, self).__init__()

        self.clip_controller = clip_controller

        self.read_head = RawReadHead(
            n_read_heads,
            word_length,
            use_mask=mask,
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.write_head = RawWriteHead(
            n_read_heads,
            word_length,
            use_mask=mask,
            dealloc_content=dealloc_content,
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.temporal_link = TemporalMemoryLinkage()
        self.sharpness_control = (
            DistSharpnessEnhancer([n_read_heads, n_read_heads])
            if link_sharpness_control
            else None
        )

        # Controller sees the input plus the previously read words.
        in_size = input_size + n_read_heads * word_length
        control_channels = (
            self.read_head.get_neural_input_size()
            + self.write_head.get_neural_input_size()
            + (
                self.sharpness_control.get_neural_input_size()
                if self.sharpness_control is not None
                else 0
            )
        )

        self.controller = controller
        controller.init(in_size)
        self.controller_to_controls = torch.nn.Linear(
            controller.get_output_size(), control_channels, bias=bias
        )
        self.controller_to_out = torch.nn.Linear(
            controller.get_output_size(), output_size, bias=bias
        )
        self.read_to_out = torch.nn.Linear(
            word_length * n_read_heads, output_size, bias=bias
        )

        self.cell_count = cell_count
        self.word_length = word_length

        self.memory = None
        self.reset_parameters()

        self.batch_first = batch_first
        self.zero_mem_tensor = None

        # Flattened per-step memory snapshots; presumably consumed externally
        # (e.g. by the mutual-information reward) — note it is never cleared
        # here. TODO confirm against the training loop.
        self.mem_state = None

        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    def reset_parameters(self):
        linear_reset(self.controller_to_controls)
        linear_reset(self.controller_to_out)
        linear_reset(self.read_to_out)
        self.controller.reset_parameters()

    def _step(self, in_data, debug, demon, rollout_storage):
        """One timestep: optional demon perturbation, controller, write, read."""
        init_debug(debug, {"read_head": {}, "write_head": {}, "temporal_links": {}})

        # input shape: [batch, channels]
        batch_size = in_data.size(0)

        if demon:
            # The demon perturbs the input based on input + current memory,
            # acting with its frozen "old" policy.
            demon_action = demon.select_action(
                torch.cat([in_data, self.memory.view(batch_size, -1)], -1),
                rollout_storage,
            )
            in_data = in_data + demon_action

        # NOTE(review): demon_action is reset before being handed to the write
        # head, so the write head always receives None; preserved as-is.
        demon_action = None

        # Run the controller on input + previously read words.
        prev_read_data = self.read_head.get_prev_data(self.memory).view(
            [batch_size, -1]
        )
        control_data = self.controller(torch.cat([in_data, prev_read_data], -1))

        # Memory-op control channels, optionally clipped for stability.
        controls = self.controller_to_controls(control_data).contiguous()
        controls = (
            controls.clamp(-self.clip_controller, self.clip_controller)
            if self.clip_controller is not None
            else controls
        )

        shapes = [
            [self.write_head.get_neural_input_size()],
            [self.read_head.get_neural_input_size()],
        ]
        if self.sharpness_control is not None:
            shapes.append(self.sharpness_control.get_neural_input_size())

        tensors = split_tensor(controls, shapes)
        write_head_control, read_head_control = tensors[:2]
        tensors = tensors[2:]

        prev_read_dist = self.read_head.get_prev_dist(self.memory)

        self.memory = self.write_head(
            demon_action,
            self.memory,
            write_head_control,
            prev_read_dist,
            debug=dict_get(debug, "write_head"),
        )

        prev_write = self.write_head.get_prev_write()
        forward_dist, backward_dist = self.temporal_link(
            prev_write["write_dist"] if prev_write is not None else None,
            prev_read_dist,
            debug=dict_get(debug, "temporal_links"),
        )

        if self.sharpness_control is not None:
            forward_dist, backward_dist = self.sharpness_control(
                tensors[0], forward_dist, backward_dist
            )

        read_data = self.read_head(
            self.memory,
            read_head_control,
            forward_dist,
            backward_dist,
            debug=dict_get(debug, "read_head"),
        )

        # Output combines controller output and freshly read words.
        return self.controller_to_out(control_data) + self.read_to_out(
            read_data.view(batch_size, -1)
        )

    def _mem_init(self, batch_size, device):
        """(Re)create the zeroed memory tensor and the mem_state list."""
        if self.zero_mem_tensor is None or self.zero_mem_tensor.size(0) != batch_size:
            self.zero_mem_tensor = torch.zeros(
                batch_size, self.cell_count, self.word_length
            ).to(device)

        self.memory = self.zero_mem_tensor

        if self.mem_state is None:
            self.mem_state = []

    def forward(self, in_data, debug=None, demon=None, rollout_storage=None):
        """Run the DNC over a whole sequence; returns stacked per-step outputs."""
        self.write_head.new_sequence()
        self.read_head.new_sequence()
        self.temporal_link.new_sequence()
        self.controller.new_sequence()

        self._mem_init(in_data.size(0 if self.batch_first else 1), in_data.device)

        out_tsteps = []

        if self.batch_first:
            # input format: [batch, time, channels]
            for t in range(in_data.size(1)):
                out_tsteps.append(
                    self._step(in_data[:, t], debug, demon, rollout_storage)
                )
                self.mem_state.append(self.memory.view(in_data.size(0), -1))
        else:
            # input format: [time, batch, channels]
            for t in range(in_data.size(0)):
                out_tsteps.append(self._step(in_data[t], debug, demon, rollout_storage))
                # Bug fix: this branch used view(-1, batch), which scrambles the
                # [batch, cell*word] layout produced by the batch-first branch;
                # flatten per batch element (batch size is dim 1 here).
                self.mem_state.append(self.memory.view(in_data.size(1), -1))

        merge_debug_tensors(debug, dim=1 if self.batch_first else 0)
        return torch.stack(out_tsteps, dim=1 if self.batch_first else 0)
class LSTMController(torch.nn.Module):
    """Stack of manually implemented LSTM layers used as the DNC controller."""

    def __init__(self, layer_sizes, out_from_all_layers=True):
        super(LSTMController, self).__init__()
        self.out_from_all_layers = out_from_all_layers
        self.layer_sizes = layer_sizes
        self.states = None
        self.outputs = None

    def new_sequence(self):
        """Reset per-sequence recurrent state for every layer."""
        self.states = [None] * len(self.layer_sizes)
        self.outputs = [None] * len(self.layer_sizes)

    def reset_parameters(self):
        def init_layer(l, index):
            """Init one merged [gates | data-input] weight matrix.

            Bug fix: this helper previously indexed layer_sizes/stdevs with the
            loop variable `i` captured from the enclosing scope instead of its
            `index` parameter. The call sites kept them equal, so behavior was
            the same, but the code was fragile; it now uses `index` throughout.
            """
            size = self.layer_sizes[index]
            # Gate rows get sigmoid-scale init, data rows an extra tanh gain.
            a = math.sqrt(3.0) * self.stdevs[index]
            l.weight.data[0:-size].uniform_(-a, a)
            a *= init.calculate_gain("tanh")
            l.weight.data[-size:].uniform_(-a, a)
            if l.bias is not None:
                l.bias.data[size:].fill_(0)
                # Forget-gate bias initialized high so gradients flow through
                # time early in training.
                l.bias.data[:size].fill_(1)

        for i in range(len(self.layer_sizes)):
            init_layer(self.in_to_all[i], i)
            init_layer(self.out_to_all[i], i)
            if i > 0:
                init_layer(self.prev_to_all[i - 1], i)

    def _add_modules(self, name, m_list):
        """Register plain-list submodules so their parameters are tracked."""
        for i, m in enumerate(m_list):
            self.add_module("%s_%d" % (name, i), m)

    def init(self, input_size):
        """Create weight matrices once the external input width is known."""
        # Each layer's effective fan-in: previous layer's output (if any) +
        # its own recurrent output + the external input. Used for Xavier std.
        self.input_sizes = [
            (self.layer_sizes[i - 1] if i > 0 else 0) + self.layer_sizes[i] + input_size
            for i in range(len(self.layer_sizes))
        ]
        self.stdevs = [
            math.sqrt(2.0 / (self.layer_sizes[i] + self.input_sizes[i]))
            for i in range(len(self.layer_sizes))
        ]
        self.in_to_all = [
            torch.nn.Linear(input_size, 4 * self.layer_sizes[i])
            for i in range(len(self.layer_sizes))
        ]
        self.out_to_all = [
            torch.nn.Linear(self.layer_sizes[i], 4 * self.layer_sizes[i], bias=False)
            for i in range(len(self.layer_sizes))
        ]
        self.prev_to_all = [
            torch.nn.Linear(
                self.layer_sizes[i - 1], 4 * self.layer_sizes[i], bias=False
            )
            for i in range(1, len(self.layer_sizes))
        ]

        self._add_modules("in_to_all", self.in_to_all)
        self._add_modules("out_to_all", self.out_to_all)
        self._add_modules("prev_to_all", self.prev_to_all)

        self.reset_parameters()

    def get_output_size(self):
        return (
            sum(self.layer_sizes) if self.out_from_all_layers else self.layer_sizes[-1]
        )

    def forward(self, data):
        """One LSTM step per layer; returns concatenated or last-layer output."""
        for i, size in enumerate(self.layer_sizes):
            d = self.in_to_all[i](data)
            if self.outputs[i] is not None:
                d += self.out_to_all[i](self.outputs[i])
            if i > 0:
                d += self.prev_to_all[i - 1](self.outputs[i - 1])

            # Last `size` channels are the candidate input; the rest are gates.
            input_data = torch.tanh(d[..., -size:])
            forget_gate, input_gate, output_gate = torch.sigmoid(d[..., :-size]).chunk(
                3, dim=-1
            )

            state_update = input_gate * input_data

            if self.states[i] is not None:
                self.states[i] = self.states[i] * forget_gate + state_update
            else:
                self.states[i] = state_update

            self.outputs[i] = output_gate * torch.tanh(self.states[i])

        return (
            torch.cat(self.outputs, -1)
            if self.out_from_all_layers
            else self.outputs[-1]
        )
isinstance(module, torch.nn.Linear): + linear_reset(module, gain=init.calculate_gain("relu")) + + def get_output_size(self): + return self.layer_sizes[-1] + + def init(self, input_size): + self.layer_sizes = self.layer_sizes + + # Xavier init: input to all gates is layers_sizes[i-1] + layer_sizes[i] + input_size -> layer_size big. + # So use xavier init according to this. + self.input_sizes = [input_size] + self.layer_sizes[:-1] + + layers = [] + for i, size in enumerate(self.layer_sizes): + layers.append(torch.nn.Linear(self.input_sizes[i], self.layer_sizes[i])) + layers.append(torch.nn.ReLU()) + self.model = torch.nn.Sequential(*layers) + self.reset_parameters() + + def forward(self, data): + return self.model(data) diff --git a/Models/Demon.py b/Models/Demon.py new file mode 100644 index 0000000..13ded38 --- /dev/null +++ b/Models/Demon.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +from torch.distributions import Normal + +from collections import namedtuple + +LOG_SIG_MAX = 2 +LOG_SIG_MIN = -20 +EPSILON = 1e-6 + +SavedAction = namedtuple("SavedAction", ["action", "log_prob", "mean"]) + + +def linear_reset(module, gain=1.0): + assert isinstance(module, torch.nn.Linear) + init.xavier_uniform_(module.weight, gain=gain) + s = module.weight.size(1) + if module.bias is not None: + module.bias.data.zero_() + + +class ZNet(nn.Module): + def __init__(self): + super(ZNet, self).__init__() + + def reset_parameters(self): + for module in self.lstm: + if isinstance(module, torch.nn.Linear): + linear_reset(module, gain=init.calculate_gain("relu")) + + for module in self.hidden2z: + if isinstance(module, torch.nn.Linear): + linear_reset(module, gain=init.calculate_gain("relu")) + + def init(self, input_size): + self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True)) + self.hidden2z = nn.Sequential(nn.Linear(32, 1)) + self.reset_parameters() + + def forward(self, data): + output, (hn, cn) 
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
EPSILON = 1e-6

SavedAction = namedtuple("SavedAction", ["action", "log_prob", "mean"])


def linear_reset(module, gain=1.0):
    """Xavier-init a Linear layer's weight and zero its bias."""
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()


class FNet(nn.Module):
    """LSTM + linear head; unconstrained scalar per timestep (ELU features)."""

    def __init__(self):
        super(FNet, self).__init__()

    def reset_parameters(self):
        # Only Linear submodules are re-initialized; the LSTM keeps its
        # default initialization.
        for module in self.lstm:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

        for module in self.hidden2z:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

    def init(self, input_size):
        """Build the recurrent encoder and scalar head."""
        self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True))
        self.hidden2z = nn.Sequential(nn.Linear(32, 1))
        self.reset_parameters()

    def forward(self, data):
        features, (hn, cn) = self.lstm(data)
        return self.hidden2z(F.elu(features))


class Demon(torch.nn.Module):
    """Gaussian policy that perturbs the DNC's external memory input."""

    def __init__(self, layer_sizes=[]):
        super(Demon, self).__init__()
        self.layer_sizes = layer_sizes
        self.action_scale = torch.tensor(1)
        self.action_bias = torch.tensor(0.0)
        self.saved_actions = []

    def get_output_size(self):
        return self.layer_sizes[-1]

    def reset_parameters(self):
        for module in self.model:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))
        linear_reset(self.embed_mean, gain=init.calculate_gain("relu"))
        linear_reset(self.embed_log_std, gain=init.calculate_gain("relu"))

    def init(self, input_size, output_size):
        """Build the MLP trunk plus mean / log-std heads."""
        self.input_sizes = [input_size] + self.layer_sizes[:-1]
        trunk = []
        for i in range(len(self.layer_sizes)):
            trunk.append(nn.Linear(self.input_sizes[i], self.layer_sizes[i]))
            trunk.append(nn.ReLU())

        self.model = nn.Sequential(*trunk)
        self.embed_mean = nn.Linear(self.layer_sizes[-1], output_size)
        self.embed_log_std = nn.Linear(self.layer_sizes[-1], output_size)

        self.reset_parameters()

    def forward(self, data):
        """Return (mean, std) of the action distribution; std is clamped."""
        h = F.relu(self.model(data))
        mean = self.embed_mean(h)
        log_std = torch.clamp(
            self.embed_log_std(h), min=LOG_SIG_MIN, max=LOG_SIG_MAX
        )
        return mean, torch.exp(log_std)

    def act(self, data):
        """Pathwise-derivative (reparameterized) action selection.

        Records (action, log_prob, mean) and returns the squashed mean.
        """
        mean, std = self.forward(data)
        normal = Normal(mean, std)
        sample = normal.rsample()

        squashed = torch.softmax(sample, dim=1)
        action = squashed * self.action_scale + self.action_bias

        log_prob = normal.log_prob(action)
        # NOTE(review): the bound correction below is the tanh-squashing
        # formula applied to a softmax squashing, and log_prob is evaluated at
        # the squashed action rather than the raw sample; preserved verbatim —
        # verify against the SAC derivation.
        log_prob = log_prob - torch.log(
            self.action_scale * (1 - squashed.pow(2)) + EPSILON
        )
        log_prob = log_prob.sum(1, keepdim=True)

        mean = torch.softmax(mean, dim=1) * self.action_scale + self.action_bias
        self.saved_actions.append(SavedAction(action, log_prob, mean))

        return mean
self.is_terminals[:] + + +class ActorCritic(nn.Module): + def __init__(self, state_dim, action_dim, action_std): + super(ActorCritic, self).__init__() + self.actor = nn.Sequential( + nn.Linear(state_dim, 64), + nn.Tanh(), + nn.Linear(64, 32), + nn.Tanh(), + nn.Linear(32, action_dim), + nn.Softmax(dim=1), + ) + + # critic + self.critic = nn.Sequential( + nn.Linear(state_dim, 64), + nn.Tanh(), + nn.Linear(64, 32), + nn.Tanh(), + nn.Linear(32, 1), + ) + self.action_var = torch.full((action_dim,), action_std * action_std).to(device) + + def forward(self): + raise NotImplementedError + + def act(self, state, rollout_storage): + action_mean = self.actor(state) + cov_mat = torch.diag(self.action_var).to(device) + + dist = MultivariateNormal(action_mean, cov_mat) + action = dist.sample() + action_logprob = dist.log_prob(action) + + if rollout_storage: + rollout_storage.states.append(state) + rollout_storage.actions.append(action) + rollout_storage.logprobs.append(action_logprob) + + return action.detach() + + def evaluate(self, state, action): + action_mean = self.actor(state) + + action_var = self.action_var.expand_as(action_mean) + cov_mat = torch.diag_embed(action_var).to(device) + + dist = MultivariateNormal(action_mean, cov_mat) + + action_logprobs = dist.log_prob(action) + dist_entropy = dist.entropy() + state_value = self.critic(state) + + return action_logprobs, torch.squeeze(state_value), dist_entropy + + +class Demon: + def __init__( + self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip + ): + self.lr = lr + self.betas = betas + self.gamma = gamma + self.eps_clip = eps_clip + self.K_epochs = K_epochs + + self.policy = ActorCritic(state_dim, action_dim, action_std).to(device) + self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas) + + self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device) + self.policy_old.load_state_dict(self.policy.state_dict()) + + self.MseLoss = nn.MSELoss() + + def 
class Demon:
    """PPO trainer for the memory-perturbing agent (clipped surrogate objective)."""

    def __init__(
        self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip
    ):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state, rollout_storage):
        """Act with the frozen pre-update policy (PPO's pi_old)."""
        return self.policy_old.act(state, rollout_storage)

    def update(self, rollout_storage):
        """One PPO update over the stored rollout; returns the last loss."""
        # Monte Carlo return estimates, discounted back from the rollout's end.
        returns = []
        running = 0
        for r in reversed(rollout_storage.rewards):
            running = r + self.gamma * running
            returns.insert(0, running)

        # Normalize the returns.
        returns = torch.stack(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)
        returns = returns.squeeze(-1)

        # Convert recorded lists to detached tensors.
        old_states = torch.squeeze(
            torch.stack(rollout_storage.states).to(device), 1
        ).detach()
        old_actions = torch.squeeze(
            torch.stack(rollout_storage.actions).to(device), 1
        ).detach()
        old_logprobs = (
            torch.squeeze(torch.stack(rollout_storage.logprobs), 1).to(device).detach()
        )

        # Optimize the policy for K epochs.
        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(
                old_states, old_actions
            )

            # Importance ratio pi_theta / pi_theta_old.
            ratios = torch.exp(logprobs - old_logprobs.detach())

            try:
                # The reward is the mutual information between consecutive
                # memory states, so there are only n-1 rewards: drop the last
                # value/ratio/entropy entries to align shapes.
                state_values = state_values[:-1, :]
                ratios = ratios[:-1, :]
                dist_entropy = dist_entropy[:-1, :]
                advantages = returns - state_values.detach()
                surr1 = ratios * advantages
                surr2 = (
                    torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip)
                    * advantages
                )
                # Clipped surrogate loss + value loss - entropy bonus.
                loss = (
                    -torch.min(surr1, surr2)
                    + 0.5 * self.MseLoss(state_values, returns)
                    - 0.01 * dist_entropy
                )
                self.optimizer.zero_grad()
                loss.mean().backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
                self.optimizer.step()
            except Exception:
                # Best-effort fallback for length-1 sequences, where the
                # slicing above fails; deliberately broad to preserve the
                # original behavior.
                loss = torch.zeros_like(returns).to(device)
                continue

        # Sync the frozen policy with the freshly optimized one.
        self.policy_old.load_state_dict(self.policy.state_dict())
        return loss


############################################
# Mutual information Estimator Network######
############################################


def linear_reset(module, gain=1.0):
    """Xavier-init a Linear layer's weight and zero its bias."""
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()


class FNet(nn.Module):
    """MINE-style Monte-Carlo mutual-information critic.

    MINE produces estimates that are neither an upper nor a lower bound on MI;
    a companion ZNet can be introduced to build finite-sample bounds.
    """

    def __init__(self):
        super(FNet, self).__init__()

    def reset_parameters(self):
        # Only Linear submodules are re-initialized; the LSTM keeps its
        # default initialization.
        for module in self.lstm:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

        for module in self.hidden2f:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

    def init(self, input_size):
        """Build the recurrent encoder and scalar head."""
        self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True))
        self.hidden2f = nn.Sequential(nn.Linear(32, 1))
        self.reset_parameters()

    def forward(self, data):
        features, (hn, cn) = self.lstm(data)
        return self.hidden2f(F.elu(features))
nn.Sequential(nn.Linear(32, 1)) + self.reset_parameters() + + def forward(self, data): + output, (hn, cn) = self.lstm(data) + output = F.elu(output) + zvals = self.hidden2z(output) + return F.softplus(zvals) diff --git a/Utils/.DS_Store b/Utils/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/Utils/.DS_Store differ diff --git a/Utils/ArgumentParser.py b/Utils/ArgumentParser.py new file mode 100644 index 0000000..438a5cf --- /dev/null +++ b/Utils/ArgumentParser.py @@ -0,0 +1,167 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================

import os
import json
import argparse


class ArgumentParser:
    """argparse wrapper with named profiles, bool/list coercion and JSON
    save/load of parsed argument values."""

    class Parsed:
        # Plain attribute container for the final parsed values.
        pass

    _type = type  # "type" is shadowed by parameter names below; keep the builtin

    @staticmethod
    def str_or_none(none_string="none"):
        """Return a parser mapping `none_string` (case-insensitive) to None."""
        def parse(s):
            return None if s.lower() == none_string else s
        return parse

    @staticmethod
    def list_or_none(none_string="none", type=int):
        """Return a parser for comma-separated lists; None for `none_string`."""
        def parse(s):
            return None if s.lower() == none_string else [type(a) for a in s.split(",") if a]
        return parse

    @staticmethod
    def _merge_args(args, new_args, arg_schemas):
        # Merge new_args into args in place, combining clashes with the
        # argument's registered "updater" callback.
        for name, val in new_args.items():
            old = args.get(name)
            if old is None:
                args[name] = val
            else:
                args[name] = arg_schemas[name]["updater"](old, val)

    class Profile:
        """A named set of argument values, optionally including other profiles."""

        def __init__(self, name, args=None, include=None):
            # `include` formerly defaulted to a shared mutable []; use None.
            include = [] if include is None else include
            assert not (args is None and not include), "One of args or include must be defined"
            self.name = name
            self.args = args
            if not isinstance(include, list):
                include = [include]
            self.include = include

        def get_args(self, arg_schemas, profile_by_name):
            """Resolve this profile: included profiles first, own args last."""
            res = {}

            for n in self.include:
                p = profile_by_name.get(n)
                assert p is not None, "Included profile %s doesn't exists" % n

                ArgumentParser._merge_args(res, p.get_args(arg_schemas, profile_by_name), arg_schemas)

            ArgumentParser._merge_args(res, self.args, arg_schemas)
            return res

    def __init__(self, description=None):
        self.parser = argparse.ArgumentParser(description=description)
        self.loaded = {}
        self.profiles = {}
        self.args = {}
        self.raw = None
        self.parsed = None
        self.parser.add_argument("-profile", type=str, help="Pre-defined profiles.")

    def add_argument(self, name, type=None, default=None, help="", save=True,
                     parser=lambda x: x, updater=lambda old, new: new):
        """Register an argument.

        bool arguments are exposed to argparse as int (0/1); `parser`
        post-processes the final value; `updater` merges profile values.
        """
        assert name not in ["profile"], "Argument name %s is reserved" % name
        assert not (type is None and default is None), "Either type or default must be given"

        if type is None:
            type = ArgumentParser._type(default)

        self.parser.add_argument(name, type=int if type == bool else type, default=None, help=help)
        if name[0] == '-':
            name = name[1:]

        self.args[name] = {
            "type": type,
            # NOTE(review): int(default) raises if type==bool and default is
            # None -- bool arguments must always receive a default.
            "default": int(default) if type == bool else default,
            "save": save,
            "parser": parser,
            "updater": updater
        }

    def add_profile(self, prof):
        """Register a Profile, or a list of Profiles."""
        if isinstance(prof, list):
            for p in prof:
                self.add_profile(p)
        else:
            self.profiles[prof.name] = prof

    def do_parse_args(self, loaded=None):
        """Parse sys.argv; unset values come from (in order) `loaded`, the
        selected -profile, then the registered defaults."""
        loaded = {} if loaded is None else loaded  # was a mutable {} default
        self.raw = self.parser.parse_args()

        profile = {}
        if self.raw.profile:
            assert not loaded, "Loading arguments from file, but profile given."
            for pr in self.raw.profile.split(","):
                p = self.profiles.get(pr)
                assert p is not None, "Invalid profile: %s. Valid profiles: %s" % (pr, self.profiles.keys())
                p = p.get_args(self.args, self.profiles)
                self._merge_args(profile, p, self.args)

        for k, v in self.raw.__dict__.items():
            if k in ["profile"]:
                continue

            if v is None:
                if k in loaded and self.args[k]["save"]:
                    self.raw.__dict__[k] = loaded[k]
                else:
                    self.raw.__dict__[k] = profile.get(k, self.args[k]["default"])

        self.parsed = ArgumentParser.Parsed()
        self.parsed.__dict__.update({k: self.args[k]["parser"](self.args[k]["type"](v)) if v is not None else None
                                     for k, v in self.raw.__dict__.items() if k in self.args})

        return self.parsed

    def parse_or_cache(self):
        """Parse once; later calls reuse the cached result."""
        if self.parsed is None:
            self.do_parse_args()

    def parse(self):
        self.parse_or_cache()
        return self.parsed

    def save(self, fname):
        """Dump the raw (pre-`parser`) argument values to a JSON file."""
        self.parse_or_cache()
        with open(fname, 'w') as outfile:
            json.dump(self.raw.__dict__, outfile, indent=4)
        return True

    def load(self, fname):
        """Load previously saved arguments (if the file exists) and re-parse."""
        if os.path.isfile(fname):
            loaded_map = {}  # renamed from "map" (shadowed the builtin)
            with open(fname, "r") as data_file:
                loaded_map = json.load(data_file)

            self.do_parse_args(loaded_map)
        return self.parsed

    def sync(self, fname, save=True):
        """Load arguments from `fname` if present; optionally write the merged
        result back, creating the directory when needed."""
        if os.path.isfile(fname):
            self.load(fname)

        if save:
            dirname = os.path.dirname(fname)  # renamed from "dir" (builtin)
            if not os.path.isdir(dirname):
                os.makedirs(dirname)

            self.save(fname)
        return self.parsed


# ==============================================================================
# Utils/Collate.py
# Copyright 2017 Robert Csordas.  Licensed under the Apache License, Version
# 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for the full text.
# ==============================================================================

import numpy as np
import torch
from operator import mul
from functools import reduce


class VarLengthCollate:
    """DataLoader collate fn padding variable-length tensors to a common shape
    with `ignore_symbol`."""

    def __init__(self, ignore_symbol=0):
        self.ignore_symbol = ignore_symbol

    def _measure_array_max_dim(self, batch):
        """Return (per-dim max size over the batch, per-dim 'differs' flags)."""
        s = list(batch[0].size())
        different = [False] * len(s)
        for i in range(1, len(batch)):
            ns = batch[i].size()
            different = [different[j] or s[j] != ns[j] for j in range(len(s))]
            s = [max(s[j], ns[j]) for j in range(len(s))]
        return s, different

    def _merge_var_len_array(self, batch):
        max_size, different = self._measure_array_max_dim(batch)
        s = [len(batch)] + max_size
        # Shared storage so results work across DataLoader worker processes.
        storage = batch[0].storage()._new_shared(reduce(mul, s, 1))
        out = batch[0].new(storage).view(s).fill_(self.ignore_symbol)
        for i, d in enumerate(batch):
            # Narrow the output slot to this sample's size in every dim that
            # varies across the batch, then copy the sample in.
            this_o = out[i]
            for j, diff in enumerate(different):
                if diff:
                    this_o = this_o.narrow(j, 0, d.size(j))
            this_o.copy_(d)
        return out

    def __call__(self, batch):
        if isinstance(batch[0], dict):
            return {k: self([b[k] for b in batch]) for k in batch[0].keys()}
        elif isinstance(batch[0], np.ndarray):
            return self([torch.from_numpy(a) for a in batch])
        elif torch.is_tensor(batch[0]):
            return self._merge_var_len_array(batch)
        else:
            assert False, "Unknown type: %s" % type(batch[0])


class MetaCollate:
    """Collate wrapper that keeps the per-sample `meta` entries as a plain
    list instead of collating them."""

    # NOTE: the default collate instance is shared between MetaCollate
    # instances; it is stateless, so this is safe.
    def __init__(self, meta_name="meta", collate=VarLengthCollate()):
        self.meta_name = meta_name
        self.collate = collate

    def __call__(self, batch):
        if isinstance(batch[0], dict):
            meta = [b[self.meta_name] for b in batch]
            batch = [{k: v for k, v in b.items() if k != self.meta_name} for b in batch]
        else:
            meta = None

        res = self.collate(batch)
        if meta is not None:
            res[self.meta_name] = meta

        return res


# ==============================================================================
# Utils/Debug.py
# Copyright 2017 Robert Csordas.  Licensed under the Apache License, Version
# 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for the full text.
# ==============================================================================
# ==============================================================================

import numpy as np
import sys
import traceback
import torch

# Global switch: most helpers below are no-ops unless this is True.
enableDebug = False


def nan_check(arg, name=None, force=False):
    """Check `arg` (tensor/Parameter/float/list) for NaN/Inf.

    Returns `arg` unchanged.  On failure prints the value, the stack trace and
    exits the process.  No-op unless enableDebug or `force` is set.
    """
    if not enableDebug and not force:
        return arg
    curr_nan = False
    if isinstance(arg, torch.autograd.Variable):
        curr_nan = not np.isfinite(arg.sum().cpu().data.numpy())
    elif isinstance(arg, torch.nn.parameter.Parameter):
        curr_nan = not np.isfinite(arg.sum().cpu().data.numpy())
        # Guard: .grad is None before the first backward pass.
        if not curr_nan and arg.grad is not None:
            curr_nan = not np.isfinite(arg.grad.sum().cpu().data.numpy())
    elif isinstance(arg, float):
        curr_nan = not np.isfinite(arg)
    elif isinstance(arg, (list, tuple)):
        for a in arg:
            nan_check(a)
    else:
        assert False, "Unsupported type %s" % type(arg)

    if curr_nan:
        # Prefer the active exception's traceback when there is one.
        if sys.exc_info()[0] is not None:
            trace = str(traceback.format_exc())
        else:
            trace = "".join(traceback.format_stack())

        print(arg)
        if name is not None:
            print("NaN found in %s." % name)
        else:
            print("NaN found.")
        if isinstance(arg, torch.autograd.Variable):
            print(" Argument is a torch tensor. Shape: %s" % list(arg.size()))

        print(trace)
        sys.exit(-1)

    return arg


def assert_range(t, min=0.0, max=1.0):
    """Assert all elements of `t` lie in [min, max] (debug mode only)."""
    if not enableDebug:
        return

    # NOTE(review): this comparison was garbled in the original text
    # ("...numpy()max:" -- the "<...>" span was eaten); reconstructed below.
    if t.min().cpu().data.numpy() < min or t.max().cpu().data.numpy() > max:
        print(t)
        assert False


def assert_dist(t, use_lower_limit=True):
    """Assert `t` is a valid probability distribution along the last dim."""
    if not enableDebug:
        return

    assert_range(t)

    if t.sum(-1).max().cpu().data.numpy() > 1.001:
        print("MAT:", t)
        print("SUM:", t.sum(-1))
        assert False

    if use_lower_limit and t.sum(-1).max().cpu().data.numpy() < 0.999:
        print(t)
        print("SUM:", t.sum(-1))
        assert False


def print_stat(name, t):
    """Print min/mean/max of tensor `t` (debug mode only)."""
    if not enableDebug:
        return

    min = t.min().cpu().data.numpy()
    max = t.max().cpu().data.numpy()
    mean = t.mean().cpu().data.numpy()

    print("%s: min: %g, mean: %g, max: %g" % (name, min, mean, max))


def dbg_print(*things):
    """print() gated on enableDebug."""
    if not enableDebug:
        return
    print(*things)


class GradPrinter(torch.autograd.Function):
    """Identity in the forward pass; prints the incoming gradient on backward."""

    @staticmethod
    def forward(ctx, a):
        return a

    @staticmethod
    def backward(ctx, g):
        print("Grad (print_grad): ", g[0])
        return g


def print_grad(t):
    return GradPrinter.apply(t)


def assert_equal(t1, ref, limit=1e-5, force=True):
    """Assert t1 ~= ref elementwise, relative to ref's mean nonzero magnitude.

    Note: enabled by default (force=True), unlike the other helpers.
    """
    if not (enableDebug or force):
        return

    assert t1.shape == ref.shape, "Tensor shapes differ: got %s, ref %s" % (t1.shape, ref.shape)
    # NOTE(review): ref.nonzero().sum() sums *indices* of nonzero elements --
    # this looks like it was meant to be the nonzero count; kept as found.
    norm = ref.abs().sum() / ref.nonzero().sum().float()
    threshold = norm * limit

    errcnt = ((t1 - ref).abs() > threshold).sum()
    if errcnt > 0:
        print("Tensors differ. (max difference: %g, norm %f). No of errors: %d of %d" %
              ((t1 - ref).abs().max().item(), norm, errcnt, t1.numel()))
        print("---------------------------------------------Out-----------------------------------------------")
        print(t1)
        print("---------------------------------------------Ref-----------------------------------------------")
        print(ref)
        print("-----------------------------------------------------------------------------------------------")
        assert False


# ==============================================================================
# Utils/Helpers.py
# Copyright 2017 Robert Csordas.  Licensed under the Apache License, Version
# 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for the full text.
# ==============================================================================

import torch
import torch.autograd


def as_numpy(data):
    """Convert a torch tensor/Variable to a numpy array; pass through others."""
    if isinstance(data, (torch.Tensor, torch.autograd.Variable)):
        return data.detach().cpu().numpy()
    else:
        return data


# ==============================================================================
# Utils/Index.py
# Copyright 2017 Robert Csordas.  Licensed under the Apache License, Version
# 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for the full text.
# ==============================================================================
# (Apache-2.0, Copyright 2017 Robert Csordas --
#  see http://www.apache.org/licenses/LICENSE-2.0)
# ==============================================================================


def index_by_dim(arr, dim, i_start, i_end=None):
    """Index or slice `arr` along dimension `dim`.

    With `i_end`: returns arr[..., i_start:i_end, ...] (dim preserved).
    Without:     returns arr[..., i_start, ...] (dim dropped).
    Negative `dim` counts from the back.
    """
    if dim < 0:
        dim += arr.ndim
    return arr[tuple([slice(None, None)] * dim +
                     [slice(i_start, i_end) if i_end is not None else i_start])]


# ==============================================================================
# Utils/Process.py
# Copyright 2017 Robert Csordas.  Licensed under the Apache License, Version
# 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for the full text.
# ==============================================================================

import sys
import ctypes
import subprocess
import os


def run(cmd, hide_stderr=True):
    """Start `cmd` as a child process and return the Popen object.

    On Linux, installs a prctl(PR_SET_PDEATHSIG, SIGKILL) hook so the child is
    killed when the parent dies.
    NOTE(review): cmd.split(" ") breaks on quoted arguments; switch to
    shlex.split if such commands are ever needed.
    """
    libc_search_dirs = ["/lib", "/lib/x86_64-linux-gnu", "/lib/powerpc64le-linux-gnu"]

    if sys.platform == "linux":
        found = None
        for d in libc_search_dirs:
            file = os.path.join(d, "libc.so.6")
            if os.path.isfile(file):
                found = file
                break

        if not found:
            print("WARNING: Cannot find libc.so.6. Cannot kill process when parent dies.")
            killer = None
        else:
            libc = ctypes.CDLL(found)
            PR_SET_PDEATHSIG = 1
            KILL = 9  # SIGKILL
            killer = lambda: libc.prctl(PR_SET_PDEATHSIG, KILL)
    else:
        print("WARNING: OS not linux. Cannot kill process when parent dies.")
        killer = None

    # subprocess.DEVNULL instead of a leaked, never-closed open(os.devnull).
    stderr = subprocess.DEVNULL if hide_stderr else None

    return subprocess.Popen(cmd.split(" "), stderr=stderr, preexec_fn=killer)


# ==============================================================================
# Utils/Profile.py
# Copyright 2017 Robert Csordas.  Licensed under the Apache License, Version
# 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for the full text.
# ==============================================================================

import atexit

ENABLED = False

_profiler = None


def construct():
    """Create the module-wide LineProfiler on first use (when ENABLED)."""
    global _profiler
    if not ENABLED:
        return

    if _profiler is None:
        from line_profiler import LineProfiler
        _profiler = LineProfiler()


def do_profile(follow=()):
    """Decorator registering the function (and `follow`ed ones) for profiling.

    A no-op returning the function unchanged when profiling is disabled.
    (`follow` formerly defaulted to a mutable []; an empty tuple is safer.)
    """
    construct()

    def inner(func):
        if _profiler is not None:
            _profiler.add_function(func)
            for f in follow:
                _profiler.add_function(f)
            _profiler.enable_by_count()
        return func
    return inner


@atexit.register
def print_prof():
    # Dump collected statistics on interpreter exit.
    if _profiler is not None:
        _profiler.print_stats()


# ==============================================================================
# Utils/Saver.py
# ==============================================================================
# Copyright 2017 Robert Csordas.  All Rights Reserved.  Licensed under the
# Apache License, Version 2.0; see http://www.apache.org/licenses/LICENSE-2.0.
# ==============================================================================

import torch
import os
import inspect
import time


class SaverElement:
    """Interface for one saveable item registered with Saver."""

    def save(self):
        raise NotImplementedError

    def load(self, saved_state):
        raise NotImplementedError


class CallbackSaver(SaverElement):
    """SaverElement built from explicit save/load callables."""

    def __init__(self, save_fn, load_fn):
        super().__init__()
        self.save = save_fn
        self.load = load_fn


class StateSaver(SaverElement):
    """Wraps anything exposing state_dict()/load_state_dict()."""

    def __init__(self, model):
        super().__init__()
        self._model = model

    def load(self, state):
        try:
            self._model.load_state_dict(state)
        except Exception as e:
            # Produce a usable diagnostic: which parameter names mismatch?
            if hasattr(self._model, "named_parameters"):
                names = set([n for n, _ in self._model.named_parameters()])
                # BUGFIX: this previously compared against
                # self._model.keys() -- modules have no .keys(); the keys of
                # the *loaded state* are what we must compare with.
                loaded = set(state.keys())
                if names != loaded:
                    d = loaded.difference(names)
                    if d:
                        print("Loaded, but not in model: %s" % list(d))
                    d = names.difference(loaded)
                    if d:
                        print("In model, but not loaded: %s" % list(d))
            if isinstance(self._model, torch.optim.Optimizer):
                # Optimizer state mismatches are tolerated (e.g. changed
                # hyperparameters); everything else is fatal.
                print("WARNING: optimizer parameters not loaded!")
            else:
                raise e

    def save(self):
        return self._model.state_dict()


class GlobalVarSaver(SaverElement):
    """Saves/restores a global variable of the *calling* module by name."""

    def __init__(self, name):
        # Capture the caller's globals so load() can rebind the variable there.
        caller_frame = inspect.getouterframes(inspect.currentframe())[1]
        self._vars = caller_frame.frame.f_globals
        self._name = name

    def load(self, state):
        self._vars.update({self._name: state})

    def save(self):
        return self._vars[self._name]


class PyObjectSaver(SaverElement):
    """Recursively snapshots/restores plain Python objects (dict/list/__dict__)."""

    def __init__(self, obj):
        self._obj = obj

    def load(self, state):
        def _load(target, state):
            if isinstance(target, dict):
                for k, v in state.items():
                    target[k] = _load(target.get(k), v)
            elif isinstance(target, list):
                if len(target) != len(state):
                    # Length changed: replace wholesale.
                    target.clear()
                    for v in state:
                        target.append(v)
                else:
                    for i, v in enumerate(state):
                        target[i] = _load(target[i], v)

            elif hasattr(target, "__dict__"):
                _load(target.__dict__, state)
            else:
                return state
            return target

        _load(self._obj, state)

    def save(self):
        def _save(target):
            if isinstance(target, dict):
                res = {k: _save(v) for k, v in target.items()}
            elif isinstance(target, list):
                res = [_save(v) for v in target]
            elif hasattr(target, "__dict__"):
                res = {k: _save(v) for k, v in target.__dict__.items()}
            else:
                res = target

            return res

        return _save(self._obj)

    @staticmethod
    def obj_supported(obj):
        return isinstance(obj, (list, dict)) or hasattr(obj, "__dict__")


class Saver:
    """Periodic checkpoint writer with thinning of old checkpoint files."""

    def __init__(self, dir, short_interval, keep_every_n_hours=4):
        self.savers = {}
        self.short_interval = short_interval
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self._keep_every_n_seconds = keep_every_n_hours * 3600

    def register(self, name, saver):
        """Register a SaverElement, a state_dict object, or a plain object."""
        assert name not in self.savers, "Saver %s already registered" % name

        if isinstance(saver, SaverElement):
            self.savers[name] = saver
        elif hasattr(saver, "state_dict") and callable(saver.state_dict):
            self.savers[name] = StateSaver(saver)
        elif PyObjectSaver.obj_supported(saver):
            self.savers[name] = PyObjectSaver(saver)
        else:
            # BUGFIX: was `assert "Unsupported..."` -- a non-empty string is
            # always truthy, so the assert could never fire.
            assert False, "Unsupported thing to save: %s" % type(saver)

    def __setitem__(self, key, value):
        self.register(key, value)

    def write(self, iter):
        """Write one checkpoint file containing all registered items."""
        fname = os.path.join(self.dir, self.model_name_from_index(iter))
        print("Saving %s" % fname)

        state = {}
        for name, fns in self.savers.items():
            state[name] = fns.save()

        torch.save(state, fname)
        print("Saved.")

        self._cleanup()

    def tick(self, iter):
        """Write a checkpoint every `short_interval` iterations."""
        if iter % self.short_interval != 0:
            return

        self.write(iter)

    @staticmethod
    def model_name_from_index(index):
        return "model-%d.pth" % index

    @staticmethod
    def get_checkpoint_index_list(dir):
        """All checkpoint indices in `dir`, newest first."""
        return list(reversed(sorted(
            [int(fn.split(".")[0].split("-")[-1]) for fn in os.listdir(dir) if fn.split(".")[-1] == "pth"])))

    @staticmethod
    def get_ckpts_in_time_window(dir, time_window_s, index_list=None):
        """Checkpoint names from `index_list` (newest first) whose mtime is
        within the last `time_window_s` seconds."""
        if index_list is None:
            index_list = Saver.get_checkpoint_index_list(dir)

        now = time.time()

        res = []
        for i in index_list:
            name = Saver.model_name_from_index(i)
            mtime = os.path.getmtime(os.path.join(dir, name))
            if now - mtime > time_window_s:
                break

            res.append(name)

        return res

    @staticmethod
    def load_last_checkpoint(dir):
        """Load the newest readable checkpoint in `dir`, or None."""
        last_checkpoint = Saver.get_checkpoint_index_list(dir)

        if last_checkpoint:
            for index in last_checkpoint:
                fname = Saver.model_name_from_index(index)
                try:
                    print("Loading %s" % fname)
                    data = torch.load(os.path.join(dir, fname))
                except Exception:
                    # Fall back to the next-older checkpoint on corruption.
                    print("WARNING: Loading %s failed. Maybe file is corrupted?" % fname)
                    continue
                return data
        return None

    def _cleanup(self):
        # Keep the 2 newest checkpoints plus one per retention window;
        # remove recent intermediates.
        index_list = self.get_checkpoint_index_list(self.dir)
        new_files = self.get_ckpts_in_time_window(self.dir, self._keep_every_n_seconds, index_list[2:])
        new_files = new_files[:-1]

        for f in new_files:
            os.remove(os.path.join(self.dir, f))

    def load(self, fname=None):
        """Restore all registered items from `fname` (or the newest checkpoint).

        Returns True on success, False when no checkpoint could be loaded.
        """
        if fname is None:
            state = self.load_last_checkpoint(self.dir)
            if not state:
                return False
        else:
            state = torch.load(fname)

        for k, s in state.items():
            if k not in self.savers:
                print("WARNING: failed to load state of %s. It doesn't exists." % k)
                continue
            self.savers[k].load(s)

        return True


# ==============================================================================
# Utils/Seed.py
# Copyright 2017 Robert Csordas.  Licensed under the Apache License, Version
# 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for the full text.
# ==============================================================================

import numpy as np

# Global RNG seed; None/0 means nondeterministic.
seed = None


def fix():
    """Fix the global seed so get_randstate() becomes deterministic."""
    global seed
    seed = 0xB0C1FA52


def get_randstate():
    """Return a fresh RandomState, seeded iff fix() was called."""
    if seed:
        return np.random.RandomState(seed)
    else:
        return np.random.RandomState()


# ==============================================================================
# Utils/Visdom.py
# ==============================================================================
# ==============================================================================

import time
import visdom
import sys
import numpy as np
import os
import socket
from . import Process

# Module-wide visdom client and server state.
vis = None
port = None
visdom_fail_count = 0


def port_used(port):
    """Return True if a TCP connection to 127.0.0.1:port succeeds."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    result = sock.connect_ex(('127.0.0.1', port))
    if result == 0:
        sock.close()
        return True
    else:
        return False


def alloc_port(start_from=7000):
    """Return the first free TCP port at or above `start_from`."""
    while True:
        if port_used(start_from):
            print("Port already used: %d" % start_from)
            start_from += 1
        else:
            return start_from


def wait_for_port(port, timeout=5):
    """Poll until `port` accepts connections; False on timeout (seconds)."""
    start_time = time.time()  # was misspelled "star_time"
    while not port_used(port):
        if time.time() - start_time > timeout:
            return False

        time.sleep(0.1)
    return True


def start(on_port=None):
    """Spawn a visdom server and connect the module-wide client to it.

    Gives up permanently after 3 failed spawn attempts.
    """
    global vis
    global port
    global visdom_fail_count
    assert vis is None, "Cannot start more than 1 visdom servers."

    if visdom_fail_count >= 3:
        return

    port = alloc_port() if on_port is None else on_port

    print("Starting Visdom server on %d" % port)
    Process.run("%s -m visdom.server -p %d" % (sys.executable, port))
    if not wait_for_port(port):
        print("ERROR: failed to start Visdom server. Server not responding.")
        visdom_fail_count += 1
        return
    print("Done.")

    vis = visdom.Visdom(port=port)


def _start_if_not_running():
    if vis is None:
        start()


def save_heatmap(dir, title, img):
    """Dump `img` to <dir>/<title>.npy (spaces -> underscores); no-op if dir is None."""
    if dir is not None:
        fname = os.path.join(dir, title.replace(" ", "_") + ".npy")
        d = os.path.dirname(fname)
        os.makedirs(d, exist_ok=True)
        np.save(fname, img)


class Plot2D:
    """Accumulating line plot: averages every `store_interval` points and
    supports a growing number of curves (missing history padded with NaN)."""

    TO_SAVE = ["x", "y", "curr_accu", "curr_cnt", "legend"]

    def __init__(self, name, store_interval=1, legend=None, xlabel=None, ylabel=None):
        _start_if_not_running()

        self.x = []
        self.y = []
        self.store_interval = store_interval
        self.curr_accu = None
        self.curr_cnt = 0
        self.name = name
        self.legend = legend
        self.xlabel = xlabel
        self.ylabel = ylabel

        self.replot = False
        self.visplot = None
        self.last_vis_update_pos = 0

    def set_legend(self, legend):
        if self.legend != legend:
            self.legend = legend
            self.replot = True

    def _send_update(self):
        """Push new points to visdom; full replot when the layout changed."""
        if not self.x or vis is None:
            return

        if self.visplot is None or self.replot:
            opts = {
                "title": self.name
            }
            if self.xlabel:
                opts["xlabel"] = self.xlabel
            if self.ylabel:
                opts["ylabel"] = self.ylabel

            if self.legend is not None:
                opts["legend"] = self.legend

            self.visplot = vis.line(X=np.asfarray(self.x), Y=np.asfarray(self.y), opts=opts, win=self.visplot)
            self.replot = False
        else:
            vis.line(
                X=np.asfarray(self.x[self.last_vis_update_pos:]),
                Y=np.asfarray(self.y[self.last_vis_update_pos:]),
                win=self.visplot,
                update='append'
            )

        self.last_vis_update_pos = len(self.x) - 1

    def add_point(self, x, y):
        """Accumulate point(s) `y` at position `x`; `y` may be a list of curves."""
        if not isinstance(y, list):
            y = [y]

        if self.curr_accu is None:
            self.curr_accu = [0.0] * len(y)

        if len(self.curr_accu) < len(y):
            # The number of curves increased: widen accumulators and backfill
            # the stored history with NaNs.
            need_to_add = (len(y) - len(self.curr_accu))

            self.curr_accu += [0.0] * need_to_add
            count = len(self.x)
            if count > 0:
                self.replot = True
                if not isinstance(self.x[0], list):
                    self.x = [[v] for v in self.x]
                    self.y = [[v] for v in self.y]

                nan = float("nan")
                for a in self.x:
                    a += [nan] * need_to_add
                for a in self.y:
                    a += [nan] * need_to_add
        elif len(self.curr_accu) > len(y):
            # Fewer curves than before: pad this sample with NaNs.
            y = y[:] + [float("nan")] * (len(self.curr_accu) - len(y))

        self.curr_accu = [self.curr_accu[i] + y[i] for i in range(len(y))]
        self.curr_cnt += 1
        if self.curr_cnt == self.store_interval:
            if len(y) > 1:
                self.x.append([x] * len(y))
                self.y.append([a / self.curr_cnt for a in self.curr_accu])
            else:
                self.x.append(x)
                self.y.append(self.curr_accu[0] / self.curr_cnt)

            self.curr_accu = [0.0] * len(y)
            self.curr_cnt = 0

            self._send_update()

    def state_dict(self):
        s = {k: self.__dict__[k] for k in self.TO_SAVE}
        return s

    def load_state_dict(self, state):
        if self.legend is not None:
            # Load legend only if not given in the constructor.
            state["legend"] = self.legend
        self.__dict__.update(state)
        self.last_vis_update_pos = 0

        # Read old format.
        if not isinstance(self.curr_accu, list) and self.curr_accu is not None:
            self.curr_accu = [self.curr_accu]

        self._send_update()


class Image:
    """Visdom image window; optionally dumps each drawn image as .npy."""

    def __init__(self, title, dumpdir=None):
        _start_if_not_running()

        self.win = None
        self.opts = dict(title=title)
        self.dumpdir = dumpdir

    def set_dump_dir(self, dumpdir):
        self.dumpdir = dumpdir

    def draw(self, img):
        if vis is None:
            return

        if isinstance(img, list):
            if self.win is None:
                self.win = vis.images(img, opts=self.opts)
            else:
                vis.images(img, win=self.win, opts=self.opts)
        else:
            if len(img.shape) == 2:
                img = np.expand_dims(img, 0)
            elif img.shape[-1] in [1, 3] and img.shape[0] not in [1, 3]:
                # cv2-style HWC image: convert to CHW and BGR -> RGB.
                img = img.transpose(2, 0, 1)
                img = img[::-1]

            if img.dtype == np.uint8:
                img = img.astype(np.float32) / 255

            self.opts["width"] = img.shape[2]
            self.opts["height"] = img.shape[1]

            save_heatmap(self.dumpdir, self.opts["title"], img)
            if self.win is None:
                self.win = vis.image(img, opts=self.opts)
            else:
                vis.image(img, win=self.win, opts=self.opts)

    def __call__(self, img):
        self.draw(img)


class Text:
    """Visdom text window remembering its current content for checkpointing."""

    def __init__(self, title):
        _start_if_not_running()

        self.win = None
        self.title = title
        self.curr_text = ""

    def set(self, text):
        self.curr_text = text

        if vis is None:
            return

        if self.win is None:
            self.win = vis.text(text, opts=dict(
                title=self.title
            ))
        else:
            vis.text(text, win=self.win)

    def state_dict(self):
        return {"text": self.curr_text}

    def load_state_dict(self, state):
        self.set(state["text"])

    def __call__(self, text):
        self.set(text)


class Heatmap:
    """Visdom heatmap window; range defaults to the drawn image's min/max."""

    def __init__(self, title, min=None, max=None, xlabel=None, ylabel=None, colormap='Viridis', dumpdir=None):
        _start_if_not_running()

        self.win = None
        self.opt = dict(title=title, colormap=colormap)
        self.dumpdir = dumpdir
        if min is not None:
            self.opt["xmin"] = min
        if max is not None:
            self.opt["xmax"] = max

        if xlabel:
            self.opt["xlabel"] = xlabel
        if ylabel:
            self.opt["ylabel"] = ylabel

    def set_dump_dir(self, dumpdir):
        self.dumpdir = dumpdir

    def draw(self, img):
        if vis is None:
            return

        o = self.opt.copy()
        if "xmin" not in o:
            o["xmin"] = float(img.min())

        if "xmax" not in o:
            o["xmax"] = float(img.max())

        save_heatmap(self.dumpdir, o["title"], img)

        if self.win is None:
            self.win = vis.heatmap(img, opts=o)
        else:
            vis.heatmap(img, win=self.win, opts=o)

    def __call__(self, img):
        self.draw(img)


# ==============================================================================
# Utils/download.py
# Copyright 2017 Robert Csordas.  Licensed under the Apache License, Version
# 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for the full text.
# ==============================================================================
#
# ==============================================================================

import requests, tarfile, io, os, zipfile

from io import BytesIO, SEEK_SET, SEEK_END


class UrlStream:
    """File-like, read-only view of a remote HTTP(S) resource.

    Seeking is supported only when the server advertises byte-range support
    (``Accept-Ranges: bytes`` together with ``Content-Length``); otherwise the
    body is buffered in memory on demand.
    """

    def __init__(self, url):
        self._url = url
        headers = requests.head(url).headers
        # HTTP header names are case-insensitive; normalize before lookup.
        headers = {k.lower(): v for k, v in headers.items()}
        self._seek_supported = headers.get('accept-ranges') == 'bytes' and 'content-length' in headers
        if self._seek_supported:
            self._size = int(headers['content-length'])
        self._curr_pos = 0
        self._buf_start_pos = 0
        self._iter = None
        self._buffer = None
        self._buf_size = 0
        self._loaded_all = False

    def seekable(self):
        return self._seek_supported

    def _load_all(self):
        """Download the remainder of the body into the internal buffer."""
        if self._loaded_all:
            return
        self._make_request()
        old_buf_pos = self._buffer.tell()
        self._buffer.seek(0, SEEK_END)
        for chunk in self._iter:
            self._buffer.write(chunk)
        self._buf_size = self._buffer.tell()
        self._buffer.seek(old_buf_pos, SEEK_SET)
        self._loaded_all = True

    def seek(self, position, whence=SEEK_SET):
        """Move the stream position; returns the new absolute position."""
        if whence == SEEK_END:
            assert position <= 0
            if self._seek_supported:
                self.seek(self._size + position)
            else:
                self._load_all()
                self._buffer.seek(position, SEEK_END)
                self._curr_pos = self._buffer.tell()
        elif whence == SEEK_SET:
            if self._curr_pos != position:
                self._curr_pos = position
                if self._seek_supported:
                    # Drop the current connection; a fresh ranged request will
                    # be issued from the new position on the next read.
                    self._iter = None
                    self._buffer = None
                else:
                    self._load_until(position)
                    self._buffer.seek(position)
                    self._curr_pos = position
        else:
            # BUGFIX: was `assert "Invalid whence %s" % whence`, which asserts
            # a non-empty string and therefore never fired.
            raise ValueError("Invalid whence %s" % whence)

        return self.tell()

    def tell(self):
        return self._curr_pos

    def _load_until(self, goal_position):
        """Pull data from the connection until the buffer reaches goal_position."""
        self._make_request()
        old_buf_pos = self._buffer.tell()
        current_position = self._buffer.seek(0, SEEK_END)

        # goal_position is absolute; the buffer starts at _buf_start_pos.
        goal_position = goal_position - self._buf_start_pos
        while current_position < goal_position:
            try:
                d = next(self._iter)
                self._buffer.write(d)
                current_position += len(d)
            except StopIteration:
                break
        self._buf_size = current_position
        self._buffer.seek(old_buf_pos, SEEK_SET)

    def _new_buffer(self):
        """Replace the buffer, carrying over any not-yet-consumed bytes."""
        remaining = self._buffer.read() if self._buffer is not None else None
        self._buffer = BytesIO()
        if remaining is not None:
            self._buffer.write(remaining)
        self._buf_start_pos = self._curr_pos
        self._buf_size = 0 if remaining is None else len(remaining)
        self._buffer.seek(0, SEEK_SET)
        self._loaded_all = False

    def _make_request(self):
        """Open the HTTP connection if needed; recycle oversized buffers."""
        if self._iter is None:
            h = {
                "User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36",
            }
            if self._seek_supported:
                # Request only the not-yet-seen byte range.
                h["Range"] = "bytes=%d-%d" % (self._curr_pos, self._size - 1)

            r = requests.get(self._url, headers=h, stream=True)

            self._iter = r.iter_content(1024 * 1024)
            self._new_buffer()
        elif self._seek_supported and self._buf_size > 128 * 1024 * 1024:
            # Cap in-memory buffering at 128 MiB when ranged re-requests work.
            self._new_buffer()

    def size(self):
        """Total size of the resource in bytes (may download it when unknown)."""
        if self._seek_supported:
            return self._size
        else:
            self._load_all()
            return self._buf_size

    def read(self, size=None):
        """Read up to `size` bytes (the whole remainder when size is None)."""
        if size is None:
            size = self.size()

        self._load_until(self._curr_pos + size)
        if self._seek_supported:
            self._curr_pos = min(self._curr_pos + size, self._size)

        return self._buffer.read(size)

    def iter_content(self, block_size):
        """Yield successive chunks of at most block_size bytes until EOF."""
        while True:
            d = self.read(block_size)
            if not len(d):
                break
            yield d


def download(url, dest=None, extract=True, ignore_if_exists=False):
    """
    Download a file from the internet.

    Args:
        url: the url to download
        dest: destination file if extract=False, or destination dir if extract=True. If None, it will be the last part of URL.
        extract: extract a tar.gz or zip file?
        ignore_if_exists: don't do anything if file exists

    Returns:
        the destination filename.
    """

    base_url = url.split("?")[0]

    if dest is None:
        dest = [f for f in base_url.split("/") if f][-1]

    if os.path.exists(dest) and ignore_if_exists:
        return dest

    stream = UrlStream(url)
    extension = base_url.split(".")[-1].lower()

    if extract and extension in ['gz', 'bz2', 'zip']:
        os.makedirs(dest, exist_ok=True)

        if extension in ['gz', 'bz2']:
            decompressed_file = tarfile.open(fileobj=stream, mode='r|' + extension)
        elif extension == 'zip':
            decompressed_file = zipfile.ZipFile(stream, mode='r')
        else:
            assert False, "Invalid extension: %s" % extension

        # BUGFIX: close the archive after extraction (it was leaked before).
        with decompressed_file:
            decompressed_file.extractall(dest)
    else:
        try:
            with open(dest, 'wb') as f:
                for d in stream.iter_content(1024 * 1024):
                    f.write(d)
        except:
            # Intentionally broad: remove the partial file on ANY failure
            # (including KeyboardInterrupt), then re-raise.
            os.remove(dest)
            raise
    return dest
#
# ==============================================================================

import subprocess
import os
import torch


def get_memory_usage():
    """Return {gpu_index: used_MiB} as reported by nvidia-smi, or None on failure."""
    try:
        proc = subprocess.Popen("nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits".split(" "),
                                stdout=subprocess.PIPE)
        lines = [s.strip().split(" ") for s in proc.communicate()[0].decode().split("\n") if s]
        # Each line looks like "<index>, <used>"; g[0] still carries the
        # trailing comma, hence the [:-1].
        return {int(g[0][:-1]): int(g[1]) for g in lines}
    except:
        # Best effort by design: no nvidia-smi (or parse error) -> None.
        return None


def get_free_gpus():
    """Return indices of GPUs that run no compute process, or None on failure."""
    try:
        free = []
        proc = subprocess.Popen("nvidia-smi --query-compute-apps=gpu_uuid --format=csv,noheader,nounits".split(" "),
                                stdout=subprocess.PIPE)
        uuids = [s.strip() for s in proc.communicate()[0].decode().split("\n") if s]

        proc = subprocess.Popen("nvidia-smi --query-gpu=index,uuid --format=csv,noheader,nounits".split(" "),
                                stdout=subprocess.PIPE)

        id_uid_pair = [s.strip().split(", ") for s in proc.communicate()[0].decode().split("\n") if s]
        for pair in id_uid_pair:
            # Renamed from `id` to avoid shadowing the builtin.
            index, uid = pair

            if uid not in uuids:
                free.append(int(index))

        return free
    except:
        # Best effort by design: no nvidia-smi -> None.
        return None


def _fix_order():
    # Make CUDA's device numbering match nvidia-smi's PCI bus order, unless
    # the user already chose an order.
    os.environ["CUDA_DEVICE_ORDER"] = os.environ.get("CUDA_DEVICE_ORDER", "PCI_BUS_ID")


def allocate(n: int = 1):
    """Pick n GPUs (free ones first, then least-used) and export CUDA_VISIBLE_DEVICES.

    Serialized across processes via a lock file so concurrent jobs don't grab
    the same GPUs. No-op (with a warning) when CUDA_VISIBLE_DEVICES is preset
    or nvidia-smi is unavailable.
    """
    # Imported lazily so importing this module (and use_gpu("<id>")) works
    # without the full project package on sys.path.
    from Utils.lockfile import LockFile

    _fix_order()
    with LockFile("/tmp/gpu_allocation_lock"):
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            print("WARNING: trying to allocate %d GPUs, but CUDA_VISIBLE_DEVICES already set to %s" %
                  (n, os.environ["CUDA_VISIBLE_DEVICES"]))
            return

        allocated = get_free_gpus()
        if allocated is None:
            print("WARNING: failed to allocate %d GPUs" % n)
            return
        allocated = allocated[:n]

        if len(allocated) < n:
            print("There is no more free GPUs. Allocating the one with least memory usage.")
            usage = get_memory_usage()
            if usage is None:
                print("WARNING: failed to allocate %d GPUs" % n)
                return

            # Group device indices by their current memory usage...
            inv_usages = {}

            for k, v in usage.items():
                if v not in inv_usages:
                    inv_usages[v] = []

                inv_usages[v].append(k)

            # ...then list devices from least to most used.
            min_usage = list(sorted(inv_usages.keys()))
            min_usage_devs = []
            for u in min_usage:
                min_usage_devs += inv_usages[u]

            min_usage_devs = [m for m in min_usage_devs if m not in allocated]

            n2 = n - len(allocated)
            if n2 > len(min_usage_devs):
                print("WARNING: trying to allocate %d GPUs but only %d available" % (n, len(min_usage_devs) + len(allocated)))
                n2 = len(min_usage_devs)

            allocated += min_usage_devs[:n2]

        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(a) for a in allocated])
        # Touch each device while the lock is held, so its memory usage rises
        # before the next process runs its query.
        for i in range(len(allocated)):
            a = torch.FloatTensor([0.0])
            a.cuda(i)


def use_gpu(gpu="auto", n_autoalloc=1):
    """Select GPUs: "auto"/"" auto-allocates n_autoalloc devices, otherwise the
    given id string is exported as CUDA_VISIBLE_DEVICES verbatim."""
    _fix_order()

    gpu = gpu.lower()
    if gpu in ["auto", ""]:
        allocate(n_autoalloc)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
+# +# ============================================================================== + +import os +import fcntl + + +class LockFile: + def __init__(self, fname): + self._fname = fname + self._fd = None + + def acquire(self): + self._fd=open(self._fname, "w") + os.chmod(self._fname, 0o777) + + fcntl.lockf(self._fd, fcntl.LOCK_EX) + + def release(self): + fcntl.lockf(self._fd, fcntl.LOCK_UN) + self._fd.close() + self._fd = None + + def __enter__(self): + self.acquire() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.release() diff --git a/Utils/timer.py b/Utils/timer.py new file mode 100644 index 0000000..402052a --- /dev/null +++ b/Utils/timer.py @@ -0,0 +1,54 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ============================================================================== + +import time + + +class OnceEvery: + def __init__(self, interval): + self._interval = interval + self._last_check = 0 + + def __call__(self): + now = time.time() + if now - self._last_check >= self._interval: + self._last_check = now + return True + else: + return False + + +class Measure: + def __init__(self, average=1): + self._start = None + self._average = average + self._accu_value = 0.0 + self._history_list = [] + + def start(self): + self._start = time.time() + + def passed(self): + if self._start is None: + return None + + p = time.time() - self._start + self._history_list.append(p) + self._accu_value += p + if len(self._history_list) > self._average: + self._accu_value -= self._history_list.pop(0) + + return self._accu_value / len(self._history_list) diff --git a/Utils/universal.py b/Utils/universal.py new file mode 100644 index 0000000..70ebf0c --- /dev/null +++ b/Utils/universal.py @@ -0,0 +1,263 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#
# ==============================================================================

import torch
import torch.nn.functional as F
import numpy as np

# Paired (numpy dtype, torch dtype) descriptors used by dtype()/cast().
float32 = [np.float32, torch.float32]
float64 = [np.float64, torch.float64]
uint8 = [np.uint8, torch.uint8]

_all_types = [float32, float64, uint8]

_dtype_numpy_map = {v[0]().dtype.name: v for v in _all_types}
_dtype_pytorch_map = {v[1]: v for v in _all_types}


def dtype(t):
    """Return the (numpy, torch) dtype pair describing tensor/array t."""
    if torch.is_tensor(t):
        return _dtype_pytorch_map[t.dtype]
    else:
        return _dtype_numpy_map[t.dtype.name]


def cast(t, type):
    """Cast t to the given (numpy, torch) dtype pair, keeping its backend."""
    if torch.is_tensor(t):
        return t.type(type[1])
    else:
        return t.astype(type[0])


def to_numpy(t):
    """Convert a torch tensor to a numpy array; pass anything else through."""
    if torch.is_tensor(t):
        return t.detach().cpu().numpy()
    else:
        return t


def to_list(t):
    """Convert a tensor/array to a (possibly nested) plain python list."""
    t = to_numpy(t)
    if isinstance(t, np.ndarray):
        t = t.tolist()
    return t


def is_tensor(t):
    """True for torch tensors and numpy arrays alike."""
    return torch.is_tensor(t) or isinstance(t, np.ndarray)


def first_batch(t):
    """Return batch element 0 of a tensor/array; pass non-tensors through."""
    if is_tensor(t):
        return t[0]
    else:
        return t


def ndim(t):
    """Number of dimensions of a tensor/array."""
    if torch.is_tensor(t):
        return t.dim()
    else:
        return t.ndim


def shape(t):
    """Shape of t as a plain list."""
    return list(t.shape)


def transpose(t, axes):
    """Permute dimensions of a tensor/array.

    NOTE: transpose() was defined twice in the original file; only this
    (second) definition was effective, so the duplicate has been removed.
    """
    if torch.is_tensor(t):
        return t.permute(*axes)
    else:
        return np.transpose(t, axes)


def apply_recursive(d, fn, filter=None):
    """Apply fn to every leaf of a nested list/tuple/dict structure.

    Leaves for which `filter` returns False are passed through unchanged.
    """
    if isinstance(d, list):
        # BUGFIX: `filter` is now propagated into nested containers; it was
        # silently dropped before, so nested leaves ignored the filter.
        return [apply_recursive(da, fn, filter) for da in d]
    elif isinstance(d, tuple):
        return tuple(apply_recursive(list(d), fn, filter))
    elif isinstance(d, dict):
        return {k: apply_recursive(v, fn, filter) for k, v in d.items()}
    else:
        if filter is None or filter(d):
            return fn(d)
        else:
            return d


def apply_to_tensors(d, fn):
    """Apply fn to every torch tensor inside a nested structure."""
    return apply_recursive(d, fn, torch.is_tensor)


def recursive_decorator(apply_this_fn):
    """Build a decorator mapping apply_this_fn over all args/kwargs leaves."""
    def decorator(func):
        def wrapped_funct(*args, **kwargs):
            args = apply_recursive(args, apply_this_fn)
            kwargs = apply_recursive(kwargs, apply_this_fn)

            return func(*args, **kwargs)

        return wrapped_funct

    return decorator


# Decorators converting all tensor arguments to numpy arrays / python lists.
untensor = recursive_decorator(to_numpy)
unnumpy = recursive_decorator(to_list)


def unbatch(only_if_dim_equal=None):
    """Decorator factory: replace tensor args by their batch element 0
    (optionally only for tensors whose ndim is in only_if_dim_equal)."""
    if only_if_dim_equal is not None and not isinstance(only_if_dim_equal, list):
        only_if_dim_equal = [only_if_dim_equal]

    def get_first_batch(t):
        if is_tensor(t) and (only_if_dim_equal is None or ndim(t) in only_if_dim_equal):
            return t[0]
        else:
            return t

    return recursive_decorator(get_first_batch)


def sigmoid(t):
    """Elementwise logistic sigmoid for tensors, arrays and scalars."""
    if torch.is_tensor(t):
        return torch.sigmoid(t)
    else:
        return 1.0 / (1.0 + np.exp(-t))


def argmax(t, dim):
    """Index of the maximum along `dim`."""
    if torch.is_tensor(t):
        _, res = t.max(dim)
    else:
        res = np.argmax(t, axis=dim)

    return res


def flip(t, axis):
    """Reverse t along the given axis/axes."""
    if torch.is_tensor(t):
        return t.flip(axis)
    else:
        return np.flip(t, axis)


def split_n(t, axis):
    """Split t into size-1 slices along `axis`."""
    if torch.is_tensor(t):
        return t.split(1, dim=axis)
    else:
        return np.split(t, t.shape[axis], axis=axis)


def cat(array_of_tensors, axis):
    """Concatenate a sequence of tensors/arrays along `axis`."""
    if torch.is_tensor(array_of_tensors[0]):
        return torch.cat(array_of_tensors, axis)
    else:
        return np.concatenate(array_of_tensors, axis)


def clamp(t, min=None, max=None):
    """Clip t to [min, max]; either bound may be None."""
    if torch.is_tensor(t):
        return t.clamp(min, max)
    else:
        if min is not None:
            t = np.maximum(t, min)

        if max is not None:
            t = np.minimum(t, max)

        return t


def power(t, p):
    """Elementwise t ** p."""
    if torch.is_tensor(t) or torch.is_tensor(p):
        return torch.pow(t, p)
    else:
        return np.power(t, p)


def random_normal_as(a, mean, std, seed=None):
    """Normal noise with the same shape as `a`.

    NOTE(review): `seed` is honored only on the numpy path; the torch path
    uses torch's global RNG regardless.
    """
    if torch.is_tensor(a):
        return torch.randn_like(a) * std + mean
    else:
        if seed is None:
            seed = np.random
        return seed.normal(loc=mean, scale=std, size=shape(a))


def pad(t, pad):
    """Zero-pad an NCHW tensor/array; `pad` is (left, right, top, bottom) as in F.pad."""
    assert ndim(t) == 4

    if torch.is_tensor(t):
        return F.pad(t, pad)
    else:
        # BUGFIX: this branch used to `assert np.pad(...)` (returning None)
        # and padded the wrong axes.  Match F.pad's convention: the last axis
        # (W) gets pad[0:2], the one before (H) gets pad[2:4].
        return np.pad(t, [[0, 0], [0, 0], list(pad[2:4]), list(pad[0:2])], mode="constant")


def dx(img):
    """Central-difference derivative along width, zero-padded at the borders."""
    lsh = img[:, :, :, 2:]
    orig = img[:, :, :, :-2]

    return pad(0.5 * (lsh - orig), (1, 1, 0, 0))


def dy(img):
    """Central-difference derivative along height, zero-padded at the borders."""
    ush = img[:, :, 2:, :]
    orig = img[:, :, :-2, :]

    return pad(0.5 * (ush - orig), (0, 0, 1, 1))


def reshape(t, shape):
    """Reshape a tensor/array to the given shape."""
    if torch.is_tensor(t):
        # reshape() also handles non-contiguous tensors, where view() raises.
        return t.reshape(*shape)
    else:
        return t.reshape(*shape)


def broadcast_to_beginning(t, target):
    """Append singleton dims so t's existing dims align with target's leading dims."""
    if torch.is_tensor(t):
        nd_target = target.dim()
        t_shape = list(t.shape)
        return t.view(*t_shape, *([1] * (nd_target - len(t_shape))))
    else:
        nd_target = target.ndim
        t_shape = list(t.shape)
        return t.reshape(*t_shape, *([1] * (nd_target - len(t_shape))))


def lin_combine(d1, w1, d2, w2, bcast_begin=False):
    """Elementwise w1*d1 + w2*d2 over matching nested list/tuple/dict structures."""
    if isinstance(d1, (list, tuple)):
        assert len(d1) == len(d2)
        res = [lin_combine(d1[i], w1, d2[i], w2) for i in range(len(d1))]
        if isinstance(d1, tuple):
            # BUGFIX: used to return tuple(d1) -- the *inputs* -- instead of
            # the combined result.
            res = tuple(res)
    elif isinstance(d1, dict):
        res = {k: lin_combine(v, w1, d2[k], w2) for k, v in d1.items()}
    else:
        if bcast_begin:
            w1 = broadcast_to_beginning(w1, d1)
            w2 = broadcast_to_beginning(w2, d2)

        res = d1 * w1 + d2 * w2

    return res
+# +# ============================================================================== + +import numpy as np +try: + import cv2 +except: + cv2=None + +from Utils.Helpers import as_numpy + +def visualize_bitmap_task(i_data, o_data, zoom=8): + if not isinstance(o_data, list): + o_data = [o_data] + + imgs = [] + for d in [i_data]+o_data: + if d is None: + continue + + d=as_numpy(d) + if d.ndim>2: + d=d[0] + + imgs.append(np.expand_dims(d.T*255, -1).astype(np.uint8)) + + img = np.concatenate(imgs, 0) + return nearest_zoom(img, zoom) + +def visualize_01(t, zoom=8): + return nearest_zoom(np.expand_dims(t*255,-1).astype(np.uint8), zoom) + +def nearest_zoom(img, zoom=1): + if zoom>1 and cv2 is not None: + return cv2.resize(img, (img.shape[1] * zoom, img.shape[0] * zoom), interpolation=cv2.INTER_NEAREST) + else: + return img + +def concatenate_tensors(tensors): + max_size = None + dtype = None + + for t in tensors: + s = t.shape + if max_size is None: + max_size = list(s) + dtype = t.dtype + continue + + assert t.ndim ==len(max_size), "Can only concatenate tensors with same ndim." + assert t.dtype == dtype, "Tensors must have the same type" + max_size = [max(max_size[i], s[i]) for i in range(len(max_size))] + + res = np.zeros([len(tensors)] + max_size, dtype=dtype) + for i, t in enumerate(tensors): + res[i][tuple([slice(0,t.shape[i]) for i in range(t.ndim)])] = t + return res + + + diff --git a/Visualize/__init__.py b/Visualize/__init__.py new file mode 100644 index 0000000..06f5289 --- /dev/null +++ b/Visualize/__init__.py @@ -0,0 +1 @@ +from .BitmapTask import * diff --git a/Visualize/preview.py b/Visualize/preview.py new file mode 100644 index 0000000..ea70521 --- /dev/null +++ b/Visualize/preview.py @@ -0,0 +1,64 @@ +# Copyright 2017 Robert Csordas. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +import time +import threading +import traceback +import sys +from Utils import universal as U + + +def preview(vis_interval=None, to_numpy=True, debatch=False): + def decorator(func): + state = { + "last_vis_time": 0, + "thread_running": False + } + + def wrapper_function(*args, **kwargs): + if state["thread_running"]: + return + + if vis_interval is not None: + now = time.time() + if now - state["last_vis_time"] < vis_interval: + return + state["last_vis_time"] = now + + state["thread_running"] = True + + if debatch: + args = U.apply_recursive(args, U.first_batch) + kwargs = U.apply_recursive(kwargs, U.first_batch) + + if to_numpy: + args = U.apply_recursive(args, U.to_numpy) + kwargs = U.apply_recursive(kwargs, U.to_numpy) + + def run(): + try: + func(*args, **kwargs) + except Exception: + traceback.print_exc() + sys.exit(-1) + + state["thread_running"] = False + + download_thread = threading.Thread(target=run) + download_thread.start() + + return wrapper_function + return decorator diff --git a/assets/demon.png b/assets/demon.png new file mode 100644 index 0000000..bad4372 Binary files /dev/null and b/assets/demon.png differ diff --git a/memory_demon.py b/memory_demon.py new file mode 100644 index 0000000..726ca58 --- /dev/null +++ b/memory_demon.py @@ -0,0 +1,1070 @@ +#!/usr/bin/env python3# +# The Initial DNC Copyright 2017 Robert Csordas. All Rights Reserved. +# The modification of the initial DNC implementation by Ari Azarafrooz. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +import functools +import os + +import torch.utils.data + +import Utils.Debug as debug +from Dataset.Bitmap.AssociativeRecall import AssociativeRecall +from Dataset.Bitmap.BitmapTaskRepeater import BitmapTaskRepeater +from Dataset.Bitmap.KeyValue import KeyValue +from Dataset.Bitmap.CopyTask import CopyData +from Dataset.Bitmap.KeyValue2Way import KeyValue2Way +from Dataset.NLP.bAbi import bAbiDataset +from Models.DNCA import DNC, LSTMController, FeedforwardController +from Models.Information_Agents import RolloutStorage, Demon, FNet, ZNet + +from Utils import Visdom +from Utils.ArgumentParser import ArgumentParser +from Utils.Index import index_by_dim +from Utils.Saver import Saver, GlobalVarSaver, StateSaver +from Utils.Collate import MetaCollate +from Utils import gpu_allocator +from Dataset.NLP.NLPTask import NLPTask +from tqdm import tqdm +from Visualize.preview import preview +from Utils.timer import OnceEvery +from Utils import Seed +import time +import sys +import signal +import math +from Utils import Profile +import shutil +import math + +#from torch.utils.tensorboard import SummaryWriter + +import numpy as np + +model_dir = "" + +Profile.ENABLED = False + +random_seed = 1 + +if random_seed: + print("Random Seed: {}".format(random_seed)) + torch.manual_seed(random_seed) + np.random.seed(random_seed) + +if 
os.path.exists("tmp_train_dir"): + shutil.rmtree("tmp_train_dir") + +action_std = 0.1 # constant std for action distribution (Multivariate Normal) +K_epochs = 1 # update policy for K epochs +eps_clip = 0.2 # clip parameter for PPO +gamma = 1 # discount factor +lr = 0.001 # parameters for Adam optimizer #0.01 +betas = (0.9, 0.999) + + +def main(): + global i + global loss_sum + global running + parser = ArgumentParser() + parser.add_argument( + "-bit_w", type=int, default=8, help="Bit vector length for copy task" + ) + parser.add_argument( + "-block_w", type=int, default=3, help="Block width to associative recall task" + ) + parser.add_argument( + "-len", + type=str, + default="4", + help="Sequence length for copy task", + parser=lambda x: [int(a) for a in x.split("-")], + ) + parser.add_argument( + "-repeat", + type=str, + default="1", + help="Sequence length for copy task", + parser=lambda x: [int(a) for a in x.split("-")], + ) + parser.add_argument( + "-batch_size", type=int, default=16, help="Sequence length for copy task" + ) + parser.add_argument( + "-n_subbatch", + type=str, + default="auto", + help="Average this much forward passes to a backward pass", + ) + parser.add_argument( + "-max_input_count_per_batch", + type=int, + default=6000, + help="Max batch_size*len that can fit into memory", + ) + parser.add_argument("-lr", type=float, default=0.0001, help="Learning rate") + parser.add_argument("-wd", type=float, default=1e-5, help="Weight decay") + parser.add_argument( + "-optimizer", type=str, default="rmsprop", help="Optimizer algorithm" + ) + parser.add_argument("-name", type=str, help="Save training to this directory") + parser.add_argument( + "-preview_interval", + type=int, + default=10, + help="Show preview every nth iteration", + ) + parser.add_argument( + "-info_interval", type=int, default=10, help="Show info every nth iteration" + ) + parser.add_argument( + "-save_interval", type=int, default=500, help="Save network every nth iteration" + ) + 
parser.add_argument( + "-masked_lookup", type=bool, default=1, help="Enable masking in content lookups" + ) + parser.add_argument( + "-visport", + type=int, + default=-1, + help="Port to run Visdom server on. -1 to disable", + ) + parser.add_argument("-gpu", default="auto", type=str, help="Run on this GPU.") + parser.add_argument("-debug", type=bool, default=0, help="Enable debugging") + parser.add_argument("-task", type=str, default="copy", help="Task to learn") + parser.add_argument( + "-mem_count", type=int, default=16, help="Number of memory cells" + ) + parser.add_argument( + "-data_word_size", type=int, default=128, help="Memory word size" + ) + parser.add_argument( + "-n_read_heads", type=int, default=1, help="Number of read heads" + ) + parser.add_argument( + "-layer_sizes", + type=str, + default="256", + help="Controller layer sizes. Separate with ,. For example 512,256,256", + parser=lambda x: [int(y) for y in x.split(",") if y], + ) + parser.add_argument("-debug_log", type=bool, default=0, help="Enable debug log") + parser.add_argument( + "-controller_type", + type=str, + default="lstm", + help="Controller type: lstm or linear", + ) + parser.add_argument( + "-lstm_use_all_outputs", + type=bool, + default=1, + help="Use all LSTM outputs as controller output vs use only the last layer", + ) + parser.add_argument( + "-momentum", type=float, default=0.9, help="Momentum for optimizer" + ) + parser.add_argument( + "-embedding_size", + type=int, + default=256, + help="Size of word embedding for NLP tasks", + ) + parser.add_argument( + "-test_interval", type=int, default=10, help="Run test in this interval" + ) + parser.add_argument( + "-dealloc_content", + type=bool, + default=1, + help="Deallocate memory content, unlike DNC, which leaves it unchanged, just decreases the usage counter, causing problems with lookup", + ) + parser.add_argument( + "-sharpness_control", + type=bool, + default=1, + help="Distribution sharpness control for forward and backward 
links", + ) + parser.add_argument( + "-think_steps", + type=int, + default=0, + help="Iddle steps before requiring the answer (for bAbi)", + ) + parser.add_argument("-dump_profile", type=str, save=False) + parser.add_argument("-test_on_start", default="0", save=False) + parser.add_argument("-dump_heatmaps", default=False, save=False) + parser.add_argument("-test_batch_size", default=16) + parser.add_argument("-mask_min", default=0.0) + parser.add_argument("-load", type=str, save=False) + parser.add_argument( + "-dataset_path", + type=str, + default="none", + parser=ArgumentParser.str_or_none(), + help="Specify babi path manually", + ) + parser.add_argument( + "-babi_train_tasks", + type=str, + default="none", + parser=ArgumentParser.list_or_none(type=str), + help="babi task list to use for training", + ) + parser.add_argument( + "-babi_test_tasks", + type=str, + default="none", + parser=ArgumentParser.list_or_none(type=str), + help="babi task list to use for testing", + ) + parser.add_argument( + "-babi_train_sets", + type=str, + default="train", + parser=ArgumentParser.list_or_none(type=str), + help="babi train sets to use", + ) + parser.add_argument( + "-babi_test_sets", + type=str, + default="test", + parser=ArgumentParser.list_or_none(type=str), + help="babi test sets to use", + ) + parser.add_argument( + "-noargsave", + type=bool, + default=False, + help="Do not save modified arguments", + save=False, + ) + parser.add_argument( + "-demo", + type=bool, + default=False, + help="Do a single step with fixed seed", + save=False, + ) + parser.add_argument( + "-exit_after", + type=int, + help="Exit after this amount of steps. 
Useful for debugging.", + save=False, + ) + parser.add_argument( + "-grad_clip", type=float, default=10.0, help="Max gradient norm" + ) + parser.add_argument( + "-clip_controller", type=float, default=20.0, help="Max gradient norm" + ) + parser.add_argument("-print_test", default=False, save=False) + + parser.add_profile( + [ + ArgumentParser.Profile( + "babi", + { + "preview_interval": 10, + "save_interval": 500, + "task": "babi", + "mem_count": 64, + "data_word_size": 64, + "n_read_heads": 4, + "layer_sizes": "128", + "controller_type": "lstm", + "lstm_use_all_outputs": True, + "momentum": 0.9, + "embedding_size": 128, + "test_interval": 10000, + "think_steps": 3, + "batch_size": 4, + }, + include=["dnc-msd"], + ), + ArgumentParser.Profile( + "repeat_copy", + { + "bit_w": 8, + "repeat": "1-8", + "len": "2-14", + "task": "copy", + "think_steps": 1, + "preview_interval": 10, + "info_interval": 10, + "save_interval": 100, + "data_word_size": 16, + "layer_sizes": "32", + "n_subbatch": 1, + "controller_type": "lstm", + }, + ), + ArgumentParser.Profile( + "repeat_copy_simple", + { + "repeat": "1-3", + }, + include="repeat_copy", + ), + ArgumentParser.Profile( + "dnc", + { + "masked_lookup": False, + "sharpness_control": False, + "dealloc_content": False, + }, + ), + ArgumentParser.Profile( + "dnc-m", + { + "masked_lookup": True, + "sharpness_control": False, + "dealloc_content": False, + }, + ), + ArgumentParser.Profile( + "dnc-s", + { + "masked_lookup": False, + "sharpness_control": True, + "dealloc_content": False, + }, + ), + ArgumentParser.Profile( + "dnc-d", + { + "masked_lookup": False, + "sharpness_control": False, + "dealloc_content": True, + }, + ), + ArgumentParser.Profile( + "dnc-md", + { + "masked_lookup": True, + "sharpness_control": False, + "dealloc_content": True, + }, + ), + ArgumentParser.Profile( + "dnc-ms", + { + "masked_lookup": True, + "sharpness_control": True, + "dealloc_content": False, + }, + ), + ArgumentParser.Profile( + "dnc-sd", + { + 
"masked_lookup": False, + "sharpness_control": True, + "dealloc_content": True, + }, + ), + ArgumentParser.Profile( + "dnc-msd", + { + "masked_lookup": True, + "sharpness_control": True, + "dealloc_content": True, + }, + ), + ArgumentParser.Profile( + "keyvalue", + { + "repeat": "1", + "len": "2-16", + "mem_count": 16, + "task": "keyvalue", + "think_steps": 1, + "preview_interval": 10, + "info_interval": 10, + "data_word_size": 32, + "bit_w": 12, + "save_interval": 1000, + "layer_sizes": "32", + }, + ), + ArgumentParser.Profile( + "keyvalue2way", + { + "task": "keyvalue2way", + }, + include="keyvalue", + ), + ArgumentParser.Profile( + "associative_recall", + { + "task": "recall", + "bit_w": 8, + "len": "2-16", + "mem_count": 64, + "data_word_size": 32, + "n_read_heads": 1, + "layer_sizes": "128", + "controller_type": "lstm", + "lstm_use_all_outputs": 1, + "think_steps": 1, + "mask_min": 0.1, + "info_interval": 10, + "save_interval": 1000, + "preview_interval": 10, + "n_subbatch": 1, + }, + ), + ] + ) + + opt = parser.parse() + assert opt.name is not None, "Training dir (-name parameter) not given" + opt = parser.sync(os.path.join(opt.name, "args.json"), save=not opt.noargsave) + + if opt.demo: + Seed.fix() + + os.makedirs(os.path.join(opt.name, "save"), exist_ok=True) + os.makedirs(os.path.join(opt.name, "preview"), exist_ok=True) + + gpu_allocator.use_gpu(opt.gpu) + + debug.enableDebug = opt.debug_log + + if opt.visport > 0: + Visdom.start(opt.visport) + + Visdom.Text("Name").set(opt.name) + + class LengthHackSampler: + def __init__(self, batch_size, length): + self.length = length + self.batch_size = batch_size + + def __iter__(self): + while True: + len = self.length() if callable(self.length) else self.length + yield [len] * self.batch_size + + def __len__(self): + return 0x7FFFFFFF + + embedding = None + test_set = None + curriculum = None + loader_reset = False + if opt.task == "copy": + dataset = CopyData(bit_w=opt.bit_w) + in_size = opt.bit_w + 1 + out_size 
= in_size + elif opt.task == "recall": + dataset = AssociativeRecall(bit_w=opt.bit_w, block_w=opt.block_w) + in_size = opt.bit_w + 2 + out_size = in_size + elif opt.task == "keyvalue": + assert opt.bit_w % 2 == 0, "Key-value datasets works only with even bit_w" + dataset = KeyValue(bit_w=opt.bit_w) + in_size = opt.bit_w + 1 + out_size = opt.bit_w // 2 + elif opt.task == "keyvalue2way": + assert opt.bit_w % 2 == 0, "Key-value datasets works only with even bit_w" + dataset = KeyValue2Way(bit_w=opt.bit_w) + in_size = opt.bit_w + 2 + out_size = opt.bit_w // 2 + elif opt.task == "babi": + dataset = bAbiDataset(think_steps=opt.think_steps, dir_name=opt.dataset_path) + test_set = bAbiDataset( + think_steps=opt.think_steps, dir_name=opt.dataset_path, name="test" + ) + dataset.use(opt.babi_train_tasks, opt.babi_train_sets) + in_size = opt.embedding_size + print("bAbi: loaded total of %d sequences." % len(dataset)) + test_set.use(opt.babi_test_tasks, opt.babi_test_sets) + out_size = len(dataset.vocabulary) + print( + "bAbi: using %d sequences for training, %d for testing" + % (len(dataset), len(test_set)) + ) + else: + assert False, "Invalid task: %s" % opt.task + + if opt.task in ["babi"]: + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=opt.batch_size, + num_workers=4, + pin_memory=True, + shuffle=True, + collate_fn=MetaCollate(), + ) + test_loader = ( + torch.utils.data.DataLoader( + test_set, + batch_size=opt.test_batch_size, + num_workers=opt.test_batch_size, + pin_memory=True, + shuffle=False, + collate_fn=MetaCollate(), + ) + if test_set is not None + else None + ) + else: + dataset = BitmapTaskRepeater(dataset) + data_loader = torch.utils.data.DataLoader( + dataset, + batch_sampler=LengthHackSampler( + opt.batch_size, BitmapTaskRepeater.key_sampler(opt.len, opt.repeat) + ), + num_workers=1, + pin_memory=True, + ) + + if opt.controller_type == "lstm": + controller_constructor = functools.partial( + LSTMController, 
out_from_all_layers=opt.lstm_use_all_outputs + ) + elif opt.controller_type == "linear": + controller_constructor = FeedforwardController + else: + assert False, "Invalid controller: %s" % opt.controller_type + + parity_size = 0 + + model = DNC( + in_size + parity_size, + out_size, + opt.data_word_size, + opt.mem_count, + opt.n_read_heads, + controller_constructor(opt.layer_sizes), + batch_first=True, + mask=opt.masked_lookup, + dealloc_content=opt.dealloc_content, + link_sharpness_control=opt.sharpness_control, + mask_min=opt.mask_min, + clip_controller=opt.clip_controller, + ) + + # model.load_state_dict(torch.load(model_dir, map_location="cpu")["model"]) + + print("data_word_size: {}".format(opt.data_word_size)) + rollout_storage = RolloutStorage() + demon_state_dim = ( + in_size + opt.mem_count * opt.data_word_size + ) #:TODO opt.mem_count * opt.data_word_size + demon_action_dim = in_size + + demon = Demon( + demon_state_dim, + demon_action_dim, + action_std, + lr, + betas, + gamma, + K_epochs, + eps_clip, + ) + + fnet = FNet() + fnet.init(2 * opt.mem_count * opt.data_word_size) + #fnet.load_state_dict(torch.load(model_dir, map_location="cpu")["FNet"]) + + znet = ZNet() + znet.init(opt.mem_count * opt.data_word_size) + #znet.load_state_dict(torch.load(model_dir, map_location="cpu")["ZNet"]) + + params = [ + {"params": [p for n, p in model.named_parameters() if not n.endswith(".bias")]}, + { + "params": [p for n, p in model.named_parameters() if n.endswith(".bias")], + "weight_decay": 0, + }, + ] + + device = ( + torch.device("cuda") + if opt.gpu != "none" and torch.cuda.is_available() + else torch.device("cpu") + ) + print("DEVICE: ", device) + + if isinstance(dataset, NLPTask): + embedding = torch.nn.Embedding(len(dataset.vocabulary), opt.embedding_size).to( + device + ) + params.append({"params": embedding.parameters(), "weight_decay": 0}) + # embedding.load_state_dict( + # torch.load(model_dir, map_location="cpu")["word_embeddings"] + # ) + + if 
opt.optimizer == "sgd": + optimizer = torch.optim.SGD( + params, lr=opt.lr, weight_decay=opt.wd, momentum=opt.momentum + ) + elif opt.optimizer == "adam": + optimizer = torch.optim.Adam(params, lr=opt.lr, weight_decay=opt.wd) + elif opt.optimizer == "rmsprop": + optimizer = torch.optim.RMSprop( + params, lr=opt.lr, weight_decay=opt.wd, momentum=opt.momentum, eps=1e-10 + ) + else: + assert "Invalid optimizer: %s" % opt.optimizer + + n_params = sum([sum([t.numel() for t in d["params"]]) for d in params]) + print("Number of parameters: %d" % n_params) + + model = model.to(device) + fnet = fnet.to(device) + znet = znet.to(device) + znet_optim = torch.optim.Adam(znet.parameters(), lr=0.001) + fnet_optim = torch.optim.Adam(fnet.parameters(), lr=0.001) + + if embedding is not None and hasattr(embedding, "to"): + embedding = embedding.to(device) + + i = 0 + loss_sum = 0 + + loss_plot = Visdom.Plot2D( + "loss", store_interval=opt.info_interval, xlabel="iterations", ylabel="loss" + ) + + if curriculum is not None: + curriculum_plot = Visdom.Plot2D( + "curriculum lesson" + + ( + " (last %d)" % (curriculum.n_lessons - 1) + if curriculum.n_lessons is not None + else "" + ), + xlabel="iterations", + ylabel="lesson", + ) + curriculum_accuracy = Visdom.Plot2D( + "curriculum accuracy", xlabel="iterations", ylabel="accuracy" + ) + + saver = Saver(os.path.join(opt.name, "save"), short_interval=opt.save_interval) + saver.register("model", StateSaver(model)) + saver.register("optimizer", StateSaver(optimizer)) + saver.register("i", GlobalVarSaver("i")) + saver.register("loss_sum", GlobalVarSaver("loss_sum")) + saver.register("loss_plot", StateSaver(loss_plot)) + saver.register("dataset", StateSaver(dataset)) + if test_set: + saver.register("test_set", StateSaver(test_set)) + + if curriculum is not None: + saver.register("curriculum", StateSaver(curriculum)) + saver.register("curriculum_plot", StateSaver(curriculum_plot)) + saver.register("curriculum_accuracy", 
StateSaver(curriculum_accuracy)) + + if isinstance(dataset, NLPTask): + saver.register("word_embeddings", StateSaver(embedding)) + elif embedding is not None: + saver.register("embeddings", StateSaver(embedding)) + + visualizers = {} + + debug_schemas = { + "read_head": {"list_dim": 2}, + "temporal_links/forward_dists": {"list_dim": 2}, + "temporal_links/backward_dists": {"list_dim": 2}, + } + + def plot_debug(debug, prefix="", schema={}): + if debug is None: + return + + for k, v in debug.items(): + curr_name = prefix + k + if curr_name in debug_schemas: + curr_schema = schema.copy() + curr_schema.update(debug_schemas[curr_name]) + else: + curr_schema = schema + + if isinstance(v, dict): + plot_debug(v, curr_name + "/", curr_schema) + continue + + data = v[0] + + if curr_schema.get("list_dim", -1) > 0: + if data.ndim != 3: + print( + "WARNING: unknown data shape for array display: %s, tensor %s" + % (data.shape, curr_name) + ) + continue + + n_steps = data.shape[curr_schema["list_dim"] - 1] + if curr_name not in visualizers: + visualizers[curr_name] = [ + Visdom.Heatmap( + curr_name + "_%d" % i, + dumpdir=os.path.join(opt.name, "preview") + if opt.dump_heatmaps + else None, + ) + for i in range(n_steps) + ] + + for i in range(n_steps): + visualizers[curr_name][i].draw( + index_by_dim(data, curr_schema["list_dim"] - 1, i) + ) + else: + if data.ndim != 2: + print( + "WARNING: unknown data shape for simple display: %s, tensor %s" + % (data.shape, curr_name) + ) + continue + + if curr_name not in visualizers: + visualizers[curr_name] = Visdom.Heatmap( + curr_name, + dumpdir=os.path.join(opt.name, "preview") + if opt.dump_heatmaps + else None, + ) + + visualizers[curr_name].draw(data) + + def run_model(input, debug=None, demon=None, rollout_storage=None): + if isinstance(dataset, NLPTask): + input = embedding(input["input"]) + else: + input = input["input"] * 2.0 - 1.0 + + return model(input, debug=debug, demon=demon, rollout_storage=rollout_storage) + + def 
run_znet(mem_state): + mem_state = mem_state[:, 1:, :] + return znet(mem_state) + + def run_fnet(mem_state, marginal=False): + shuffled_mem_state = None + if not marginal: + input = torch.cat((mem_state[:, :-1, :], mem_state[:, 1:, :]), dim=2) + else: + shuffled_indx = torch.randperm(mem_state.size(0)).to( + device + ) # random index for shuffling the elements of batch + shuffled_mem_state = mem_state.index_select(0, shuffled_indx) + input = torch.cat( + (mem_state[:, :-1, :], shuffled_mem_state[:, 1:, :]), dim=2 + ) + + return fnet(input), shuffled_mem_state + + def multiply_grads(params, mul): + if mul == 1: + return + + for pa in params: + for p in pa["params"]: + p.grad.data *= mul + + def test(): + if test_set is None: + return + + print("TESTING...") + start_time = time.time() + t = test_set.start_test() + with torch.no_grad(): + for data in tqdm(test_loader): + data = { + k: v.to(device) if torch.is_tensor(v) else v + for k, v in data.items() + } + if hasattr(dataset, "prepare"): + data = dataset.prepare(data) + + net_out = run_model(data, demon=demon) + test_set.veify_result(t, data, net_out) + + test_set.show_test_results(i, t) + print("Test done in %gs" % (time.time() - start_time)) + + if opt.test_on_start.lower() in ["on", "1", "true", "quit"]: + test() + if opt.test_on_start.lower() == "quit": + saver.write(i) + sys.exit(-1) + + if opt.print_test: + model.eval() + total = 0 + correct = 0 + with torch.no_grad(): + for data in tqdm(test_loader): + if not running: + return + + data = { + k: v.to(device) if torch.is_tensor(v) else v + for k, v in data.items() + } + if hasattr(test_set, "prepare"): + data = test_set.prepare(data) + + net_out = run_model(data, demon) + + c, t = test_set.curriculum_measure(net_out, data["output"]) + total += t + correct += c + + print( + "Test result: %2.f%% (%d out of %d correct)" + % (100.0 * correct / total, correct, total) + ) + model.train() + return + + iter_start_time = time.time() if i % opt.info_interval == 0 else 
None + data_load_total_time = 0 + + start_i = i + + if opt.dump_profile: + profiler = torch.autograd.profiler.profile(use_cuda=True) + + if opt.dump_heatmaps: + dataset.set_dump_dir(os.path.join(opt.name, "preview")) + + @preview() + def do_visualize(raw_data, output, pos_map, debug): + if pos_map is not None: + output = embedding.backmap_output( + output, pos_map, raw_data["output"].shape[1] + ) + dataset.visualize_preview(raw_data, output) + + if debug is not None: + plot_debug(debug) + + preview_timer = OnceEvery(opt.preview_interval) + + pos_map = None + start_iter = i + + if curriculum is not None: + curriculum.init() + + ma_et = 1.0 + while running: + data_load_timer = time.time() + for data in data_loader: + if not running: + break + + if loader_reset: + print("Loader reset requested. Resetting...") + loader_reset = False + if curriculum is not None: + curriculum.lesson_started() + break + + if opt.dump_profile: + if i == start_i + 1: + print("Starting profiler") + profiler.__enter__() + elif i == start_i + 5 + 1: + print("Stopping profiler") + profiler.__exit__(None, None, None) + print("Average stats") + print(profiler.key_averages().table("cpu_time_total")) + print("Writing trace to file") + profiler.export_chrome_trace(opt.dump_profile) + print("Done.") + sys.exit(0) + else: + print("Step %d out of 5" % (i - start_i)) + + debug.dbg_print("-------------------------------------") + raw_data = data + + data = { + k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items() + } + if hasattr(dataset, "prepare"): + data = dataset.prepare(data) + + data_load_total_time += time.time() - data_load_timer + + need_preview = preview_timer() + debug_data = {} if opt.debug and need_preview else None + + optimizer.zero_grad() + # demon_optim.zero_grad() + # znet_optim.zero_grad() + # fnet_optim.zero_grad() + + if opt.n_subbatch == "auto": + n_subbatch = math.ceil( + data["input"].numel() / opt.max_input_count_per_batch + ) + else: + n_subbatch = 
int(opt.n_subbatch) + + real_batch = max(math.floor(opt.batch_size / n_subbatch), 1) + n_subbatch = math.ceil(opt.batch_size / real_batch) + remaning_batch = opt.batch_size % real_batch + + for subbatch in range(n_subbatch): + if not running: + break + input = data["input"] + target = data["output"] + + if n_subbatch != 1: + input = input[subbatch * real_batch : (subbatch + 1) * real_batch] + target = target[subbatch * real_batch : (subbatch + 1) * real_batch] + + f2 = data.copy() + f2["input"] = input + + # Demon modifies the memory before DNC model + output = run_model( + f2, + debug=debug_data if subbatch == n_subbatch - 1 else None, + demon=demon, + rollout_storage=rollout_storage, + ) + l = dataset.loss(output, target) + l.backward() + + mem_state = torch.stack(model.mem_state, dim=1) + t, _ = run_fnet(mem_state.detach()) + z = run_znet(mem_state.detach()) + + et, shuffled_mem = run_fnet(mem_state.detach(), marginal=True) + et = torch.exp(et) + mi_lb = t - (torch.mean(et) / z + torch.log(z) - 1) + + demon_rewards = mi_lb + info_loss = -mi_lb.mean() + + info_loss.backward() + + fnet_optim.step() + fnet_optim.zero_grad() + + znet_optim.step() + znet_optim.zero_grad() + + del model.mem_state[:] # reset the mem state + + for j in range(0, demon_rewards.size(1)): + rollout_storage.rewards.append(demon_rewards[:, j].detach()) + + # update if its time + if i % 1 == 0: # TODO demon_update_timestep = C + demon_loss = demon.update(rollout_storage) + rollout_storage.clear_storage() + + debug.nan_check(l, force=True) + + if curriculum is not None: + curriculum.update(*dataset.curriculum_measure(output, target)) + + if remaning_batch != 0 and subbatch == n_subbatch - 2: + multiply_grads(params, real_batch / remaning_batch) + + if n_subbatch != 1: + if remaning_batch == 0: + multiply_grads(params, 1 / n_subbatch) + else: + multiply_grads(params, remaning_batch / opt.batch_size) + + for p in params: + torch.nn.utils.clip_grad_norm_(p["params"], opt.grad_clip) + + 
optimizer.step() + + i += 1 + + curr_loss = l.data.item() + loss_plot.add_point(i, curr_loss) + + # writer.add_scalar("associative-recall-loss/dnc-md", curr_loss, i) + # writer.add_scalar("associative-recall-mutual_info/dnc-md", info_loss, i) + # writer.add_scalar( + # "associative-recall-demon_loss/dnc-md", demon_loss.mean(), i + # ) + + loss_sum += curr_loss + + if i % opt.info_interval == 0: + tim = time.time() + loss_avg = loss_sum / opt.info_interval + + if curriculum is not None: + curriculum_accuracy.add_point(i, curriculum.get_accuracy()) + curriculum_plot.add_point(i, curriculum.step) + + message = "Iteration %d, loss: %.4f" % (i, loss_avg) + if iter_start_time is not None: + message += ( + " (%.2f ms/iter, load time %.2g ms/iter, visport: %s)" + % ( + (tim - iter_start_time) / opt.info_interval * 1000.0, + data_load_total_time / opt.info_interval * 1000.0, + Visdom.port, + ) + ) + print(message) + iter_start_time = tim + loss_sum = 0 + data_load_total_time = 0 + + debug.dbg_print("Iteration %d, loss %g" % (i, curr_loss)) + + if need_preview: + do_visualize(raw_data, output, pos_map, debug_data) + + if i % opt.test_interval == 0: + test() + + saver.tick(i) + + if opt.demo and opt.exit_after is None: + running = False + input("Press enter to quit.") + + if opt.exit_after is not None and (i - start_iter) >= opt.exit_after: + running = False + + data_load_timer = time.time() + + +if __name__ == "__main__": + #writer = SummaryWriter() + global running + running = True + + def signal_handler(signal, frame): + global running + print("You pressed Ctrl+C!") + running = False + + signal.signal(signal.SIGINT, signal_handler) + + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3caa71c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +tqdm +torch +visdom +numpy +tensorboard + + +