diff --git a/Dataset/.DS_Store b/Dataset/.DS_Store
new file mode 100644
index 0000000..3f2c165
Binary files /dev/null and b/Dataset/.DS_Store differ
diff --git a/Dataset/Bitmap/.DS_Store b/Dataset/Bitmap/.DS_Store
new file mode 100644
index 0000000..5008ddf
Binary files /dev/null and b/Dataset/Bitmap/.DS_Store differ
diff --git a/Dataset/Bitmap/AssociativeRecall.py b/Dataset/Bitmap/AssociativeRecall.py
new file mode 100644
index 0000000..86f9434
--- /dev/null
+++ b/Dataset/Bitmap/AssociativeRecall.py
@@ -0,0 +1,72 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import numpy as np
+from .BitmapTask import BitmapTask
+from Utils.Seed import get_randstate
+
+
+class AssociativeRecall(BitmapTask):
+ def __init__(self, length=None, bit_w=8, block_w=3, transform=lambda x: x):
+ super(AssociativeRecall, self).__init__()
+ self.length = length
+ self.bit_w = bit_w
+ self.block_w = block_w
+ self.transform = transform
+ self.seed = None
+
+ def __getitem__(self, key):
+ if self.seed is None:
+ self.seed = get_randstate()
+
+ length = self.length() if callable(self.length) else self.length
+ if length is None:
+ # Random length batch hack.
+ length = key
+
+ stride = self.block_w + 1
+
+ d = self.seed.randint(
+ 0, 2, [length * (self.block_w + 1), self.bit_w + 2]
+ ).astype(np.float32)
+ d[:, -2:] = 0
+
+ # Terminate input block
+ for i in range(1, length, 1):
+ d[i * stride - 1, :] = 0
+ d[i * stride - 1, -2] = 1
+
+ # Terminate input sequence
+ d[-1, :] = 0
+ d[-1, -1] = 1
+
+ # Add and terminate query
+ ti = self.seed.randint(0, length - 1)
+ d = np.concatenate(
+ (
+ d,
+ d[ti * stride : (ti + 1) * stride - 1],
+ np.zeros([self.block_w + 1, self.bit_w + 2], np.float32),
+ ),
+ axis=0,
+ )
+ d[-(1 + self.block_w), -1] = 1
+
+ # Target
+ target = np.zeros_like(d)
+ target[-self.block_w :] = d[(ti + 1) * stride : (ti + 2) * stride - 1]
+
+ return self.transform({"input": d, "output": target})
diff --git a/Dataset/Bitmap/BitmapTask.py b/Dataset/Bitmap/BitmapTask.py
new file mode 100644
index 0000000..eac03e4
--- /dev/null
+++ b/Dataset/Bitmap/BitmapTask.py
@@ -0,0 +1,82 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import torch
+import torch.nn.functional as F
+from Visualize.BitmapTask import visualize_bitmap_task
+from Utils import Visdom
+from Utils import universal as U
+
+import numpy as np
+
+
+class BitmapTask(torch.utils.data.Dataset):
+ def __init__(self):
+ super(BitmapTask, self).__init__()
+
+ self._img = Visdom.Image("preview")
+
+ def set_dump_dir(self, dir):
+ self._img.set_dump_dir(dir)
+
+ def __len__(self):
+ return 0x7FFFFFFF
+
+ def visualize_preview(self, data, net_output):
+ img = visualize_bitmap_task(
+ data["input"], [data["output"], U.sigmoid(net_output)]
+ )
+ self._img.draw(img)
+
+ def loss(self, net_output, target):
+ return F.binary_cross_entropy_with_logits(
+ net_output, target, reduction="sum"
+ ) / net_output.size(0)
+
+ def accuracy(self, net_output, target):
+ return F.binary_cross_entropy_with_logits(
+ net_output, target, reduction="sum"
+ ) / net_output.size(0)
+
+ def demon_loss(self, net_output, target, saved_actions, device):
+ """
+ computes the loss for the demon
+ :param net_output:
+ :param target:
+ :param saved_actions:
+ :return:
+ """
+ net_output = net_output.detach()
+ loss = F.binary_cross_entropy_with_logits(
+ net_output, target, reduction="none"
+ ).sum(dim=-1)
+
+ policy_losses = [] # list to save actor (policy) loss
+
+ discount_factor = 0.99
+ for i in range(0, loss.size(1)): # computing expected total reward
+ discount_vector = torch.from_numpy(np.array([np.power(discount_factor,i) for i in range(loss.size(1)-i)])).to(device)
+ policy_losses.append(((saved_actions[i].log_prob).squeeze(1) * (discount_vector*loss[:, i:]).mean(dim=1)))
+
+ demon_loss = torch.stack(policy_losses).mean(dim=0)/loss.size(1)
+
+ return demon_loss
+
+ def state_dict(self):
+ return {}
+
+ def load_state_dict(self, state):
+ pass
diff --git a/Dataset/Bitmap/BitmapTaskRepeater.py b/Dataset/Bitmap/BitmapTaskRepeater.py
new file mode 100644
index 0000000..0457b08
--- /dev/null
+++ b/Dataset/Bitmap/BitmapTaskRepeater.py
@@ -0,0 +1,56 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import numpy as np
+import random
+from .BitmapTask import BitmapTask
+
+
+class BitmapTaskRepeater(BitmapTask):
+ def __init__(self, dataset):
+ super(BitmapTaskRepeater, self).__init__()
+ self.dataset = dataset
+
+ def __getitem__(self, key):
+ r = [self.dataset[k] for k in key]
+ if len(r) == 1:
+ return r[0]
+ else:
+ return {
+ "input": np.concatenate([a["input"] for a in r], axis=0),
+ "output": np.concatenate([a["output"] for a in r], axis=0),
+ }
+
+ @staticmethod
+ def key_sampler(length, repeat):
+ def call_sampler(s):
+ if callable(s):
+ return s()
+ elif isinstance(s, list):
+ if len(s) == 2:
+ return random.randint(*s)
+ elif len(s) == 1:
+ return s[0]
+ else:
+ assert False, "Invalid sample parameter: %s" % s
+ else:
+ return s
+
+ def s():
+ r = call_sampler(repeat)
+ return [call_sampler(length) for i in range(r)]
+
+ return s
diff --git a/Dataset/Bitmap/CopyTask.py b/Dataset/Bitmap/CopyTask.py
new file mode 100644
index 0000000..953f9f3
--- /dev/null
+++ b/Dataset/Bitmap/CopyTask.py
@@ -0,0 +1,49 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import numpy as np
+from .BitmapTask import BitmapTask
+from Utils.Seed import get_randstate
+
+
+class CopyData(BitmapTask):
+ def __init__(self, length=None, bit_w=8, transform=lambda x: x):
+ super(CopyData, self).__init__()
+ self.length = length
+ self.bit_w = bit_w
+ self.transform = transform
+ self.seed = None
+
+ def __getitem__(self, key):
+ if self.seed is None:
+ self.seed = get_randstate()
+
+ length = self.length() if callable(self.length) else self.length
+ if length is None:
+ # Random length batch hack.
+ length = key
+
+ d = self.seed.randint(0, 2, [length + 1, self.bit_w + 1]).astype(np.float32)
+ z = np.zeros_like(d)
+
+ d[-1] = 0
+ d[:, -1] = 0
+ d[-1, -1] = 1
+
+ i_p = np.concatenate((d, z), axis=0)
+ o_p = np.concatenate((z, d), axis=0)
+
+ return self.transform({"input": i_p, "output": o_p})
diff --git a/Dataset/Bitmap/KeyValue.py b/Dataset/Bitmap/KeyValue.py
new file mode 100644
index 0000000..527dc9b
--- /dev/null
+++ b/Dataset/Bitmap/KeyValue.py
@@ -0,0 +1,81 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import math
+import numpy as np
+from .BitmapTask import BitmapTask
+from Utils.Seed import get_randstate
+
+
+class KeyValue(BitmapTask):
+ def __init__(self, length=None, bit_w=8, transform=lambda x: x):
+ assert bit_w % 2 == 0, "bit_w must be even"
+ super(KeyValue, self).__init__()
+ self.length = length
+ self.bit_w = bit_w
+ self.transform = transform
+ self.seed = None
+ self.key_w = self.bit_w // 2
+ self.max_key = 2 ** self.key_w - 1
+
+ def __getitem__(self, key):
+ if self.seed is None:
+ self.seed = get_randstate()
+
+ if self.length is None:
+ # Random length batch hack.
+ length = key
+ else:
+ length = self.length() if callable(self.length) else self.length
+
+ # keys must be unique
+ keys = None
+ last_size = 0
+ while last_size != length:
+ res = self.seed.random_integers(0, self.max_key, size=(length - last_size))
+ if keys is not None:
+ keys = np.concatenate((res, keys))
+ else:
+ keys = res
+
+ keys = np.unique(keys)
+ last_size = keys.size
+
+ # view as bunch of uint8s, convert them to bit patterns, then cut the correct amount from it
+ keys = keys.view(np.uint8).reshape(length, -1)
+ keys = keys[:, : math.ceil(self.key_w / 8)]
+ keys = np.unpackbits(np.expand_dims(keys, -1), axis=-1)
+ keys = np.flip(keys, axis=-1).reshape(keys.shape[0], -1)[:, : self.key_w]
+ keys = keys.astype(np.float32)
+
+ values = self.seed.randint(0, 2, keys.shape).astype(np.float32)
+
+ perm = self.seed.permutation(length)
+ keys_perm = keys[perm, :]
+ values_perm = values[perm, :]
+
+ i_p = np.zeros((2 * length + 2, self.bit_w + 1), dtype=np.float32)
+ i_p[:length, : self.key_w] = keys
+ i_p[:length, self.key_w : -1] = values
+ i_p[length + 1 : -1, : self.key_w] = keys_perm
+
+ i_p[length, -1] = 1
+ i_p[-1, -1] = 1
+
+ o_p = np.zeros((2 * length + 2, self.key_w), dtype=np.float32)
+ o_p[length + 1 : -1] = values_perm
+
+ return self.transform({"input": i_p, "output": o_p})
diff --git a/Dataset/Bitmap/KeyValue2Way.py b/Dataset/Bitmap/KeyValue2Way.py
new file mode 100644
index 0000000..0656b10
--- /dev/null
+++ b/Dataset/Bitmap/KeyValue2Way.py
@@ -0,0 +1,89 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import math
+import numpy as np
+from .BitmapTask import BitmapTask
+from Utils.Seed import get_randstate
+
+
+class KeyValue2Way(BitmapTask):
+ def __init__(self, length=None, bit_w=8, transform=lambda x: x):
+ assert bit_w % 2 == 0, "bit_w must be even"
+ super(KeyValue2Way, self).__init__()
+ self.length = length
+ self.bit_w = bit_w
+ self.transform = transform
+ self.seed = None
+ self.key_w = self.bit_w // 2
+ self.max_key = 2 ** self.key_w - 1
+
+ def __getitem__(self, key):
+ if self.seed is None:
+ self.seed = get_randstate()
+
+ if self.length is None:
+ # Random length batch hack.
+ length = key
+ else:
+ length = self.length() if callable(self.length) else self.length
+
+ # keys must be unique
+ keys = None
+ last_size = 0
+ while last_size != length:
+ res = self.seed.random_integers(0, self.max_key, size=(length - last_size))
+ if keys is not None:
+ keys = np.concatenate((res, keys))
+ else:
+ keys = res
+
+ keys = np.unique(keys)
+ last_size = keys.size
+
+ # view as bunch of uint8s, convert them to bit patterns, then cut the correct amount from it
+ keys = keys.view(np.uint8).reshape(length, -1)
+ keys = keys[:, : math.ceil(self.key_w / 8)]
+ keys = np.unpackbits(np.expand_dims(keys, -1), axis=-1)
+ keys = np.flip(keys, axis=-1).reshape(keys.shape[0], -1)[:, : self.key_w]
+ keys = keys.astype(np.float32)
+
+ values = self.seed.randint(0, 2, keys.shape).astype(np.float32)
+
+ perm = self.seed.permutation(length)
+ keys_perm = keys[perm, :]
+ values_perm = values[perm, :]
+
+ i_p = np.zeros((3 * (length + 1), self.bit_w + 2), dtype=np.float32)
+ o_p = np.zeros((3 * (length + 1), self.key_w), dtype=np.float32)
+
+ i_p[:length, : self.key_w] = keys
+ i_p[:length, self.key_w : -2] = values
+ i_p[length + 1 : 2 * length + 1, : self.key_w] = keys_perm
+ o_p[length + 1 : 2 * length + 1] = values_perm
+
+ perm = self.seed.permutation(length)
+ keys_perm = keys[perm, :]
+ values_perm = values[perm, :]
+
+ o_p[2 * (length + 1) : -1] = keys_perm
+ i_p[2 * (length + 1) : -1, : self.key_w] = values_perm
+
+ i_p[length, -2] = 1
+ i_p[2 * length + 1, -1] = 1
+ i_p[-1, -2:] = 1
+
+ return self.transform({"input": i_p, "output": o_p})
diff --git a/Dataset/Bitmap/__init__.py b/Dataset/Bitmap/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Dataset/NLP/.DS_Store b/Dataset/NLP/.DS_Store
new file mode 100644
index 0000000..e98ece8
Binary files /dev/null and b/Dataset/NLP/.DS_Store differ
diff --git a/Dataset/NLP/.gitignore b/Dataset/NLP/.gitignore
new file mode 100644
index 0000000..06cf653
--- /dev/null
+++ b/Dataset/NLP/.gitignore
@@ -0,0 +1 @@
+cache
diff --git a/Dataset/NLP/NLPTask.py b/Dataset/NLP/NLPTask.py
new file mode 100644
index 0000000..419347e
--- /dev/null
+++ b/Dataset/NLP/NLPTask.py
@@ -0,0 +1,153 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import torch
+import torch.nn.functional as F
+import os
+from .Vocabulary import Vocabulary
+from Utils import Visdom
+from Utils import universal as U
+
+import numpy as np
+
+
+class NLPTask(torch.utils.data.Dataset):
+ def __init__(self):
+ super(NLPTask, self).__init__()
+
+ self.my_dir = os.path.abspath(os.path.dirname(__file__))
+ self.cache_dir = os.path.join(self.my_dir, "cache")
+
+ if not os.path.isdir(self.cache_dir):
+ os.makedirs(self.cache_dir)
+
+ self.vocabulary = self._load_vocabulary()
+
+ self._preview = None
+
+ def _load_vocabulary(self):
+ cache_file = os.path.join(self.cache_dir, "vocabulary.pth")
+ if not os.path.isfile(cache_file):
+ print("WARNING: Vocabulary not found. Removing cached files.")
+ for f in os.listdir(self.cache_dir):
+ f = os.path.join(self.cache_dir, f)
+ if f.endswith(".pth"):
+ print(" " + f)
+ os.remove(f)
+ return Vocabulary()
+ else:
+ return torch.load(cache_file)
+
+ def save_vocabulary(self):
+ cache_file = os.path.join(self.cache_dir, "vocabulary.pth")
+ if os.path.isfile(cache_file):
+ os.remove(cache_file)
+ torch.save(self.vocabulary, cache_file)
+
+ def loss(self, net_output, target):
+ s = list(net_output.size())
+ return (
+ F.cross_entropy(
+ net_output.view([s[0] * s[1], s[2]]),
+ target.view([-1]),
+ ignore_index=0,
+ reduction="sum",
+ )
+ / s[0]
+ )
+
+
+ def demon_loss(self, net_output, target, saved_actions, device):
+ """
+ computes the loss for the demon
+ :param net_output:
+ :param target:
+ :param saved_actions:
+ :return:
+ """
+ net_output = net_output.detach()
+ s = list(net_output.size())
+ loss = F.cross_entropy(
+ net_output.view([s[0] * s[1], s[2]]),
+ target.view([-1]),
+ ignore_index=0,
+ reduction="none",
+ ).view(s[0], s[1])
+
+ policy_losses = [] # list to save actor (policy) loss
+
+ discount_factor = 0.99
+ for i in range(0, loss.size(1)): # computing expected total reward
+ discount_vector = torch.from_numpy(
+ np.array([np.power(discount_factor, i) for i in range(loss.size(1) - i)])).to(device)
+ policy_losses.append(((saved_actions[i].log_prob).squeeze(1) * (discount_vector * loss[:, i:]).mean(dim=1)))
+
+ demon_loss = torch.stack(policy_losses).mean(dim=0)
+
+ return demon_loss
+
+
+ def generate_preview_text(self, data, net_output):
+ input = U.to_numpy(data["input"][0])
+ reference = U.to_numpy(data["output"][0])
+ net_out = U.argmax(net_output[0], -1)
+ net_out = U.to_numpy(net_out)
+
+ res = ""
+ start_index = 0
+
+ for i in range(input.shape[0]):
+ if reference[i] != 0:
+ if start_index < i:
+ end_index = i
+ while end_index > start_index and input[end_index] == 0:
+ end_index -= 1
+
+ if end_index > start_index:
+ sentence = (
+ " ".join(
+ self.vocabulary.indices_to_sentence(
+ input[start_index:i].tolist()
+ )
+ )
+ .replace(" .", ".")
+ .replace(" ,", ",")
+ .replace(" ?", "?")
+ .split(". ")
+ )
+ sentence = ". ".join([s.capitalize() for s in sentence])
+ res += sentence + "
"
+
+ start_index = i + 1
+
+ match = reference[i] == net_out[i]
+ res += '%s [%s]
' % (
+ "green" if match else "red",
+ self.vocabulary.indices_to_sentence([net_out[i]])[0],
+ self.vocabulary.indices_to_sentence([reference[i]])[0],
+ )
+ return res
+
+ def visualize_preview(self, data, net_output):
+ res = self.generate_preview_text(data, net_output)
+
+ if self._preview is None:
+ self._preview = Visdom.Text("Preview")
+
+ self._preview.set(res)
+
+ def set_dump_dir(self, dir):
+ pass
diff --git a/Dataset/NLP/Vocabulary.py b/Dataset/NLP/Vocabulary.py
new file mode 100644
index 0000000..5b77fdc
--- /dev/null
+++ b/Dataset/NLP/Vocabulary.py
@@ -0,0 +1,52 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+
+class Vocabulary:
+ def __init__(self):
+ self.words = {"-": 0, "?": 1, "": 2}
+ self.inv_words = {0: "-", 1: "?", 2: ""}
+ self.next_id = 3
+ self.punctations = [".", "?", ","]
+
+ def _process_word(self, w, add_words):
+ if not w.isalpha() and w not in self.punctations:
+ print("WARNING: word with unknown characters: %s", w)
+ w = ""
+
+ if w not in self.words:
+ if add_words:
+ self.words[w] = self.next_id
+ self.inv_words[self.next_id] = w
+ self.next_id += 1
+ else:
+ w = ""
+
+ return self.words[w]
+
+ def sentence_to_indices(self, sentence, add_words=True):
+ for p in self.punctations:
+ sentence = sentence.replace(p, " %s " % p)
+
+ return [
+ self._process_word(w, add_words) for w in sentence.lower().split(" ") if w
+ ]
+
+ def indices_to_sentence(self, indices):
+ return [self.inv_words[i] for i in indices]
+
+ def __len__(self):
+ return len(self.words)
diff --git a/Dataset/NLP/__init__.py b/Dataset/NLP/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Dataset/NLP/bAbi.py b/Dataset/NLP/bAbi.py
new file mode 100644
index 0000000..013775f
--- /dev/null
+++ b/Dataset/NLP/bAbi.py
@@ -0,0 +1,297 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import os
+import glob
+import torch
+from collections import namedtuple
+import numpy as np
+from .NLPTask import NLPTask
+from Utils import Visdom
+
+Sentence = namedtuple("Sentence", ["sentence", "answer", "supporting_facts"])
+
+
+class bAbiDataset(NLPTask):
+ URL = "http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz"
+ DIR_NAME = "tasks_1-20_v1-2"
+
+ def __init__(
+ self, dirs=["en-10k"], sets=None, think_steps=0, dir_name=None, name=None
+ ):
+ super(bAbiDataset, self).__init__()
+
+ self._test_res_win = None
+ self._test_plot_win = None
+ self._think_steps = think_steps
+
+ if dir_name is None:
+ self._download()
+ dir_name = os.path.join(self.cache_dir, self.DIR_NAME)
+
+ self.data = {}
+ for d in dirs:
+ self.data[d] = self._load_or_create(os.path.join(dir_name, d))
+
+ self.all_tasks = None
+ self.name = name
+ self.use(sets=sets)
+
+ def _make_active_list(self, tasks, sets, dirs):
+ def verify(name, checker):
+ if checker is None:
+ return True
+
+ if callable(checker):
+ return checker(name)
+ elif isinstance(checker, list):
+ return name in checker
+ else:
+ return name == checker
+
+ res = []
+ for dirname, setlist in self.data.items():
+ if not verify(dirname, dirs):
+ continue
+
+ for sname, tasklist in setlist.items():
+ if not verify(sname, sets):
+ continue
+
+ for task, data in tasklist.items():
+ name = task.split("_")[0][2:]
+ if not verify(name, tasks):
+ continue
+
+ res += [(d, dirname, task, sname) for d in data]
+
+ return res
+
+ def use(self, tasks=None, sets=None, dirs=None):
+ self.all_tasks = self._make_active_list(tasks=tasks, sets=sets, dirs=dirs)
+
+ def __len__(self):
+ return len(self.all_tasks)
+
+ def _get_seq(self, index):
+ return self.all_tasks[index]
+
+ def _seq_to_nn_input(self, seq):
+ in_arr = []
+ out_arr = []
+ hasAnswer = False
+ for sentence in seq[0]:
+ in_arr += sentence.sentence
+ out_arr += [0] * len(sentence.sentence)
+ if sentence.answer is not None:
+ in_arr += [0] * (len(sentence.answer) + self._think_steps)
+ out_arr += [0] * self._think_steps + sentence.answer
+ hasAnswer = True
+
+ in_arr = np.asarray(in_arr, np.int64)
+ out_arr = np.asarray(out_arr, np.int64)
+
+ return {
+ "input": in_arr,
+ "output": out_arr,
+ "meta": {"dir": seq[1], "task": seq[2], "set": seq[3]},
+ }
+
+ def __getitem__(self, item):
+ seq = self._get_seq(item)
+ return self._seq_to_nn_input(seq)
+
+ def _load_or_create(self, directory):
+ cache_name = directory.replace("/", "_")
+ cache_file = os.path.join(self.cache_dir, cache_name + ".pth")
+ if not os.path.isfile(cache_file):
+ print("bAbI: Loading %s" % directory)
+ res = self._load_dir(directory)
+ print("Write: ", cache_file)
+ self.save_vocabulary()
+ torch.save(res, cache_file)
+ else:
+ res = torch.load(cache_file)
+ return res
+
+ def _download(self):
+ if not os.path.isdir(os.path.join(self.cache_dir, self.DIR_NAME)):
+ print(self.URL)
+ print("bAbi data not found. Downloading...")
+ import requests, tarfile, io
+
+ request = requests.get(
+ self.URL,
+ headers={
+ "User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36"
+ },
+ )
+
+ decompressed_file = tarfile.open(
+ fileobj=io.BytesIO(request.content), mode="r|gz"
+ )
+ decompressed_file.extractall(self.cache_dir)
+ print("Done")
+
+ def _load_dir(
+ self,
+ directory,
+ parse_name=lambda x: x.split(".")[0],
+ parse_set=lambda x: x.split(".")[0].split("_")[-1],
+ ):
+ res = {}
+ for f in glob.glob(os.path.join(directory, "**", "*.txt"), recursive=True):
+ basename = os.path.basename(f)
+ task_name = parse_name(basename)
+ set = parse_set(basename)
+ print("Loading", f)
+
+ s = res.get(set)
+ if s is None:
+ s = {}
+ res[set] = s
+ s[task_name] = self._load_task(f, task_name)
+
+ return res
+
+ def _load_task(self, filename, task_name):
+ task = []
+ currTask = []
+
+ nextIndex = 1
+ with open(filename, "r") as f:
+ for line in f:
+ line = [f.strip() for f in line.split("\t")]
+ line[0] = line[0].split(" ")
+ i = int(line[0][0])
+ line[0] = " ".join(line[0][1:])
+
+ if i != nextIndex:
+ nextIndex = i
+ task.append(currTask)
+ currTask = []
+
+ isQuestion = len(line) > 1
+ currTask.append(
+ Sentence(
+ self.vocabulary.sentence_to_indices(line[0]),
+ self.vocabulary.sentence_to_indices(line[1].replace(",", " "))
+ if isQuestion
+ else None,
+ [int(f) for f in line[2].split(" ")] if isQuestion else None,
+ )
+ )
+
+ nextIndex += 1
+ return task
+
+ def start_test(self):
+ return {}
+
+ def veify_result(self, test, data, net_output):
+ _, net_output = net_output.max(-1)
+
+ ref = data["output"]
+
+ mask = 1.0 - ref.eq(0).float()
+
+ correct = (torch.eq(net_output, ref).float() * mask).sum(-1)
+ total = mask.sum(-1)
+
+ correct = correct.data.cpu().numpy()
+ total = total.data.cpu().numpy()
+
+ for i in range(correct.shape[0]):
+ task = data["meta"][i]["task"]
+ if task not in test:
+ test[task] = {"total": 0, "correct": 0}
+
+ d = test[task]
+ d["total"] += total[i]
+ d["correct"] += correct[i]
+
+ def _ensure_test_wins_exists(self, legend=None):
+ if self._test_res_win is None:
+ n = ("[" + self.name + "]") if self.name is not None else ""
+ self._test_res_win = Visdom.Text("Test results" + n)
+ self._test_plot_win = Visdom.Plot2D("Test results" + n, legend=legend)
+ elif self._test_plot_win.legend is None:
+ self._test_plot_win.set_legend(legend=legend)
+
+ def show_test_results(self, iteration, test):
+ res = {k: v["correct"] / v["total"] for k, v in test.items()}
+
+ t = ""
+
+ all_keys = list(res.keys())
+
+ num_keys = [k for k in all_keys if k.startswith("qa")]
+ tmp = [
+ i[0]
+ for i in sorted(
+ enumerate(num_keys), key=lambda x: int(x[1][2:].split("_")[0])
+ )
+ ]
+ num_keys = [num_keys[j] for j in tmp]
+
+ all_keys = num_keys + sorted([k for k in all_keys if not k.startswith("qa")])
+
+ err_precent = [(1.0 - res[k]) * 100.0 for k in all_keys]
+
+ n_passed = sum([int(p <= 5) for p in err_precent])
+ n_total = len(err_precent)
+ err_precent = err_precent + [sum(err_precent) / len(err_precent)]
+ all_keys += ["mean"]
+
+ for i, k in enumerate(all_keys):
+ t += '%s: %.2f%%
' % (
+ "green" if err_precent[i] <= 5 else "red",
+ k,
+ err_precent[i],
+ )
+
+ t += "
Total: %d of %d passed." % (n_passed, n_total)
+
+ self._ensure_test_wins_exists(
+ legend=[i.split("_")[0] if i.startswith("qa") else i for i in all_keys]
+ )
+
+ self._test_res_win.set(t)
+ self._test_plot_win.add_point(iteration, err_precent)
+
+ def state_dict(self):
+ if self._test_res_win is not None:
+ return {
+ "_test_res_win": self._test_res_win.state_dict(),
+ "_test_plot_win": self._test_plot_win.state_dict(),
+ }
+ else:
+ return {}
+
+ def load_state_dict(self, state):
+ if state:
+ self._ensure_test_wins_exists()
+ self._test_res_win.load_state_dict(state["_test_res_win"])
+ self._test_plot_win.load_state_dict(state["_test_plot_win"])
+ self._test_plot_win.legend = None
+
+ def visualize_preview(self, data, net_output):
+ res = self.generate_preview_text(data, net_output)
+ res = ("%s
" % data["meta"][0]["task"]) + res
+ if self._preview is None:
+ self._preview = Visdom.Text("Preview")
+
+ self._preview.set(res)
diff --git a/Dataset/__init__.py b/Dataset/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/Dataset/__init__.py
@@ -0,0 +1 @@
+
diff --git a/Models/.DS_Store b/Models/.DS_Store
new file mode 100644
index 0000000..e98ece8
Binary files /dev/null and b/Models/.DS_Store differ
diff --git a/Models/DNCA.py b/Models/DNCA.py
new file mode 100644
index 0000000..6dd7a30
--- /dev/null
+++ b/Models/DNCA.py
@@ -0,0 +1,965 @@
+# The Initial DNC Copyright 2017 Robert Csordas. All Rights Reserved.
+# The modification of the initial DNC implementation by Ari Azarafrooz.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+
+import torch
+import torch.utils.data
+import torch.nn.functional as F
+import torch.nn.init as init
+import functools
+import math
+
+
+def oneplus(t):
+ return F.softplus(t, 1, 20) + 1.0
+
+
+def get_next_tensor_part(src, dims, prev_pos=0):
+ if not isinstance(dims, list):
+ dims = [dims]
+ n = functools.reduce(lambda x, y: x * y, dims)
+ data = src.narrow(-1, prev_pos, n)
+ return (
+ data.contiguous().view(list(data.size())[:-1] + dims)
+ if len(dims) > 1
+ else data,
+ prev_pos + n,
+ )
+
+
+def split_tensor(src, shapes):
+ pos = 0
+ res = []
+ for s in shapes:
+ d, pos = get_next_tensor_part(src, s, pos)
+ res.append(d)
+ return res
+
+
+def dict_get(dict, name):
+ return dict.get(name) if dict is not None else None
+
+
+def dict_append(dict, name, val):
+ if dict is not None:
+ l = dict.get(name)
+ if not l:
+ l = []
+ dict[name] = l
+ l.append(val)
+
+
+def init_debug(debug, initial):
+ if debug is not None and not debug:
+ debug.update(initial)
+
+
+def merge_debug_tensors(d, dim):
+ if d is not None:
+ for k, v in d.items():
+ if isinstance(v, dict):
+ merge_debug_tensors(v, dim)
+ elif isinstance(v, list):
+ d[k] = torch.stack(v, dim)
+
+
+def linear_reset(module, gain=1.0):
+ assert isinstance(module, torch.nn.Linear)
+ init.xavier_uniform_(module.weight, gain=gain)
+ s = module.weight.size(1)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+_EPS = 1e-6
+
+
+class AllocationManager(torch.nn.Module):
+ def __init__(self):
+ super(AllocationManager, self).__init__()
+ self.usages = None
+ self.zero_usages = None
+ self.debug_sequ_init = False
+ self.one = None
+
+ def _init_sequence(self, prev_read_distributions):
+ # prev_read_distributions size is [batch, n_heads, cell count]
+ s = prev_read_distributions.size()
+ if self.zero_usages is None or list(self.zero_usages.size()) != [s[0], s[-1]]:
+ self.zero_usages = torch.zeros(
+ s[0], s[-1], device=prev_read_distributions.device
+ )
+ if self.debug_sequ_init:
+ self.zero_usages += torch.arange(0, s[-1]).unsqueeze(0) * 1e-10
+
+ self.usages = self.zero_usages
+
+ def _init_consts(self, device):
+ if self.one is None:
+ self.one = torch.ones(1, device=device)
+
+ def new_sequence(self):
+ self.usages = None
+
+ def update_usages(
+ self, prev_write_distribution, prev_read_distributions, free_gates
+ ):
+ # Read distributions shape: [batch, n_heads, cell count]
+ # Free gates shape: [batch, n_heads]
+
+ self._init_consts(prev_read_distributions.device)
+ phi = torch.addcmul(
+ self.one, -1, free_gates.unsqueeze(-1), prev_read_distributions
+ ).prod(-2)
+ # Phi is the free tensor, sized [batch, cell count]
+
+ # If memory usage counter if doesn't exists
+ if self.usages is None:
+ self._init_sequence(prev_read_distributions)
+ # in first timestep nothing is written or read yet, so we don't need any further processing
+ else:
+ self.usages = (
+ torch.addcmul(
+ self.usages, 1, prev_write_distribution.detach(), (1 - self.usages)
+ )
+ * phi
+ )
+
+ return phi
+
+ def forward(self, prev_write_distribution, prev_read_distributions, free_gates):
+ phi = self.update_usages(
+ prev_write_distribution, prev_read_distributions, free_gates
+ )
+ sorted_usage, free_list = (self.usages * (1.0 - _EPS) + _EPS).sort(-1)
+
+ u_prod = sorted_usage.cumprod(-1)
+ one_minus_usage = 1.0 - sorted_usage
+ sorted_scores = torch.cat(
+ [one_minus_usage[..., 0:1], one_minus_usage[..., 1:] * u_prod[..., :-1]],
+ dim=-1,
+ )
+
+ return sorted_scores.clone().scatter_(-1, free_list, sorted_scores), phi
+
+
+class ContentAddressGenerator(torch.nn.Module):
+ def __init__(
+ self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False
+ ):
+ super(ContentAddressGenerator, self).__init__()
+ self.disable_content_norm = disable_content_norm
+ self.mask_min = mask_min
+ self.disable_key_masking = disable_key_masking
+
+ def forward(self, memory, keys, betas, mask=None):
+ # Memory shape [batch, cell count, word length]
+ # Key shape [batch, n heads*, word length]
+ # Betas shape [batch, n heads]
+ if mask is not None and self.mask_min != 0:
+ mask = mask * (1.0 - self.mask_min) + self.mask_min
+
+ single_head = keys.dim() == 2
+ if single_head:
+ # Single head
+ keys = keys.unsqueeze(1)
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+
+ memory = memory.unsqueeze(1)
+ keys = keys.unsqueeze(-2)
+
+ if mask is not None:
+ mask = mask.unsqueeze(-2)
+ memory = memory * mask
+ if not self.disable_key_masking:
+ keys = keys * mask
+
+ # Shape [batch, n heads, cell count]
+ norm = keys.norm(dim=-1)
+ if not self.disable_content_norm:
+ norm = norm * memory.norm(dim=-1)
+
+ scores = (memory * keys).sum(-1) / (norm + _EPS)
+ scores *= betas.unsqueeze(-1)
+
+ res = F.softmax(scores, scores.dim() - 1)
+ return res.squeeze(1) if single_head else res
+
+
+class WriteHead(torch.nn.Module):
+ @staticmethod
+ def create_write_archive(write_dist, erase_vector, write_vector, phi):
+ return dict(
+ write_dist=write_dist,
+ erase_vector=erase_vector,
+ write_vector=write_vector,
+ phi=phi,
+ )
+
+ def __init__(
+ self,
+ dealloc_content=True,
+ disable_content_norm=False,
+ mask_min=0.0,
+ disable_key_masking=False,
+ ):
+ super(WriteHead, self).__init__()
+ self.write_content_generator = ContentAddressGenerator(
+ disable_content_norm,
+ mask_min=mask_min,
+ disable_key_masking=disable_key_masking,
+ )
+ self.allocation_manager = AllocationManager()
+ self.last_write = None
+ self.dealloc_content = dealloc_content
+ self.new_sequence()
+
+ def new_sequence(self):
+ self.last_write = None
+ self.allocation_manager.new_sequence()
+
+ @staticmethod
+ def mem_update(memory, write_dist, erase_vector, write_vector, phi):
+ # In original paper the memory content is NOT deallocated, which makes content based addressing basically
+ # unusable when multiple similar steps should be done. The reason for this is that the memory contents are
+ # still there, so the lookup will find them, unless an allocation clears it before the next search, which is
+ # completely random. So I'm arguing that erase matrix should also take in account the free gates (multiply it
+ # with phi)
+ write_dist = write_dist.unsqueeze(-1)
+
+ erase_matrix = 1.0 - write_dist * erase_vector.unsqueeze(-2)
+ if phi is not None:
+ erase_matrix = erase_matrix * phi.unsqueeze(-1)
+
+ update_matrix = write_dist * write_vector.unsqueeze(-2)
+ return memory * erase_matrix + update_matrix
+
+ def forward(
+ self,
+ demon_action,
+ memory,
+ write_content_key,
+ write_beta,
+ erase_vector,
+ write_vector,
+ alloc_gate,
+ write_gate,
+ free_gates,
+ prev_read_dist,
+ write_mask=None,
+ debug=None,
+ ):
+ last_w_dist = (
+ self.last_write["write_dist"] if self.last_write is not None else None
+ )
+
+ content_dist = self.write_content_generator(
+ memory, write_content_key, write_beta, mask=write_mask
+ )
+ alloc_dist, phi = self.allocation_manager(
+ last_w_dist, prev_read_dist, free_gates
+ )
+
+ # Shape [batch, cell count]
+ write_dist = write_gate * (
+ alloc_gate * alloc_dist + (1 - alloc_gate) * content_dist
+ )
+ self.last_write = WriteHead.create_write_archive(
+ write_dist,
+ erase_vector,
+ write_vector,
+ phi if self.dealloc_content else None,
+ )
+
+ dict_append(debug, "alloc_dist", alloc_dist)
+ dict_append(debug, "write_dist", write_dist)
+ dict_append(debug, "mem_usages", self.allocation_manager.usages)
+ dict_append(debug, "free_gates", free_gates)
+ dict_append(debug, "write_betas", write_beta)
+ dict_append(debug, "write_gate", write_gate)
+ dict_append(debug, "write_vector", write_vector)
+ dict_append(debug, "alloc_gate", alloc_gate)
+ dict_append(debug, "erase_vector", erase_vector)
+ if write_mask is not None:
+ dict_append(debug, "write_mask", write_mask)
+
+ return WriteHead.mem_update(memory, **self.last_write)
+
+
+class RawWriteHead(torch.nn.Module):
+ def __init__(
+ self,
+ n_read_heads,
+ word_length,
+ use_mask=False,
+ dealloc_content=True,
+ disable_content_norm=False,
+ mask_min=0.0,
+ disable_key_masking=False,
+ ):
+ super(RawWriteHead, self).__init__()
+ self.write_head = WriteHead(
+ dealloc_content=dealloc_content,
+ disable_content_norm=disable_content_norm,
+ mask_min=mask_min,
+ disable_key_masking=disable_key_masking,
+ )
+ self.word_length = word_length
+ self.n_read_heads = n_read_heads
+ self.use_mask = use_mask
+ self.input_size = (
+ 3 * self.word_length
+ + self.n_read_heads
+ + 3
+ + (self.word_length if use_mask else 0)
+ )
+
+ def new_sequence(self):
+ self.write_head.new_sequence()
+
+ def get_prev_write(self):
+ return self.write_head.last_write
+
+ def forward(self, demon_action, memory, nn_output, prev_read_dist, debug):
+ shapes = (
+ [[self.word_length]] * (4 if self.use_mask else 3)
+ + [[self.n_read_heads]]
+ + [[1]] * 3
+ )
+ tensors = split_tensor(nn_output, shapes)
+
+ if self.use_mask:
+ write_mask = torch.sigmoid(tensors[0])
+ tensors = tensors[1:]
+ else:
+ write_mask = None
+
+ (
+ write_content_key,
+ erase_vector,
+ write_vector,
+ free_gates,
+ write_beta,
+ alloc_gate,
+ write_gate,
+ ) = tensors
+
+ erase_vector = torch.sigmoid(erase_vector)
+ free_gates = torch.sigmoid(free_gates)
+ write_beta = oneplus(write_beta)
+ alloc_gate = torch.sigmoid(alloc_gate)
+ write_gate = torch.sigmoid(write_gate)
+
+ return self.write_head(
+ demon_action,
+ memory,
+ write_content_key,
+ write_beta,
+ erase_vector,
+ write_vector,
+ alloc_gate,
+ write_gate,
+ free_gates,
+ prev_read_dist,
+ debug=debug,
+ write_mask=write_mask,
+ )
+
+ def get_neural_input_size(self):
+ return self.input_size
+
+
+class TemporalMemoryLinkage(torch.nn.Module):
+ def __init__(self):
+ super(TemporalMemoryLinkage, self).__init__()
+ self.temp_link_mat = None
+ self.precedence_weighting = None
+ self.diag_mask = None
+
+ self.initial_temp_link_mat = None
+ self.initial_precedence_weighting = None
+ self.initial_diag_mask = None
+ self.initial_shape = None
+
+ def new_sequence(self):
+ self.temp_link_mat = None
+ self.precedence_weighting = None
+ self.diag_mask = None
+
+ def _init_link(self, w_dist):
+ s = list(w_dist.size())
+ if self.initial_shape is None or s != self.initial_shape:
+ self.initial_temp_link_mat = torch.zeros(s[0], s[-1], s[-1]).to(
+ w_dist.device
+ )
+ self.initial_precedence_weighting = torch.zeros(s[0], s[-1]).to(
+ w_dist.device
+ )
+ self.initial_diag_mask = (
+ 1.0 - torch.eye(s[-1]).unsqueeze(0).to(w_dist)
+ ).detach()
+
+ self.temp_link_mat = self.initial_temp_link_mat
+ self.precedence_weighting = self.initial_precedence_weighting
+ self.diag_mask = self.initial_diag_mask
+
+ def _update_precedence(self, w_dist):
+ # w_dist shape: [ batch, cell count ]
+ self.precedence_weighting = (
+ 1.0 - w_dist.sum(-1, keepdim=True)
+ ) * self.precedence_weighting + w_dist
+
+ def _update_links(self, w_dist):
+ if self.temp_link_mat is None:
+ self._init_link(w_dist)
+
+ wt_i = w_dist.unsqueeze(-1)
+ wt_j = w_dist.unsqueeze(-2)
+ pt_j = self.precedence_weighting.unsqueeze(-2)
+
+ self.temp_link_mat = (
+ (1 - wt_i - wt_j) * self.temp_link_mat + wt_i * pt_j
+ ) * self.diag_mask
+
+ def forward(self, w_dist, prev_r_dists, debug=None):
+ self._update_links(w_dist)
+ self._update_precedence(w_dist)
+
+ # prev_r_dists shape: [ batch, n heads, cell count ]
+ # Emulate matrix-vector multiplication by broadcast and sum. This way we don't need to transpose the matrix
+ tlm_multi_head = self.temp_link_mat.unsqueeze(1)
+
+ forward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-2)).sum(-1)
+ backward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-1)).sum(-2)
+
+ dict_append(debug, "forward_dists", forward_dist)
+ dict_append(debug, "backward_dists", backward_dist)
+ dict_append(debug, "precedence_weights", self.precedence_weighting)
+
+ # output shapes [ batch, n_heads, cell_count ]
+ return forward_dist, backward_dist
+
+
+class ReadHead(torch.nn.Module):
+ def __init__(
+ self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False
+ ):
+ super(ReadHead, self).__init__()
+ self.content_addr_generator = ContentAddressGenerator(
+ disable_content_norm=disable_content_norm,
+ mask_min=mask_min,
+ disable_key_masking=disable_key_masking,
+ )
+ self.read_dist = None
+ self.read_data = None
+ self.new_sequence()
+
+ def new_sequence(self):
+ self.read_dist = None
+ self.read_data = None
+
+ def forward(
+ self,
+ memory,
+ read_content_keys,
+ read_betas,
+ forward_dist,
+ backward_dist,
+ gates,
+ read_mask=None,
+ debug=None,
+ ):
+ content_dist = self.content_addr_generator(
+ memory, read_content_keys, read_betas, mask=read_mask
+ )
+
+ self.read_dist = (
+ backward_dist * gates[..., 0:1]
+ + content_dist * gates[..., 1:2]
+ + forward_dist * gates[..., 2:]
+ )
+
+ # memory shape: [ batch, cell count, word_length ]
+ # read_dist shape: [ batch, n heads, cell count ]
+ # result shape: [ batch, n_heads, word_length ]
+ self.read_data = (memory.unsqueeze(1) * self.read_dist.unsqueeze(-1)).sum(-2)
+
+ dict_append(debug, "content_dist", content_dist)
+ dict_append(debug, "balance", gates)
+ dict_append(debug, "read_dist", self.read_dist)
+ dict_append(debug, "read_content_keys", read_content_keys)
+ if read_mask is not None:
+ dict_append(debug, "read_mask", read_mask)
+ dict_append(debug, "read_betas", read_betas.unsqueeze(-2))
+ if read_mask is not None:
+ dict_append(debug, "read_mask", read_mask)
+
+ return self.read_data
+
+
+class RawReadHead(torch.nn.Module):
+ def __init__(
+ self,
+ n_heads,
+ word_length,
+ use_mask=False,
+ disable_content_norm=False,
+ mask_min=0.0,
+ disable_key_masking=False,
+ ):
+ super(RawReadHead, self).__init__()
+ self.read_head = ReadHead(
+ disable_content_norm=disable_content_norm,
+ mask_min=mask_min,
+ disable_key_masking=disable_key_masking,
+ )
+ self.n_heads = n_heads
+ self.word_length = word_length
+ self.use_mask = use_mask
+ self.input_size = self.n_heads * (
+ self.word_length * (2 if use_mask else 1) + 3 + 1
+ )
+
+ def get_prev_dist(self, memory):
+ if self.read_head.read_dist is not None:
+ return self.read_head.read_dist
+ else:
+ m_shape = memory.size()
+ return torch.zeros(m_shape[0], self.n_heads, m_shape[1]).to(memory)
+
+ def get_prev_data(self, memory):
+ if self.read_head.read_data is not None:
+ return self.read_head.read_data
+ else:
+ m_shape = memory.size()
+ return torch.zeros(m_shape[0], self.n_heads, m_shape[-1]).to(memory)
+
+ def new_sequence(self):
+ self.read_head.new_sequence()
+
+ def forward(self, memory, nn_output, forward_dist, backward_dist, debug):
+ shapes = [[self.n_heads, self.word_length]] * (2 if self.use_mask else 1) + [
+ [self.n_heads],
+ [self.n_heads, 3],
+ ]
+ tensors = split_tensor(nn_output, shapes)
+
+ if self.use_mask:
+ read_mask = torch.sigmoid(tensors[0])
+ tensors = tensors[1:]
+ else:
+ read_mask = None
+
+ keys, betas, gates = tensors
+
+ betas = oneplus(betas)
+ gates = F.softmax(gates, gates.dim() - 1)
+
+ return self.read_head(
+ memory,
+ keys,
+ betas,
+ forward_dist,
+ backward_dist,
+ gates,
+ debug=debug,
+ read_mask=read_mask,
+ )
+
+ def get_neural_input_size(self):
+ return self.input_size
+
+
+class DistSharpnessEnhancer(torch.nn.Module):
+ def __init__(self, n_heads):
+ super(DistSharpnessEnhancer, self).__init__()
+ self.n_heads = n_heads if isinstance(n_heads, list) else [n_heads]
+ self.n_data = sum(self.n_heads)
+
+ def forward(self, nn_input, *dists):
+ assert len(dists) == len(self.n_heads)
+ nn_input = oneplus(nn_input[..., : self.n_data])
+ factors = split_tensor(nn_input, self.n_heads)
+
+ res = []
+ for i, d in enumerate(dists):
+ s = list(d.size())
+ ndim = d.dim()
+ f = factors[i]
+ if ndim == 2:
+ assert self.n_heads[i] == 1
+ elif ndim == 3:
+ f = f.unsqueeze(-1)
+ else:
+ assert False
+
+ d += _EPS
+ d = d / d.max(dim=-1, keepdim=True)[0]
+ d = d.pow(f)
+ d = d / d.sum(dim=-1, keepdim=True)
+ res.append(d)
+ return res
+
+ def get_neural_input_size(self):
+ return self.n_data
+
+
+class DNC(torch.nn.Module):
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ word_length,
+ cell_count,
+ n_read_heads,
+ controller,
+ batch_first=False,
+ clip_controller=20,
+ bias=True,
+ mask=False,
+ dealloc_content=True,
+ link_sharpness_control=True,
+ disable_content_norm=False,
+ mask_min=0.0,
+ disable_key_masking=False,
+ ):
+ super(DNC, self).__init__()
+
+ self.clip_controller = clip_controller
+
+ self.read_head = RawReadHead(
+ n_read_heads,
+ word_length,
+ use_mask=mask,
+ disable_content_norm=disable_content_norm,
+ mask_min=mask_min,
+ disable_key_masking=disable_key_masking,
+ )
+ self.write_head = RawWriteHead(
+ n_read_heads,
+ word_length,
+ use_mask=mask,
+ dealloc_content=dealloc_content,
+ disable_content_norm=disable_content_norm,
+ mask_min=mask_min,
+ disable_key_masking=disable_key_masking,
+ )
+ self.temporal_link = TemporalMemoryLinkage()
+ self.sharpness_control = (
+ DistSharpnessEnhancer([n_read_heads, n_read_heads])
+ if link_sharpness_control
+ else None
+ )
+
+ in_size = input_size + n_read_heads * word_length
+ control_channels = (
+ self.read_head.get_neural_input_size()
+ + self.write_head.get_neural_input_size()
+ + (
+ self.sharpness_control.get_neural_input_size()
+ if self.sharpness_control is not None
+ else 0
+ )
+ )
+
+ self.controller = controller
+ controller.init(in_size)
+ self.controller_to_controls = torch.nn.Linear(
+ controller.get_output_size(), control_channels, bias=bias
+ )
+ self.controller_to_out = torch.nn.Linear(
+ controller.get_output_size(), output_size, bias=bias
+ )
+ self.read_to_out = torch.nn.Linear(
+ word_length * n_read_heads, output_size, bias=bias
+ )
+
+ self.cell_count = cell_count
+ self.word_length = word_length
+
+ self.memory = None
+ self.reset_parameters()
+
+ self.batch_first = batch_first
+ self.zero_mem_tensor = None
+
+ self.mem_state = None
+
+ self.device = (
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ )
+
+ def reset_parameters(self):
+ linear_reset(self.controller_to_controls)
+ linear_reset(self.controller_to_out)
+ linear_reset(self.read_to_out)
+ self.controller.reset_parameters()
+
+ def _step(self, in_data, debug, demon, rollout_storage):
+ init_debug(debug, {"read_head": {}, "write_head": {}, "temporal_links": {}})
+
+ # input shape: [ batch, channels ]
+ batch_size = in_data.size(0)
+
+ # # run the demon if it is used
+ if demon:
+ # Running policy_old:
+ demon_action = demon.select_action(
+ torch.cat([in_data, self.memory.view(batch_size, -1)], -1),
+ rollout_storage,
+ )
+ in_data = in_data + demon_action
+
+ demon_action = None
+
+ # run the controller
+ prev_read_data = self.read_head.get_prev_data(self.memory).view(
+ [batch_size, -1]
+ )
+
+ control_data = self.controller(torch.cat([in_data, prev_read_data], -1))
+
+ # memory ops
+ controls = self.controller_to_controls(control_data).contiguous()
+ controls = (
+ controls.clamp(-self.clip_controller, self.clip_controller)
+ if self.clip_controller is not None
+ else controls
+ )
+
+ shapes = [
+ [self.write_head.get_neural_input_size()],
+ [self.read_head.get_neural_input_size()],
+ ]
+ if self.sharpness_control is not None:
+ shapes.append(self.sharpness_control.get_neural_input_size())
+
+ tensors = split_tensor(controls, shapes)
+
+ write_head_control, read_head_control = tensors[:2]
+ tensors = tensors[2:]
+
+ prev_read_dist = self.read_head.get_prev_dist(self.memory)
+
+ self.memory = self.write_head(
+ demon_action,
+ self.memory,
+ write_head_control,
+ prev_read_dist,
+ debug=dict_get(debug, "write_head"),
+ )
+
+ prev_write = self.write_head.get_prev_write()
+ forward_dist, backward_dist = self.temporal_link(
+ prev_write["write_dist"] if prev_write is not None else None,
+ prev_read_dist,
+ debug=dict_get(debug, "temporal_links"),
+ )
+
+ if self.sharpness_control is not None:
+ forward_dist, backward_dist = self.sharpness_control(
+ tensors[0], forward_dist, backward_dist
+ )
+
+ read_data = self.read_head(
+ self.memory,
+ read_head_control,
+ forward_dist,
+ backward_dist,
+ debug=dict_get(debug, "read_head"),
+ )
+
+ # output:
+ return self.controller_to_out(control_data) + self.read_to_out(
+ read_data.view(batch_size, -1)
+ )
+
+ def _mem_init(self, batch_size, device):
+ if self.zero_mem_tensor is None or self.zero_mem_tensor.size(0) != batch_size:
+ self.zero_mem_tensor = torch.zeros(
+ batch_size, self.cell_count, self.word_length
+ ).to(device)
+
+ self.memory = self.zero_mem_tensor
+
+ if self.mem_state is None:
+ self.mem_state = []
+
+ def forward(self, in_data, debug=None, demon=None, rollout_storage=None):
+ self.write_head.new_sequence()
+ self.read_head.new_sequence()
+ self.temporal_link.new_sequence()
+ self.controller.new_sequence()
+
+ self._mem_init(in_data.size(0 if self.batch_first else 1), in_data.device)
+
+ out_tsteps = []
+
+ if self.batch_first:
+ # input format: batch, time, channels
+ for t in range(in_data.size(1)):
+ out_tsteps.append(
+ self._step(in_data[:, t], debug, demon, rollout_storage)
+ )
+ self.mem_state.append(self.memory.view(in_data.size(0), -1))
+ else:
+ # input format: time, batch, channels
+ for t in range(in_data.size(0)):
+ out_tsteps.append(self._step(in_data[t], debug, demon, rollout_storage))
+ self.mem_state.append(self.memory.view(-1, in_data.size(0)))
+
+ merge_debug_tensors(debug, dim=1 if self.batch_first else 0)
+ return torch.stack(out_tsteps, dim=1 if self.batch_first else 0)
+
+
+class LSTMController(torch.nn.Module):
+ def __init__(self, layer_sizes, out_from_all_layers=True):
+ super(LSTMController, self).__init__()
+ self.out_from_all_layers = out_from_all_layers
+ self.layer_sizes = layer_sizes
+ self.states = None
+ self.outputs = None
+
+ def new_sequence(self):
+ self.states = [None] * len(self.layer_sizes)
+ self.outputs = [None] * len(self.layer_sizes)
+
+ def reset_parameters(self):
+ def init_layer(l, index):
+ size = self.layer_sizes[index]
+ # Initialize all matrices to sigmoid, just data input to tanh
+ a = math.sqrt(3.0) * self.stdevs[i]
+ l.weight.data[0:-size].uniform_(-a, a)
+ a *= init.calculate_gain("tanh")
+ l.weight.data[-size:].uniform_(-a, a)
+ if l.bias is not None:
+ l.bias.data[self.layer_sizes[i] :].fill_(0)
+ # init forget gate to large number.
+ l.bias.data[: self.layer_sizes[i]].fill_(1)
+
+ # xavier init merged input weights
+ for i in range(len(self.layer_sizes)):
+ init_layer(self.in_to_all[i], i)
+ init_layer(self.out_to_all[i], i)
+ if i > 0:
+ init_layer(self.prev_to_all[i - 1], i)
+
+ def _add_modules(self, name, m_list):
+ for i, m in enumerate(m_list):
+ self.add_module("%s_%d" % (name, i), m)
+
+ def init(self, input_size):
+ self.layer_sizes = self.layer_sizes
+
+ # Xavier init: input to all gates is layers_sizes[i-1] + layer_sizes[i] + input_size -> layer_size big.
+ # So use xavier init according to this.
+ self.input_sizes = [
+ (self.layer_sizes[i - 1] if i > 0 else 0) + self.layer_sizes[i] + input_size
+ for i in range(len(self.layer_sizes))
+ ]
+ self.stdevs = [
+ math.sqrt(2.0 / (self.layer_sizes[i] + self.input_sizes[i]))
+ for i in range(len(self.layer_sizes))
+ ]
+ self.in_to_all = [
+ torch.nn.Linear(input_size, 4 * self.layer_sizes[i])
+ for i in range(len(self.layer_sizes))
+ ]
+ self.out_to_all = [
+ torch.nn.Linear(self.layer_sizes[i], 4 * self.layer_sizes[i], bias=False)
+ for i in range(len(self.layer_sizes))
+ ]
+ self.prev_to_all = [
+ torch.nn.Linear(
+ self.layer_sizes[i - 1], 4 * self.layer_sizes[i], bias=False
+ )
+ for i in range(1, len(self.layer_sizes))
+ ]
+
+ self._add_modules("in_to_all", self.in_to_all)
+ self._add_modules("out_to_all", self.out_to_all)
+ self._add_modules("prev_to_all", self.prev_to_all)
+
+ self.reset_parameters()
+
+ def get_output_size(self):
+ return (
+ sum(self.layer_sizes) if self.out_from_all_layers else self.layer_sizes[-1]
+ )
+
+ def forward(self, data):
+ for i, size in enumerate(self.layer_sizes):
+ d = self.in_to_all[i](data)
+ if self.outputs[i] is not None:
+ d += self.out_to_all[i](self.outputs[i])
+ if i > 0:
+ d += self.prev_to_all[i - 1](self.outputs[i - 1])
+
+ input_data = torch.tanh(d[..., -size:])
+ forget_gate, input_gate, output_gate = torch.sigmoid(d[..., :-size]).chunk(
+ 3, dim=-1
+ )
+
+ state_update = input_gate * input_data
+
+ if self.states[i] is not None:
+ self.states[i] = self.states[i] * forget_gate + state_update
+ else:
+ self.states[i] = state_update
+
+ self.outputs[i] = output_gate * torch.tanh(self.states[i])
+
+ return (
+ torch.cat(self.outputs, -1)
+ if self.out_from_all_layers
+ else self.outputs[-1]
+ )
+
+
+class FeedforwardController(torch.nn.Module):
+ def __init__(self, layer_sizes=[]):
+ super(FeedforwardController, self).__init__()
+ self.layer_sizes = layer_sizes
+
+ def new_sequence(self):
+ pass
+
+ def reset_parameters(self):
+ for module in self.model:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ def get_output_size(self):
+ return self.layer_sizes[-1]
+
+ def init(self, input_size):
+ self.layer_sizes = self.layer_sizes
+
+ # Xavier init: input to all gates is layers_sizes[i-1] + layer_sizes[i] + input_size -> layer_size big.
+ # So use xavier init according to this.
+ self.input_sizes = [input_size] + self.layer_sizes[:-1]
+
+ layers = []
+ for i, size in enumerate(self.layer_sizes):
+ layers.append(torch.nn.Linear(self.input_sizes[i], self.layer_sizes[i]))
+ layers.append(torch.nn.ReLU())
+ self.model = torch.nn.Sequential(*layers)
+ self.reset_parameters()
+
+ def forward(self, data):
+ return self.model(data)
diff --git a/Models/Demon.py b/Models/Demon.py
new file mode 100644
index 0000000..13ded38
--- /dev/null
+++ b/Models/Demon.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+from torch.distributions import Normal
+
+from collections import namedtuple
+
+LOG_SIG_MAX = 2
+LOG_SIG_MIN = -20
+EPSILON = 1e-6
+
+SavedAction = namedtuple("SavedAction", ["action", "log_prob", "mean"])
+
+
+def linear_reset(module, gain=1.0):
+ assert isinstance(module, torch.nn.Linear)
+ init.xavier_uniform_(module.weight, gain=gain)
+ s = module.weight.size(1)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+class ZNet(nn.Module):
+ def __init__(self):
+ super(ZNet, self).__init__()
+
+ def reset_parameters(self):
+ for module in self.lstm:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ for module in self.hidden2z:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ def init(self, input_size):
+ self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True))
+ self.hidden2z = nn.Sequential(nn.Linear(32, 1))
+ self.reset_parameters()
+
+ def forward(self, data):
+ output, (hn, cn) = self.lstm(data)
+ zvals = self.hidden2z(output)
+ return F.softplus(zvals)
+
+
+class FNet(nn.Module):
+ def __init__(self):
+ super(FNet, self).__init__()
+
+ def reset_parameters(self):
+ for module in self.lstm:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ for module in self.hidden2z:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ def init(self, input_size):
+ self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True))
+ self.hidden2z = nn.Sequential(nn.Linear(32, 1))
+ self.reset_parameters()
+
+ def forward(self, data):
+ output, (hn, cn) = self.lstm(data)
+ output = F.elu(output)
+ fvals = self.hidden2z(output)
+ return fvals
+
+
+class Demon(torch.nn.Module):
+ """
+ Demon manipulates the external memory of DNC.
+ """
+
+ def __init__(self, layer_sizes=[]):
+ super(Demon, self).__init__()
+ self.layer_sizes = layer_sizes
+ self.action_scale = torch.tensor(1)
+ self.action_bias = torch.tensor(0.0)
+ self.saved_actions = []
+
+ def get_output_size(self):
+ return self.layer_sizes[-1]
+
+ def reset_parameters(self):
+ for module in self.model:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+ linear_reset(self.embed_mean, gain=init.calculate_gain("relu"))
+ linear_reset(self.embed_log_std, gain=init.calculate_gain("relu"))
+
+ def init(self, input_size, output_size):
+ # Xavier init: input to all gates is layers_sizes[i-1] + layer_sizes[i] + input_size -> layer_size big.
+ # So use xavier init according to this.
+ self.input_sizes = [input_size] + self.layer_sizes[:-1]
+ layers = []
+ for i, size in enumerate(self.layer_sizes):
+ layers.append(nn.Linear(self.input_sizes[i], self.layer_sizes[i]))
+ layers.append(nn.ReLU())
+
+ self.model = nn.Sequential(*layers)
+ self.embed_mean = nn.Linear(self.layer_sizes[-1], output_size)
+ self.embed_log_std = nn.Linear(self.layer_sizes[-1], output_size)
+
+ self.reset_parameters()
+
+ def forward(self, data):
+ x = self.model(data)
+ x = F.relu(x)
+ mean, log_std = self.embed_mean(x), self.embed_log_std(x)
+ log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
+ std = torch.exp(log_std)
+ return mean, std
+
+ def act(self, data):
+ """
+ pathwise derivative estimator for taking actions.
+ :param data:
+ :return:
+ """
+ mean, std = self.forward(data)
+ normal = Normal(mean, std)
+ x = normal.rsample()
+
+ y = torch.softmax(x, dim=1)
+
+ action = y * self.action_scale + self.action_bias
+ log_prob = normal.log_prob(action)
+ # Enforcing Action Bound
+ log_prob -= torch.log(self.action_scale * (1 - y.pow(2)) + EPSILON)
+ log_prob = log_prob.sum(1, keepdim=True)
+
+ mean = torch.softmax(mean, dim=1) * self.action_scale + self.action_bias
+ self.saved_actions.append(SavedAction(action, log_prob, mean))
+
+ return mean
diff --git a/Models/Information_Agents.py b/Models/Information_Agents.py
new file mode 100644
index 0000000..2abbbd4
--- /dev/null
+++ b/Models/Information_Agents.py
@@ -0,0 +1,238 @@
+import torch
+import torch.nn as nn
+from torch.distributions import MultivariateNormal
+import torch.nn.functional as F
+import torch.nn.init as init
+
+import numpy as np
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+class RolloutStorage:
+ def __init__(self):
+ self.actions = []
+ self.states = []
+ self.logprobs = []
+ self.rewards = []
+ self.is_terminals = []
+
+ def clear_storage(self):
+ del self.actions[:]
+ del self.states[:]
+ del self.logprobs[:]
+ del self.rewards[:]
+ del self.is_terminals[:]
+
+
+class ActorCritic(nn.Module):
+ def __init__(self, state_dim, action_dim, action_std):
+ super(ActorCritic, self).__init__()
+ self.actor = nn.Sequential(
+ nn.Linear(state_dim, 64),
+ nn.Tanh(),
+ nn.Linear(64, 32),
+ nn.Tanh(),
+ nn.Linear(32, action_dim),
+ nn.Softmax(dim=1),
+ )
+
+ # critic
+ self.critic = nn.Sequential(
+ nn.Linear(state_dim, 64),
+ nn.Tanh(),
+ nn.Linear(64, 32),
+ nn.Tanh(),
+ nn.Linear(32, 1),
+ )
+ self.action_var = torch.full((action_dim,), action_std * action_std).to(device)
+
+ def forward(self):
+ raise NotImplementedError
+
+ def act(self, state, rollout_storage):
+ action_mean = self.actor(state)
+ cov_mat = torch.diag(self.action_var).to(device)
+
+ dist = MultivariateNormal(action_mean, cov_mat)
+ action = dist.sample()
+ action_logprob = dist.log_prob(action)
+
+ if rollout_storage:
+ rollout_storage.states.append(state)
+ rollout_storage.actions.append(action)
+ rollout_storage.logprobs.append(action_logprob)
+
+ return action.detach()
+
+ def evaluate(self, state, action):
+ action_mean = self.actor(state)
+
+ action_var = self.action_var.expand_as(action_mean)
+ cov_mat = torch.diag_embed(action_var).to(device)
+
+ dist = MultivariateNormal(action_mean, cov_mat)
+
+ action_logprobs = dist.log_prob(action)
+ dist_entropy = dist.entropy()
+ state_value = self.critic(state)
+
+ return action_logprobs, torch.squeeze(state_value), dist_entropy
+
+
+class Demon:
+ def __init__(
+ self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip
+ ):
+ self.lr = lr
+ self.betas = betas
+ self.gamma = gamma
+ self.eps_clip = eps_clip
+ self.K_epochs = K_epochs
+
+ self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
+ self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)
+
+ self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
+ self.policy_old.load_state_dict(self.policy.state_dict())
+
+ self.MseLoss = nn.MSELoss()
+
+ def select_action(self, state, rollout_storage):
+ return self.policy_old.act(state, rollout_storage)
+
+ def update(self, rollout_storage):
+ # Monte Carlo estimate of rewards:
+ rewards = []
+ discounted_reward = 0
+ for reward in reversed(rollout_storage.rewards):
+ discounted_reward = reward + (self.gamma * discounted_reward)
+ rewards.insert(0, discounted_reward)
+
+ # Normalizing the rewards:
+ rewards = torch.stack(rewards)
+ rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
+ rewards = rewards.squeeze(-1)
+
+ # convert list to tensor
+ old_states = torch.squeeze(
+ torch.stack(rollout_storage.states).to(device), 1
+ ).detach()
+ old_actions = torch.squeeze(
+ torch.stack(rollout_storage.actions).to(device), 1
+ ).detach()
+ old_logprobs = (
+ torch.squeeze(torch.stack(rollout_storage.logprobs), 1).to(device).detach()
+ )
+
+ # Optimize policy for K epochs:
+ for _ in range(self.K_epochs):
+ # Evaluating old actions and values :
+ logprobs, state_values, dist_entropy = self.policy.evaluate(
+ old_states, old_actions
+ )
+
+ # Finding the ratio (pi_theta / pi_theta__old):
+ ratios = torch.exp(logprobs - old_logprobs.detach())
+
+ try:
+ state_values = state_values[
+ :-1, :
+ ] # reward is computed as the mutual info between consequenct mem state,
+ # therefore n-1 values only.
+ ratios = ratios[:-1, :] # the same for ratio
+ dist_entropy = dist_entropy[:-1, :] # the same for entropy
+ advantages = rewards - state_values.detach()
+ surr1 = ratios * advantages
+ surr2 = (
+ torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip)
+ * advantages
+ )
+ # Finding Surrogate Loss:
+ loss = (
+ -torch.min(surr1, surr2)
+ + 0.5 * self.MseLoss(state_values, rewards)
+ - 0.01 * dist_entropy
+ )
+ # take gradient step
+ self.optimizer.zero_grad()
+ loss.mean().backward()
+ torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
+ self.optimizer.step()
+ except Exception:
+ # Do thing for the sequences of lentgh 1.
+ loss = torch.zeros_like(rewards).to(device)
+ continue
+
+ # Copy new weights into old policy:
+ self.policy_old.load_state_dict(self.policy.state_dict())
+ return loss
+
+
+############################################
+# Mutual information Estimator Network######
+############################################
+
+
+def linear_reset(module, gain=1.0):
+ assert isinstance(module, torch.nn.Linear)
+ init.xavier_uniform_(module.weight, gain=gain)
+ s = module.weight.size(1)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+class FNet(nn.Module):
+ """
+ Monte-Carlo estimators for Mutual Information Known as MINE.
+ Mine produces estimates that are neither an upper or lower bound on MI.
+ Other ZNet can be Introduced to address the problem of building bounds with finite samples (unlike Monte Carlo)
+ """
+
+ def __init__(self):
+ super(FNet, self).__init__()
+
+ def reset_parameters(self):
+ for module in self.lstm:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ for module in self.hidden2f:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ def init(self, input_size):
+ self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True))
+ self.hidden2f = nn.Sequential(nn.Linear(32, 1))
+ self.reset_parameters()
+
+ def forward(self, data):
+ output, (hn, cn) = self.lstm(data)
+ output = F.elu(output)
+ fvals = self.hidden2f(output)
+ return fvals
+
+
+class ZNet(nn.Module):
+ def __init__(self):
+ super(ZNet, self).__init__()
+
+ def reset_parameters(self):
+ for module in self.lstm:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ for module in self.hidden2z:
+ if isinstance(module, torch.nn.Linear):
+ linear_reset(module, gain=init.calculate_gain("relu"))
+
+ def init(self, input_size):
+ self.lstm = nn.Sequential(nn.LSTM(input_size, 32, batch_first=True))
+ self.hidden2z = nn.Sequential(nn.Linear(32, 1))
+ self.reset_parameters()
+
+ def forward(self, data):
+ output, (hn, cn) = self.lstm(data)
+ output = F.elu(output)
+ zvals = self.hidden2z(output)
+ return F.softplus(zvals)
diff --git a/Utils/.DS_Store b/Utils/.DS_Store
new file mode 100644
index 0000000..5008ddf
Binary files /dev/null and b/Utils/.DS_Store differ
diff --git a/Utils/ArgumentParser.py b/Utils/ArgumentParser.py
new file mode 100644
index 0000000..438a5cf
--- /dev/null
+++ b/Utils/ArgumentParser.py
@@ -0,0 +1,167 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import os
+import json
+import argparse
+
+
+class ArgumentParser:
+ class Parsed:
+ pass
+
+ _type = type
+
+ @staticmethod
+ def str_or_none(none_string="none"):
+ def parse(s):
+ return None if s.lower()==none_string else s
+ return parse
+
+ @staticmethod
+ def list_or_none(none_string="none", type=int):
+ def parse(s):
+ return None if s.lower() == none_string else [type(a) for a in s.split(",") if a]
+ return parse
+
+ @staticmethod
+ def _merge_args(args, new_args, arg_schemas):
+ for name, val in new_args.items():
+ old = args.get(name)
+ if old is None:
+ args[name] = val
+ else:
+ args[name] = arg_schemas[name]["updater"](old, val)
+
+ class Profile:
+ def __init__(self, name, args=None, include=[]):
+ assert not (args is None and not include), "One of args or include must be defined"
+ self.name = name
+ self.args = args
+ if not isinstance(include, list):
+ include=[include]
+ self.include = include
+
+ def get_args(self, arg_schemas, profile_by_name):
+ res = {}
+
+ for n in self.include:
+ p = profile_by_name.get(n)
+ assert p is not None, "Included profile %s doesn't exists" % n
+
+ ArgumentParser._merge_args(res, p.get_args(arg_schemas, profile_by_name), arg_schemas)
+
+ ArgumentParser._merge_args(res, self.args, arg_schemas)
+ return res
+
+
+ def __init__(self, description=None):
+ self.parser = argparse.ArgumentParser(description=description)
+ self.loaded = {}
+ self.profiles = {}
+ self.args = {}
+ self.raw = None
+ self.parsed = None
+ self.parser.add_argument("-profile", type=str, help="Pre-defined profiles.")
+
+ def add_argument(self, name, type=None, default=None, help="", save=True, parser=lambda x:x, updater=lambda old, new:new):
+ assert name not in ["profile"], "Argument name %s is reserved" % name
+ assert not (type is None and default is None), "Either type or default must be given"
+
+ if type is None:
+ type = ArgumentParser._type(default)
+
+ self.parser.add_argument(name, type=int if type==bool else type, default=None, help=help)
+ if name[0] == '-':
+ name = name[1:]
+
+ self.args[name] = {
+ "type": type,
+ "default": int(default) if type==bool else default,
+ "save": save,
+ "parser": parser,
+ "updater": updater
+ }
+
+ def add_profile(self, prof):
+ if isinstance(prof, list):
+ for p in prof:
+ self.add_profile(p)
+ else:
+ self.profiles[prof.name] = prof
+
+ def do_parse_args(self, loaded={}):
+ self.raw = self.parser.parse_args()
+
+ profile = {}
+ if self.raw.profile:
+ assert not loaded, "Loading arguments from file, but profile given."
+ for pr in self.raw.profile.split(","):
+ p = self.profiles.get(pr)
+ assert p is not None, "Invalid profile: %s. Valid profiles: %s" % (pr, self.profiles.keys())
+ p = p.get_args(self.args, self.profiles)
+ self._merge_args(profile, p, self.args)
+
+ for k, v in self.raw.__dict__.items():
+ if k in ["profile"]:
+ continue
+
+ if v is None:
+ if k in loaded and self.args[k]["save"]:
+ self.raw.__dict__[k] = loaded[k]
+ else:
+ self.raw.__dict__[k] = profile.get(k, self.args[k]["default"])
+
+ self.parsed = ArgumentParser.Parsed()
+ self.parsed.__dict__.update({k: self.args[k]["parser"](self.args[k]["type"](v)) if v is not None else None
+ for k,v in self.raw.__dict__.items() if k in self.args})
+
+ return self.parsed
+
+ def parse_or_cache(self):
+ if self.parsed is None:
+ self.do_parse_args()
+
+ def parse(self):
+ self.parse_or_cache()
+ return self.parsed
+
+ def save(self, fname):
+ self.parse_or_cache()
+ with open(fname, 'w') as outfile:
+ json.dump(self.raw.__dict__, outfile, indent=4)
+ return True
+
+ def load(self, fname):
+ if os.path.isfile(fname):
+ map = {}
+ with open(fname, "r") as data_file:
+ map = json.load(data_file)
+
+ self.do_parse_args(map)
+ return self.parsed
+
+ def sync(self, fname, save=True):
+ if os.path.isfile(fname):
+ self.load(fname)
+
+ if save:
+ dir = os.path.dirname(fname)
+ if not os.path.isdir(dir):
+ os.makedirs(dir)
+
+ self.save(fname)
+ return self.parsed
diff --git a/Utils/Collate.py b/Utils/Collate.py
new file mode 100644
index 0000000..3b593d2
--- /dev/null
+++ b/Utils/Collate.py
@@ -0,0 +1,75 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import numpy as np
+import torch
+from operator import mul
+from functools import reduce
+
+class VarLengthCollate:
+ def __init__(self, ignore_symbol=0):
+ self.ignore_symbol = ignore_symbol
+
+ def _measure_array_max_dim(self, batch):
+ s=list(batch[0].size())
+ different=[False] * len(s)
+ for i in range(1, len(batch)):
+ ns = batch[i].size()
+ different = [different[j] or s[j]!=ns[j] for j in range(len(s))]
+ s=[max(s[j], ns[j]) for j in range(len(s))]
+ return s, different
+
+ def _merge_var_len_array(self, batch):
+ max_size, different = self._measure_array_max_dim(batch)
+ s=[len(batch)] + max_size
+ storage = batch[0].storage()._new_shared(reduce(mul, s, 1))
+ out = batch[0].new(storage).view(s).fill_(self.ignore_symbol)
+ for i, d in enumerate(batch):
+ this_o = out[i]
+ for j, diff in enumerate(different):
+ if different[j]:
+ this_o = this_o.narrow(j,0,d.size(j))
+ this_o.copy_(d)
+ return out
+
+
+ def __call__(self, batch):
+ if isinstance(batch[0], dict):
+ return {k: self([b[k] for b in batch]) for k in batch[0].keys()}
+ elif isinstance(batch[0], np.ndarray):
+ return self([torch.from_numpy(a) for a in batch])
+ elif torch.is_tensor(batch[0]):
+ return self._merge_var_len_array(batch)
+ else:
+ assert False, "Unknown type: %s" % type(batch[0])
+
+class MetaCollate:
+ def __init__(self, meta_name="meta", collate=VarLengthCollate()):
+ self.meta_name = meta_name
+ self.collate = collate
+
+ def __call__(self, batch):
+ if isinstance(batch[0], dict):
+ meta = [b[self.meta_name] for b in batch]
+ batch = [{k: v for k,v in b.items() if k!=self.meta_name} for b in batch]
+ else:
+ meta = None
+
+ res = self.collate(batch)
+ if meta is not None:
+ res[self.meta_name] = meta
+
+ return res
\ No newline at end of file
diff --git a/Utils/Debug.py b/Utils/Debug.py
new file mode 100644
index 0000000..779ecaf
--- /dev/null
+++ b/Utils/Debug.py
@@ -0,0 +1,133 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import numpy as np
+import sys
+import traceback
+import torch
+
+enableDebug = False
+
+def nan_check(arg, name=None, force=False):
+ if not enableDebug and not force:
+ return arg
+ is_nan = False
+ curr_nan = False
+ if isinstance(arg, torch.autograd.Variable):
+ curr_nan = not np.isfinite(arg.sum().cpu().data.numpy())
+ elif isinstance(arg, torch.nn.parameter.Parameter):
+ curr_nan = (not np.isfinite(arg.sum().cpu().data.numpy())) or (not np.isfinite(arg.grad.sum().cpu().data.numpy()))
+ elif isinstance(arg, float):
+ curr_nan = not np.isfinite(arg)
+ elif isinstance(arg, (list, tuple)):
+ for a in arg:
+ nan_check(a)
+ else:
+ assert False, "Unsupported type %s" % type(arg)
+
+ if curr_nan:
+ if sys.exc_info()[0] is not None:
+ trace = str(traceback.format_exc())
+ else:
+ trace = "".join(traceback.format_stack())
+
+ print(arg)
+ if name is not None:
+ print("NaN found in %s." % name)
+ else:
+ print("NaN found.")
+ if isinstance(arg, torch.autograd.Variable):
+ print(" Argument is a torch tensor. Shape: %s" % list(arg.size()))
+
+ print(trace)
+ sys.exit(-1)
+
+ return arg
+
+
+def assert_range(t, min=0.0, max=1.0):
+ if not enableDebug:
+ return
+
+ if t.min().cpu().data.numpy()max:
+ print(t)
+ assert False
+
+
+def assert_dist(t, use_lower_limit=True):
+ if not enableDebug:
+ return
+
+ assert_range(t)
+
+ if t.sum(-1).max().cpu().data.numpy()>1.001:
+ print("MAT:", t)
+ print("SUM:", t.sum(-1))
+ assert False
+
+ if use_lower_limit and t.sum(-1).max().cpu().data.numpy()<0.999:
+ print(t)
+ print("SUM:", t.sum(-1))
+ assert False
+
+
+def print_stat(name, t):
+ if not enableDebug:
+ return
+
+ min = t.min().cpu().data.numpy()
+ max = t.max().cpu().data.numpy()
+ mean = t.mean().cpu().data.numpy()
+
+ print("%s: min: %g, mean: %g, max: %g" % (name, min, mean, max))
+
+
+def dbg_print(*things):
+ if not enableDebug:
+ return
+ print(*things)
+
+class GradPrinter(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, a):
+ return a
+
+ @staticmethod
+ def backward(ctx, g):
+ print("Grad (print_grad): ", g[0])
+ return g
+
+def print_grad(t):
+ return GradPrinter.apply(t)
+
+def assert_equal(t1, ref, limit=1e-5, force=True):
+ if not (enableDebug or force):
+ return
+
+ assert t1.shape==ref.shape, "Tensor shapes differ: got %s, ref %s" % (t1.shape, ref.shape)
+ norm = ref.abs().sum() / ref.nonzero().sum().float()
+ threshold = norm * limit
+
+ errcnt = ((t1 - ref).abs() > threshold).sum()
+ if errcnt > 0:
+ print("Tensors differ. (max difference: %g, norm %f). No of errors: %d of %d" %
+ ((t1 - ref).abs().max().item(), norm, errcnt, t1.numel()))
+ print("---------------------------------------------Out-----------------------------------------------")
+ print(t1)
+ print("---------------------------------------------Ref-----------------------------------------------")
+ print(ref)
+ print("-----------------------------------------------------------------------------------------------")
+ assert False
\ No newline at end of file
diff --git a/Utils/Helpers.py b/Utils/Helpers.py
new file mode 100644
index 0000000..4868150
--- /dev/null
+++ b/Utils/Helpers.py
@@ -0,0 +1,24 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import torch
+import torch.autograd
+
+def as_numpy(data):
+ if isinstance(data, (torch.Tensor, torch.autograd.Variable)):
+ return data.detach().cpu().numpy()
+ else:
+ return data
diff --git a/Utils/Index.py b/Utils/Index.py
new file mode 100644
index 0000000..3813caf
--- /dev/null
+++ b/Utils/Index.py
@@ -0,0 +1,20 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+def index_by_dim(arr, dim, i_start, i_end=None):
+ if dim<0:
+ dim += arr.ndim
+ return arr[tuple([slice(None,None)] * dim + [slice(i_start, i_end) if i_end is not None else i_start])]
\ No newline at end of file
diff --git a/Utils/Process.py b/Utils/Process.py
new file mode 100644
index 0000000..12fe617
--- /dev/null
+++ b/Utils/Process.py
@@ -0,0 +1,51 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import sys
+import ctypes
+import subprocess
+import os
+
+def run(cmd, hide_stderr = True):
+ libc_search_dirs = ["/lib", "/lib/x86_64-linux-gnu", "/lib/powerpc64le-linux-gnu"]
+
+ if sys.platform=="linux" :
+ found = None
+ for d in libc_search_dirs:
+ file = os.path.join(d, "libc.so.6")
+ if os.path.isfile(file):
+ found = file
+ break
+
+ if not found:
+ print("WARNING: Cannot find libc.so.6. Cannot kill process when parent dies.")
+ killer = None
+ else:
+ libc = ctypes.CDLL(found)
+ PR_SET_PDEATHSIG = 1
+ KILL = 9
+ killer = lambda: libc.prctl(PR_SET_PDEATHSIG, KILL)
+ else:
+ print("WARNING: OS not linux. Cannot kill process when parent dies.")
+ killer = None
+
+
+ if hide_stderr:
+ stderr = open(os.devnull,'w')
+ else:
+ stderr = None
+
+ return subprocess.Popen(cmd.split(" "), stderr=stderr, preexec_fn=killer)
diff --git a/Utils/Profile.py b/Utils/Profile.py
new file mode 100644
index 0000000..60eb546
--- /dev/null
+++ b/Utils/Profile.py
@@ -0,0 +1,48 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import atexit
+
+ENABLED=False
+
+_profiler = None
+
+
+def construct():
+ global _profiler
+ if not ENABLED:
+ return
+
+ if _profiler is None:
+ from line_profiler import LineProfiler
+ _profiler = LineProfiler()
+
+
+def do_profile(follow=[]):
+ construct()
+ def inner(func):
+ if _profiler is not None:
+ _profiler.add_function(func)
+ for f in follow:
+ _profiler.add_function(f)
+ _profiler.enable_by_count()
+ return func
+ return inner
+
+@atexit.register
+def print_prof():
+ if _profiler is not None:
+ _profiler.print_stats()
diff --git a/Utils/Saver.py b/Utils/Saver.py
new file mode 100644
index 0000000..99f6984
--- /dev/null
+++ b/Utils/Saver.py
@@ -0,0 +1,233 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import torch
+import os
+import inspect
+import time
+
+
+class SaverElement:
+ def save(self):
+ raise NotImplementedError
+
+ def load(self, saved_state):
+ raise NotImplementedError
+
+
+class CallbackSaver(SaverElement):
+ def __init__(self, save_fn, load_fn):
+ super().__init__()
+ self.save = save_fn
+ self.load = load_fn
+
+
+class StateSaver(SaverElement):
+ def __init__(self, model):
+ super().__init__()
+ self._model = model
+
+ def load(self, state):
+ try:
+ self._model.load_state_dict(state)
+ except Exception as e:
+ if hasattr(self._model, "named_parameters"):
+ names = set([n for n, _ in self._model.named_parameters()])
+ loaded = set(self._model.keys())
+ if names!=loaded:
+ d = loaded.difference(names)
+ if d:
+ print("Loaded, but not in model: %s" % list(d))
+ d = names.difference(loaded)
+ if d:
+ print("In model, but not loaded: %s" % list(d))
+ if isinstance(self._model, torch.optim.Optimizer):
+ print("WARNING: optimizer parameters not loaded!")
+ else:
+ raise e
+
+ def save(self):
+ return self._model.state_dict()
+
+
+class GlobalVarSaver(SaverElement):
+ def __init__(self, name):
+ caller_frame = inspect.getouterframes(inspect.currentframe())[1]
+ self._vars = caller_frame.frame.f_globals
+ self._name = name
+
+ def load(self, state):
+ self._vars.update({self._name: state})
+
+ def save(self):
+ return self._vars[self._name]
+
+
+class PyObjectSaver(SaverElement):
+ def __init__(self, obj):
+ self._obj = obj
+
+ def load(self, state):
+ def _load(target, state):
+ if isinstance(target, dict):
+ for k, v in state.items():
+ target[k] = _load(target.get(k), v)
+ elif isinstance(target, list):
+ if len(target)!=len(state):
+ target.clear()
+ for v in state:
+ target.append(v)
+ else:
+ for i, v in enumerate(state):
+ target[i] = _load(target[i], v)
+
+ elif hasattr(target, "__dict__"):
+ _load(target.__dict__, state)
+ else:
+ return state
+ return target
+
+ _load(self._obj, state)
+
+ def save(self):
+ def _save(target):
+ if isinstance(target, dict):
+ res = {k: _save(v) for k, v in target.items()}
+ elif isinstance(target, list):
+ res = [_save(v) for v in target]
+ elif hasattr(target, "__dict__"):
+ res = {k: _save(v) for k, v in target.__dict__.items()}
+ else:
+ res = target
+
+ return res
+
+ return _save(self._obj)
+
+ @staticmethod
+ def obj_supported(obj):
+ return isinstance(obj, (list, dict)) or hasattr(obj, "__dict__")
+
+
+class Saver:
+ def __init__(self, dir, short_interval, keep_every_n_hours=4):
+ self.savers = {}
+ self.short_interval = short_interval
+ os.makedirs(dir, exist_ok=True)
+ self.dir = dir
+ self._keep_every_n_seconds = keep_every_n_hours * 3600
+
+ def register(self, name, saver):
+ assert name not in self.savers, "Saver %s already registered" % name
+
+ if isinstance(saver, SaverElement):
+ self.savers[name] = saver
+ elif hasattr(saver, "state_dict") and callable(saver.state_dict):
+ self.savers[name] = StateSaver(saver)
+ elif PyObjectSaver.obj_supported(saver):
+ self.savers[name] = PyObjectSaver(saver)
+ else:
+ assert "Unsupported thing to save: %s" % type(saver)
+
+ def __setitem__(self, key, value):
+ self.register(key, value)
+
+ def write(self, iter):
+ fname = os.path.join(self.dir, self.model_name_from_index(iter))
+ print("Saving %s" % fname)
+
+ state = {}
+ for name, fns in self.savers.items():
+ state[name] = fns.save()
+
+ torch.save(state, fname)
+ print("Saved.")
+
+ self._cleanup()
+
+ def tick(self, iter):
+ if iter % self.short_interval != 0:
+ return
+
+ self.write(iter)
+
+ @staticmethod
+ def model_name_from_index(index):
+ return "model-%d.pth" % index
+
+ @staticmethod
+ def get_checkpoint_index_list(dir):
+ return list(reversed(sorted(
+ [int(fn.split(".")[0].split("-")[-1]) for fn in os.listdir(dir) if fn.split(".")[-1] == "pth"])))
+
+ @staticmethod
+ def get_ckpts_in_time_window(dir, time_window_s, index_list=None):
+ if index_list is None:
+ index_list = Saver.get_checkpoint_index_list(dir)
+
+
+ now = time.time()
+
+ res = []
+ for i in index_list:
+ name = Saver.model_name_from_index(i)
+ mtime = os.path.getmtime(os.path.join(dir, name))
+ if now - mtime > time_window_s:
+ break
+
+ res.append(name)
+
+ return res
+
+ @staticmethod
+ def load_last_checkpoint(dir):
+ last_checkpoint = Saver.get_checkpoint_index_list(dir)
+
+ if last_checkpoint:
+ for index in last_checkpoint:
+ fname = Saver.model_name_from_index(index)
+ try:
+ print("Loading %s" % fname)
+ data = torch.load(os.path.join(dir, fname))
+ except:
+ print("WARNING: Loading %s failed. Maybe file is corrupted?" % fname)
+ continue
+ return data
+ return None
+
+ def _cleanup(self):
+ index_list = self.get_checkpoint_index_list(self.dir)
+ new_files = self.get_ckpts_in_time_window(self.dir, self._keep_every_n_seconds, index_list[2:])
+ new_files = new_files[:-1]
+
+ for f in new_files:
+ os.remove(os.path.join(self.dir, f))
+
+ def load(self, fname=None):
+ if fname is None:
+ state = self.load_last_checkpoint(self.dir)
+ if not state:
+ return False
+ else:
+ state = torch.load(fname)
+
+ for k,s in state.items():
+ if k not in self.savers:
+ print("WARNING: failed to load state of %s. It doesn't exists." % k)
+ continue
+ self.savers[k].load(s)
+
+ return True
diff --git a/Utils/Seed.py b/Utils/Seed.py
new file mode 100644
index 0000000..ee32095
--- /dev/null
+++ b/Utils/Seed.py
@@ -0,0 +1,31 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import numpy as np
+
+seed = None
+
+
+def fix():
+ global seed
+ seed = 0xB0C1FA52
+
+
+def get_randstate():
+ if seed:
+ return np.random.RandomState(seed)
+ else:
+ return np.random.RandomState()
diff --git a/Utils/Visdom.py b/Utils/Visdom.py
new file mode 100644
index 0000000..ec69fb4
--- /dev/null
+++ b/Utils/Visdom.py
@@ -0,0 +1,323 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import time
+import visdom
+import sys
+import numpy as np
+import os
+import socket
+from . import Process
+
+vis = None
+port = None
+visdom_fail_count = 0
+
+
+def port_used(port):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ result = sock.connect_ex(('127.0.0.1', port))
+ if result == 0:
+ sock.close()
+ return True
+ else:
+ return False
+
+
+def alloc_port(start_from=7000):
+ while True:
+ if port_used(start_from):
+ print("Port already used: %d" % start_from)
+ start_from += 1
+ else:
+ return start_from
+
+
+def wait_for_port(port, timeout=5):
+ star_time = time.time()
+ while not port_used(port):
+ if time.time() - star_time > timeout:
+ return False
+
+ time.sleep(0.1)
+ return True
+
+
+def start(on_port=None):
+ global vis
+ global port
+ global visdom_fail_count
+ assert vis is None, "Cannot start more than 1 visdom servers."
+
+ if visdom_fail_count>=3:
+ return
+
+ port = alloc_port() if on_port is None else on_port
+
+ print("Starting Visdom server on %d" % port)
+ Process.run("%s -m visdom.server -p %d" % (sys.executable, port))
+ if not wait_for_port(port):
+ print("ERROR: failed to start Visdom server. Server not responding.")
+ visdom_fail_count += 1
+ return
+ print("Done.")
+
+ vis = visdom.Visdom(port=port)
+
+
+def _start_if_not_running():
+ if vis is None:
+ start()
+
+
+def save_heatmap(dir, title, img):
+ if dir is not None:
+ fname = os.path.join(dir, title.replace(" ", "_") + ".npy")
+ d = os.path.dirname(fname)
+ os.makedirs(d, exist_ok=True)
+ np.save(fname, img)
+
+
+class Plot2D:
+ TO_SAVE = ["x", "y", "curr_accu", "curr_cnt", "legend"]
+
+ def __init__(self, name, store_interval=1, legend=None, xlabel=None, ylabel=None):
+ _start_if_not_running()
+
+ self.x = []
+ self.y = []
+ self.store_interval = store_interval
+ self.curr_accu = None
+ self.curr_cnt = 0
+ self.name = name
+ self.legend = legend
+ self.xlabel = xlabel
+ self.ylabel = ylabel
+
+ self.replot = False
+ self.visplot = None
+ self.last_vis_update_pos = 0
+
+ def set_legend(self, legend):
+ if self.legend != legend:
+ self.legend = legend
+ self.replot = True
+
+ def _send_update(self):
+ if not self.x or vis is None:
+ return
+
+ if self.visplot is None or self.replot:
+ opts = {
+ "title": self.name
+ }
+ if self.xlabel:
+ opts["xlabel"] = self.xlabel
+ if self.ylabel:
+ opts["ylabel"] = self.ylabel
+
+ if self.legend is not None:
+ opts["legend"] = self.legend
+
+ self.visplot = vis.line(X=np.asfarray(self.x), Y=np.asfarray(self.y), opts=opts, win=self.visplot)
+ self.replot = False
+ else:
+ vis.line(
+ X=np.asfarray(self.x[self.last_vis_update_pos:]),
+ Y=np.asfarray(self.y[self.last_vis_update_pos:]),
+ win=self.visplot,
+ update='append'
+ )
+
+ self.last_vis_update_pos = len(self.x) - 1
+
+ def add_point(self, x, y):
+ if not isinstance(y, list):
+ y = [y]
+
+
+
+ if self.curr_accu is None:
+ self.curr_accu = [0.0] * len(y)
+
+ if len(self.curr_accu) < len(y):
+ # The number of curves increased.
+ need_to_add = (len(y) - len(self.curr_accu))
+
+ self.curr_accu += [0.0] * need_to_add
+ count = len(self.x)
+ if count>0:
+ self.replot = True
+ if not isinstance(self.x[0], list):
+ self.x = [[x] for x in self.x]
+ self.y = [[y] for y in self.y]
+
+ nan = float("nan")
+ for a in self.x:
+ a += [nan] * need_to_add
+ for a in self.y:
+ a += [nan] * need_to_add
+ elif len(self.curr_accu) > len(y):
+ y = y[:] + [float("nan")] * (len(self.curr_accu) - len(y))
+
+ self.curr_accu = [self.curr_accu[i] + y[i] for i in range(len(y))]
+ self.curr_cnt += 1
+ if self.curr_cnt == self.store_interval:
+ if len(y) > 1:
+ self.x.append([x] * len(y))
+ self.y.append([a / self.curr_cnt for a in self.curr_accu])
+ else:
+ self.x.append(x)
+ self.y.append(self.curr_accu[0] / self.curr_cnt)
+
+ self.curr_accu = [0.0] * len(y)
+ self.curr_cnt = 0
+
+ self._send_update()
+
+ def state_dict(self):
+ s = {k: self.__dict__[k] for k in self.TO_SAVE}
+ return s
+
+ def load_state_dict(self, state):
+ if self.legend is not None:
+ # Load legend only if not given in the constructor.
+ state["legend"] = self.legend
+ self.__dict__.update(state)
+ self.last_vis_update_pos = 0
+
+ # Read old format
+ if not isinstance(self.curr_accu, list) and self.curr_accu is not None:
+ self.curr_accu = [self.curr_accu]
+
+ self._send_update()
+
+
+class Image:
+ def __init__(self, title, dumpdir=None):
+ _start_if_not_running()
+
+ self.win = None
+ self.opts = dict(title=title)
+ self.dumpdir = dumpdir
+
+ def set_dump_dir(self, dumpdir):
+ self.dumpdir = dumpdir
+
+ def draw(self, img):
+ if vis is None:
+ return
+
+ if isinstance(img, list):
+ if self.win is None:
+ self.win = vis.images(img, opts=self.opts)
+ else:
+ vis.images(img, win=self.win, opts=self.opts)
+ else:
+ if len(img.shape)==2:
+ img = np.expand_dims(img, 0)
+ elif img.shape[-1] in [1,3] and img.shape[0] not in [1,3]:
+ # cv2 image
+ img = img.transpose(2,0,1)
+ img = img[::-1]
+
+ if img.dtype==np.uint8:
+ img = img.astype(np.float32)/255
+
+ self.opts["width"] = img.shape[2]
+ self.opts["height"] = img.shape[1]
+
+ save_heatmap(self.dumpdir, self.opts["title"], img)
+ if self.win is None:
+ self.win = vis.image(img, opts=self.opts)
+ else:
+ vis.image(img, win=self.win, opts=self.opts)
+
+ def __call__(self, img):
+ self.draw(img)
+
+
+class Text:
+ def __init__(self, title):
+ _start_if_not_running()
+
+ self.win = None
+ self.title = title
+ self.curr_text = ""
+
+ def set(self, text):
+ self.curr_text = text
+
+ if vis is None:
+ return
+
+ if self.win is None:
+ self.win = vis.text(text, opts=dict(
+ title=self.title
+ ))
+ else:
+ vis.text(text, win=self.win)
+
+ def state_dict(self):
+ return {"text": self.curr_text}
+
+ def load_state_dict(self, state):
+ self.set(state["text"])
+
+ def __call__(self, text):
+ self.set(text)
+
+
+class Heatmap:
+ def __init__(self, title, min=None, max=None, xlabel=None, ylabel=None, colormap='Viridis', dumpdir=None):
+ _start_if_not_running()
+
+ self.win = None
+ self.opt = dict(title=title, colormap=colormap)
+ self.dumpdir = dumpdir
+ if min is not None:
+ self.opt["xmin"] = min
+ if max is not None:
+ self.opt["xmax"] = max
+
+ if xlabel:
+ self.opt["xlabel"] = xlabel
+ if ylabel:
+ self.opt["ylabel"] = ylabel
+
+ def set_dump_dir(self, dumpdir):
+ self.dumpdir = dumpdir
+
+ def draw(self, img):
+ if vis is None:
+ return
+
+ o = self.opt.copy()
+ if "xmin" not in o:
+ o["xmin"] = float(img.min())
+
+ if "xmax" not in o:
+ o["xmax"] = float(img.max())
+
+ save_heatmap(self.dumpdir, o["title"], img)
+
+ if self.win is None:
+ self.win = vis.heatmap(img, opts=o)
+ else:
+ vis.heatmap(img, win=self.win, opts=o)
+
+ def __call__(self, img):
+ self.draw(img)
diff --git a/Utils/download.py b/Utils/download.py
new file mode 100644
index 0000000..e78b5e1
--- /dev/null
+++ b/Utils/download.py
@@ -0,0 +1,189 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import requests, tarfile, io, os, zipfile
+
+from io import BytesIO, SEEK_SET, SEEK_END
+
+
+class UrlStream:
+ def __init__(self, url):
+ self._url = url
+ headers = requests.head(url).headers
+ headers = {k.lower(): v for k,v in headers.items()}
+ self._seek_supported = headers.get('accept-ranges')=='bytes' and 'content-length' in headers
+ if self._seek_supported:
+ self._size = int(headers['content-length'])
+ self._curr_pos = 0
+ self._buf_start_pos = 0
+ self._iter = None
+ self._buffer = None
+ self._buf_size = 0
+ self._loaded_all = False
+
+ def seekable(self):
+ return self._seek_supported
+
+ def _load_all(self):
+ if self._loaded_all:
+ return
+ self._make_request()
+ old_buf_pos = self._buffer.tell()
+ self._buffer.seek(0, SEEK_END)
+ for chunk in self._iter:
+ self._buffer.write(chunk)
+ self._buf_size = self._buffer.tell()
+ self._buffer.seek(old_buf_pos, SEEK_SET)
+ self._loaded_all = True
+
+ def seek(self, position, whence=SEEK_SET):
+ if whence == SEEK_END:
+ assert position<=0
+ if self._seek_supported:
+ self.seek(self._size + position)
+ else:
+ self._load_all()
+ self._buffer.seek(position, SEEK_END)
+ self._curr_pos = self._buffer.tell()
+ elif whence==SEEK_SET:
+ if self._curr_pos != position:
+ self._curr_pos = position
+ if self._seek_supported:
+ self._iter = None
+ self._buffer = None
+ else:
+ self._load_until(position)
+ self._buffer.seek(position)
+ self._curr_pos = position
+ else:
+ assert "Invalid whence %s" % whence
+
+ return self.tell()
+
+ def tell(self):
+ return self._curr_pos
+
+ def _load_until(self, goal_position):
+ self._make_request()
+ old_buf_pos = self._buffer.tell()
+ current_position = self._buffer.seek(0, SEEK_END)
+
+ goal_position = goal_position - self._buf_start_pos
+ while current_position < goal_position:
+ try:
+ d = next(self._iter)
+ self._buffer.write(d)
+ current_position += len(d)
+ except StopIteration:
+ break
+ self._buf_size = current_position
+ self._buffer.seek(old_buf_pos, SEEK_SET)
+
+ def _new_buffer(self):
+ remaining = self._buffer.read() if self._buffer is not None else None
+ self._buffer = BytesIO()
+ if remaining is not None:
+ self._buffer.write(remaining)
+ self._buf_start_pos = self._curr_pos
+ self._buf_size = 0 if remaining is None else len(remaining)
+ self._buffer.seek(0, SEEK_SET)
+ self._loaded_all = False
+
+ def _make_request(self):
+ if self._iter is None:
+ h = {
+ "User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36",
+ }
+ if self._seek_supported:
+ h["Range"] = "bytes=%d-%d" % (self._curr_pos, self._size - 1)
+
+ r = requests.get(self._url, headers=h, stream=True)
+
+ self._iter = r.iter_content(1024 * 1024)
+ self._new_buffer()
+ elif self._seek_supported and self._buf_size > 128 * 1024 * 1024:
+ self._new_buffer()
+
+ def size(self):
+ if self._seek_supported:
+ return self._size
+ else:
+ self._load_all()
+ return self._buf_size
+
+ def read(self, size=None):
+ if size is None:
+ size = self.size()
+
+ self._load_until(self._curr_pos + size)
+ if self._seek_supported:
+ self._curr_pos = min(self._curr_pos+size, self._size)
+
+ return self._buffer.read(size)
+
+ def iter_content(self, block_size):
+ while True:
+ d = self.read(block_size)
+ if not len(d):
+ break
+ yield d
+
+
+def download(url, dest=None, extract=True, ignore_if_exists=False):
+ """
+ Download a file from the internet.
+
+ Args:
+ url: the url to download
+ dest: destination file if extract=False, or destionation dir if extract=True. If None, it will be the last part of URL.
+ extract: extract a tar.gz or zip file?
+ ignore_if_exists: don't do anything if file exists
+
+ Returns:
+ the destination filename.
+ """
+
+ base_url = url.split("?")[0]
+
+ if dest is None:
+ dest = [f for f in base_url.split("/") if f][-1]
+
+ if os.path.exists(dest) and ignore_if_exists:
+ return dest
+
+ stream = UrlStream(url)
+ extension = base_url.split(".")[-1].lower()
+
+ if extract and extension in ['gz', 'bz2', 'zip']:
+ os.makedirs(dest, exist_ok=True)
+
+ if extension in ['gz', 'bz2']:
+ decompressed_file = tarfile.open(fileobj=stream, mode='r|'+extension)
+ elif extension=='zip':
+ decompressed_file = zipfile.ZipFile(stream, mode='r')
+ else:
+ assert False, "Invalid extension: %s" % extension
+
+ decompressed_file.extractall(dest)
+ else:
+ try:
+ with open(dest, 'wb') as f:
+ for d in stream.iter_content(1024*1024):
+ f.write(d)
+ except:
+ os.remove(dest)
+ raise
+ return dest
diff --git a/Utils/gpu_allocator.py b/Utils/gpu_allocator.py
new file mode 100644
index 0000000..be88455
--- /dev/null
+++ b/Utils/gpu_allocator.py
@@ -0,0 +1,111 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import subprocess
+import os
+import torch
+from Utils.lockfile import LockFile
+
+def get_memory_usage():
+ try:
+ proc = subprocess.Popen("nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits".split(" "),
+ stdout=subprocess.PIPE)
+ lines = [s.strip().split(" ") for s in proc.communicate()[0].decode().split("\n") if s]
+ return {int(g[0][:-1]): int(g[1]) for g in lines}
+ except:
+ return None
+
+
+def get_free_gpus():
+ try:
+ free = []
+ proc = subprocess.Popen("nvidia-smi --query-compute-apps=gpu_uuid --format=csv,noheader,nounits".split(" "),
+ stdout=subprocess.PIPE)
+ uuids = [s.strip() for s in proc.communicate()[0].decode().split("\n") if s]
+
+ proc = subprocess.Popen("nvidia-smi --query-gpu=index,uuid --format=csv,noheader,nounits".split(" "),
+ stdout=subprocess.PIPE)
+
+ id_uid_pair = [s.strip().split(", ") for s in proc.communicate()[0].decode().split("\n") if s]
+ for i in id_uid_pair:
+ id, uid = i
+
+ if uid not in uuids:
+ free.append(int(id))
+
+ return free
+ except:
+ return None
+
+def _fix_order():
+ os.environ["CUDA_DEVICE_ORDER"] = os.environ.get("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
+
+def allocate(n:int = 1):
+ _fix_order()
+ with LockFile("/tmp/gpu_allocation_lock"):
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
+ print("WARNING: trying to allocate %d GPUs, but CUDA_VISIBLE_DEVICES already set to %s" %
+ (n, os.environ["CUDA_VISIBLE_DEVICES"]))
+ return
+
+ allocated = get_free_gpus()
+ if allocated is None:
+ print("WARNING: failed to allocate %d GPUs" % n)
+ return
+ allocated = allocated[:n]
+
+ if len(allocated) < n:
+ print("There is no more free GPUs. Allocating the one with least memory usage.")
+ usage = get_memory_usage()
+ if usage is None:
+ print("WARNING: failed to allocate %d GPUs" % n)
+ return
+
+ inv_usages = {}
+
+ for k, v in usage.items():
+ if v not in inv_usages:
+ inv_usages[v] = []
+
+ inv_usages[v].append(k)
+
+ min_usage = list(sorted(inv_usages.keys()))
+ min_usage_devs = []
+ for u in min_usage:
+ min_usage_devs += inv_usages[u]
+
+ min_usage_devs = [m for m in min_usage_devs if m not in allocated]
+
+ n2 = n - len(allocated)
+ if n2>len(min_usage_devs):
+ print("WARNING: trying to allocate %d GPUs but only %d available" % (n, len(min_usage_devs)+len(allocated)))
+ n2 = len(min_usage_devs)
+
+ allocated += min_usage_devs[:n2]
+
+ os.environ["CUDA_VISIBLE_DEVICES"]=",".join([str(a) for a in allocated])
+ for i in range(len(allocated)):
+ a = torch.FloatTensor([0.0])
+ a.cuda(i)
+
+def use_gpu(gpu="auto", n_autoalloc=1):
+ _fix_order()
+
+ gpu = gpu.lower()
+ if gpu in ["auto", ""]:
+ allocate(n_autoalloc)
+ else:
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpu
diff --git a/Utils/lockfile.py b/Utils/lockfile.py
new file mode 100644
index 0000000..f713b34
--- /dev/null
+++ b/Utils/lockfile.py
@@ -0,0 +1,41 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import os
+import fcntl
+
+
+class LockFile:
+ def __init__(self, fname):
+ self._fname = fname
+ self._fd = None
+
+ def acquire(self):
+ self._fd=open(self._fname, "w")
+ os.chmod(self._fname, 0o777)
+
+ fcntl.lockf(self._fd, fcntl.LOCK_EX)
+
+ def release(self):
+ fcntl.lockf(self._fd, fcntl.LOCK_UN)
+ self._fd.close()
+ self._fd = None
+
+ def __enter__(self):
+ self.acquire()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.release()
diff --git a/Utils/timer.py b/Utils/timer.py
new file mode 100644
index 0000000..402052a
--- /dev/null
+++ b/Utils/timer.py
@@ -0,0 +1,54 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import time
+
+
+class OnceEvery:
+ def __init__(self, interval):
+ self._interval = interval
+ self._last_check = 0
+
+ def __call__(self):
+ now = time.time()
+ if now - self._last_check >= self._interval:
+ self._last_check = now
+ return True
+ else:
+ return False
+
+
+class Measure:
+ def __init__(self, average=1):
+ self._start = None
+ self._average = average
+ self._accu_value = 0.0
+ self._history_list = []
+
+ def start(self):
+ self._start = time.time()
+
+ def passed(self):
+ if self._start is None:
+ return None
+
+ p = time.time() - self._start
+ self._history_list.append(p)
+ self._accu_value += p
+ if len(self._history_list) > self._average:
+ self._accu_value -= self._history_list.pop(0)
+
+ return self._accu_value / len(self._history_list)
diff --git a/Utils/universal.py b/Utils/universal.py
new file mode 100644
index 0000000..70ebf0c
--- /dev/null
+++ b/Utils/universal.py
@@ -0,0 +1,263 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+float32 = [np.float32, torch.float32]
+float64 = [np.float64, torch.float64]
+uint8 = [np.uint8, torch.uint8]
+
+_all_types = [float32, float64, uint8]
+
+_dtype_numpy_map = {v[0]().dtype.name:v for v in _all_types}
+_dtype_pytorch_map = {v[1]:v for v in _all_types}
+
+
+def dtype(t):
+ if torch.is_tensor(t):
+ return _dtype_pytorch_map[t.dtype]
+ else:
+ return _dtype_numpy_map[t.dtype.name]
+
+
+def cast(t, type):
+ if torch.is_tensor(t):
+ return t.type(type[1])
+ else:
+ return t.astype(type[0])
+
+
+def to_numpy(t):
+ if torch.is_tensor(t):
+ return t.detach().cpu().numpy()
+ else:
+ return t
+
+
+def to_list(t):
+ t = to_numpy(t)
+ if isinstance(t, np.ndarray):
+ t = t.tolist()
+ return t
+
+
+def is_tensor(t):
+ return torch.is_tensor(t) or isinstance(t, np.ndarray)
+
+
+def first_batch(t):
+ if is_tensor(t):
+ return t[0]
+ else:
+ return t
+
+
+def ndim(t):
+ if torch.is_tensor(t):
+ return t.dim()
+ else:
+ return t.ndim
+
+
+def shape(t):
+ return list(t.shape)
+
+
+def transpose(t, axis):
+ if torch.is_tensor(t):
+ return t.permute(axis)
+ else:
+ return np.transpose(t, axis)
+
+
+def apply_recursive(d, fn, filter=None):
+ if isinstance(d, list):
+ return [apply_recursive(da, fn) for da in d]
+ elif isinstance(d, tuple):
+ return tuple(apply_recursive(list(d), fn))
+ elif isinstance(d, dict):
+ return {k: apply_recursive(v, fn) for k, v in d.items()}
+ else:
+ if filter is None or filter(d):
+ return fn(d)
+ else:
+ return d
+
+
+def apply_to_tensors(d, fn):
+ return apply_recursive(d, fn, torch.is_tensor)
+
+
+def recursive_decorator(apply_this_fn):
+ def decorator(func):
+ def wrapped_funct(*args, **kwargs):
+ args = apply_recursive(args, apply_this_fn)
+ kwargs = apply_recursive(kwargs, apply_this_fn)
+
+ return func(*args, **kwargs)
+
+ return wrapped_funct
+
+ return decorator
+
+untensor = recursive_decorator(to_numpy)
+unnumpy = recursive_decorator(to_list)
+
+def unbatch(only_if_dim_equal=None):
+ if only_if_dim_equal is not None and not isinstance(only_if_dim_equal, list):
+ only_if_dim_equal = [only_if_dim_equal]
+
+ def get_first_batch(t):
+ if is_tensor(t) and (only_if_dim_equal is None or ndim(t) in only_if_dim_equal):
+ return t[0]
+ else:
+ return t
+
+ return recursive_decorator(get_first_batch)
+
+
+def sigmoid(t):
+ if torch.is_tensor(t):
+ return torch.sigmoid(t)
+ else:
+ return 1.0 / (1.0 + np.exp(-t))
+
+
+def argmax(t, dim):
+ if torch.is_tensor(t):
+ _, res = t.max(dim)
+ else:
+ res = np.argmax(t, axis=dim)
+
+ return res
+
+
+def flip(t, axis):
+ if torch.is_tensor(t):
+ return t.flip(axis)
+ else:
+ return np.flip(t, axis)
+
+
+def transpose(t, axes):
+ if torch.is_tensor(t):
+ return t.permute(*axes)
+ else:
+ return np.transpose(t, axes)
+
+
+def split_n(t, axis):
+ if torch.is_tensor(t):
+ return t.split(1, dim=axis)
+ else:
+ return np.split(t, t.shape[axis], axis=axis)
+
+
+def cat(array_of_tensors, axis):
+ if torch.is_tensor(array_of_tensors[0]):
+ return torch.cat(array_of_tensors, axis)
+ else:
+ return np.concatenate(array_of_tensors, axis)
+
+
+def clamp(t, min=None, max=None):
+ if torch.is_tensor(t):
+ return t.clamp(min, max)
+ else:
+ if min is not None:
+ t = np.maximum(t, min)
+
+ if max is not None:
+ t = np.minimum(t, max)
+
+ return t
+
+
+def power(t, p):
+ if torch.is_tensor(t) or torch.is_tensor(p):
+ return torch.pow(t, p)
+ else:
+ return np.power(t, p)
+
+
+def random_normal_as(a, mean, std, seed=None):
+ if torch.is_tensor(a):
+ return torch.randn_like(a) * std + mean
+ else:
+ if seed is None:
+ seed = np.random
+ return seed.normal(loc=mean, scale=std, size=shape(a))
+
+
+def pad(t, pad):
+ assert ndim(t) == 4
+
+ if torch.is_tensor(t):
+ return F.pad(t, pad)
+ else:
+ assert np.pad(t, ([0,0], [0,0], pad[0:2], pad[2:]))
+
+
+def dx(img):
+ lsh = img[:, :, :, 2:]
+ orig = img[:, :, :, :-2]
+
+ return pad(0.5 * (lsh - orig), (1, 1, 0, 0))
+
+
+def dy(img):
+ ush = img[:, :, 2:, :]
+ orig = img[:, :, :-2, :]
+
+ return pad(0.5 * (ush - orig), (0, 0, 1, 1))
+
+
+def reshape(t, shape):
+ if torch.is_tensor(t):
+ return t.view(*shape)
+ else:
+ return t.reshape(*shape)
+
+
+def broadcast_to_beginning(t, target):
+ if torch.is_tensor(t):
+ nd_target = target.dim()
+ t_shape = list(t.shape)
+ return t.view(*t_shape, *([1]*(nd_target-len(t_shape))))
+ else:
+ nd_target = target.ndim
+ t_shape = list(t.shape)
+ return t.reshape(*t_shape, *([1] * (nd_target - len(t_shape))))
+
+
+def lin_combine(d1,w1, d2,w2, bcast_begin=False):
+ if isinstance(d1, (list, tuple)):
+ assert len(d1) == len(d2)
+ res = [lin_combine(d1[i], w1, d2[i], w2) for i in range(len(d1))]
+ if isinstance(d1, tuple):
+ res = tuple(d1)
+ elif isinstance(d1, dict):
+ res = {k: lin_combine(v, w1, d2[k], w2) for k, v in d1.items()}
+ else:
+ if bcast_begin:
+ w1 = broadcast_to_beginning(w1, d1)
+ w2 = broadcast_to_beginning(w2, d2)
+
+ res = d1 * w1 + d2 * w2
+
+ return res
diff --git a/Visualize/.DS_Store b/Visualize/.DS_Store
new file mode 100644
index 0000000..5008ddf
Binary files /dev/null and b/Visualize/.DS_Store differ
diff --git a/Visualize/BitmapTask.py b/Visualize/BitmapTask.py
new file mode 100644
index 0000000..7721337
--- /dev/null
+++ b/Visualize/BitmapTask.py
@@ -0,0 +1,73 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import numpy as np
+try:
+ import cv2
+except:
+ cv2=None
+
+from Utils.Helpers import as_numpy
+
+def visualize_bitmap_task(i_data, o_data, zoom=8):
+ if not isinstance(o_data, list):
+ o_data = [o_data]
+
+ imgs = []
+ for d in [i_data]+o_data:
+ if d is None:
+ continue
+
+ d=as_numpy(d)
+ if d.ndim>2:
+ d=d[0]
+
+ imgs.append(np.expand_dims(d.T*255, -1).astype(np.uint8))
+
+ img = np.concatenate(imgs, 0)
+ return nearest_zoom(img, zoom)
+
+def visualize_01(t, zoom=8):
+ return nearest_zoom(np.expand_dims(t*255,-1).astype(np.uint8), zoom)
+
+def nearest_zoom(img, zoom=1):
+ if zoom>1 and cv2 is not None:
+ return cv2.resize(img, (img.shape[1] * zoom, img.shape[0] * zoom), interpolation=cv2.INTER_NEAREST)
+ else:
+ return img
+
+def concatenate_tensors(tensors):
+ max_size = None
+ dtype = None
+
+ for t in tensors:
+ s = t.shape
+ if max_size is None:
+ max_size = list(s)
+ dtype = t.dtype
+ continue
+
+ assert t.ndim ==len(max_size), "Can only concatenate tensors with same ndim."
+ assert t.dtype == dtype, "Tensors must have the same type"
+ max_size = [max(max_size[i], s[i]) for i in range(len(max_size))]
+
+ res = np.zeros([len(tensors)] + max_size, dtype=dtype)
+ for i, t in enumerate(tensors):
+ res[i][tuple([slice(0,t.shape[i]) for i in range(t.ndim)])] = t
+ return res
+
+
+
diff --git a/Visualize/__init__.py b/Visualize/__init__.py
new file mode 100644
index 0000000..06f5289
--- /dev/null
+++ b/Visualize/__init__.py
@@ -0,0 +1 @@
+from .BitmapTask import *
diff --git a/Visualize/preview.py b/Visualize/preview.py
new file mode 100644
index 0000000..ea70521
--- /dev/null
+++ b/Visualize/preview.py
@@ -0,0 +1,64 @@
+# Copyright 2017 Robert Csordas. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import time
+import threading
+import traceback
+import sys
+from Utils import universal as U
+
+
+def preview(vis_interval=None, to_numpy=True, debatch=False):
+ def decorator(func):
+ state = {
+ "last_vis_time": 0,
+ "thread_running": False
+ }
+
+ def wrapper_function(*args, **kwargs):
+ if state["thread_running"]:
+ return
+
+ if vis_interval is not None:
+ now = time.time()
+ if now - state["last_vis_time"] < vis_interval:
+ return
+ state["last_vis_time"] = now
+
+ state["thread_running"] = True
+
+ if debatch:
+ args = U.apply_recursive(args, U.first_batch)
+ kwargs = U.apply_recursive(kwargs, U.first_batch)
+
+ if to_numpy:
+ args = U.apply_recursive(args, U.to_numpy)
+ kwargs = U.apply_recursive(kwargs, U.to_numpy)
+
+ def run():
+ try:
+ func(*args, **kwargs)
+ except Exception:
+ traceback.print_exc()
+ sys.exit(-1)
+
+ state["thread_running"] = False
+
+ download_thread = threading.Thread(target=run)
+ download_thread.start()
+
+ return wrapper_function
+ return decorator
diff --git a/assets/demon.png b/assets/demon.png
new file mode 100644
index 0000000..bad4372
Binary files /dev/null and b/assets/demon.png differ
diff --git a/memory_demon.py b/memory_demon.py
new file mode 100644
index 0000000..726ca58
--- /dev/null
+++ b/memory_demon.py
@@ -0,0 +1,1070 @@
+#!/usr/bin/env python3#
+# The Initial DNC Copyright 2017 Robert Csordas. All Rights Reserved.
+# The modification of the initial DNC implementation by Ari Azarafrooz.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+import functools
+import os
+
+import torch.utils.data
+
+import Utils.Debug as debug
+from Dataset.Bitmap.AssociativeRecall import AssociativeRecall
+from Dataset.Bitmap.BitmapTaskRepeater import BitmapTaskRepeater
+from Dataset.Bitmap.KeyValue import KeyValue
+from Dataset.Bitmap.CopyTask import CopyData
+from Dataset.Bitmap.KeyValue2Way import KeyValue2Way
+from Dataset.NLP.bAbi import bAbiDataset
+from Models.DNCA import DNC, LSTMController, FeedforwardController
+from Models.Information_Agents import RolloutStorage, Demon, FNet, ZNet
+
+from Utils import Visdom
+from Utils.ArgumentParser import ArgumentParser
+from Utils.Index import index_by_dim
+from Utils.Saver import Saver, GlobalVarSaver, StateSaver
+from Utils.Collate import MetaCollate
+from Utils import gpu_allocator
+from Dataset.NLP.NLPTask import NLPTask
+from tqdm import tqdm
+from Visualize.preview import preview
+from Utils.timer import OnceEvery
+from Utils import Seed
+import time
+import sys
+import signal
+import math
+from Utils import Profile
+import shutil
+import math
+
+#from torch.utils.tensorboard import SummaryWriter
+
+import numpy as np
+
+model_dir = ""
+
+Profile.ENABLED = False
+
+random_seed = 1
+
+if random_seed:
+ print("Random Seed: {}".format(random_seed))
+ torch.manual_seed(random_seed)
+ np.random.seed(random_seed)
+
+if os.path.exists("tmp_train_dir"):
+ shutil.rmtree("tmp_train_dir")
+
+action_std = 0.1 # constant std for action distribution (Multivariate Normal)
+K_epochs = 1 # update policy for K epochs
+eps_clip = 0.2 # clip parameter for PPO
+gamma = 1 # discount factor
+lr = 0.001 # parameters for Adam optimizer #0.01
+betas = (0.9, 0.999)
+
+
+def main():
+ global i
+ global loss_sum
+ global running
+ parser = ArgumentParser()
+ parser.add_argument(
+ "-bit_w", type=int, default=8, help="Bit vector length for copy task"
+ )
+ parser.add_argument(
+ "-block_w", type=int, default=3, help="Block width to associative recall task"
+ )
+ parser.add_argument(
+ "-len",
+ type=str,
+ default="4",
+ help="Sequence length for copy task",
+ parser=lambda x: [int(a) for a in x.split("-")],
+ )
+ parser.add_argument(
+ "-repeat",
+ type=str,
+ default="1",
+ help="Sequence length for copy task",
+ parser=lambda x: [int(a) for a in x.split("-")],
+ )
+ parser.add_argument(
+ "-batch_size", type=int, default=16, help="Sequence length for copy task"
+ )
+ parser.add_argument(
+ "-n_subbatch",
+ type=str,
+ default="auto",
+ help="Average this much forward passes to a backward pass",
+ )
+ parser.add_argument(
+ "-max_input_count_per_batch",
+ type=int,
+ default=6000,
+ help="Max batch_size*len that can fit into memory",
+ )
+ parser.add_argument("-lr", type=float, default=0.0001, help="Learning rate")
+ parser.add_argument("-wd", type=float, default=1e-5, help="Weight decay")
+ parser.add_argument(
+ "-optimizer", type=str, default="rmsprop", help="Optimizer algorithm"
+ )
+ parser.add_argument("-name", type=str, help="Save training to this directory")
+ parser.add_argument(
+ "-preview_interval",
+ type=int,
+ default=10,
+ help="Show preview every nth iteration",
+ )
+ parser.add_argument(
+ "-info_interval", type=int, default=10, help="Show info every nth iteration"
+ )
+ parser.add_argument(
+ "-save_interval", type=int, default=500, help="Save network every nth iteration"
+ )
+ parser.add_argument(
+ "-masked_lookup", type=bool, default=1, help="Enable masking in content lookups"
+ )
+ parser.add_argument(
+ "-visport",
+ type=int,
+ default=-1,
+ help="Port to run Visdom server on. -1 to disable",
+ )
+ parser.add_argument("-gpu", default="auto", type=str, help="Run on this GPU.")
+ parser.add_argument("-debug", type=bool, default=0, help="Enable debugging")
+ parser.add_argument("-task", type=str, default="copy", help="Task to learn")
+ parser.add_argument(
+ "-mem_count", type=int, default=16, help="Number of memory cells"
+ )
+ parser.add_argument(
+ "-data_word_size", type=int, default=128, help="Memory word size"
+ )
+ parser.add_argument(
+ "-n_read_heads", type=int, default=1, help="Number of read heads"
+ )
+ parser.add_argument(
+ "-layer_sizes",
+ type=str,
+ default="256",
+ help="Controller layer sizes. Separate with ,. For example 512,256,256",
+ parser=lambda x: [int(y) for y in x.split(",") if y],
+ )
+ parser.add_argument("-debug_log", type=bool, default=0, help="Enable debug log")
+ parser.add_argument(
+ "-controller_type",
+ type=str,
+ default="lstm",
+ help="Controller type: lstm or linear",
+ )
+ parser.add_argument(
+ "-lstm_use_all_outputs",
+ type=bool,
+ default=1,
+ help="Use all LSTM outputs as controller output vs use only the last layer",
+ )
+ parser.add_argument(
+ "-momentum", type=float, default=0.9, help="Momentum for optimizer"
+ )
+ parser.add_argument(
+ "-embedding_size",
+ type=int,
+ default=256,
+ help="Size of word embedding for NLP tasks",
+ )
+ parser.add_argument(
+ "-test_interval", type=int, default=10, help="Run test in this interval"
+ )
+ parser.add_argument(
+ "-dealloc_content",
+ type=bool,
+ default=1,
+ help="Deallocate memory content, unlike DNC, which leaves it unchanged, just decreases the usage counter, causing problems with lookup",
+ )
+ parser.add_argument(
+ "-sharpness_control",
+ type=bool,
+ default=1,
+ help="Distribution sharpness control for forward and backward links",
+ )
+ parser.add_argument(
+ "-think_steps",
+ type=int,
+ default=0,
+ help="Iddle steps before requiring the answer (for bAbi)",
+ )
+ parser.add_argument("-dump_profile", type=str, save=False)
+ parser.add_argument("-test_on_start", default="0", save=False)
+ parser.add_argument("-dump_heatmaps", default=False, save=False)
+ parser.add_argument("-test_batch_size", default=16)
+ parser.add_argument("-mask_min", default=0.0)
+ parser.add_argument("-load", type=str, save=False)
+ parser.add_argument(
+ "-dataset_path",
+ type=str,
+ default="none",
+ parser=ArgumentParser.str_or_none(),
+ help="Specify babi path manually",
+ )
+ parser.add_argument(
+ "-babi_train_tasks",
+ type=str,
+ default="none",
+ parser=ArgumentParser.list_or_none(type=str),
+ help="babi task list to use for training",
+ )
+ parser.add_argument(
+ "-babi_test_tasks",
+ type=str,
+ default="none",
+ parser=ArgumentParser.list_or_none(type=str),
+ help="babi task list to use for testing",
+ )
+ parser.add_argument(
+ "-babi_train_sets",
+ type=str,
+ default="train",
+ parser=ArgumentParser.list_or_none(type=str),
+ help="babi train sets to use",
+ )
+ parser.add_argument(
+ "-babi_test_sets",
+ type=str,
+ default="test",
+ parser=ArgumentParser.list_or_none(type=str),
+ help="babi test sets to use",
+ )
+ parser.add_argument(
+ "-noargsave",
+ type=bool,
+ default=False,
+ help="Do not save modified arguments",
+ save=False,
+ )
+ parser.add_argument(
+ "-demo",
+ type=bool,
+ default=False,
+ help="Do a single step with fixed seed",
+ save=False,
+ )
+ parser.add_argument(
+ "-exit_after",
+ type=int,
+ help="Exit after this amount of steps. Useful for debugging.",
+ save=False,
+ )
+ parser.add_argument(
+ "-grad_clip", type=float, default=10.0, help="Max gradient norm"
+ )
+ parser.add_argument(
+ "-clip_controller", type=float, default=20.0, help="Max gradient norm"
+ )
+ parser.add_argument("-print_test", default=False, save=False)
+
+ parser.add_profile(
+ [
+ ArgumentParser.Profile(
+ "babi",
+ {
+ "preview_interval": 10,
+ "save_interval": 500,
+ "task": "babi",
+ "mem_count": 64,
+ "data_word_size": 64,
+ "n_read_heads": 4,
+ "layer_sizes": "128",
+ "controller_type": "lstm",
+ "lstm_use_all_outputs": True,
+ "momentum": 0.9,
+ "embedding_size": 128,
+ "test_interval": 10000,
+ "think_steps": 3,
+ "batch_size": 4,
+ },
+ include=["dnc-msd"],
+ ),
+ ArgumentParser.Profile(
+ "repeat_copy",
+ {
+ "bit_w": 8,
+ "repeat": "1-8",
+ "len": "2-14",
+ "task": "copy",
+ "think_steps": 1,
+ "preview_interval": 10,
+ "info_interval": 10,
+ "save_interval": 100,
+ "data_word_size": 16,
+ "layer_sizes": "32",
+ "n_subbatch": 1,
+ "controller_type": "lstm",
+ },
+ ),
+ ArgumentParser.Profile(
+ "repeat_copy_simple",
+ {
+ "repeat": "1-3",
+ },
+ include="repeat_copy",
+ ),
+ ArgumentParser.Profile(
+ "dnc",
+ {
+ "masked_lookup": False,
+ "sharpness_control": False,
+ "dealloc_content": False,
+ },
+ ),
+ ArgumentParser.Profile(
+ "dnc-m",
+ {
+ "masked_lookup": True,
+ "sharpness_control": False,
+ "dealloc_content": False,
+ },
+ ),
+ ArgumentParser.Profile(
+ "dnc-s",
+ {
+ "masked_lookup": False,
+ "sharpness_control": True,
+ "dealloc_content": False,
+ },
+ ),
+ ArgumentParser.Profile(
+ "dnc-d",
+ {
+ "masked_lookup": False,
+ "sharpness_control": False,
+ "dealloc_content": True,
+ },
+ ),
+ ArgumentParser.Profile(
+ "dnc-md",
+ {
+ "masked_lookup": True,
+ "sharpness_control": False,
+ "dealloc_content": True,
+ },
+ ),
+ ArgumentParser.Profile(
+ "dnc-ms",
+ {
+ "masked_lookup": True,
+ "sharpness_control": True,
+ "dealloc_content": False,
+ },
+ ),
+ ArgumentParser.Profile(
+ "dnc-sd",
+ {
+ "masked_lookup": False,
+ "sharpness_control": True,
+ "dealloc_content": True,
+ },
+ ),
+ ArgumentParser.Profile(
+ "dnc-msd",
+ {
+ "masked_lookup": True,
+ "sharpness_control": True,
+ "dealloc_content": True,
+ },
+ ),
+ ArgumentParser.Profile(
+ "keyvalue",
+ {
+ "repeat": "1",
+ "len": "2-16",
+ "mem_count": 16,
+ "task": "keyvalue",
+ "think_steps": 1,
+ "preview_interval": 10,
+ "info_interval": 10,
+ "data_word_size": 32,
+ "bit_w": 12,
+ "save_interval": 1000,
+ "layer_sizes": "32",
+ },
+ ),
+ ArgumentParser.Profile(
+ "keyvalue2way",
+ {
+ "task": "keyvalue2way",
+ },
+ include="keyvalue",
+ ),
+ ArgumentParser.Profile(
+ "associative_recall",
+ {
+ "task": "recall",
+ "bit_w": 8,
+ "len": "2-16",
+ "mem_count": 64,
+ "data_word_size": 32,
+ "n_read_heads": 1,
+ "layer_sizes": "128",
+ "controller_type": "lstm",
+ "lstm_use_all_outputs": 1,
+ "think_steps": 1,
+ "mask_min": 0.1,
+ "info_interval": 10,
+ "save_interval": 1000,
+ "preview_interval": 10,
+ "n_subbatch": 1,
+ },
+ ),
+ ]
+ )
+
+ opt = parser.parse()
+ assert opt.name is not None, "Training dir (-name parameter) not given"
+ opt = parser.sync(os.path.join(opt.name, "args.json"), save=not opt.noargsave)
+
+ if opt.demo:
+ Seed.fix()
+
+ os.makedirs(os.path.join(opt.name, "save"), exist_ok=True)
+ os.makedirs(os.path.join(opt.name, "preview"), exist_ok=True)
+
+ gpu_allocator.use_gpu(opt.gpu)
+
+ debug.enableDebug = opt.debug_log
+
+ if opt.visport > 0:
+ Visdom.start(opt.visport)
+
+ Visdom.Text("Name").set(opt.name)
+
+ class LengthHackSampler:
+ def __init__(self, batch_size, length):
+ self.length = length
+ self.batch_size = batch_size
+
+ def __iter__(self):
+ while True:
+ len = self.length() if callable(self.length) else self.length
+ yield [len] * self.batch_size
+
+ def __len__(self):
+ return 0x7FFFFFFF
+
+ embedding = None
+ test_set = None
+ curriculum = None
+ loader_reset = False
+ if opt.task == "copy":
+ dataset = CopyData(bit_w=opt.bit_w)
+ in_size = opt.bit_w + 1
+ out_size = in_size
+ elif opt.task == "recall":
+ dataset = AssociativeRecall(bit_w=opt.bit_w, block_w=opt.block_w)
+ in_size = opt.bit_w + 2
+ out_size = in_size
+ elif opt.task == "keyvalue":
+ assert opt.bit_w % 2 == 0, "Key-value datasets works only with even bit_w"
+ dataset = KeyValue(bit_w=opt.bit_w)
+ in_size = opt.bit_w + 1
+ out_size = opt.bit_w // 2
+ elif opt.task == "keyvalue2way":
+ assert opt.bit_w % 2 == 0, "Key-value datasets works only with even bit_w"
+ dataset = KeyValue2Way(bit_w=opt.bit_w)
+ in_size = opt.bit_w + 2
+ out_size = opt.bit_w // 2
+ elif opt.task == "babi":
+ dataset = bAbiDataset(think_steps=opt.think_steps, dir_name=opt.dataset_path)
+ test_set = bAbiDataset(
+ think_steps=opt.think_steps, dir_name=opt.dataset_path, name="test"
+ )
+ dataset.use(opt.babi_train_tasks, opt.babi_train_sets)
+ in_size = opt.embedding_size
+ print("bAbi: loaded total of %d sequences." % len(dataset))
+ test_set.use(opt.babi_test_tasks, opt.babi_test_sets)
+ out_size = len(dataset.vocabulary)
+ print(
+ "bAbi: using %d sequences for training, %d for testing"
+ % (len(dataset), len(test_set))
+ )
+ else:
+ assert False, "Invalid task: %s" % opt.task
+
+ if opt.task in ["babi"]:
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=opt.batch_size,
+ num_workers=4,
+ pin_memory=True,
+ shuffle=True,
+ collate_fn=MetaCollate(),
+ )
+ test_loader = (
+ torch.utils.data.DataLoader(
+ test_set,
+ batch_size=opt.test_batch_size,
+ num_workers=opt.test_batch_size,
+ pin_memory=True,
+ shuffle=False,
+ collate_fn=MetaCollate(),
+ )
+ if test_set is not None
+ else None
+ )
+ else:
+ dataset = BitmapTaskRepeater(dataset)
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_sampler=LengthHackSampler(
+ opt.batch_size, BitmapTaskRepeater.key_sampler(opt.len, opt.repeat)
+ ),
+ num_workers=1,
+ pin_memory=True,
+ )
+
+ if opt.controller_type == "lstm":
+ controller_constructor = functools.partial(
+ LSTMController, out_from_all_layers=opt.lstm_use_all_outputs
+ )
+ elif opt.controller_type == "linear":
+ controller_constructor = FeedforwardController
+ else:
+ assert False, "Invalid controller: %s" % opt.controller_type
+
+ parity_size = 0
+
+ model = DNC(
+ in_size + parity_size,
+ out_size,
+ opt.data_word_size,
+ opt.mem_count,
+ opt.n_read_heads,
+ controller_constructor(opt.layer_sizes),
+ batch_first=True,
+ mask=opt.masked_lookup,
+ dealloc_content=opt.dealloc_content,
+ link_sharpness_control=opt.sharpness_control,
+ mask_min=opt.mask_min,
+ clip_controller=opt.clip_controller,
+ )
+
+ # model.load_state_dict(torch.load(model_dir, map_location="cpu")["model"])
+
+ print("data_word_size: {}".format(opt.data_word_size))
+ rollout_storage = RolloutStorage()
+ demon_state_dim = (
+ in_size + opt.mem_count * opt.data_word_size
+ ) #:TODO opt.mem_count * opt.data_word_size
+ demon_action_dim = in_size
+
+ demon = Demon(
+ demon_state_dim,
+ demon_action_dim,
+ action_std,
+ lr,
+ betas,
+ gamma,
+ K_epochs,
+ eps_clip,
+ )
+
+ fnet = FNet()
+ fnet.init(2 * opt.mem_count * opt.data_word_size)
+ #fnet.load_state_dict(torch.load(model_dir, map_location="cpu")["FNet"])
+
+ znet = ZNet()
+ znet.init(opt.mem_count * opt.data_word_size)
+ #znet.load_state_dict(torch.load(model_dir, map_location="cpu")["ZNet"])
+
+ params = [
+ {"params": [p for n, p in model.named_parameters() if not n.endswith(".bias")]},
+ {
+ "params": [p for n, p in model.named_parameters() if n.endswith(".bias")],
+ "weight_decay": 0,
+ },
+ ]
+
+ device = (
+ torch.device("cuda")
+ if opt.gpu != "none" and torch.cuda.is_available()
+ else torch.device("cpu")
+ )
+ print("DEVICE: ", device)
+
+ if isinstance(dataset, NLPTask):
+ embedding = torch.nn.Embedding(len(dataset.vocabulary), opt.embedding_size).to(
+ device
+ )
+ params.append({"params": embedding.parameters(), "weight_decay": 0})
+ # embedding.load_state_dict(
+ # torch.load(model_dir, map_location="cpu")["word_embeddings"]
+ # )
+
+ if opt.optimizer == "sgd":
+ optimizer = torch.optim.SGD(
+ params, lr=opt.lr, weight_decay=opt.wd, momentum=opt.momentum
+ )
+ elif opt.optimizer == "adam":
+ optimizer = torch.optim.Adam(params, lr=opt.lr, weight_decay=opt.wd)
+ elif opt.optimizer == "rmsprop":
+ optimizer = torch.optim.RMSprop(
+ params, lr=opt.lr, weight_decay=opt.wd, momentum=opt.momentum, eps=1e-10
+ )
+ else:
+ assert "Invalid optimizer: %s" % opt.optimizer
+
+ n_params = sum([sum([t.numel() for t in d["params"]]) for d in params])
+ print("Number of parameters: %d" % n_params)
+
+ model = model.to(device)
+ fnet = fnet.to(device)
+ znet = znet.to(device)
+ znet_optim = torch.optim.Adam(znet.parameters(), lr=0.001)
+ fnet_optim = torch.optim.Adam(fnet.parameters(), lr=0.001)
+
+ if embedding is not None and hasattr(embedding, "to"):
+ embedding = embedding.to(device)
+
+ i = 0
+ loss_sum = 0
+
+ loss_plot = Visdom.Plot2D(
+ "loss", store_interval=opt.info_interval, xlabel="iterations", ylabel="loss"
+ )
+
+ if curriculum is not None:
+ curriculum_plot = Visdom.Plot2D(
+ "curriculum lesson"
+ + (
+ " (last %d)" % (curriculum.n_lessons - 1)
+ if curriculum.n_lessons is not None
+ else ""
+ ),
+ xlabel="iterations",
+ ylabel="lesson",
+ )
+ curriculum_accuracy = Visdom.Plot2D(
+ "curriculum accuracy", xlabel="iterations", ylabel="accuracy"
+ )
+
+ saver = Saver(os.path.join(opt.name, "save"), short_interval=opt.save_interval)
+ saver.register("model", StateSaver(model))
+ saver.register("optimizer", StateSaver(optimizer))
+ saver.register("i", GlobalVarSaver("i"))
+ saver.register("loss_sum", GlobalVarSaver("loss_sum"))
+ saver.register("loss_plot", StateSaver(loss_plot))
+ saver.register("dataset", StateSaver(dataset))
+ if test_set:
+ saver.register("test_set", StateSaver(test_set))
+
+ if curriculum is not None:
+ saver.register("curriculum", StateSaver(curriculum))
+ saver.register("curriculum_plot", StateSaver(curriculum_plot))
+ saver.register("curriculum_accuracy", StateSaver(curriculum_accuracy))
+
+ if isinstance(dataset, NLPTask):
+ saver.register("word_embeddings", StateSaver(embedding))
+ elif embedding is not None:
+ saver.register("embeddings", StateSaver(embedding))
+
+ visualizers = {}
+
+ debug_schemas = {
+ "read_head": {"list_dim": 2},
+ "temporal_links/forward_dists": {"list_dim": 2},
+ "temporal_links/backward_dists": {"list_dim": 2},
+ }
+
+ def plot_debug(debug, prefix="", schema={}):
+ if debug is None:
+ return
+
+ for k, v in debug.items():
+ curr_name = prefix + k
+ if curr_name in debug_schemas:
+ curr_schema = schema.copy()
+ curr_schema.update(debug_schemas[curr_name])
+ else:
+ curr_schema = schema
+
+ if isinstance(v, dict):
+ plot_debug(v, curr_name + "/", curr_schema)
+ continue
+
+ data = v[0]
+
+ if curr_schema.get("list_dim", -1) > 0:
+ if data.ndim != 3:
+ print(
+ "WARNING: unknown data shape for array display: %s, tensor %s"
+ % (data.shape, curr_name)
+ )
+ continue
+
+ n_steps = data.shape[curr_schema["list_dim"] - 1]
+ if curr_name not in visualizers:
+ visualizers[curr_name] = [
+ Visdom.Heatmap(
+ curr_name + "_%d" % i,
+ dumpdir=os.path.join(opt.name, "preview")
+ if opt.dump_heatmaps
+ else None,
+ )
+ for i in range(n_steps)
+ ]
+
+ for i in range(n_steps):
+ visualizers[curr_name][i].draw(
+ index_by_dim(data, curr_schema["list_dim"] - 1, i)
+ )
+ else:
+ if data.ndim != 2:
+ print(
+ "WARNING: unknown data shape for simple display: %s, tensor %s"
+ % (data.shape, curr_name)
+ )
+ continue
+
+ if curr_name not in visualizers:
+ visualizers[curr_name] = Visdom.Heatmap(
+ curr_name,
+ dumpdir=os.path.join(opt.name, "preview")
+ if opt.dump_heatmaps
+ else None,
+ )
+
+ visualizers[curr_name].draw(data)
+
+ def run_model(input, debug=None, demon=None, rollout_storage=None):
+ if isinstance(dataset, NLPTask):
+ input = embedding(input["input"])
+ else:
+ input = input["input"] * 2.0 - 1.0
+
+ return model(input, debug=debug, demon=demon, rollout_storage=rollout_storage)
+
+ def run_znet(mem_state):
+ mem_state = mem_state[:, 1:, :]
+ return znet(mem_state)
+
+ def run_fnet(mem_state, marginal=False):
+ shuffled_mem_state = None
+ if not marginal:
+ input = torch.cat((mem_state[:, :-1, :], mem_state[:, 1:, :]), dim=2)
+ else:
+ shuffled_indx = torch.randperm(mem_state.size(0)).to(
+ device
+ ) # random index for shuffling the elements of batch
+ shuffled_mem_state = mem_state.index_select(0, shuffled_indx)
+ input = torch.cat(
+ (mem_state[:, :-1, :], shuffled_mem_state[:, 1:, :]), dim=2
+ )
+
+ return fnet(input), shuffled_mem_state
+
+ def multiply_grads(params, mul):
+ if mul == 1:
+ return
+
+ for pa in params:
+ for p in pa["params"]:
+ p.grad.data *= mul
+
+ def test():
+ if test_set is None:
+ return
+
+ print("TESTING...")
+ start_time = time.time()
+ t = test_set.start_test()
+ with torch.no_grad():
+ for data in tqdm(test_loader):
+ data = {
+ k: v.to(device) if torch.is_tensor(v) else v
+ for k, v in data.items()
+ }
+ if hasattr(dataset, "prepare"):
+ data = dataset.prepare(data)
+
+ net_out = run_model(data, demon=demon)
+ test_set.veify_result(t, data, net_out)
+
+ test_set.show_test_results(i, t)
+ print("Test done in %gs" % (time.time() - start_time))
+
+ if opt.test_on_start.lower() in ["on", "1", "true", "quit"]:
+ test()
+ if opt.test_on_start.lower() == "quit":
+ saver.write(i)
+ sys.exit(-1)
+
+ if opt.print_test:
+ model.eval()
+ total = 0
+ correct = 0
+ with torch.no_grad():
+ for data in tqdm(test_loader):
+ if not running:
+ return
+
+ data = {
+ k: v.to(device) if torch.is_tensor(v) else v
+ for k, v in data.items()
+ }
+ if hasattr(test_set, "prepare"):
+ data = test_set.prepare(data)
+
+ net_out = run_model(data, demon)
+
+ c, t = test_set.curriculum_measure(net_out, data["output"])
+ total += t
+ correct += c
+
+ print(
+ "Test result: %2.f%% (%d out of %d correct)"
+ % (100.0 * correct / total, correct, total)
+ )
+ model.train()
+ return
+
+ iter_start_time = time.time() if i % opt.info_interval == 0 else None
+ data_load_total_time = 0
+
+ start_i = i
+
+ if opt.dump_profile:
+ profiler = torch.autograd.profiler.profile(use_cuda=True)
+
+ if opt.dump_heatmaps:
+ dataset.set_dump_dir(os.path.join(opt.name, "preview"))
+
+ @preview()
+ def do_visualize(raw_data, output, pos_map, debug):
+ if pos_map is not None:
+ output = embedding.backmap_output(
+ output, pos_map, raw_data["output"].shape[1]
+ )
+ dataset.visualize_preview(raw_data, output)
+
+ if debug is not None:
+ plot_debug(debug)
+
+ preview_timer = OnceEvery(opt.preview_interval)
+
+ pos_map = None
+ start_iter = i
+
+ if curriculum is not None:
+ curriculum.init()
+
+ ma_et = 1.0
+ while running:
+ data_load_timer = time.time()
+ for data in data_loader:
+ if not running:
+ break
+
+ if loader_reset:
+ print("Loader reset requested. Resetting...")
+ loader_reset = False
+ if curriculum is not None:
+ curriculum.lesson_started()
+ break
+
+ if opt.dump_profile:
+ if i == start_i + 1:
+ print("Starting profiler")
+ profiler.__enter__()
+ elif i == start_i + 5 + 1:
+ print("Stopping profiler")
+ profiler.__exit__(None, None, None)
+ print("Average stats")
+ print(profiler.key_averages().table("cpu_time_total"))
+ print("Writing trace to file")
+ profiler.export_chrome_trace(opt.dump_profile)
+ print("Done.")
+ sys.exit(0)
+ else:
+ print("Step %d out of 5" % (i - start_i))
+
+ debug.dbg_print("-------------------------------------")
+ raw_data = data
+
+ data = {
+ k: v.to(device) if torch.is_tensor(v) else v for k, v in data.items()
+ }
+ if hasattr(dataset, "prepare"):
+ data = dataset.prepare(data)
+
+ data_load_total_time += time.time() - data_load_timer
+
+ need_preview = preview_timer()
+ debug_data = {} if opt.debug and need_preview else None
+
+ optimizer.zero_grad()
+ # demon_optim.zero_grad()
+ # znet_optim.zero_grad()
+ # fnet_optim.zero_grad()
+
+ if opt.n_subbatch == "auto":
+ n_subbatch = math.ceil(
+ data["input"].numel() / opt.max_input_count_per_batch
+ )
+ else:
+ n_subbatch = int(opt.n_subbatch)
+
+ real_batch = max(math.floor(opt.batch_size / n_subbatch), 1)
+ n_subbatch = math.ceil(opt.batch_size / real_batch)
+ remaning_batch = opt.batch_size % real_batch
+
+ for subbatch in range(n_subbatch):
+ if not running:
+ break
+ input = data["input"]
+ target = data["output"]
+
+ if n_subbatch != 1:
+ input = input[subbatch * real_batch : (subbatch + 1) * real_batch]
+ target = target[subbatch * real_batch : (subbatch + 1) * real_batch]
+
+ f2 = data.copy()
+ f2["input"] = input
+
+ # Demon modifies the memory before DNC model
+ output = run_model(
+ f2,
+ debug=debug_data if subbatch == n_subbatch - 1 else None,
+ demon=demon,
+ rollout_storage=rollout_storage,
+ )
+ l = dataset.loss(output, target)
+ l.backward()
+
+ mem_state = torch.stack(model.mem_state, dim=1)
+ t, _ = run_fnet(mem_state.detach())
+ z = run_znet(mem_state.detach())
+
+ et, shuffled_mem = run_fnet(mem_state.detach(), marginal=True)
+ et = torch.exp(et)
+ mi_lb = t - (torch.mean(et) / z + torch.log(z) - 1)
+
+ demon_rewards = mi_lb
+ info_loss = -mi_lb.mean()
+
+ info_loss.backward()
+
+ fnet_optim.step()
+ fnet_optim.zero_grad()
+
+ znet_optim.step()
+ znet_optim.zero_grad()
+
+ del model.mem_state[:] # reset the mem state
+
+ for j in range(0, demon_rewards.size(1)):
+ rollout_storage.rewards.append(demon_rewards[:, j].detach())
+
+ # update if its time
+ if i % 1 == 0: # TODO demon_update_timestep = C
+ demon_loss = demon.update(rollout_storage)
+ rollout_storage.clear_storage()
+
+ debug.nan_check(l, force=True)
+
+ if curriculum is not None:
+ curriculum.update(*dataset.curriculum_measure(output, target))
+
+ if remaning_batch != 0 and subbatch == n_subbatch - 2:
+ multiply_grads(params, real_batch / remaning_batch)
+
+ if n_subbatch != 1:
+ if remaning_batch == 0:
+ multiply_grads(params, 1 / n_subbatch)
+ else:
+ multiply_grads(params, remaning_batch / opt.batch_size)
+
+ for p in params:
+ torch.nn.utils.clip_grad_norm_(p["params"], opt.grad_clip)
+
+ optimizer.step()
+
+ i += 1
+
+ curr_loss = l.data.item()
+ loss_plot.add_point(i, curr_loss)
+
+ # writer.add_scalar("associative-recall-loss/dnc-md", curr_loss, i)
+ # writer.add_scalar("associative-recall-mutual_info/dnc-md", info_loss, i)
+ # writer.add_scalar(
+ # "associative-recall-demon_loss/dnc-md", demon_loss.mean(), i
+ # )
+
+ loss_sum += curr_loss
+
+ if i % opt.info_interval == 0:
+ tim = time.time()
+ loss_avg = loss_sum / opt.info_interval
+
+ if curriculum is not None:
+ curriculum_accuracy.add_point(i, curriculum.get_accuracy())
+ curriculum_plot.add_point(i, curriculum.step)
+
+ message = "Iteration %d, loss: %.4f" % (i, loss_avg)
+ if iter_start_time is not None:
+ message += (
+ " (%.2f ms/iter, load time %.2g ms/iter, visport: %s)"
+ % (
+ (tim - iter_start_time) / opt.info_interval * 1000.0,
+ data_load_total_time / opt.info_interval * 1000.0,
+ Visdom.port,
+ )
+ )
+ print(message)
+ iter_start_time = tim
+ loss_sum = 0
+ data_load_total_time = 0
+
+ debug.dbg_print("Iteration %d, loss %g" % (i, curr_loss))
+
+ if need_preview:
+ do_visualize(raw_data, output, pos_map, debug_data)
+
+ if i % opt.test_interval == 0:
+ test()
+
+ saver.tick(i)
+
+ if opt.demo and opt.exit_after is None:
+ running = False
+ input("Press enter to quit.")
+
+ if opt.exit_after is not None and (i - start_iter) >= opt.exit_after:
+ running = False
+
+ data_load_timer = time.time()
+
+
+if __name__ == "__main__":
+ #writer = SummaryWriter()
+ global running
+ running = True
+
+ def signal_handler(signal, frame):
+ global running
+ print("You pressed Ctrl+C!")
+ running = False
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3caa71c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+tqdm
+torch
+visdom
+numpy
+tensorboard
+
+
+