dnc-demon impl
This commit is contained in:
parent
a278481269
commit
e6b110aaae
BIN
Dataset/.DS_Store
vendored
Normal file
BIN
Dataset/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
Dataset/Bitmap/.DS_Store
vendored
Normal file
BIN
Dataset/Bitmap/.DS_Store
vendored
Normal file
Binary file not shown.
72
Dataset/Bitmap/AssociativeRecall.py
Normal file
72
Dataset/Bitmap/AssociativeRecall.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
from .BitmapTask import BitmapTask
|
||||
from Utils.Seed import get_randstate
|
||||
|
||||
|
||||
class AssociativeRecall(BitmapTask):
    """Associative-recall bitmap task.

    A sample shows a run of random bit blocks, then repeats one of them as a
    query; the target is the block that followed the queried block in the
    input.  The last two input channels are control flags: channel ``-2``
    marks the end of a block, channel ``-1`` the end of the input sequence
    and the end of the query.
    """

    def __init__(self, length=None, bit_w=8, block_w=3, transform=lambda x: x):
        """
        :param length: number of blocks; an int, a callable returning an int,
            or None to use the dataset key as the length
        :param bit_w: number of data bits per row
        :param block_w: number of rows per block
        :param transform: callable applied to the sample dict before returning
        """
        super(AssociativeRecall, self).__init__()
        self.seed = None  # random state, fetched lazily on first access
        self.transform = transform
        self.block_w = block_w
        self.bit_w = bit_w
        self.length = length

    def __getitem__(self, key):
        """Generate one random sample.

        :param key: used as the block count when ``self.length`` is None
        :return: ``transform({"input": ..., "output": ...})`` with two float32
            arrays of identical shape
        """
        if self.seed is None:
            self.seed = get_randstate()

        n_blocks = self.length
        if callable(n_blocks):
            n_blocks = n_blocks()
        if n_blocks is None:
            # Random length batch hack.
            n_blocks = key

        rows_per_block = self.block_w + 1
        n_channels = self.bit_w + 2

        seq = self.seed.randint(
            0, 2, [n_blocks * rows_per_block, n_channels]
        ).astype(np.float32)
        seq[:, -2:] = 0

        # The last row of every block except the final one is a block terminator.
        for block in range(1, n_blocks):
            row = block * rows_per_block - 1
            seq[row, :] = 0
            seq[row, -2] = 1

        # The final row terminates the whole input sequence.
        seq[-1, :] = 0
        seq[-1, -1] = 1

        # Pick the queried block (never the last one, so a successor exists),
        # append its data rows and reserve rows for the answer.
        queried = self.seed.randint(0, n_blocks - 1)
        query_rows = seq[queried * rows_per_block : (queried + 1) * rows_per_block - 1]
        answer_rows = np.zeros([rows_per_block, n_channels], np.float32)
        seq = np.concatenate((seq, query_rows, answer_rows), axis=0)
        seq[-rows_per_block, -1] = 1

        # The expected output is the block right after the queried one.
        target = np.zeros_like(seq)
        target[-self.block_w :] = seq[
            (queried + 1) * rows_per_block : (queried + 2) * rows_per_block - 1
        ]

        return self.transform({"input": seq, "output": target})
|
82
Dataset/Bitmap/BitmapTask.py
Normal file
82
Dataset/Bitmap/BitmapTask.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from Visualize.BitmapTask import visualize_bitmap_task
|
||||
from Utils import Visdom
|
||||
from Utils import universal as U
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BitmapTask(torch.utils.data.Dataset):
    """Common base of the synthetic bitmap tasks.

    Provides the loss functions, a Visdom preview window and no-op
    (de)serialisation hooks.  Subclasses implement ``__getitem__``.
    """

    def __init__(self):
        super(BitmapTask, self).__init__()

        self._img = Visdom.Image("preview")

    def set_dump_dir(self, dir):
        """Forward the dump directory to the preview image window."""
        self._img.set_dump_dir(dir)

    def __len__(self):
        # Samples are generated on demand, so report a very large length.
        return 0x7FFFFFFF

    def visualize_preview(self, data, net_output):
        """Draw the input, the target and the sigmoid of the network output."""
        img = visualize_bitmap_task(
            data["input"], [data["output"], U.sigmoid(net_output)]
        )
        self._img.draw(img)

    def loss(self, net_output, target):
        """Return the summed binary cross-entropy divided by ``net_output.size(0)``.

        :param net_output: logits
        :param target: tensor of the same shape as ``net_output``
        """
        return F.binary_cross_entropy_with_logits(
            net_output, target, reduction="sum"
        ) / net_output.size(0)

    def accuracy(self, net_output, target):
        """Return the same value as :meth:`loss`.

        NOTE(review): despite the name this is the cross-entropy loss, not a
        fraction of correct bits; kept as is because callers may rely on it.
        """
        return F.binary_cross_entropy_with_logits(
            net_output, target, reduction="sum"
        ) / net_output.size(0)

    def demon_loss(self, net_output, target, saved_actions, device):
        """
        computes the loss for the demon

        Dimension 1 of the per-element loss is treated as the step axis.  For
        every step ``i`` the discounted (factor 0.99) losses from step ``i``
        onwards are averaged and multiplied by the log-probability of the
        action saved for that step.

        :param net_output: logits; detached, so no gradient reaches the network
        :param target: tensor of the same shape as ``net_output``
        :param saved_actions: sequence indexed by step whose items expose a
            ``log_prob`` tensor with a singleton dimension 1
        :param device: device on which the discount vector is created
        :return: the per-step policy losses averaged over steps and divided
            once more by the number of steps
        """
        net_output = net_output.detach()
        step_loss = F.binary_cross_entropy_with_logits(
            net_output, target, reduction="none"
        ).sum(dim=-1)
        n_steps = step_loss.size(1)

        discount_factor = 0.99
        # Built once and in the dtype of the loss: the previous per-step
        # float64 NumPy vector cost O(T^2) and promoted the result to double.
        discounts = discount_factor ** torch.arange(
            n_steps, dtype=step_loss.dtype, device=device
        )

        # Actor (policy) loss of every step: log-prob times mean discounted loss.
        policy_losses = [
            saved_actions[i].log_prob.squeeze(1)
            * (discounts[: n_steps - i] * step_loss[:, i:]).mean(dim=1)
            for i in range(n_steps)
        ]

        return torch.stack(policy_losses).mean(dim=0) / n_steps

    def state_dict(self):
        """Return an empty state; the generated tasks keep nothing to save."""
        return {}

    def load_state_dict(self, state):
        """Accept and ignore a saved state."""
        pass
|
56
Dataset/Bitmap/BitmapTaskRepeater.py
Normal file
56
Dataset/Bitmap/BitmapTaskRepeater.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
from .BitmapTask import BitmapTask
|
||||
|
||||
|
||||
class BitmapTaskRepeater(BitmapTask):
    """Wrap a bitmap dataset so that one item is several samples back to back."""

    def __init__(self, dataset):
        """
        :param dataset: indexable dataset returning dicts with "input" and
            "output" arrays
        """
        super(BitmapTaskRepeater, self).__init__()
        self.dataset = dataset

    def __getitem__(self, key):
        """Fetch one sample per entry of ``key`` and join them along axis 0.

        :param key: iterable of keys for the wrapped dataset
        :return: the single sample unchanged when ``key`` has one entry,
            otherwise a dict with the concatenated "input" and "output"
        """
        samples = [self.dataset[k] for k in key]
        if len(samples) != 1:
            return {
                "input": np.concatenate([s["input"] for s in samples], axis=0),
                "output": np.concatenate([s["output"] for s in samples], axis=0),
            }
        return samples[0]

    @staticmethod
    def key_sampler(length, repeat):
        """Build a zero-argument function producing a list of sample lengths.

        ``length`` and ``repeat`` each may be a callable (called for a value),
        a two-element list (inclusive ``random.randint`` bounds), a
        one-element list (that element) or any other value (used as is).

        :return: function returning ``repeat`` draws of ``length`` as a list
        """

        def draw(spec):
            if callable(spec):
                return spec()
            if not isinstance(spec, list):
                return spec
            if len(spec) == 2:
                return random.randint(*spec)
            if len(spec) == 1:
                return spec[0]
            assert False, "Invalid sample parameter: %s" % spec

        def sample_keys():
            # The repeat count is drawn first, then one length per repetition.
            return [draw(length) for _ in range(draw(repeat))]

        return sample_keys
|
49
Dataset/Bitmap/CopyTask.py
Normal file
49
Dataset/Bitmap/CopyTask.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
from .BitmapTask import BitmapTask
|
||||
from Utils.Seed import get_randstate
|
||||
|
||||
|
||||
class CopyData(BitmapTask):
    """Copy task: a random bit pattern is shown, then has to be reproduced."""

    def __init__(self, length=None, bit_w=8, transform=lambda x: x):
        """
        :param length: pattern length; an int, a callable returning an int,
            or None to use the dataset key as the length
        :param bit_w: number of data bits per row
        :param transform: callable applied to the sample dict before returning
        """
        super(CopyData, self).__init__()
        self.seed = None  # random state, fetched lazily on first access
        self.transform = transform
        self.bit_w = bit_w
        self.length = length

    def __getitem__(self, key):
        """Generate one random sample.

        :param key: used as the pattern length when ``self.length`` is None
        :return: ``transform({"input": ..., "output": ...})``; the input is the
            pattern followed by zeros, the output zeros followed by the pattern
        """
        if self.seed is None:
            self.seed = get_randstate()

        seq_len = self.length
        if callable(seq_len):
            seq_len = seq_len()
        if seq_len is None:
            # Random length batch hack.
            seq_len = key

        shape = [seq_len + 1, self.bit_w + 1]
        pattern = self.seed.randint(0, 2, shape).astype(np.float32)
        blank = np.zeros_like(pattern)

        # The extra channel is a flag set only on the extra, otherwise empty,
        # final row.
        pattern[:, -1] = 0
        pattern[-1] = 0
        pattern[-1, -1] = 1

        return self.transform(
            {
                "input": np.concatenate((pattern, blank), axis=0),
                "output": np.concatenate((blank, pattern), axis=0),
            }
        )
|
81
Dataset/Bitmap/KeyValue.py
Normal file
81
Dataset/Bitmap/KeyValue.py
Normal file
@ -0,0 +1,81 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from .BitmapTask import BitmapTask
|
||||
from Utils.Seed import get_randstate
|
||||
|
||||
|
||||
class KeyValue(BitmapTask):
    """Key-value retrieval task.

    Input layout (``bit_w + 1`` channels, last channel is a flag):

    * rows ``0 .. length-1``: unique key bits followed by random value bits,
    * row ``length``: flag only,
    * rows ``length+1 .. 2*length``: the keys again in a random order,
    * last row: flag only.

    The output (``key_w`` channels) holds, on the query rows, the value stored
    with the queried key.
    """

    def __init__(self, length=None, bit_w=8, transform=lambda x: x):
        """
        :param length: number of key-value pairs; an int, a callable returning
            an int, or None to use the dataset key as the length
        :param bit_w: total data bits per row, half key and half value
        :param transform: callable applied to the sample dict before returning
        """
        assert bit_w % 2 == 0, "bit_w must be even"
        super(KeyValue, self).__init__()
        self.length = length
        self.bit_w = bit_w
        self.transform = transform
        self.seed = None  # random state, fetched lazily on first access
        self.key_w = self.bit_w // 2
        self.max_key = 2 ** self.key_w - 1

    def _sample_key_bits(self, length):
        """Draw ``length`` distinct keys and return them as a float32 bit matrix."""
        if length > self.max_key + 1:
            # Without this check the rejection loop below would spin forever.
            raise ValueError(
                "Cannot draw %d unique keys from only %d possible keys"
                % (length, self.max_key + 1)
            )

        # Keep topping up with fresh draws until enough distinct keys exist;
        # np.unique also leaves the keys sorted in ascending order.
        keys = np.empty(0, dtype=np.int64)
        while keys.size != length:
            # randint's upper bound is exclusive; this replaces the deprecated
            # random_integers(0, max_key) and draws the same values.
            res = self.seed.randint(0, self.max_key + 1, size=(length - keys.size))
            keys = np.unique(np.concatenate((res, keys)))

        # Column k holds bit k (least significant first) of every key.
        bits = (keys[:, None] >> np.arange(self.key_w)) & 1
        return bits.astype(np.float32)

    def __getitem__(self, key):
        """Generate one random sample.

        :param key: used as the pair count when ``self.length`` is None
        :return: ``transform({"input": ..., "output": ...})``
        :raises ValueError: if more pairs are requested than distinct keys exist
        """
        if self.seed is None:
            self.seed = get_randstate()

        if self.length is None:
            # Random length batch hack.
            length = key
        else:
            length = self.length() if callable(self.length) else self.length

        # keys must be unique
        keys = self._sample_key_bits(length)
        values = self.seed.randint(0, 2, keys.shape).astype(np.float32)

        perm = self.seed.permutation(length)
        keys_perm = keys[perm, :]
        values_perm = values[perm, :]

        i_p = np.zeros((2 * length + 2, self.bit_w + 1), dtype=np.float32)
        i_p[:length, : self.key_w] = keys
        i_p[:length, self.key_w : -1] = values
        i_p[length + 1 : -1, : self.key_w] = keys_perm

        # Flags closing the storage phase and the query phase.
        i_p[length, -1] = 1
        i_p[-1, -1] = 1

        o_p = np.zeros((2 * length + 2, self.key_w), dtype=np.float32)
        o_p[length + 1 : -1] = values_perm

        return self.transform({"input": i_p, "output": o_p})
|
89
Dataset/Bitmap/KeyValue2Way.py
Normal file
89
Dataset/Bitmap/KeyValue2Way.py
Normal file
@ -0,0 +1,89 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from .BitmapTask import BitmapTask
|
||||
from Utils.Seed import get_randstate
|
||||
|
||||
|
||||
class KeyValue2Way(BitmapTask):
    """Two-way key-value retrieval task.

    Input layout (``bit_w + 2`` channels, last two channels are flags):

    * rows ``0 .. length-1``: unique key bits followed by random value bits,
    * row ``length``: flag ``-2``,
    * rows ``length+1 .. 2*length``: keys in a random order (answer: values),
    * row ``2*length+1``: flag ``-1``,
    * rows ``2*length+2 .. 3*length+1``: values in another random order,
      placed in the key columns (answer: keys),
    * last row: both flags.

    The output has ``key_w`` channels.
    """

    def __init__(self, length=None, bit_w=8, transform=lambda x: x):
        """
        :param length: number of key-value pairs; an int, a callable returning
            an int, or None to use the dataset key as the length
        :param bit_w: total data bits per row, half key and half value
        :param transform: callable applied to the sample dict before returning
        """
        assert bit_w % 2 == 0, "bit_w must be even"
        super(KeyValue2Way, self).__init__()
        self.length = length
        self.bit_w = bit_w
        self.transform = transform
        self.seed = None  # random state, fetched lazily on first access
        self.key_w = self.bit_w // 2
        self.max_key = 2 ** self.key_w - 1

    def _sample_key_bits(self, length):
        """Draw ``length`` distinct keys and return them as a float32 bit matrix."""
        if length > self.max_key + 1:
            # Without this check the rejection loop below would spin forever.
            raise ValueError(
                "Cannot draw %d unique keys from only %d possible keys"
                % (length, self.max_key + 1)
            )

        # Keep topping up with fresh draws until enough distinct keys exist;
        # np.unique also leaves the keys sorted in ascending order.
        keys = np.empty(0, dtype=np.int64)
        while keys.size != length:
            # randint's upper bound is exclusive; this replaces the deprecated
            # random_integers(0, max_key) and draws the same values.
            res = self.seed.randint(0, self.max_key + 1, size=(length - keys.size))
            keys = np.unique(np.concatenate((res, keys)))

        # Column k holds bit k (least significant first) of every key.
        bits = (keys[:, None] >> np.arange(self.key_w)) & 1
        return bits.astype(np.float32)

    def __getitem__(self, key):
        """Generate one random sample.

        :param key: used as the pair count when ``self.length`` is None
        :return: ``transform({"input": ..., "output": ...})``
        :raises ValueError: if more pairs are requested than distinct keys exist
        """
        if self.seed is None:
            self.seed = get_randstate()

        if self.length is None:
            # Random length batch hack.
            length = key
        else:
            length = self.length() if callable(self.length) else self.length

        # keys must be unique
        keys = self._sample_key_bits(length)
        values = self.seed.randint(0, 2, keys.shape).astype(np.float32)

        i_p = np.zeros((3 * (length + 1), self.bit_w + 2), dtype=np.float32)
        o_p = np.zeros((3 * (length + 1), self.key_w), dtype=np.float32)

        # Storage phase.
        i_p[:length, : self.key_w] = keys
        i_p[:length, self.key_w : -2] = values

        # First query phase: keys are shown, values are expected.
        perm = self.seed.permutation(length)
        i_p[length + 1 : 2 * length + 1, : self.key_w] = keys[perm, :]
        o_p[length + 1 : 2 * length + 1] = values[perm, :]

        # Second query phase: values are shown, keys are expected.
        perm = self.seed.permutation(length)
        o_p[2 * (length + 1) : -1] = keys[perm, :]
        i_p[2 * (length + 1) : -1, : self.key_w] = values[perm, :]

        # Phase separators.
        i_p[length, -2] = 1
        i_p[2 * length + 1, -1] = 1
        i_p[-1, -2:] = 1

        return self.transform({"input": i_p, "output": o_p})
|
0
Dataset/Bitmap/__init__.py
Normal file
0
Dataset/Bitmap/__init__.py
Normal file
BIN
Dataset/NLP/.DS_Store
vendored
Normal file
BIN
Dataset/NLP/.DS_Store
vendored
Normal file
Binary file not shown.
1
Dataset/NLP/.gitignore
vendored
Normal file
1
Dataset/NLP/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
cache
|
153
Dataset/NLP/NLPTask.py
Normal file
153
Dataset/NLP/NLPTask.py
Normal file
@ -0,0 +1,153 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import os
|
||||
from .Vocabulary import Vocabulary
|
||||
from Utils import Visdom
|
||||
from Utils import universal as U
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class NLPTask(torch.utils.data.Dataset):
    """Base class of the text tasks.

    Owns the on-disk cache directory (``cache`` next to this file), the shared
    vocabulary stored in it, the loss functions and an HTML text preview.
    Target index 0 is ignored by the losses.
    """

    def __init__(self):
        super(NLPTask, self).__init__()

        self.my_dir = os.path.abspath(os.path.dirname(__file__))
        self.cache_dir = os.path.join(self.my_dir, "cache")

        if not os.path.isdir(self.cache_dir):
            os.makedirs(self.cache_dir)

        self.vocabulary = self._load_vocabulary()

        self._preview = None  # Visdom text window, created on first use

    def _load_vocabulary(self):
        """Load the cached vocabulary, or start a fresh one.

        When no vocabulary file exists every ``.pth`` file in the cache
        directory is deleted, because cached data stores vocabulary indices.
        """
        cache_file = os.path.join(self.cache_dir, "vocabulary.pth")
        if os.path.isfile(cache_file):
            # NOTE(review): torch.load unpickles arbitrary objects; only safe
            # because the cache directory is written by this code itself.
            return torch.load(cache_file)

        print("WARNING: Vocabulary not found. Removing cached files.")
        for entry in os.listdir(self.cache_dir):
            path = os.path.join(self.cache_dir, entry)
            if path.endswith(".pth"):
                print(" " + path)
                os.remove(path)
        return Vocabulary()

    def save_vocabulary(self):
        """Write the vocabulary to the cache, replacing an existing file."""
        cache_file = os.path.join(self.cache_dir, "vocabulary.pth")
        if os.path.isfile(cache_file):
            os.remove(cache_file)
        torch.save(self.vocabulary, cache_file)

    def loss(self, net_output, target):
        """Return the summed cross-entropy divided by ``net_output.size(0)``.

        :param net_output: 3-D logits; the last dimension is the class axis
        :param target: class indices, one per position of the first two
            dimensions; index 0 is ignored
        """
        s = list(net_output.size())
        return (
            F.cross_entropy(
                net_output.view([s[0] * s[1], s[2]]),
                target.view([-1]),
                ignore_index=0,
                reduction="sum",
            )
            / s[0]
        )

    def demon_loss(self, net_output, target, saved_actions, device):
        """
        computes the loss for the demon

        Dimension 1 of ``net_output`` is treated as the step axis.  For every
        step ``i`` the discounted (factor 0.99) losses from step ``i`` onwards
        are averaged and multiplied by the log-probability of the action saved
        for that step.

        :param net_output: 3-D logits; detached, so no gradient reaches the
            network
        :param target: class indices; index 0 is ignored
        :param saved_actions: sequence indexed by step whose items expose a
            ``log_prob`` tensor with a singleton dimension 1
        :param device: device on which the discount vector is created
        :return: the per-step policy losses averaged over steps
        """
        net_output = net_output.detach()
        s = list(net_output.size())
        step_loss = F.cross_entropy(
            net_output.view([s[0] * s[1], s[2]]),
            target.view([-1]),
            ignore_index=0,
            reduction="none",
        ).view(s[0], s[1])
        n_steps = step_loss.size(1)

        discount_factor = 0.99
        # Built once and in the dtype of the loss: the previous per-step
        # float64 NumPy vector cost O(T^2) and promoted the result to double.
        discounts = discount_factor ** torch.arange(
            n_steps, dtype=step_loss.dtype, device=device
        )

        # Actor (policy) loss of every step: log-prob times mean discounted loss.
        policy_losses = [
            saved_actions[i].log_prob.squeeze(1)
            * (discounts[: n_steps - i] * step_loss[:, i:]).mean(dim=1)
            for i in range(n_steps)
        ]

        return torch.stack(policy_losses).mean(dim=0)

    def generate_preview_text(self, data, net_output):
        """Render the first sample of a batch as HTML.

        The input words preceding each answer position are printed as
        sentences, followed by the predicted word and, in brackets, the
        reference word, coloured green on a match and red otherwise.

        :param data: batch dict with "input" and "output" index sequences
        :param net_output: network output; the argmax over the last axis is
            taken as the prediction
        :return: HTML string
        """
        tokens = U.to_numpy(data["input"][0])
        reference = U.to_numpy(data["output"][0])
        net_out = U.to_numpy(U.argmax(net_output[0], -1))

        res = ""
        start_index = 0

        for i in range(tokens.shape[0]):
            if reference[i] == 0:
                # Not an answer position.
                continue

            if start_index < i:
                # Skip backwards over padding to see whether any text precedes
                # this answer.
                end_index = i
                while end_index > start_index and tokens[end_index] == 0:
                    end_index -= 1

                if end_index > start_index:
                    sentence = (
                        " ".join(
                            self.vocabulary.indices_to_sentence(
                                tokens[start_index:i].tolist()
                            )
                        )
                        .replace(" .", ".")
                        .replace(" ,", ",")
                        .replace(" ?", "?")
                        .split(". ")
                    )
                    sentence = ". ".join([s.capitalize() for s in sentence])
                    res += sentence + "<br>"

            start_index = i + 1

            match = reference[i] == net_out[i]
            res += '<b><font color="%s">%s [%s]</font><br></b>' % (
                "green" if match else "red",
                self.vocabulary.indices_to_sentence([net_out[i]])[0],
                self.vocabulary.indices_to_sentence([reference[i]])[0],
            )
        return res

    def visualize_preview(self, data, net_output):
        """Show the preview text of the batch in a Visdom text window."""
        res = self.generate_preview_text(data, net_output)

        if self._preview is None:
            self._preview = Visdom.Text("Preview")

        self._preview.set(res)

    def set_dump_dir(self, dir):
        """Accept and ignore the dump directory; text previews are not dumped."""
        pass
|
52
Dataset/NLP/Vocabulary.py
Normal file
52
Dataset/NLP/Vocabulary.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Vocabulary:
    """Bidirectional word <-> index mapping that grows on demand.

    Indices 0, 1 and 2 are reserved for ``"-"``, ``"?"`` and ``"<UNK>"``.
    """

    def __init__(self):
        self.words = {"-": 0, "?": 1, "<UNK>": 2}
        self.inv_words = {0: "-", 1: "?", 2: "<UNK>"}
        self.next_id = 3
        # Tokens that are split off from the neighbouring words.
        self.punctations = [".", "?", ","]

    def _process_word(self, w, add_words):
        """Return the index of ``w``.

        Words that are neither alphabetic nor punctuation become ``<UNK>``.
        An unseen word is registered when ``add_words`` is true and mapped to
        ``<UNK>`` otherwise.
        """
        if not w.isalpha() and w not in self.punctations:
            # The word was previously passed as a second print argument, which
            # printed a literal "%s" instead of formatting the message.
            print("WARNING: word with unknown characters: %s" % w)
            w = "<UNK>"

        if w not in self.words:
            if add_words:
                self.words[w] = self.next_id
                self.inv_words[self.next_id] = w
                self.next_id += 1
            else:
                w = "<UNK>"

        return self.words[w]

    def sentence_to_indices(self, sentence, add_words=True):
        """Convert a sentence to a list of word indices.

        Punctuation is separated from words, the text is lower-cased and split
        on single spaces; empty tokens are dropped.

        :param sentence: text to convert
        :param add_words: register unseen words instead of mapping them to
            ``<UNK>``
        :return: list of int indices
        """
        for p in self.punctations:
            sentence = sentence.replace(p, " %s " % p)

        return [
            self._process_word(w, add_words) for w in sentence.lower().split(" ") if w
        ]

    def indices_to_sentence(self, indices):
        """Return the list of words for an iterable of indices."""
        return [self.inv_words[i] for i in indices]

    def __len__(self):
        return len(self.words)
|
0
Dataset/NLP/__init__.py
Normal file
0
Dataset/NLP/__init__.py
Normal file
297
Dataset/NLP/bAbi.py
Normal file
297
Dataset/NLP/bAbi.py
Normal file
@ -0,0 +1,297 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
from .NLPTask import NLPTask
|
||||
from Utils import Visdom
|
||||
|
||||
# One parsed line of a bAbI file. `sentence` is the list of vocabulary indices
# of the text; `answer` (index list) and `supporting_facts` (list of ints) are
# None for lines that are not questions.
Sentence = namedtuple("Sentence", ["sentence", "answer", "supporting_facts"])
|
||||
|
||||
|
||||
class bAbiDataset(NLPTask):
|
||||
URL = "http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz"
|
||||
DIR_NAME = "tasks_1-20_v1-2"
|
||||
|
||||
def __init__(
|
||||
self, dirs=["en-10k"], sets=None, think_steps=0, dir_name=None, name=None
|
||||
):
|
||||
super(bAbiDataset, self).__init__()
|
||||
|
||||
self._test_res_win = None
|
||||
self._test_plot_win = None
|
||||
self._think_steps = think_steps
|
||||
|
||||
if dir_name is None:
|
||||
self._download()
|
||||
dir_name = os.path.join(self.cache_dir, self.DIR_NAME)
|
||||
|
||||
self.data = {}
|
||||
for d in dirs:
|
||||
self.data[d] = self._load_or_create(os.path.join(dir_name, d))
|
||||
|
||||
self.all_tasks = None
|
||||
self.name = name
|
||||
self.use(sets=sets)
|
||||
|
||||
def _make_active_list(self, tasks, sets, dirs):
|
||||
def verify(name, checker):
|
||||
if checker is None:
|
||||
return True
|
||||
|
||||
if callable(checker):
|
||||
return checker(name)
|
||||
elif isinstance(checker, list):
|
||||
return name in checker
|
||||
else:
|
||||
return name == checker
|
||||
|
||||
res = []
|
||||
for dirname, setlist in self.data.items():
|
||||
if not verify(dirname, dirs):
|
||||
continue
|
||||
|
||||
for sname, tasklist in setlist.items():
|
||||
if not verify(sname, sets):
|
||||
continue
|
||||
|
||||
for task, data in tasklist.items():
|
||||
name = task.split("_")[0][2:]
|
||||
if not verify(name, tasks):
|
||||
continue
|
||||
|
||||
res += [(d, dirname, task, sname) for d in data]
|
||||
|
||||
return res
|
||||
|
||||
def use(self, tasks=None, sets=None, dirs=None):
|
||||
self.all_tasks = self._make_active_list(tasks=tasks, sets=sets, dirs=dirs)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.all_tasks)
|
||||
|
||||
def _get_seq(self, index):
|
||||
return self.all_tasks[index]
|
||||
|
||||
def _seq_to_nn_input(self, seq):
|
||||
in_arr = []
|
||||
out_arr = []
|
||||
hasAnswer = False
|
||||
for sentence in seq[0]:
|
||||
in_arr += sentence.sentence
|
||||
out_arr += [0] * len(sentence.sentence)
|
||||
if sentence.answer is not None:
|
||||
in_arr += [0] * (len(sentence.answer) + self._think_steps)
|
||||
out_arr += [0] * self._think_steps + sentence.answer
|
||||
hasAnswer = True
|
||||
|
||||
in_arr = np.asarray(in_arr, np.int64)
|
||||
out_arr = np.asarray(out_arr, np.int64)
|
||||
|
||||
return {
|
||||
"input": in_arr,
|
||||
"output": out_arr,
|
||||
"meta": {"dir": seq[1], "task": seq[2], "set": seq[3]},
|
||||
}
|
||||
|
||||
def __getitem__(self, item):
|
||||
seq = self._get_seq(item)
|
||||
return self._seq_to_nn_input(seq)
|
||||
|
||||
def _load_or_create(self, directory):
|
||||
cache_name = directory.replace("/", "_")
|
||||
cache_file = os.path.join(self.cache_dir, cache_name + ".pth")
|
||||
if not os.path.isfile(cache_file):
|
||||
print("bAbI: Loading %s" % directory)
|
||||
res = self._load_dir(directory)
|
||||
print("Write: ", cache_file)
|
||||
self.save_vocabulary()
|
||||
torch.save(res, cache_file)
|
||||
else:
|
||||
res = torch.load(cache_file)
|
||||
return res
|
||||
|
||||
def _download(self):
|
||||
if not os.path.isdir(os.path.join(self.cache_dir, self.DIR_NAME)):
|
||||
print(self.URL)
|
||||
print("bAbi data not found. Downloading...")
|
||||
import requests, tarfile, io
|
||||
|
||||
request = requests.get(
|
||||
self.URL,
|
||||
headers={
|
||||
"User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36"
|
||||
},
|
||||
)
|
||||
|
||||
decompressed_file = tarfile.open(
|
||||
fileobj=io.BytesIO(request.content), mode="r|gz"
|
||||
)
|
||||
decompressed_file.extractall(self.cache_dir)
|
||||
print("Done")
|
||||
|
||||
def _load_dir(
|
||||
self,
|
||||
directory,
|
||||
parse_name=lambda x: x.split(".")[0],
|
||||
parse_set=lambda x: x.split(".")[0].split("_")[-1],
|
||||
):
|
||||
res = {}
|
||||
for f in glob.glob(os.path.join(directory, "**", "*.txt"), recursive=True):
|
||||
basename = os.path.basename(f)
|
||||
task_name = parse_name(basename)
|
||||
set = parse_set(basename)
|
||||
print("Loading", f)
|
||||
|
||||
s = res.get(set)
|
||||
if s is None:
|
||||
s = {}
|
||||
res[set] = s
|
||||
s[task_name] = self._load_task(f, task_name)
|
||||
|
||||
return res
|
||||
|
||||
def _load_task(self, filename, task_name):
|
||||
task = []
|
||||
currTask = []
|
||||
|
||||
nextIndex = 1
|
||||
with open(filename, "r") as f:
|
||||
for line in f:
|
||||
line = [f.strip() for f in line.split("\t")]
|
||||
line[0] = line[0].split(" ")
|
||||
i = int(line[0][0])
|
||||
line[0] = " ".join(line[0][1:])
|
||||
|
||||
if i != nextIndex:
|
||||
nextIndex = i
|
||||
task.append(currTask)
|
||||
currTask = []
|
||||
|
||||
isQuestion = len(line) > 1
|
||||
currTask.append(
|
||||
Sentence(
|
||||
self.vocabulary.sentence_to_indices(line[0]),
|
||||
self.vocabulary.sentence_to_indices(line[1].replace(",", " "))
|
||||
if isQuestion
|
||||
else None,
|
||||
[int(f) for f in line[2].split(" ")] if isQuestion else None,
|
||||
)
|
||||
)
|
||||
|
||||
nextIndex += 1
|
||||
return task
|
||||
|
||||
def start_test(self):
|
||||
return {}
|
||||
|
||||
def veify_result(self, test, data, net_output):
|
||||
_, net_output = net_output.max(-1)
|
||||
|
||||
ref = data["output"]
|
||||
|
||||
mask = 1.0 - ref.eq(0).float()
|
||||
|
||||
correct = (torch.eq(net_output, ref).float() * mask).sum(-1)
|
||||
total = mask.sum(-1)
|
||||
|
||||
correct = correct.data.cpu().numpy()
|
||||
total = total.data.cpu().numpy()
|
||||
|
||||
for i in range(correct.shape[0]):
|
||||
task = data["meta"][i]["task"]
|
||||
if task not in test:
|
||||
test[task] = {"total": 0, "correct": 0}
|
||||
|
||||
d = test[task]
|
||||
d["total"] += total[i]
|
||||
d["correct"] += correct[i]
|
||||
|
||||
def _ensure_test_wins_exists(self, legend=None):
|
||||
if self._test_res_win is None:
|
||||
n = ("[" + self.name + "]") if self.name is not None else ""
|
||||
self._test_res_win = Visdom.Text("Test results" + n)
|
||||
self._test_plot_win = Visdom.Plot2D("Test results" + n, legend=legend)
|
||||
elif self._test_plot_win.legend is None:
|
||||
self._test_plot_win.set_legend(legend=legend)
|
||||
|
||||
def show_test_results(self, iteration, test):
|
||||
res = {k: v["correct"] / v["total"] for k, v in test.items()}
|
||||
|
||||
t = ""
|
||||
|
||||
all_keys = list(res.keys())
|
||||
|
||||
num_keys = [k for k in all_keys if k.startswith("qa")]
|
||||
tmp = [
|
||||
i[0]
|
||||
for i in sorted(
|
||||
enumerate(num_keys), key=lambda x: int(x[1][2:].split("_")[0])
|
||||
)
|
||||
]
|
||||
num_keys = [num_keys[j] for j in tmp]
|
||||
|
||||
all_keys = num_keys + sorted([k for k in all_keys if not k.startswith("qa")])
|
||||
|
||||
err_precent = [(1.0 - res[k]) * 100.0 for k in all_keys]
|
||||
|
||||
n_passed = sum([int(p <= 5) for p in err_precent])
|
||||
n_total = len(err_precent)
|
||||
err_precent = err_precent + [sum(err_precent) / len(err_precent)]
|
||||
all_keys += ["mean"]
|
||||
|
||||
for i, k in enumerate(all_keys):
|
||||
t += '<font color="%s">%s: <b>%.2f%%</b></font><br>' % (
|
||||
"green" if err_precent[i] <= 5 else "red",
|
||||
k,
|
||||
err_precent[i],
|
||||
)
|
||||
|
||||
t += "<br><b>Total: %d of %d passed.</b>" % (n_passed, n_total)
|
||||
|
||||
self._ensure_test_wins_exists(
|
||||
legend=[i.split("_")[0] if i.startswith("qa") else i for i in all_keys]
|
||||
)
|
||||
|
||||
self._test_res_win.set(t)
|
||||
self._test_plot_win.add_point(iteration, err_precent)
|
||||
|
||||
def state_dict(self):
    """Return the saved state of both test windows, or ``{}`` if they do not exist yet."""
    if self._test_res_win is None:
        return {}

    saved = {}
    for attr in ("_test_res_win", "_test_plot_win"):
        saved[attr] = getattr(self, attr).state_dict()
    return saved
|
||||
|
||||
def load_state_dict(self, state):
    """Restore both test windows from ``state``; a falsy ``state`` is ignored.

    After loading, the plot window's ``legend`` attribute is reset to ``None``.
    """
    if not state:
        return

    self._ensure_test_wins_exists()
    for attr in ("_test_res_win", "_test_plot_win"):
        getattr(self, attr).load_state_dict(state[attr])
    self._test_plot_win.legend = None
|
||||
|
||||
def visualize_preview(self, data, net_output):
    """Show the generated preview text, headed by the first sample's task name.

    The "Preview" text window is created on first use.
    """
    body = self.generate_preview_text(data, net_output)
    header = "<b><u>%s</u></b><br>" % data["meta"][0]["task"]

    if self._preview is None:
        self._preview = Visdom.Text("Preview")

    self._preview.set(header + body)
|
1
Dataset/__init__.py
Normal file
1
Dataset/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
BIN
Models/.DS_Store
vendored
Normal file
BIN
Models/.DS_Store
vendored
Normal file
Binary file not shown.
965
Models/DNCA.py
Normal file
965
Models/DNCA.py
Normal file
@ -0,0 +1,965 @@
|
||||
# The Initial DNC Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
# The modification of the initial DNC implementation by Ari Azarafrooz.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
import functools
|
||||
import math
|
||||
|
||||
|
||||
def oneplus(t):
    """Return ``softplus(t) + 1`` element-wise (softplus with beta=1, threshold=20)."""
    softplus = F.softplus(t, beta=1, threshold=20)
    return softplus + 1.0
|
||||
|
||||
|
||||
def get_next_tensor_part(src, dims, prev_pos=0):
    """Cut the next chunk out of the last dimension of ``src``.

    Args:
        src: tensor to read from.
        dims: shape of the chunk, either a list of sizes or a single size.
        prev_pos: offset in the last dimension where the chunk starts.

    Returns:
        ``(chunk, next_pos)``. ``chunk`` covers ``prod(dims)`` elements of the
        last dimension; when ``dims`` has more than one entry the last
        dimension is reshaped to ``dims``. ``next_pos`` is the offset just
        past the chunk.
    """
    shape = dims if isinstance(dims, list) else [dims]
    width = functools.reduce(lambda a, b: a * b, shape)
    chunk = src.narrow(-1, prev_pos, width)
    next_pos = prev_pos + width

    if len(shape) > 1:
        # Keep the leading dimensions, unfold the last one into ``shape``.
        chunk = chunk.contiguous().view(list(chunk.size())[:-1] + shape)
    return chunk, next_pos
|
||||
|
||||
|
||||
def split_tensor(src, shapes):
    """Split the last dimension of ``src`` into consecutive chunks.

    Args:
        src: tensor to split.
        shapes: one entry per chunk, each in the form accepted by
            ``get_next_tensor_part`` (a list of sizes or a single size).

    Returns:
        List of chunks, in the order of ``shapes``.
    """
    parts = []
    offset = 0
    for shape in shapes:
        part, offset = get_next_tensor_part(src, shape, offset)
        parts.append(part)
    return parts
|
||||
|
||||
|
||||
def dict_get(dict, name):
    """Return ``dict.get(name)``, or ``None`` when ``dict`` itself is ``None``."""
    if dict is None:
        return None
    return dict.get(name)
|
||||
|
||||
|
||||
def dict_append(dict, name, val):
    """Append ``val`` to the list stored under ``name``; do nothing if ``dict`` is ``None``.

    A missing (or falsy) entry is replaced by a fresh list first.
    """
    if dict is None:
        return

    bucket = dict.get(name)
    if not bucket:
        bucket = dict[name] = []
    bucket.append(val)
|
||||
|
||||
|
||||
def init_debug(debug, initial):
    """Fill an empty ``debug`` dict with ``initial``; leave ``None`` or non-empty dicts alone."""
    if debug is None or debug:
        return
    debug.update(initial)
|
||||
|
||||
|
||||
def merge_debug_tensors(d, dim):
    """Replace every list in the (possibly nested) dict ``d`` by ``torch.stack(list, dim)``.

    Nested dicts are processed recursively and other values are left as they
    are. ``d`` is modified in place; ``None`` is accepted and ignored.
    """
    if d is None:
        return

    for key, value in d.items():
        if isinstance(value, dict):
            merge_debug_tensors(value, dim)
        elif isinstance(value, list):
            # Re-assigning an existing key is safe while iterating.
            d[key] = torch.stack(value, dim)
|
||||
|
||||
|
||||
def linear_reset(module, gain=1.0):
    """Re-initialise a ``torch.nn.Linear``: Xavier-uniform weights and zero bias.

    Args:
        module: the linear layer to reset (asserted to be ``torch.nn.Linear``).
        gain: scaling factor forwarded to ``xavier_uniform_``.
    """
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()
|
||||
|
||||
|
||||
# Small constant added to usages, norms and distributions in this module to
# keep them away from exact zero (avoids division by zero and zero products).
_EPS = 1e-6
|
||||
|
||||
|
||||
class AllocationManager(torch.nn.Module):
    """Track how much every memory cell is used and propose where to write next.

    State kept between the time steps of one sequence:
        usages: ``[batch, cell count]`` usage per cell; ``None`` until the
            first update after :meth:`new_sequence`.
        zero_usages: cached start value for ``usages``.
        one: cached one-element tensor holding ``1``.
    """

    def __init__(self):
        super(AllocationManager, self).__init__()
        self.usages = None
        self.zero_usages = None
        # When True, the start usages get a tiny per-cell ramp (index * 1e-10)
        # instead of exact zeros.
        self.debug_sequ_init = False
        self.one = None

    def _init_sequence(self, prev_read_distributions):
        """Set ``usages`` to the (cached) start value for a new sequence."""
        # prev_read_distributions size is [batch, n_heads, cell count]
        s = prev_read_distributions.size()
        device = prev_read_distributions.device
        if (
            self.zero_usages is None
            or list(self.zero_usages.size()) != [s[0], s[-1]]
            or self.zero_usages.device != device
        ):
            self.zero_usages = torch.zeros(s[0], s[-1], device=device)
            if self.debug_sequ_init:
                # Build the ramp on the same device as the usage tensor.
                self.zero_usages += (
                    torch.arange(0, s[-1], device=device).unsqueeze(0) * 1e-10
                )

        self.usages = self.zero_usages

    def _init_consts(self, device):
        """Create the cached constant ``1`` on ``device`` if needed."""
        if self.one is None or self.one.device != device:
            self.one = torch.ones(1, device=device)

    def new_sequence(self):
        """Forget the usages of the previous sequence."""
        self.usages = None

    def update_usages(
        self, prev_write_distribution, prev_read_distributions, free_gates
    ):
        """Update ``self.usages`` from the previous write/read and return ``phi``.

        Args:
            prev_write_distribution: ``[batch, cell count]`` write weighting of
                the previous step; not used on the first step of a sequence.
            prev_read_distributions: ``[batch, n_heads, cell count]``.
            free_gates: ``[batch, n_heads]``.

        Returns:
            ``phi``, sized ``[batch, cell count]``: the product over heads of
            ``1 - free_gate * read_distribution``.
        """
        self._init_consts(prev_read_distributions.device)
        # Keyword ``value`` replaces the deprecated positional-scalar overload.
        phi = torch.addcmul(
            self.one,
            free_gates.unsqueeze(-1),
            prev_read_distributions,
            value=-1,
        ).prod(-2)

        if self.usages is None:
            # First time step: nothing has been written or read yet, so the
            # usages only need to be initialised.
            self._init_sequence(prev_read_distributions)
        else:
            # u = (u + w * (1 - u)) * phi; the write weighting is detached.
            self.usages = (
                torch.addcmul(
                    self.usages, prev_write_distribution.detach(), 1 - self.usages
                )
                * phi
            )

        return phi

    def forward(self, prev_write_distribution, prev_read_distributions, free_gates):
        """Return ``(allocation scores, phi)``.

        The cells are sorted by increasing usage; the score of a cell is
        ``(1 - usage)`` times the product of the usages of all cells sorted
        before it. Scores are returned in the original cell order, sized
        ``[batch, cell count]``.
        """
        phi = self.update_usages(
            prev_write_distribution, prev_read_distributions, free_gates
        )
        # Squash usages into [_EPS, 1] so the cumulative product never hits 0.
        sorted_usage, free_list = (self.usages * (1.0 - _EPS) + _EPS).sort(-1)

        u_prod = sorted_usage.cumprod(-1)
        one_minus_usage = 1.0 - sorted_usage
        sorted_scores = torch.cat(
            [one_minus_usage[..., 0:1], one_minus_usage[..., 1:] * u_prod[..., :-1]],
            dim=-1,
        )

        # Undo the sort: put every score back at its cell's original index.
        return sorted_scores.clone().scatter_(-1, free_list, sorted_scores), phi
|
||||
|
||||
|
||||
class ContentAddressGenerator(torch.nn.Module):
    """Content-based addressing: softmax over key/memory similarity scores.

    Args:
        disable_content_norm: if True, scores are divided by the key norm only
            instead of the product of key norm and memory-word norm.
        mask_min: lower bound the mask is rescaled to (``mask * (1 - m) + m``);
            ``0`` leaves the mask untouched.
        disable_key_masking: if True, the mask is applied to the memory only,
            not to the keys.
    """

    def __init__(
        self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False
    ):
        super(ContentAddressGenerator, self).__init__()
        self.disable_content_norm = disable_content_norm
        self.mask_min = mask_min
        self.disable_key_masking = disable_key_masking

    def forward(self, memory, keys, betas, mask=None):
        """Return the address distribution over memory cells.

        Args:
            memory: ``[batch, cell count, word length]``.
            keys: ``[batch, n heads, word length]``, or ``[batch, word length]``
                for a single head.
            betas: ``[batch, n heads]`` sharpening factors.
            mask: optional tensor shaped like ``keys``.

        Returns:
            ``[batch, n heads, cell count]``; the head dimension is dropped
            when ``keys`` was given without it.
        """
        if mask is not None and self.mask_min != 0:
            mask = mask * (1.0 - self.mask_min) + self.mask_min

        single_head = keys.dim() == 2
        if single_head:
            # Insert the missing head dimension.
            keys = keys.unsqueeze(1)
            if mask is not None:
                mask = mask.unsqueeze(1)

        # Make memory and keys broadcast to [batch, n heads, cell count, word length].
        memory = memory.unsqueeze(1)
        keys = keys.unsqueeze(-2)

        if mask is not None:
            mask = mask.unsqueeze(-2)
            memory = memory * mask
            if not self.disable_key_masking:
                keys = keys * mask

        # Shape [batch, n heads, cell count]
        denominator = keys.norm(dim=-1)
        if not self.disable_content_norm:
            denominator = denominator * memory.norm(dim=-1)

        scores = (memory * keys).sum(-1) / (denominator + _EPS)
        scores *= betas.unsqueeze(-1)

        distribution = F.softmax(scores, dim=-1)
        if single_head:
            return distribution.squeeze(1)
        return distribution
|
||||
|
||||
|
||||
class WriteHead(torch.nn.Module):
    """Write head: mixes allocation and content addressing, then updates the memory.

    ``last_write`` holds the arguments of the most recent memory update (see
    :meth:`create_write_archive`) or ``None`` at the start of a sequence.
    """

    @staticmethod
    def create_write_archive(write_dist, erase_vector, write_vector, phi):
        """Bundle the arguments of one memory update into a dict."""
        return {
            "write_dist": write_dist,
            "erase_vector": erase_vector,
            "write_vector": write_vector,
            "phi": phi,
        }

    def __init__(
        self,
        dealloc_content=True,
        disable_content_norm=False,
        mask_min=0.0,
        disable_key_masking=False,
    ):
        super(WriteHead, self).__init__()
        self.write_content_generator = ContentAddressGenerator(
            disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.allocation_manager = AllocationManager()
        self.last_write = None
        # If True, phi is passed to mem_update so freed cells are also erased.
        self.dealloc_content = dealloc_content
        self.new_sequence()

    def new_sequence(self):
        """Drop the state of the previous sequence."""
        self.last_write = None
        self.allocation_manager.new_sequence()

    @staticmethod
    def mem_update(memory, write_dist, erase_vector, write_vector, phi):
        """Return ``memory * erase_matrix + write_dist (outer) write_vector``.

        In the original paper the memory content is NOT deallocated, which
        makes content based addressing basically unusable when multiple
        similar steps should be done: the old contents are still there, so a
        lookup keeps finding them until an allocation happens to overwrite
        them. Therefore the erase matrix also takes the free gates into
        account: when ``phi`` is given, it is multiplied into the erase matrix.
        """
        weights = write_dist.unsqueeze(-1)

        keep = 1.0 - weights * erase_vector.unsqueeze(-2)
        if phi is not None:
            keep = keep * phi.unsqueeze(-1)

        added = weights * write_vector.unsqueeze(-2)
        return memory * keep + added

    def forward(
        self,
        demon_action,
        memory,
        write_content_key,
        write_beta,
        erase_vector,
        write_vector,
        alloc_gate,
        write_gate,
        free_gates,
        prev_read_dist,
        write_mask=None,
        debug=None,
    ):
        """Perform one write and return the updated memory.

        ``demon_action`` is accepted but not used by this method. The write
        weighting is ``write_gate * (alloc_gate * allocation + (1 - alloc_gate)
        * content)`` and has shape ``[batch, cell count]``.
        """
        previous = self.last_write
        last_w_dist = None if previous is None else previous["write_dist"]

        content_dist = self.write_content_generator(
            memory, write_content_key, write_beta, mask=write_mask
        )
        alloc_dist, phi = self.allocation_manager(
            last_w_dist, prev_read_dist, free_gates
        )

        # Shape [batch, cell count]
        write_dist = write_gate * (
            alloc_gate * alloc_dist + (1 - alloc_gate) * content_dist
        )
        self.last_write = WriteHead.create_write_archive(
            write_dist,
            erase_vector,
            write_vector,
            phi if self.dealloc_content else None,
        )

        debug_entries = [
            ("alloc_dist", alloc_dist),
            ("write_dist", write_dist),
            ("mem_usages", self.allocation_manager.usages),
            ("free_gates", free_gates),
            ("write_betas", write_beta),
            ("write_gate", write_gate),
            ("write_vector", write_vector),
            ("alloc_gate", alloc_gate),
            ("erase_vector", erase_vector),
        ]
        if write_mask is not None:
            debug_entries.append(("write_mask", write_mask))
        for entry_name, entry_value in debug_entries:
            dict_append(debug, entry_name, entry_value)

        return WriteHead.mem_update(memory, **self.last_write)
|
||||
|
||||
|
||||
class RawWriteHead(torch.nn.Module):
    """Adapter that decodes a flat controller output into ``WriteHead`` inputs.

    The expected layout of the last dimension of ``nn_output`` is:
    ``[write mask (only if use_mask)] + key + erase vector + write vector``
    (``word_length`` each), then ``n_read_heads`` free gates, then one value
    each for write beta, allocation gate and write gate.
    """

    def __init__(
        self,
        n_read_heads,
        word_length,
        use_mask=False,
        dealloc_content=True,
        disable_content_norm=False,
        mask_min=0.0,
        disable_key_masking=False,
    ):
        super(RawWriteHead, self).__init__()
        self.write_head = WriteHead(
            dealloc_content=dealloc_content,
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.word_length = word_length
        self.n_read_heads = n_read_heads
        self.use_mask = use_mask
        # Number of controller output channels this head consumes.
        self.input_size = (
            3 * self.word_length
            + self.n_read_heads
            + 3
            + (self.word_length if use_mask else 0)
        )

    def new_sequence(self):
        """Reset the wrapped write head for a new sequence."""
        self.write_head.new_sequence()

    def get_prev_write(self):
        """Return the wrapped head's ``last_write`` record (may be ``None``)."""
        return self.write_head.last_write

    def forward(self, demon_action, memory, nn_output, prev_read_dist, debug):
        """Split ``nn_output``, apply the activations and run the write head.

        Returns whatever the wrapped ``WriteHead`` returns (the updated memory).
        """
        n_word_chunks = 4 if self.use_mask else 3
        layout = (
            [[self.word_length]] * n_word_chunks
            + [[self.n_read_heads]]
            + [[1]] * 3
        )
        parts = split_tensor(nn_output, layout)

        write_mask = None
        if self.use_mask:
            # The mask, when present, is the first chunk.
            write_mask = torch.sigmoid(parts[0])
            parts = parts[1:]

        key, erase, value, free, beta, alloc, write = parts

        # Gates and the erase vector live in (0, 1); beta goes through oneplus.
        erase = torch.sigmoid(erase)
        free = torch.sigmoid(free)
        beta = oneplus(beta)
        alloc = torch.sigmoid(alloc)
        write = torch.sigmoid(write)

        return self.write_head(
            demon_action,
            memory,
            key,
            beta,
            erase,
            value,
            alloc,
            write,
            free,
            prev_read_dist,
            debug=debug,
            write_mask=write_mask,
        )

    def get_neural_input_size(self):
        """Return the number of controller output channels this head needs."""
        return self.input_size
|
||||
|
||||
|
||||
class TemporalMemoryLinkage(torch.nn.Module):
    """Maintain the temporal link matrix and precedence weighting of the memory.

    Per-sequence state:
        temp_link_mat: ``[batch, cell count, cell count]`` link matrix.
        precedence_weighting: ``[batch, cell count]``.
        diag_mask: ``[1, cell count, cell count]``, zero on the diagonal.

    The ``initial_*`` attributes cache the start values so they can be reused
    by later sequences with the same shape, device and dtype. None of the
    updates in this class modify these tensors in place.
    """

    def __init__(self):
        super(TemporalMemoryLinkage, self).__init__()
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

        self.initial_temp_link_mat = None
        self.initial_precedence_weighting = None
        self.initial_diag_mask = None
        self.initial_shape = None

    def new_sequence(self):
        """Drop the per-sequence state (the cached start values are kept)."""
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

    def _init_link(self, w_dist):
        """Set the per-sequence state to its start values for ``w_dist``'s shape."""
        s = list(w_dist.size())
        cache_stale = (
            self.initial_shape is None
            or s != self.initial_shape
            or self.initial_temp_link_mat.device != w_dist.device
            or self.initial_diag_mask.dtype != w_dist.dtype
        )
        if cache_stale:
            self.initial_temp_link_mat = torch.zeros(s[0], s[-1], s[-1]).to(
                w_dist.device
            )
            self.initial_precedence_weighting = torch.zeros(s[0], s[-1]).to(
                w_dist.device
            )
            self.initial_diag_mask = (
                1.0 - torch.eye(s[-1]).unsqueeze(0).to(w_dist)
            ).detach()
            # Remember the shape; without this the cache check never succeeds.
            self.initial_shape = s

        self.temp_link_mat = self.initial_temp_link_mat
        self.precedence_weighting = self.initial_precedence_weighting
        self.diag_mask = self.initial_diag_mask

    def _update_precedence(self, w_dist):
        """p = (1 - sum(w)) * p + w, with ``w_dist`` shaped ``[batch, cell count]``."""
        self.precedence_weighting = (
            1.0 - w_dist.sum(-1, keepdim=True)
        ) * self.precedence_weighting + w_dist

    def _update_links(self, w_dist):
        """L[i, j] = (1 - w[i] - w[j]) * L[i, j] + w[i] * p[j], diagonal forced to 0."""
        if self.temp_link_mat is None:
            self._init_link(w_dist)

        wt_i = w_dist.unsqueeze(-1)
        wt_j = w_dist.unsqueeze(-2)
        pt_j = self.precedence_weighting.unsqueeze(-2)

        self.temp_link_mat = (
            (1 - wt_i - wt_j) * self.temp_link_mat + wt_i * pt_j
        ) * self.diag_mask

    def forward(self, w_dist, prev_r_dists, debug=None):
        """Update the links with ``w_dist`` and follow them from the previous reads.

        Args:
            w_dist: ``[batch, cell count]`` write weighting of this step.
            prev_r_dists: ``[batch, n heads, cell count]`` previous read weightings.
            debug: optional dict that collects intermediate tensors.

        Returns:
            ``(forward_dist, backward_dist)``, both ``[batch, n_heads, cell_count]``.
        """
        # The links must be updated with the old precedence weighting.
        self._update_links(w_dist)
        self._update_precedence(w_dist)

        # Emulate matrix-vector multiplication by broadcast and sum. This way
        # we don't need to transpose the matrix.
        tlm_multi_head = self.temp_link_mat.unsqueeze(1)

        forward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-2)).sum(-1)
        backward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-1)).sum(-2)

        dict_append(debug, "forward_dists", forward_dist)
        dict_append(debug, "backward_dists", backward_dist)
        dict_append(debug, "precedence_weights", self.precedence_weighting)

        return forward_dist, backward_dist
|
||||
|
||||
|
||||
class ReadHead(torch.nn.Module):
    """Read head: mixes backward, content and forward addressing and reads the memory.

    ``read_dist`` and ``read_data`` hold the result of the most recent read
    and are ``None`` at the start of a sequence.
    """

    def __init__(
        self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False
    ):
        super(ReadHead, self).__init__()
        self.content_addr_generator = ContentAddressGenerator(
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.read_dist = None
        self.read_data = None
        self.new_sequence()

    def new_sequence(self):
        """Forget the read of the previous sequence."""
        self.read_dist = None
        self.read_data = None

    def forward(
        self,
        memory,
        read_content_keys,
        read_betas,
        forward_dist,
        backward_dist,
        gates,
        read_mask=None,
        debug=None,
    ):
        """Read from ``memory`` and return the read vectors.

        Args:
            memory: ``[batch, cell count, word_length]``.
            read_content_keys, read_betas, read_mask: inputs of the content
                addressing.
            forward_dist, backward_dist: ``[batch, n heads, cell count]``.
            gates: last dimension holds the mixing weights in the order
                backward, content, forward.
            debug: optional dict that collects intermediate tensors.

        Returns:
            ``[batch, n_heads, word_length]`` read data.
        """
        content_dist = self.content_addr_generator(
            memory, read_content_keys, read_betas, mask=read_mask
        )

        self.read_dist = (
            backward_dist * gates[..., 0:1]
            + content_dist * gates[..., 1:2]
            + forward_dist * gates[..., 2:]
        )

        # memory shape: [ batch, cell count, word_length ]
        # read_dist shape: [ batch, n heads, cell count ]
        # result shape: [ batch, n_heads, word_length ]
        self.read_data = (memory.unsqueeze(1) * self.read_dist.unsqueeze(-1)).sum(-2)

        dict_append(debug, "content_dist", content_dist)
        dict_append(debug, "balance", gates)
        dict_append(debug, "read_dist", self.read_dist)
        dict_append(debug, "read_content_keys", read_content_keys)
        # Logged exactly once per step so that every debug list has one entry
        # per time step when the lists are stacked afterwards.
        if read_mask is not None:
            dict_append(debug, "read_mask", read_mask)
        dict_append(debug, "read_betas", read_betas.unsqueeze(-2))

        return self.read_data
|
||||
|
||||
|
||||
class RawReadHead(torch.nn.Module):
    """Adapter that decodes a flat controller output into ``ReadHead`` inputs.

    Layout of the last dimension of ``nn_output``: ``[read mask (only if
    use_mask)] + keys`` (``n_heads * word_length`` each), then ``n_heads``
    betas, then ``n_heads * 3`` mixing gates.
    """

    def __init__(
        self,
        n_heads,
        word_length,
        use_mask=False,
        disable_content_norm=False,
        mask_min=0.0,
        disable_key_masking=False,
    ):
        super(RawReadHead, self).__init__()
        self.read_head = ReadHead(
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.n_heads = n_heads
        self.word_length = word_length
        self.use_mask = use_mask
        # Per head: key (+ mask), 3 gates and 1 beta.
        self.input_size = self.n_heads * (
            self.word_length * (2 if use_mask else 1) + 3 + 1
        )

    def get_prev_dist(self, memory):
        """Return the last read distribution, or zeros ``[batch, n_heads, cell count]``."""
        previous = self.read_head.read_dist
        if previous is not None:
            return previous

        batch, cells = memory.size()[0], memory.size()[1]
        return torch.zeros(batch, self.n_heads, cells).to(memory)

    def get_prev_data(self, memory):
        """Return the last read data, or zeros ``[batch, n_heads, word length]``."""
        previous = self.read_head.read_data
        if previous is not None:
            return previous

        batch, word = memory.size()[0], memory.size()[-1]
        return torch.zeros(batch, self.n_heads, word).to(memory)

    def new_sequence(self):
        """Reset the wrapped read head for a new sequence."""
        self.read_head.new_sequence()

    def forward(self, memory, nn_output, forward_dist, backward_dist, debug):
        """Split ``nn_output``, apply the activations and run the read head."""
        n_word_chunks = 2 if self.use_mask else 1
        layout = [[self.n_heads, self.word_length]] * n_word_chunks + [
            [self.n_heads],
            [self.n_heads, 3],
        ]
        parts = split_tensor(nn_output, layout)

        read_mask = None
        if self.use_mask:
            # The mask, when present, is the first chunk.
            read_mask = torch.sigmoid(parts[0])
            parts = parts[1:]

        keys, betas, gates = parts

        betas = oneplus(betas)
        # The three gates of every head sum to one.
        gates = F.softmax(gates, dim=-1)

        return self.read_head(
            memory,
            keys,
            betas,
            forward_dist,
            backward_dist,
            gates,
            debug=debug,
            read_mask=read_mask,
        )

    def get_neural_input_size(self):
        """Return the number of controller output channels this head needs."""
        return self.input_size
|
||||
|
||||
|
||||
class DistSharpnessEnhancer(torch.nn.Module):
    """Sharpen distributions by raising them to a learned power and renormalising.

    Args:
        n_heads: number of heads per distribution passed to :meth:`forward`,
            either a list (one entry per distribution) or a single number.
    """

    def __init__(self, n_heads):
        super(DistSharpnessEnhancer, self).__init__()
        self.n_heads = n_heads if isinstance(n_heads, list) else [n_heads]
        # Total number of exponents read from the network output.
        self.n_data = sum(self.n_heads)

    def forward(self, nn_input, *dists):
        """Return the sharpened version of every distribution in ``dists``.

        Args:
            nn_input: tensor whose first ``n_data`` entries of the last
                dimension hold the raw exponents (passed through ``oneplus``).
            *dists: one distribution per entry of ``n_heads``; either 2-D
                (single head, no head dimension) or 3-D (with head dimension).

        Returns:
            List of tensors shaped like the inputs, each summing to one over
            the last dimension. The input tensors are not modified.

        Raises:
            AssertionError: if the number of distributions does not match
                ``n_heads`` or a distribution has an unsupported rank.
        """
        assert len(dists) == len(self.n_heads)
        nn_input = oneplus(nn_input[..., : self.n_data])
        factors = split_tensor(nn_input, self.n_heads)

        res = []
        for i, d in enumerate(dists):
            ndim = d.dim()
            f = factors[i]
            if ndim == 2:
                assert self.n_heads[i] == 1
            elif ndim == 3:
                # One exponent per head, broadcast over the cells.
                f = f.unsqueeze(-1)
            else:
                raise AssertionError(
                    "distributions must be 2- or 3-dimensional, got %d" % ndim
                )

            # Out-of-place add: the caller's tensor (which may already be
            # stored for debugging) must stay untouched.
            d = d + _EPS
            # Normalising by the maximum first keeps pow() well conditioned.
            d = d / d.max(dim=-1, keepdim=True)[0]
            d = d.pow(f)
            d = d / d.sum(dim=-1, keepdim=True)
            res.append(d)
        return res

    def get_neural_input_size(self):
        """Return the number of controller output channels this module needs."""
        return self.n_data
|
||||
|
||||
|
||||
class DNC(torch.nn.Module):
    """Differentiable Neural Computer: a controller coupled to an external memory.

    Args:
        input_size: number of input channels per time step.
        output_size: number of output channels per time step.
        word_length: size of one memory word.
        cell_count: number of memory cells.
        n_read_heads: number of read heads.
        controller: controller module; it must provide ``init(in_size)``,
            ``get_output_size()``, ``reset_parameters()``, ``new_sequence()``
            and be callable on a ``[batch, in_size]`` tensor.
        batch_first: if True the input is ``[batch, time, channels]``,
            otherwise ``[time, batch, channels]``.
        clip_controller: the memory control signals are clamped to
            ``[-clip_controller, clip_controller]``; ``None`` disables this.
        bias: whether the linear layers created here have a bias.
        mask: enables the learned masks of the read and write heads.
        dealloc_content: forwarded to the write head.
        link_sharpness_control: if True, forward/backward link distributions
            are sharpened with a ``DistSharpnessEnhancer``.
        disable_content_norm, mask_min, disable_key_masking: forwarded to the
            content addressing of both heads.
    """

    def __init__(
        self,
        input_size,
        output_size,
        word_length,
        cell_count,
        n_read_heads,
        controller,
        batch_first=False,
        clip_controller=20,
        bias=True,
        mask=False,
        dealloc_content=True,
        link_sharpness_control=True,
        disable_content_norm=False,
        mask_min=0.0,
        disable_key_masking=False,
    ):
        super(DNC, self).__init__()

        self.clip_controller = clip_controller

        self.read_head = RawReadHead(
            n_read_heads,
            word_length,
            use_mask=mask,
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.write_head = RawWriteHead(
            n_read_heads,
            word_length,
            use_mask=mask,
            dealloc_content=dealloc_content,
            disable_content_norm=disable_content_norm,
            mask_min=mask_min,
            disable_key_masking=disable_key_masking,
        )
        self.temporal_link = TemporalMemoryLinkage()
        self.sharpness_control = (
            DistSharpnessEnhancer([n_read_heads, n_read_heads])
            if link_sharpness_control
            else None
        )

        # The controller sees the input plus the previous read vectors.
        in_size = input_size + n_read_heads * word_length
        control_channels = (
            self.read_head.get_neural_input_size()
            + self.write_head.get_neural_input_size()
            + (
                self.sharpness_control.get_neural_input_size()
                if self.sharpness_control is not None
                else 0
            )
        )

        self.controller = controller
        controller.init(in_size)
        self.controller_to_controls = torch.nn.Linear(
            controller.get_output_size(), control_channels, bias=bias
        )
        self.controller_to_out = torch.nn.Linear(
            controller.get_output_size(), output_size, bias=bias
        )
        self.read_to_out = torch.nn.Linear(
            word_length * n_read_heads, output_size, bias=bias
        )

        self.cell_count = cell_count
        self.word_length = word_length

        self.memory = None
        self.reset_parameters()

        self.batch_first = batch_first
        self.zero_mem_tensor = None

        # List of flattened memory snapshots, one per processed time step.
        # NOTE(review): nothing in this class clears it between forward()
        # calls; presumably the caller resets it - confirm.
        self.mem_state = None

        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    def reset_parameters(self):
        """Re-initialise the linear layers created here and the controller."""
        linear_reset(self.controller_to_controls)
        linear_reset(self.controller_to_out)
        linear_reset(self.read_to_out)
        self.controller.reset_parameters()

    def _step(self, in_data, debug, demon, rollout_storage):
        """Process one time step; ``in_data`` is ``[batch, channels]``.

        Returns the ``[batch, output_size]`` output of this step and updates
        ``self.memory`` and the state of the heads.
        """
        init_debug(debug, {"read_head": {}, "write_head": {}, "temporal_links": {}})

        # input shape: [ batch, channels ]
        batch_size = in_data.size(0)

        # Run the demon if it is used: its action is added to the input.
        if demon:
            # Running policy_old:
            demon_action = demon.select_action(
                torch.cat([in_data, self.memory.view(batch_size, -1)], -1),
                rollout_storage,
            )
            in_data = in_data + demon_action

        # The write head receives no demon action.
        demon_action = None

        # run the controller
        prev_read_data = self.read_head.get_prev_data(self.memory).view(
            [batch_size, -1]
        )

        control_data = self.controller(torch.cat([in_data, prev_read_data], -1))

        # memory ops
        controls = self.controller_to_controls(control_data).contiguous()
        controls = (
            controls.clamp(-self.clip_controller, self.clip_controller)
            if self.clip_controller is not None
            else controls
        )

        shapes = [
            [self.write_head.get_neural_input_size()],
            [self.read_head.get_neural_input_size()],
        ]
        if self.sharpness_control is not None:
            shapes.append(self.sharpness_control.get_neural_input_size())

        tensors = split_tensor(controls, shapes)

        write_head_control, read_head_control = tensors[:2]
        tensors = tensors[2:]

        prev_read_dist = self.read_head.get_prev_dist(self.memory)

        self.memory = self.write_head(
            demon_action,
            self.memory,
            write_head_control,
            prev_read_dist,
            debug=dict_get(debug, "write_head"),
        )

        prev_write = self.write_head.get_prev_write()
        forward_dist, backward_dist = self.temporal_link(
            prev_write["write_dist"] if prev_write is not None else None,
            prev_read_dist,
            debug=dict_get(debug, "temporal_links"),
        )

        if self.sharpness_control is not None:
            forward_dist, backward_dist = self.sharpness_control(
                tensors[0], forward_dist, backward_dist
            )

        read_data = self.read_head(
            self.memory,
            read_head_control,
            forward_dist,
            backward_dist,
            debug=dict_get(debug, "read_head"),
        )

        # output:
        return self.controller_to_out(control_data) + self.read_to_out(
            read_data.view(batch_size, -1)
        )

    def _mem_init(self, batch_size, device):
        """Reset the memory to zeros and make sure ``mem_state`` is a list."""
        if self.zero_mem_tensor is None or self.zero_mem_tensor.size(0) != batch_size:
            self.zero_mem_tensor = torch.zeros(
                batch_size, self.cell_count, self.word_length
            ).to(device)

        self.memory = self.zero_mem_tensor

        if self.mem_state is None:
            self.mem_state = []

    def forward(self, in_data, debug=None, demon=None, rollout_storage=None):
        """Run the DNC over a whole sequence.

        Args:
            in_data: ``[batch, time, channels]`` if ``batch_first`` else
                ``[time, batch, channels]``.
            debug: optional dict that collects intermediate tensors.
            demon: optional object with ``select_action(state, rollout_storage)``
                whose result is added to every step's input.
            rollout_storage: forwarded to ``demon.select_action``.

        Returns:
            The outputs of all time steps, stacked along the time dimension.

        After every step the memory, flattened to ``[batch, -1]``, is appended
        to ``self.mem_state``.
        """
        self.write_head.new_sequence()
        self.read_head.new_sequence()
        self.temporal_link.new_sequence()
        self.controller.new_sequence()

        time_dim = 1 if self.batch_first else 0
        batch_size = in_data.size(0 if self.batch_first else 1)

        self._mem_init(batch_size, in_data.device)

        out_tsteps = []
        for t in range(in_data.size(time_dim)):
            step_input = in_data[:, t] if self.batch_first else in_data[t]
            out_tsteps.append(self._step(step_input, debug, demon, rollout_storage))
            # Always flatten per batch element. The time-major branch used to
            # reshape with the sequence length instead of the batch size.
            self.mem_state.append(self.memory.view(batch_size, -1))

        merge_debug_tensors(debug, dim=time_dim)
        return torch.stack(out_tsteps, dim=time_dim)
|
||||
|
||||
|
||||
class LSTMController(torch.nn.Module):
    """Multi-layer LSTM controller that is stepped one time step at a time.

    Every layer receives the controller input directly; layers after the first
    additionally receive the output of the layer below.

    Args:
        layer_sizes: hidden size of every layer.
        out_from_all_layers: if True the output is the concatenation of all
            layers' outputs, otherwise only the last layer's output.

    ``init(input_size)`` must be called before use, and ``new_sequence()``
    before the first step of every sequence.
    """

    def __init__(self, layer_sizes, out_from_all_layers=True):
        super(LSTMController, self).__init__()
        self.out_from_all_layers = out_from_all_layers
        self.layer_sizes = layer_sizes
        self.states = None
        self.outputs = None

    def new_sequence(self):
        """Clear the cell states and outputs of all layers."""
        self.states = [None] * len(self.layer_sizes)
        self.outputs = [None] * len(self.layer_sizes)

    def reset_parameters(self):
        """Re-initialise all weights and biases of the layers created by ``init``."""

        def init_layer(l, index):
            size = self.layer_sizes[index]
            # Initialize all matrices to sigmoid, just data input to tanh.
            # The last ``size`` rows produce the tanh data input.
            a = math.sqrt(3.0) * self.stdevs[index]
            l.weight.data[0:-size].uniform_(-a, a)
            a *= init.calculate_gain("tanh")
            l.weight.data[-size:].uniform_(-a, a)
            if l.bias is not None:
                l.bias.data[size:].fill_(0)
                # The first ``size`` entries belong to the forget gate; start
                # it open.
                l.bias.data[:size].fill_(1)

        # xavier init merged input weights
        for i in range(len(self.layer_sizes)):
            init_layer(self.in_to_all[i], i)
            init_layer(self.out_to_all[i], i)
            if i > 0:
                init_layer(self.prev_to_all[i - 1], i)

    def _add_modules(self, name, m_list):
        """Register the modules of ``m_list`` as ``<name>_<index>``."""
        for i, m in enumerate(m_list):
            self.add_module("%s_%d" % (name, i), m)

    def init(self, input_size):
        """Create the layers for inputs of size ``input_size`` and initialise them."""
        # Xavier init: input to all gates is layers_sizes[i-1] + layer_sizes[i]
        # + input_size -> layer_size big. So use xavier init according to this.
        self.input_sizes = [
            (self.layer_sizes[i - 1] if i > 0 else 0) + self.layer_sizes[i] + input_size
            for i in range(len(self.layer_sizes))
        ]
        self.stdevs = [
            math.sqrt(2.0 / (self.layer_sizes[i] + self.input_sizes[i]))
            for i in range(len(self.layer_sizes))
        ]
        # Every projection produces the 4 gate pre-activations at once.
        self.in_to_all = [
            torch.nn.Linear(input_size, 4 * self.layer_sizes[i])
            for i in range(len(self.layer_sizes))
        ]
        self.out_to_all = [
            torch.nn.Linear(self.layer_sizes[i], 4 * self.layer_sizes[i], bias=False)
            for i in range(len(self.layer_sizes))
        ]
        self.prev_to_all = [
            torch.nn.Linear(
                self.layer_sizes[i - 1], 4 * self.layer_sizes[i], bias=False
            )
            for i in range(1, len(self.layer_sizes))
        ]

        self._add_modules("in_to_all", self.in_to_all)
        self._add_modules("out_to_all", self.out_to_all)
        self._add_modules("prev_to_all", self.prev_to_all)

        self.reset_parameters()

    def get_output_size(self):
        """Return the size of the last dimension of :meth:`forward`'s result."""
        return (
            sum(self.layer_sizes) if self.out_from_all_layers else self.layer_sizes[-1]
        )

    def forward(self, data):
        """Advance all layers by one time step and return the controller output."""
        for i, size in enumerate(self.layer_sizes):
            d = self.in_to_all[i](data)
            if self.outputs[i] is not None:
                # Recurrent connection from this layer's previous output.
                d += self.out_to_all[i](self.outputs[i])
            if i > 0:
                # Connection from the layer below (already updated this step).
                d += self.prev_to_all[i - 1](self.outputs[i - 1])

            # Layout of d: [forget | input | output | data], each ``size`` wide.
            input_data = torch.tanh(d[..., -size:])
            forget_gate, input_gate, output_gate = torch.sigmoid(d[..., :-size]).chunk(
                3, dim=-1
            )

            state_update = input_gate * input_data

            if self.states[i] is not None:
                self.states[i] = self.states[i] * forget_gate + state_update
            else:
                self.states[i] = state_update

            self.outputs[i] = output_gate * torch.tanh(self.states[i])

        return (
            torch.cat(self.outputs, -1)
            if self.out_from_all_layers
            else self.outputs[-1]
        )
|
||||
|
||||
|
||||
class FeedforwardController(torch.nn.Module):
    """Stateless controller: a stack of ``Linear`` + ``ReLU`` layers.

    Args:
        layer_sizes: output size of every layer; defaults to an empty list.

    ``init(input_size)`` must be called before the controller is used.
    """

    def __init__(self, layer_sizes=None):
        super(FeedforwardController, self).__init__()
        # A fresh list per instance instead of a shared mutable default.
        self.layer_sizes = [] if layer_sizes is None else layer_sizes

    def new_sequence(self):
        """Nothing to reset: the controller keeps no state between steps."""
        pass

    def reset_parameters(self):
        """Re-initialise every linear layer with the gain for ReLU."""
        for module in self.model:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

    def get_output_size(self):
        """Return the size of the last layer."""
        return self.layer_sizes[-1]

    def init(self, input_size):
        """Create the layers for inputs of size ``input_size`` and initialise them."""
        # Layer i maps the previous layer's size (or the input size) to
        # layer_sizes[i].
        self.input_sizes = [input_size] + self.layer_sizes[:-1]

        layers = []
        for in_size, out_size in zip(self.input_sizes, self.layer_sizes):
            layers.append(torch.nn.Linear(in_size, out_size))
            layers.append(torch.nn.ReLU())
        self.model = torch.nn.Sequential(*layers)
        self.reset_parameters()

    def forward(self, data):
        """Apply the layer stack to ``data``."""
        return self.model(data)
|
140
Models/Demon.py
Normal file
140
Models/Demon.py
Normal file
@ -0,0 +1,140 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
from torch.distributions import Normal
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
# Bounds used to clamp the predicted log standard deviation before it is exponentiated.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
# Small constant added inside the log of the action-bound correction so it never sees zero.
EPSILON = 1e-6

# One recorded step: the scaled action, its summed log-probability and the scaled mean.
SavedAction = namedtuple("SavedAction", ["action", "log_prob", "mean"])
|
||||
|
||||
|
||||
def linear_reset(module, gain=1.0):
    """
    Re-initialise a linear layer in place: Xavier-uniform weights, zero bias.

    :param module: the ``torch.nn.Linear`` to reset.
    :param gain: scaling factor forwarded to ``xavier_uniform_``.
    :raises AssertionError: if ``module`` is not a ``torch.nn.Linear``.
    """
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()
|
||||
|
||||
|
||||
class ZNet(nn.Module):
    """
    LSTM followed by a linear read-out passed through softplus.

    The layers are created by :meth:`init`, not by the constructor.
    """

    def __init__(self):
        super(ZNet, self).__init__()

    def reset_parameters(self):
        """Reset the ``nn.Linear`` children of both containers with the ReLU gain."""
        # Only nn.Linear instances are touched, so the nn.LSTM keeps its default initialisation.
        for container in (self.lstm, self.hidden2z):
            for layer in container:
                if isinstance(layer, torch.nn.Linear):
                    linear_reset(layer, gain=init.calculate_gain("relu"))

    def init(self, input_size):
        """
        Build the layers and initialise them.

        :param input_size: feature width of the LSTM input.
        """
        recurrent = nn.LSTM(input_size, 32, batch_first=True)
        self.lstm = nn.Sequential(recurrent)
        readout = nn.Linear(32, 1)
        self.hidden2z = nn.Sequential(readout)
        self.reset_parameters()

    def forward(self, data):
        """Return softplus of the linear read-out applied to the LSTM output."""
        # The Sequential wraps a single LSTM, so it hands back the LSTM's (output, (h, c)) tuple.
        sequence_out, (_h_n, _c_n) = self.lstm(data)
        return F.softplus(self.hidden2z(sequence_out))
|
||||
|
||||
|
||||
class FNet(nn.Module):
    """
    LSTM followed by ELU and a linear read-out.

    The layers are created by :meth:`init`, not by the constructor.
    """

    def __init__(self):
        super(FNet, self).__init__()

    def reset_parameters(self):
        """Reset the ``nn.Linear`` children of both containers with the ReLU gain."""
        # Only nn.Linear instances are touched, so the nn.LSTM keeps its default initialisation.
        for container in (self.lstm, self.hidden2z):
            for layer in container:
                if isinstance(layer, torch.nn.Linear):
                    linear_reset(layer, gain=init.calculate_gain("relu"))

    def init(self, input_size):
        """
        Build the layers and initialise them.

        :param input_size: feature width of the LSTM input.
        """
        recurrent = nn.LSTM(input_size, 32, batch_first=True)
        self.lstm = nn.Sequential(recurrent)
        readout = nn.Linear(32, 1)
        self.hidden2z = nn.Sequential(readout)
        self.reset_parameters()

    def forward(self, data):
        """Return the linear read-out of the ELU-activated LSTM output."""
        # The Sequential wraps a single LSTM, so it hands back the LSTM's (output, (h, c)) tuple.
        sequence_out, (_h_n, _c_n) = self.lstm(data)
        activated = F.elu(sequence_out)
        return self.hidden2z(activated)
|
||||
|
||||
|
||||
class Demon(torch.nn.Module):
    """
    Demon manipulates the external memory of DNC.

    An MLP trunk feeds two linear heads that give the mean and the log
    standard deviation of a Gaussian. :meth:`act` samples from it and records
    the step in ``self.saved_actions``. The layers are built by :meth:`init`.
    """

    def __init__(self, layer_sizes=None):
        """
        :param layer_sizes: list with the width of each trunk layer; defaults to an empty list.
        """
        super(Demon, self).__init__()
        # A fresh list per instance: a shared ``[]`` default would leak between instances.
        self.layer_sizes = [] if layer_sizes is None else layer_sizes
        self.action_scale = torch.tensor(1)
        self.action_bias = torch.tensor(0.0)
        self.saved_actions = []

    def get_output_size(self):
        """Return the width of the last trunk layer."""
        return self.layer_sizes[-1]

    def reset_parameters(self):
        """Re-initialise the trunk's Linear layers and both heads with the ReLU gain."""
        gain = init.calculate_gain("relu")
        for module in self.model:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=gain)
        linear_reset(self.embed_mean, gain=gain)
        linear_reset(self.embed_log_std, gain=gain)

    def init(self, input_size, output_size):
        """
        Build the trunk and the two heads, then initialise them.

        :param input_size: width of the data passed to :meth:`forward`.
        :param output_size: width of the mean / log-std heads.
        """
        # Trunk layer i maps input_sizes[i] -> layer_sizes[i].
        self.input_sizes = [input_size] + self.layer_sizes[:-1]
        layers = []
        for in_width, out_width in zip(self.input_sizes, self.layer_sizes):
            layers.append(nn.Linear(in_width, out_width))
            layers.append(nn.ReLU())

        self.model = nn.Sequential(*layers)
        self.embed_mean = nn.Linear(self.layer_sizes[-1], output_size)
        self.embed_log_std = nn.Linear(self.layer_sizes[-1], output_size)

        self.reset_parameters()

    def forward(self, data):
        """
        :param data: input to the trunk.
        :return: ``(mean, std)`` where ``std`` is the exponential of the clamped log-std head.
        """
        x = self.model(data)
        x = F.relu(x)
        mean, log_std = self.embed_mean(x), self.embed_log_std(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        std = torch.exp(log_std)
        return mean, std

    def act(self, data):
        """
        pathwise derivative estimator for taking actions.

        Samples with ``rsample`` (reparameterised), pushes the sample through a
        softmax over dim 1 and appends a ``SavedAction`` to ``self.saved_actions``.

        :param data: input forwarded to :meth:`forward`.
        :return: the softmax of the Gaussian mean, scaled and shifted; the
            sampled action itself is only stored, not returned.
        """
        mean, std = self.forward(data)
        normal = Normal(mean, std)
        x = normal.rsample()

        y = torch.softmax(x, dim=1)

        action = y * self.action_scale + self.action_bias
        # NOTE(review): the log-probability is evaluated at the squashed `action`, not at the
        # raw sample `x` that was actually drawn -- confirm this is intended.
        log_prob = normal.log_prob(action)
        # Enforcing Action Bound
        # NOTE(review): (1 - y^2) is the derivative of tanh, but the squashing used here is a
        # softmax -- confirm the correction term.
        log_prob -= torch.log(self.action_scale * (1 - y.pow(2)) + EPSILON)
        log_prob = log_prob.sum(1, keepdim=True)

        mean = torch.softmax(mean, dim=1) * self.action_scale + self.action_bias
        self.saved_actions.append(SavedAction(action, log_prob, mean))

        return mean
|
238
Models/Information_Agents.py
Normal file
238
Models/Information_Agents.py
Normal file
@ -0,0 +1,238 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions import MultivariateNormal
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Use the first CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
class RolloutStorage:
    """Plain container holding the per-step lists collected during a rollout."""

    def __init__(self):
        self.actions, self.states, self.logprobs = [], [], []
        self.rewards, self.is_terminals = [], []

    def clear_storage(self):
        """Empty every list in place; the list objects themselves are kept."""
        buffers = (
            self.actions,
            self.states,
            self.logprobs,
            self.rewards,
            self.is_terminals,
        )
        for buffer in buffers:
            del buffer[:]
|
||||
|
||||
|
||||
class ActorCritic(nn.Module):
    """
    Actor and critic MLPs sharing the same layout (state -> 64 -> 32 -> out, Tanh in between).

    The actor ends in a softmax over dim 1 and provides the mean of a Gaussian
    with a fixed diagonal covariance of ``action_std ** 2``.
    """

    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()

        def build(out_dim, *tail):
            # Same trunk for both networks; `tail` holds optional trailing modules.
            return nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.Tanh(),
                nn.Linear(64, 32),
                nn.Tanh(),
                nn.Linear(32, out_dim),
                *tail
            )

        self.actor = build(action_dim, nn.Softmax(dim=1))
        # critic
        self.critic = build(1)
        self.action_var = torch.full((action_dim,), action_std * action_std).to(device)

    def forward(self):
        """Not used; call :meth:`act` or :meth:`evaluate` instead."""
        raise NotImplementedError

    def act(self, state, rollout_storage):
        """
        Sample an action for ``state`` and optionally record the step.

        :param state: input to the actor network.
        :param rollout_storage: when truthy, receives the state, the action and its log-probability.
        :return: the sampled action, detached.
        """
        mean = self.actor(state)
        covariance = torch.diag(self.action_var).to(device)
        dist = MultivariateNormal(mean, covariance)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        if rollout_storage:
            rollout_storage.states.append(state)
            rollout_storage.actions.append(action)
            rollout_storage.logprobs.append(action_logprob)

        return action.detach()

    def evaluate(self, state, action):
        """
        Score ``action`` under the current policy.

        :return: ``(log-probabilities, squeezed critic values, distribution entropy)``.
        """
        mean = self.actor(state)
        variance = self.action_var.expand_as(mean)
        dist = MultivariateNormal(mean, torch.diag_embed(variance).to(device))

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.critic(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy
|
||||
|
||||
|
||||
class Demon:
    """
    Clipped-surrogate policy-gradient trainer built around two ``ActorCritic`` copies.

    ``policy`` is optimised; ``policy_old`` is used to select actions and is
    synchronised with ``policy`` at the end of every :meth:`update`.
    """

    def __init__(
        self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip
    ):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state, rollout_storage):
        """Sample an action from the old policy, recording the step in ``rollout_storage``."""
        return self.policy_old.act(state, rollout_storage)

    def update(self, rollout_storage):
        """
        Run ``K_epochs`` optimisation steps on the stored rollout.

        :param rollout_storage: provides ``rewards``, ``states``, ``actions`` and ``logprobs``.
        :return: the un-reduced loss tensor of the last epoch, or zeros shaped like the
            normalised rewards when the last epoch was skipped (or no epoch ran).
        """
        # Monte Carlo estimate of rewards: accumulate back-to-front, then restore the order.
        returns = []
        discounted_reward = 0
        for reward in reversed(rollout_storage.rewards):
            discounted_reward = reward + (self.gamma * discounted_reward)
            returns.append(discounted_reward)
        returns.reverse()

        # Normalizing the rewards:
        rewards = torch.stack(returns)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
        rewards = rewards.squeeze(-1)

        # convert list to tensor
        old_states = torch.squeeze(
            torch.stack(rollout_storage.states).to(device), 1
        ).detach()
        old_actions = torch.squeeze(
            torch.stack(rollout_storage.actions).to(device), 1
        ).detach()
        old_logprobs = (
            torch.squeeze(torch.stack(rollout_storage.logprobs), 1).to(device).detach()
        )

        # Defined up front so the return value exists even when K_epochs is 0.
        loss = torch.zeros_like(rewards).to(device)

        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values :
            logprobs, state_values, dist_entropy = self.policy.evaluate(
                old_states, old_actions
            )

            # Finding the ratio (pi_theta / pi_theta__old):
            ratios = torch.exp(logprobs - old_logprobs.detach())

            try:
                # The reward is computed as the mutual info between consecutive memory
                # states, therefore only n-1 values are used.
                state_values = state_values[:-1, :]
                ratios = ratios[:-1, :]  # the same for ratio
                dist_entropy = dist_entropy[:-1, :]  # the same for entropy
                advantages = rewards - state_values.detach()
                surr1 = ratios * advantages
                surr2 = (
                    torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip)
                    * advantages
                )
                # Finding Surrogate Loss:
                loss = (
                    -torch.min(surr1, surr2)
                    + 0.5 * self.MseLoss(state_values, rewards)
                    - 0.01 * dist_entropy
                )
                # take gradient step
                self.optimizer.zero_grad()
                loss.mean().backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
                self.optimizer.step()
            except Exception:
                # Do nothing for sequences of length 1.
                # NOTE(review): this also hides every other failure in the block above
                # (e.g. an indexing or shape error); consider narrowing the exception type.
                loss = torch.zeros_like(rewards).to(device)
                continue

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())
        return loss
|
||||
|
||||
|
||||
############################################
|
||||
# Mutual information Estimator Network######
|
||||
############################################
|
||||
|
||||
|
||||
def linear_reset(module, gain=1.0):
    """
    Re-initialise a linear layer in place: Xavier-uniform weights, zero bias.

    :param module: the ``torch.nn.Linear`` to reset.
    :param gain: scaling factor forwarded to ``xavier_uniform_``.
    :raises AssertionError: if ``module`` is not a ``torch.nn.Linear``.
    """
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()
|
||||
|
||||
|
||||
class FNet(nn.Module):
    """
    Monte-Carlo estimators for Mutual Information Known as MINE.
    Mine produces estimates that are neither an upper or lower bound on MI.
    Other ZNet can be Introduced to address the problem of building bounds with finite samples (unlike Monte Carlo)

    Structure: LSTM -> ELU -> linear read-out; the layers are created by :meth:`init`.
    """

    def __init__(self):
        super(FNet, self).__init__()

    def reset_parameters(self):
        """Reset the ``nn.Linear`` children of both containers with the ReLU gain."""
        # Only nn.Linear instances are touched, so the nn.LSTM keeps its default initialisation.
        for container in (self.lstm, self.hidden2f):
            for layer in container:
                if isinstance(layer, torch.nn.Linear):
                    linear_reset(layer, gain=init.calculate_gain("relu"))

    def init(self, input_size):
        """
        Build the layers and initialise them.

        :param input_size: feature width of the LSTM input.
        """
        recurrent = nn.LSTM(input_size, 32, batch_first=True)
        self.lstm = nn.Sequential(recurrent)
        readout = nn.Linear(32, 1)
        self.hidden2f = nn.Sequential(readout)
        self.reset_parameters()

    def forward(self, data):
        """Return the linear read-out of the ELU-activated LSTM output."""
        # The Sequential wraps a single LSTM, so it hands back the LSTM's (output, (h, c)) tuple.
        sequence_out, (_h_n, _c_n) = self.lstm(data)
        activated = F.elu(sequence_out)
        return self.hidden2f(activated)
|
||||
|
||||
|
||||
class ZNet(nn.Module):
    """
    LSTM -> ELU -> linear read-out -> softplus.

    The layers are created by :meth:`init`, not by the constructor.
    """

    def __init__(self):
        super(ZNet, self).__init__()

    def reset_parameters(self):
        """Reset the ``nn.Linear`` children of both containers with the ReLU gain."""
        # Only nn.Linear instances are touched, so the nn.LSTM keeps its default initialisation.
        for container in (self.lstm, self.hidden2z):
            for layer in container:
                if isinstance(layer, torch.nn.Linear):
                    linear_reset(layer, gain=init.calculate_gain("relu"))

    def init(self, input_size):
        """
        Build the layers and initialise them.

        :param input_size: feature width of the LSTM input.
        """
        recurrent = nn.LSTM(input_size, 32, batch_first=True)
        self.lstm = nn.Sequential(recurrent)
        readout = nn.Linear(32, 1)
        self.hidden2z = nn.Sequential(readout)
        self.reset_parameters()

    def forward(self, data):
        """Return softplus of the read-out applied to the ELU-activated LSTM output."""
        # The Sequential wraps a single LSTM, so it hands back the LSTM's (output, (h, c)) tuple.
        sequence_out, (_h_n, _c_n) = self.lstm(data)
        activated = F.elu(sequence_out)
        return F.softplus(self.hidden2z(activated))
|
BIN
Utils/.DS_Store
vendored
Normal file
BIN
Utils/.DS_Store
vendored
Normal file
Binary file not shown.
167
Utils/ArgumentParser.py
Normal file
167
Utils/ArgumentParser.py
Normal file
@ -0,0 +1,167 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
|
||||
|
||||
class ArgumentParser:
    """
    Wrapper around ``argparse`` that adds named profiles and JSON persistence.

    The value of each argument is resolved in this order: command line, values
    loaded from a file (for arguments registered with ``save=True``), the
    selected profile(s), and finally the registered default.
    """

    class Parsed:
        """Empty namespace object; receives the final values as attributes."""
        pass

    # Handle on the builtin: several methods take a parameter named ``type`` that shadows it.
    _type = type

    @staticmethod
    def str_or_none(none_string="none"):
        """Return a parser mapping ``none_string`` (case-insensitive) to None, anything else to itself."""
        def parse(s):
            return None if s.lower() == none_string else s
        return parse

    @staticmethod
    def list_or_none(none_string="none", type=int):
        """Return a parser turning a comma-separated string into a list of ``type``; empty items are dropped."""
        def parse(s):
            return None if s.lower() == none_string else [type(a) for a in s.split(",") if a]
        return parse

    @staticmethod
    def _merge_args(args, new_args, arg_schemas):
        """Merge ``new_args`` into ``args`` in place, combining clashes with the argument's ``updater``."""
        for name, val in new_args.items():
            old = args.get(name)
            if old is None:
                args[name] = val
            else:
                args[name] = arg_schemas[name]["updater"](old, val)

    class Profile:
        """A named set of argument values that may include other profiles by name."""

        def __init__(self, name, args=None, include=None):
            assert not (args is None and not include), "One of args or include must be defined"
            self.name = name
            self.args = args
            if include is None:
                include = []
            elif not isinstance(include, list):
                include = [include]
            self.include = include

        def get_args(self, arg_schemas, profile_by_name):
            """Return the profile's arguments with all included profiles merged in first."""
            res = {}

            for n in self.include:
                p = profile_by_name.get(n)
                assert p is not None, "Included profile %s doesn't exists" % n

                ArgumentParser._merge_args(res, p.get_args(arg_schemas, profile_by_name), arg_schemas)

            # A profile that consists only of includes has no arguments of its own.
            if self.args is not None:
                ArgumentParser._merge_args(res, self.args, arg_schemas)
            return res

    def __init__(self, description=None):
        self.parser = argparse.ArgumentParser(description=description)
        self.loaded = {}
        self.profiles = {}
        self.args = {}
        self.raw = None
        self.parsed = None
        self.parser.add_argument("-profile", type=str, help="Pre-defined profiles.")

    def add_argument(self, name, type=None, default=None, help="", save=True, parser=lambda x: x,
                     updater=lambda old, new: new):
        """
        Register an argument.

        :param name: option string as passed to argparse (e.g. ``-lr``).
        :param type: value type; inferred from ``default`` when omitted. Booleans are read as ints.
        :param default: value used when neither command line, file nor profile provide one.
        :param help: help text.
        :param save: whether a value loaded from file may be used for this argument.
        :param parser: post-processing applied to the typed value.
        :param updater: ``(old, new) -> merged`` used when profiles define the value twice.
        """
        assert name not in ["profile"], "Argument name %s is reserved" % name
        assert not (type is None and default is None), "Either type or default must be given"

        if type is None:
            type = ArgumentParser._type(default)

        self.parser.add_argument(name, type=int if type == bool else type, default=None, help=help)

        # Mirror the attribute name argparse derives for an option: all leading dashes are
        # removed and inner dashes become underscores. Positional names are used unchanged.
        if name[0] == '-':
            name = name.lstrip('-').replace('-', '_')

        self.args[name] = {
            "type": type,
            "default": int(default) if type == bool and default is not None else default,
            "save": save,
            "parser": parser,
            "updater": updater
        }

    def add_profile(self, prof):
        """Register one profile or a (possibly nested) list of profiles."""
        if isinstance(prof, list):
            for p in prof:
                self.add_profile(p)
        else:
            self.profiles[prof.name] = prof

    def do_parse_args(self, loaded=None):
        """
        Parse the command line and resolve every argument.

        :param loaded: mapping of previously saved values, if any.
        :return: the ``Parsed`` object holding the final values.
        """
        loaded = {} if loaded is None else loaded
        self.raw = self.parser.parse_args()
        raw_values = self.raw.__dict__

        profile = {}
        if self.raw.profile:
            assert not loaded, "Loading arguments from file, but profile given."
            for pr in self.raw.profile.split(","):
                p = self.profiles.get(pr)
                assert p is not None, "Invalid profile: %s. Valid profiles: %s" % (pr, self.profiles.keys())
                p = p.get_args(self.args, self.profiles)
                self._merge_args(profile, p, self.args)

        # Only existing keys are reassigned, so iterating the dict while updating it is safe.
        for k, v in raw_values.items():
            if k in ["profile"]:
                continue

            if v is None:
                if k in loaded and self.args[k]["save"]:
                    raw_values[k] = loaded[k]
                else:
                    raw_values[k] = profile.get(k, self.args[k]["default"])

        self.parsed = ArgumentParser.Parsed()
        self.parsed.__dict__.update({k: self.args[k]["parser"](self.args[k]["type"](v)) if v is not None else None
                                     for k, v in raw_values.items() if k in self.args})

        return self.parsed

    def parse_or_cache(self):
        """Parse once; later calls reuse the stored result."""
        if self.parsed is None:
            self.do_parse_args()

    def parse(self):
        """Return the parsed arguments, parsing first if needed."""
        self.parse_or_cache()
        return self.parsed

    def save(self, fname):
        """Write the raw (unprocessed) argument values to ``fname`` as JSON."""
        self.parse_or_cache()
        with open(fname, 'w') as outfile:
            json.dump(self.raw.__dict__, outfile, indent=4)
        return True

    def load(self, fname):
        """Re-parse using the values stored in ``fname``; does nothing if the file does not exist."""
        if os.path.isfile(fname):
            with open(fname, "r") as data_file:
                stored = json.load(data_file)

            self.do_parse_args(stored)
        return self.parsed

    def sync(self, fname, save=True):
        """Load ``fname`` when it exists and, if ``save`` is set, write the current values back."""
        if os.path.isfile(fname):
            self.load(fname)

        if save:
            # dirname is "" for a bare file name; there is nothing to create in that case.
            directory = os.path.dirname(fname)
            if directory and not os.path.isdir(directory):
                os.makedirs(directory)

            self.save(fname)
        return self.parsed
|
75
Utils/Collate.py
Normal file
75
Utils/Collate.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from operator import mul
|
||||
from functools import reduce
|
||||
|
||||
class VarLengthCollate:
    """
    Collate function that stacks tensors of differing sizes into one padded batch.

    Dicts are collated key by key, numpy arrays are converted to tensors first.
    """

    def __init__(self, ignore_symbol=0):
        """
        :param ignore_symbol: value written to the padded positions.
        """
        self.ignore_symbol = ignore_symbol

    def _measure_array_max_dim(self, batch):
        """Return (per-dimension maximum size, per-dimension flag telling whether sizes differ)."""
        s = list(batch[0].size())
        different = [False] * len(s)
        for item in batch[1:]:
            ns = item.size()
            different = [different[j] or s[j] != ns[j] for j in range(len(s))]
            s = [max(s[j], ns[j]) for j in range(len(s))]
        return s, different

    def _merge_var_len_array(self, batch):
        """Copy every tensor into the leading corner of its slot in a padded output tensor."""
        max_size, different = self._measure_array_max_dim(batch)
        # Same dtype and device as the first element, pre-filled with the padding value.
        out = batch[0].new_full([len(batch)] + max_size, self.ignore_symbol)
        for i, d in enumerate(batch):
            this_o = out[i]
            # Shrink the target view only along dimensions whose size varies in the batch.
            for j, diff in enumerate(different):
                if diff:
                    this_o = this_o.narrow(j, 0, d.size(j))
            this_o.copy_(d)
        return out

    def __call__(self, batch):
        """
        :param batch: list of dicts, numpy arrays or tensors (all of the same kind).
        :raises AssertionError: for any other element type.
        """
        if isinstance(batch[0], dict):
            return {k: self([b[k] for b in batch]) for k in batch[0].keys()}
        elif isinstance(batch[0], np.ndarray):
            return self([torch.from_numpy(a) for a in batch])
        elif torch.is_tensor(batch[0]):
            return self._merge_var_len_array(batch)
        else:
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError("Unknown type: %s" % type(batch[0]))
|
||||
|
||||
class MetaCollate:
    """
    Collate wrapper that keeps one dict entry out of the wrapped collate function.

    For dict samples the ``meta_name`` entry is gathered into a plain list and
    re-attached to the collated result; other samples are passed straight through.
    """

    def __init__(self, meta_name="meta", collate=VarLengthCollate()):
        self.meta_name = meta_name
        self.collate = collate

    def __call__(self, batch):
        if not isinstance(batch[0], dict):
            return self.collate(batch)

        key = self.meta_name
        meta = [sample[key] for sample in batch]
        without_meta = [
            {name: value for name, value in sample.items() if name != key}
            for sample in batch
        ]

        res = self.collate(without_meta)
        res[key] = meta
        return res
|
133
Utils/Debug.py
Normal file
133
Utils/Debug.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
|
||||
# Global switch: while False, the checks and prints in this module return early (unless forced).
enableDebug = False
|
||||
|
||||
def nan_check(arg, name=None, force=False):
    """
    Terminate the process if ``arg`` contains a non-finite value.

    Tensors are tested through the finiteness of their sum; a Parameter's
    gradient is tested as well when it exists. Lists and tuples are checked
    element by element with the same ``name`` and ``force``.

    :param arg: tensor, Parameter, float, or list/tuple of those.
    :param name: optional label used in the report.
    :param force: run the check even when debugging is disabled.
    :return: ``arg`` unchanged.
    :raises AssertionError: if ``arg`` has an unsupported type.
    """
    # ``force`` is tested first so a forced check never depends on the global switch.
    if not force and not enableDebug:
        return arg

    def sum_not_finite(tensor):
        return not np.isfinite(tensor.sum().cpu().data.numpy())

    curr_nan = False
    # Parameter must be tested before the generic tensor check, which would also match it.
    if isinstance(arg, torch.nn.parameter.Parameter):
        curr_nan = sum_not_finite(arg) or (arg.grad is not None and sum_not_finite(arg.grad))
    elif isinstance(arg, torch.autograd.Variable):
        curr_nan = sum_not_finite(arg)
    elif isinstance(arg, float):
        curr_nan = not np.isfinite(arg)
    elif isinstance(arg, (list, tuple)):
        for a in arg:
            nan_check(a, name, force)
    else:
        raise AssertionError("Unsupported type %s" % type(arg))

    if curr_nan:
        # Prefer the traceback of an exception being handled, else the current call stack.
        if sys.exc_info()[0] is not None:
            trace = str(traceback.format_exc())
        else:
            trace = "".join(traceback.format_stack())

        print(arg)
        if name is not None:
            print("NaN found in %s." % name)
        else:
            print("NaN found.")
        if isinstance(arg, torch.autograd.Variable):
            print("   Argument is a torch tensor. Shape: %s" % list(arg.size()))

        print(trace)
        sys.exit(-1)

    return arg
|
||||
|
||||
|
||||
def assert_range(t, min=0.0, max=1.0):
    """
    Debug-only check that every element of ``t`` lies within ``[min, max]``.

    Prints the tensor and raises AssertionError on violation; does nothing
    while debugging is disabled.
    """
    if not enableDebug:
        return

    out_of_range = t.min().cpu().data.numpy() < min
    if not out_of_range:
        out_of_range = t.max().cpu().data.numpy() > max

    if out_of_range:
        print(t)
        assert False
|
||||
|
||||
|
||||
def assert_dist(t, use_lower_limit=True):
    """
    Debug-only check that the last dimension of ``t`` looks like a distribution.

    Elements must lie in [0, 1] and the largest sum over the last dimension
    must not exceed 1.001; with ``use_lower_limit`` it must also reach 0.999.
    Does nothing while debugging is disabled.
    """
    if not enableDebug:
        return

    assert_range(t)

    sums = t.sum(-1)
    largest_sum = sums.max().cpu().data.numpy()

    if largest_sum > 1.001:
        print("MAT:", t)
        print("SUM:", sums)
        assert False

    # NOTE(review): the lower limit is applied to the largest sum, so it only fires when
    # every slice sums below 0.999 -- confirm whether min() was intended.
    if use_lower_limit and largest_sum < 0.999:
        print(t)
        print("SUM:", sums)
        assert False
|
||||
|
||||
|
||||
def print_stat(name, t):
    """Debug-only: print the minimum, mean and maximum of ``t`` labelled with ``name``."""
    if not enableDebug:
        return

    lowest = t.min().cpu().data.numpy()
    highest = t.max().cpu().data.numpy()
    average = t.mean().cpu().data.numpy()

    print("%s: min: %g, mean: %g, max: %g" % (name, lowest, average, highest))
|
||||
|
||||
|
||||
def dbg_print(*things):
    """Forward the arguments to ``print`` only while debugging is enabled."""
    if enableDebug:
        print(*things)
|
||||
|
||||
class GradPrinter(torch.autograd.Function):
    """Identity operation that prints element 0 of the incoming gradient during backward."""

    @staticmethod
    def forward(ctx, tensor):
        # The input is handed back untouched.
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        first = grad_output[0]
        print("Grad (print_grad): ", first)
        return grad_output
|
||||
|
||||
def print_grad(t):
    """Return ``t`` routed through ``GradPrinter`` so its gradient is printed on backward."""
    printed = GradPrinter.apply(t)
    return printed
|
||||
|
||||
def assert_equal(t1, ref, limit=1e-5, force=True):
    """
    Assert that ``t1`` matches ``ref`` element-wise within a relative tolerance.

    The tolerance is ``limit`` times the mean absolute value of the non-zero
    entries of ``ref``; when ``ref`` has no non-zero entry, ``limit`` is used
    as an absolute tolerance.

    :param t1: tensor under test.
    :param ref: reference tensor of the same shape.
    :param limit: relative tolerance.
    :param force: run the check even when debugging is disabled.
    :raises AssertionError: on a shape mismatch or when any element differs too much.
    """
    # ``force`` is tested first so a forced check never depends on the global switch.
    if not (force or enableDebug):
        return

    assert t1.shape == ref.shape, "Tensor shapes differ: got %s, ref %s" % (t1.shape, ref.shape)

    # Count the non-zero elements; summing ``nonzero()`` would add up their indices instead.
    nonzero_count = int((ref != 0).sum())
    norm = float(ref.abs().sum()) / nonzero_count if nonzero_count else 1.0
    threshold = norm * limit

    difference = (t1 - ref).abs()
    errcnt = int((difference > threshold).sum())
    if errcnt > 0:
        print("Tensors differ. (max difference: %g, norm %f). No of errors: %d of %d" %
              (difference.max().item(), norm, errcnt, t1.numel()))
        print("---------------------------------------------Out-----------------------------------------------")
        print(t1)
        print("---------------------------------------------Ref-----------------------------------------------")
        print(ref)
        print("-----------------------------------------------------------------------------------------------")
        raise AssertionError("Tensors differ")
|
24
Utils/Helpers.py
Normal file
24
Utils/Helpers.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import torch
|
||||
import torch.autograd
|
||||
|
||||
def as_numpy(data):
    """Convert a torch tensor to a detached CPU numpy array; return anything else unchanged."""
    if not isinstance(data, (torch.Tensor, torch.autograd.Variable)):
        return data
    return data.detach().cpu().numpy()
|
20
Utils/Index.py
Normal file
20
Utils/Index.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
def index_by_dim(arr, dim, i_start, i_end=None):
    """
    Index ``arr`` along a single dimension, taking all preceding dimensions in full.

    With ``i_end`` given the slice ``i_start:i_end`` is taken, otherwise the
    single index ``i_start``. A negative ``dim`` counts from the last dimension.
    """
    if dim < 0:
        dim += arr.ndim

    selector = i_start if i_end is None else slice(i_start, i_end)
    full_dims = (slice(None, None),) * dim
    return arr[full_dims + (selector,)]
|
51
Utils/Process.py
Normal file
51
Utils/Process.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import sys
|
||||
import ctypes
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
def run(cmd, hide_stderr = True):
    """
    Start ``cmd`` as a child process that is killed when this process dies (Linux only).

    :param cmd: command line; it is split on single spaces, so quoting is not supported.
    :param hide_stderr: discard the child's stderr when True, inherit it otherwise.
    :return: the ``subprocess.Popen`` object of the started process.
    """
    libc_search_dirs = ["/lib", "/lib/x86_64-linux-gnu", "/lib/powerpc64le-linux-gnu"]

    killer = None
    if sys.platform=="linux" :
        candidates = (os.path.join(d, "libc.so.6") for d in libc_search_dirs)
        found = next((path for path in candidates if os.path.isfile(path)), None)

        if not found:
            print("WARNING: Cannot find libc.so.6. Cannot kill process when parent dies.")
        else:
            libc = ctypes.CDLL(found)
            PR_SET_PDEATHSIG = 1
            KILL = 9
            # Runs in the child before exec: ask the kernel to send KILL when the parent dies.
            killer = lambda: libc.prctl(PR_SET_PDEATHSIG, KILL)
    else:
        print("WARNING: OS not linux. Cannot kill process when parent dies.")

    # DEVNULL avoids opening (and never closing) a file object for os.devnull.
    stderr = subprocess.DEVNULL if hide_stderr else None

    return subprocess.Popen(cmd.split(" "), stderr=stderr, preexec_fn=killer)
|
48
Utils/Profile.py
Normal file
48
Utils/Profile.py
Normal file
@ -0,0 +1,48 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import atexit
|
||||
|
||||
# Master switch: a profiler is only created while this is True.
ENABLED=False

# The LineProfiler instance, created on demand; None until then.
_profiler = None
|
||||
|
||||
|
||||
def construct():
    """Create the module-wide LineProfiler on first use, provided profiling is enabled."""
    global _profiler
    if ENABLED and _profiler is None:
        # Imported here so line_profiler is only needed when profiling is switched on.
        from line_profiler import LineProfiler
        _profiler = LineProfiler()
|
||||
|
||||
|
||||
def do_profile(follow=[]):
    """
    Decorator factory registering a function (plus ``follow``) with the profiler.

    The decorated function is always returned unchanged; without a profiler
    the decorator does nothing else.
    """
    construct()

    def inner(func):
        if _profiler is None:
            return func

        _profiler.add_function(func)
        for extra in follow:
            _profiler.add_function(extra)
        _profiler.enable_by_count()
        return func

    return inner
|
||||
|
||||
@atexit.register
def print_prof():
    """Registered with atexit: print the collected statistics if a profiler was created."""
    if _profiler is None:
        return
    _profiler.print_stats()
|
233
Utils/Saver.py
Normal file
233
Utils/Saver.py
Normal file
@ -0,0 +1,233 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import torch
|
||||
import os
|
||||
import inspect
|
||||
import time
|
||||
|
||||
|
||||
class SaverElement:
    """Interface of one saveable item: produce a state and restore from it."""

    def save(self):
        """Return the state to be stored. Must be overridden."""
        raise NotImplementedError

    def load(self, saved_state):
        """Restore from ``saved_state``. Must be overridden."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class CallbackSaver(SaverElement):
    """Saver element whose ``save``/``load`` are user supplied callables."""

    def __init__(self, save_fn, load_fn):
        """Store the callbacks as instance attributes named ``save``/``load``.

        Args:
            save_fn: callable stored as ``self.save``.
            load_fn: callable stored as ``self.load``.
        """
        super().__init__()
        # Instance attributes shadow the abstract methods of the base class.
        self.save, self.load = save_fn, load_fn
|
||||
|
||||
|
||||
class StateSaver(SaverElement):
    """Saver element for objects exposing ``state_dict``/``load_state_dict``."""

    def __init__(self, model):
        """Remember the wrapped object.

        Args:
            model: object providing ``state_dict()`` and ``load_state_dict()``.
        """
        super().__init__()
        self._model = model

    def load(self, state):
        """Load ``state`` into the wrapped object.

        If loading fails and the object has ``named_parameters``, the
        differences between the parameter names of the object and the keys of
        ``state`` are printed to help diagnose the mismatch. A failure is
        tolerated (with a warning) only for ``torch.optim.Optimizer``
        instances; for everything else the original exception is re-raised.
        """
        try:
            self._model.load_state_dict(state)
        except Exception:
            if hasattr(self._model, "named_parameters") and hasattr(state, "keys"):
                names = {n for n, _ in self._model.named_parameters()}
                # Compare against the keys of the state being loaded; the
                # wrapped model itself has no ``keys()``.
                loaded = set(state.keys())
                if names != loaded:
                    d = loaded.difference(names)
                    if d:
                        print("Loaded, but not in model: %s" % list(d))
                    d = names.difference(loaded)
                    if d:
                        print("In model, but not loaded: %s" % list(d))
            if isinstance(self._model, torch.optim.Optimizer):
                print("WARNING: optimizer parameters not loaded!")
            else:
                raise

    def save(self):
        """Return ``state_dict()`` of the wrapped object."""
        return self._model.state_dict()
|
||||
|
||||
|
||||
class GlobalVarSaver(SaverElement):
    """Saver element for a single global variable of the creating module."""

    def __init__(self, name):
        """Bind to the globals of the code that instantiates this object.

        Args:
            name: name of the global variable to save and restore.
        """
        # The frame one level up is the caller of this constructor.
        creator = inspect.currentframe().f_back
        self._vars = creator.f_globals
        self._name = name

    def load(self, state):
        """Assign ``state`` to the tracked global variable."""
        self._vars.update({self._name: state})

    def save(self):
        """Return the current value of the tracked global variable."""
        return self._vars[self._name]
|
||||
|
||||
|
||||
class PyObjectSaver(SaverElement):
    """Saver element for plain Python containers and attribute-bearing objects.

    Dicts, lists and objects with ``__dict__`` are traversed recursively;
    anything else is stored and restored as-is.
    """

    def __init__(self, obj):
        self._obj = obj

    def load(self, state):
        """Restore ``state`` into the wrapped object in place."""

        def restore(target, saved):
            if isinstance(target, dict):
                for key, value in saved.items():
                    target[key] = restore(target.get(key), value)
                return target

            if isinstance(target, list):
                if len(target) == len(saved):
                    for idx, value in enumerate(saved):
                        target[idx] = restore(target[idx], value)
                else:
                    # Length changed: replace the content without recursing.
                    target.clear()
                    target.extend(saved)
                return target

            if hasattr(target, "__dict__"):
                # Objects are restored through their attribute dictionary.
                restore(target.__dict__, saved)
                return target

            # Leaf value: the saved state replaces it.
            return saved

        restore(self._obj, state)

    def save(self):
        """Return a nested dict/list snapshot of the wrapped object."""

        def snapshot(target):
            if isinstance(target, dict):
                return {k: snapshot(v) for k, v in target.items()}
            if isinstance(target, list):
                return [snapshot(v) for v in target]
            if hasattr(target, "__dict__"):
                return {k: snapshot(v) for k, v in target.__dict__.items()}
            return target

        return snapshot(self._obj)

    @staticmethod
    def obj_supported(obj):
        """Return True for lists, dicts and objects that have ``__dict__``."""
        return hasattr(obj, "__dict__") or isinstance(obj, (list, dict))
|
||||
|
||||
|
||||
class Saver:
    """Write and restore checkpoints of a set of registered objects.

    Checkpoints are stored in ``dir`` as ``model-<index>.pth``. After every
    write, old checkpoints are thinned out (see ``_cleanup``).
    """

    def __init__(self, dir, short_interval, keep_every_n_hours=4):
        """Create the checkpoint directory and remember the settings.

        Args:
            dir: directory for the checkpoint files (created if missing).
            short_interval: ``tick`` writes a checkpoint when the iteration
                number is a multiple of this value.
            keep_every_n_hours: length of the time window used by the cleanup.
        """
        self.savers = {}
        self.short_interval = short_interval
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self._keep_every_n_seconds = keep_every_n_hours * 3600

    def register(self, name, saver):
        """Register ``saver`` under ``name``.

        ``saver`` may be a ``SaverElement``, an object with a callable
        ``state_dict``, or anything ``PyObjectSaver`` supports.

        Raises:
            AssertionError: if ``name`` is already registered.
            TypeError: if the object is of none of the supported kinds.
        """
        assert name not in self.savers, "Saver %s already registered" % name

        if isinstance(saver, SaverElement):
            self.savers[name] = saver
        elif hasattr(saver, "state_dict") and callable(saver.state_dict):
            self.savers[name] = StateSaver(saver)
        elif PyObjectSaver.obj_supported(saver):
            self.savers[name] = PyObjectSaver(saver)
        else:
            # Previously a bare ``assert "<message>"``, which never fires.
            raise TypeError("Unsupported thing to save: %s" % type(saver))

    def __setitem__(self, key, value):
        """``saver[name] = obj`` is shorthand for ``register(name, obj)``."""
        self.register(key, value)

    def write(self, iter):
        """Save the state of every registered element as checkpoint ``iter``."""
        fname = os.path.join(self.dir, self.model_name_from_index(iter))
        print("Saving %s" % fname)

        state = {name: element.save() for name, element in self.savers.items()}

        torch.save(state, fname)
        print("Saved.")

        self._cleanup()

    def tick(self, iter):
        """Write a checkpoint if ``iter`` is a multiple of ``short_interval``."""
        if iter % self.short_interval != 0:
            return

        self.write(iter)

    @staticmethod
    def model_name_from_index(index):
        """Return the checkpoint file name for ``index``."""
        return "model-%d.pth" % index

    @staticmethod
    def get_checkpoint_index_list(dir):
        """Return the checkpoint indices found in ``dir``, newest index first.

        Only files whose name is exactly what ``model_name_from_index`` would
        generate are considered, so unrelated ``.pth`` files are ignored
        instead of raising ``ValueError``.
        """
        indices = []
        for fn in os.listdir(dir):
            stem, ext = os.path.splitext(fn)
            if ext != ".pth":
                continue
            try:
                index = int(stem.rpartition("-")[2])
            except ValueError:
                continue
            # Round-trip check: later code rebuilds the name from the index.
            if Saver.model_name_from_index(index) == fn:
                indices.append(index)
        return sorted(indices, reverse=True)

    @staticmethod
    def get_ckpts_in_time_window(dir, time_window_s, index_list=None):
        """Return names of checkpoints modified within the last ``time_window_s`` seconds.

        ``index_list`` (default: all checkpoints in ``dir``) is walked in
        order and the walk stops at the first checkpoint older than the window.
        """
        if index_list is None:
            index_list = Saver.get_checkpoint_index_list(dir)

        now = time.time()

        res = []
        for i in index_list:
            name = Saver.model_name_from_index(i)
            mtime = os.path.getmtime(os.path.join(dir, name))
            if now - mtime > time_window_s:
                break

            res.append(name)

        return res

    @staticmethod
    def load_last_checkpoint(dir):
        """Load the newest readable checkpoint from ``dir``.

        Checkpoints that fail to load are skipped with a warning.

        Returns:
            The loaded data, or None if no checkpoint could be loaded.
        """
        for index in Saver.get_checkpoint_index_list(dir):
            fname = Saver.model_name_from_index(index)
            try:
                print("Loading %s" % fname)
                # NOTE(review): torch.load unpickles data; only use on trusted files.
                data = torch.load(os.path.join(dir, fname))
            except Exception:
                print("WARNING: Loading %s failed. Maybe file is corrupted?" % fname)
                continue
            return data
        return None

    def _cleanup(self):
        """Delete superfluous recent checkpoints.

        The two newest checkpoints are always kept. Of the remaining ones that
        fall inside the time window, all but the oldest are removed.
        """
        index_list = self.get_checkpoint_index_list(self.dir)
        new_files = self.get_ckpts_in_time_window(self.dir, self._keep_every_n_seconds, index_list[2:])
        new_files = new_files[:-1]

        for f in new_files:
            os.remove(os.path.join(self.dir, f))

    def load(self, fname=None):
        """Restore the registered elements from a checkpoint.

        Args:
            fname: checkpoint file; if None the newest checkpoint in the
                saver directory is used.

        Returns:
            False if no checkpoint was available (only possible when
            ``fname`` is None), True otherwise.
        """
        if fname is None:
            state = self.load_last_checkpoint(self.dir)
            if not state:
                return False
        else:
            # NOTE(review): torch.load unpickles data; only use on trusted files.
            state = torch.load(fname)

        for k, s in state.items():
            if k not in self.savers:
                print("WARNING: failed to load state of %s. It doesn't exists." % k)
                continue
            self.savers[k].load(s)

        return True
|
31
Utils/Seed.py
Normal file
31
Utils/Seed.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Module-wide seed; None means "not fixed" (random states are seeded by NumPy).
seed = None


def fix():
    """Fix the module-wide seed to a constant so random states are reproducible."""
    global seed
    seed = 0xB0C1FA52


def get_randstate():
    """Return a new ``np.random.RandomState``.

    The state is seeded with the module-wide ``seed`` when one is set
    (including a seed of 0); otherwise NumPy chooses the seed.
    """
    if seed is not None:
        return np.random.RandomState(seed)
    return np.random.RandomState()
|
323
Utils/Visdom.py
Normal file
323
Utils/Visdom.py
Normal file
@ -0,0 +1,323 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import time
|
||||
import visdom
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
import socket
|
||||
from . import Process
|
||||
|
||||
vis = None
|
||||
port = None
|
||||
visdom_fail_count = 0
|
||||
|
||||
|
||||
def port_used(port):
    """Return True if something accepts TCP connections on 127.0.0.1:``port``.

    The probing socket is always closed, whether or not the connection
    attempt succeeds.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # connect_ex returns 0 on success and an errno value otherwise.
        return sock.connect_ex(('127.0.0.1', port)) == 0
|
||||
|
||||
|
||||
def alloc_port(start_from=7000):
    """Return the first port at or above ``start_from`` that is not in use.

    Every occupied port that is skipped is reported on stdout.
    """
    candidate = start_from
    while port_used(candidate):
        print("Port already used: %d" % candidate)
        candidate += 1
    return candidate
|
||||
|
||||
|
||||
def wait_for_port(port, timeout=5):
    """Poll until ``port`` accepts connections or ``timeout`` seconds pass.

    Returns:
        True once the port is in use, False if the timeout expired first.
    """
    started = time.time()
    while True:
        if port_used(port):
            return True
        if time.time() - started > timeout:
            return False
        # Pause between two probes.
        time.sleep(0.1)
|
||||
|
||||
|
||||
def start(on_port=None):
    """Launch a Visdom server process and connect the module-wide client.

    Args:
        on_port: port to use; if None a free port is searched for.

    After three failed attempts further calls return immediately. On success
    the globals ``vis`` and ``port`` are set.
    """
    global vis
    global port
    global visdom_fail_count
    assert vis is None, "Cannot start more than 1 visdom servers."

    if visdom_fail_count >= 3:
        return

    port = on_port if on_port is not None else alloc_port()

    print("Starting Visdom server on %d" % port)
    Process.run("%s -m visdom.server -p %d" % (sys.executable, port))

    if wait_for_port(port):
        print("Done.")
        vis = visdom.Visdom(port=port)
    else:
        print("ERROR: failed to start Visdom server. Server not responding.")
        visdom_fail_count += 1
|
||||
|
||||
|
||||
def _start_if_not_running():
    """Start the Visdom server unless a client is already connected."""
    if vis is not None:
        return
    start()
|
||||
|
||||
|
||||
def save_heatmap(dir, title, img):
    """Store ``img`` as ``<dir>/<title>.npy``; do nothing when ``dir`` is None.

    Spaces in ``title`` are replaced by underscores. Missing directories
    (including ones introduced by path separators in ``title``) are created.
    """
    if dir is None:
        return

    target = os.path.join(dir, title.replace(" ", "_") + ".npy")
    os.makedirs(os.path.dirname(target), exist_ok=True)
    np.save(target, img)
|
||||
|
||||
|
||||
class Plot2D:
    """Line plot of one or more curves, mirrored to a Visdom window.

    Added points are averaged over ``store_interval`` consecutive calls before
    being stored. The attributes listed in ``TO_SAVE`` make up the saved state.
    """

    TO_SAVE = ["x", "y", "curr_accu", "curr_cnt", "legend"]

    def __init__(self, name, store_interval=1, legend=None, xlabel=None, ylabel=None):
        """Create an empty plot and make sure the Visdom server was started.

        Args:
            name: plot title.
            store_interval: number of added points averaged into one stored point.
            legend: optional legend passed to Visdom.
            xlabel: optional x axis label.
            ylabel: optional y axis label.
        """
        _start_if_not_running()

        self.x = []
        self.y = []
        self.store_interval = store_interval
        self.curr_accu = None
        self.curr_cnt = 0
        self.name = name
        self.legend = legend
        self.xlabel = xlabel
        self.ylabel = ylabel

        self.replot = False
        self.visplot = None
        self.last_vis_update_pos = 0

    def set_legend(self, legend):
        """Change the legend; the plot is redrawn on the next update."""
        if self.legend != legend:
            self.legend = legend
            self.replot = True

    def _send_update(self):
        """Send the stored points to Visdom (full redraw or append)."""
        if not self.x or vis is None:
            return

        if self.visplot is None or self.replot:
            opts = {
                "title": self.name
            }
            if self.xlabel:
                opts["xlabel"] = self.xlabel
            if self.ylabel:
                opts["ylabel"] = self.ylabel

            if self.legend is not None:
                opts["legend"] = self.legend

            # np.asfarray was removed in NumPy 2.0; asarray with a float dtype
            # is the documented replacement.
            self.visplot = vis.line(X=np.asarray(self.x, dtype=np.float64),
                                    Y=np.asarray(self.y, dtype=np.float64),
                                    opts=opts, win=self.visplot)
            self.replot = False
        else:
            vis.line(
                X=np.asarray(self.x[self.last_vis_update_pos:], dtype=np.float64),
                Y=np.asarray(self.y[self.last_vis_update_pos:], dtype=np.float64),
                win=self.visplot,
                update='append'
            )

        self.last_vis_update_pos = len(self.x) - 1

    def add_point(self, x, y):
        """Accumulate the value(s) ``y`` at position ``x``.

        Args:
            x: x coordinate.
            y: a single value or a list with one value per curve.

        When the number of curves grows, the already stored points are padded
        with NaN for the new curves; when fewer values than curves are given,
        the missing ones are treated as NaN. After ``store_interval`` calls
        the averaged point is stored and the plot is updated.
        """
        if not isinstance(y, list):
            y = [y]

        if self.curr_accu is None:
            self.curr_accu = [0.0] * len(y)

        if len(self.curr_accu) < len(y):
            # The number of curves increased.
            need_to_add = (len(y) - len(self.curr_accu))

            self.curr_accu += [0.0] * need_to_add
            count = len(self.x)
            if count > 0:
                self.replot = True
                if not isinstance(self.x[0], list):
                    # Stored points were scalars (single curve): wrap them.
                    self.x = [[px] for px in self.x]
                    self.y = [[py] for py in self.y]

                nan = float("nan")
                for a in self.x:
                    a += [nan] * need_to_add
                for a in self.y:
                    a += [nan] * need_to_add
        elif len(self.curr_accu) > len(y):
            y = y[:] + [float("nan")] * (len(self.curr_accu) - len(y))

        self.curr_accu = [self.curr_accu[i] + y[i] for i in range(len(y))]
        self.curr_cnt += 1
        if self.curr_cnt == self.store_interval:
            if len(y) > 1:
                self.x.append([x] * len(y))
                self.y.append([a / self.curr_cnt for a in self.curr_accu])
            else:
                self.x.append(x)
                self.y.append(self.curr_accu[0] / self.curr_cnt)

            self.curr_accu = [0.0] * len(y)
            self.curr_cnt = 0

            self._send_update()

    def state_dict(self):
        """Return the attributes named in ``TO_SAVE`` as a dict."""
        s = {k: self.__dict__[k] for k in self.TO_SAVE}
        return s

    def load_state_dict(self, state):
        """Restore the plot from ``state`` and redraw it.

        A legend given in the constructor takes precedence over the saved one.
        ``state`` itself is not modified.
        """
        state = dict(state)
        if self.legend is not None:
            # Load legend only if not given in the constructor.
            state["legend"] = self.legend
        self.__dict__.update(state)
        self.last_vis_update_pos = 0

        # Read old format
        if not isinstance(self.curr_accu, list) and self.curr_accu is not None:
            self.curr_accu = [self.curr_accu]

        self._send_update()
|
||||
|
||||
|
||||
class Image:
    """An image (or list of images) shown in a Visdom window."""

    def __init__(self, title, dumpdir=None):
        """Create the window description; the window opens on the first draw.

        Args:
            title: window title.
            dumpdir: if not None, single images are also saved there as .npy.
        """
        _start_if_not_running()

        self.win = None
        self.opts = dict(title=title)
        self.dumpdir = dumpdir

    def set_dump_dir(self, dumpdir):
        """Set the directory where drawn images are dumped."""
        self.dumpdir = dumpdir

    def draw(self, img):
        """Show ``img``; does nothing when no Visdom client is connected.

        A list is sent with ``vis.images``. Otherwise a 2-D array gets a
        leading axis added, and an array whose last axis has size 1 or 3 (and
        whose first axis does not) is transposed to channel-first with the
        channel order reversed (the original comment calls this the cv2
        layout). uint8 data is scaled to floats by dividing by 255.
        """
        if vis is None:
            return

        if isinstance(img, list):
            if self.win is None:
                self.win = vis.images(img, opts=self.opts)
            else:
                vis.images(img, win=self.win, opts=self.opts)
            return

        if len(img.shape) == 2:
            img = np.expand_dims(img, 0)
        elif img.shape[-1] in [1, 3] and img.shape[0] not in [1, 3]:
            # cv2 image
            img = img.transpose(2, 0, 1)[::-1]

        if img.dtype == np.uint8:
            img = img.astype(np.float32) / 255

        self.opts["width"] = img.shape[2]
        self.opts["height"] = img.shape[1]

        save_heatmap(self.dumpdir, self.opts["title"], img)
        if self.win is None:
            self.win = vis.image(img, opts=self.opts)
        else:
            vis.image(img, win=self.win, opts=self.opts)

    def __call__(self, img):
        """Calling the object is the same as ``draw(img)``."""
        self.draw(img)
|
||||
|
||||
|
||||
class Text:
    """A text window in Visdom that remembers its current content."""

    def __init__(self, title):
        """Create the window description; the window opens on the first ``set``."""
        _start_if_not_running()

        self.win = None
        self.title = title
        self.curr_text = ""

    def set(self, text):
        """Remember ``text`` and show it if a Visdom client is connected."""
        self.curr_text = text

        if vis is None:
            return

        if self.win is not None:
            vis.text(text, win=self.win)
        else:
            self.win = vis.text(text, opts=dict(
                title=self.title
            ))

    def state_dict(self):
        """Return the current text as ``{"text": ...}``."""
        return {"text": self.curr_text}

    def load_state_dict(self, state):
        """Show the text stored under ``state["text"]``."""
        self.set(state["text"])

    def __call__(self, text):
        """Calling the object is the same as ``set(text)``."""
        self.set(text)
|
||||
|
||||
|
||||
class Heatmap:
    """A heatmap shown in a Visdom window."""

    def __init__(self, title, min=None, max=None, xlabel=None, ylabel=None, colormap='Viridis', dumpdir=None):
        """Collect the options of the heatmap window.

        Args:
            title: window title.
            min: value stored as the "xmin" option; if None it is taken from
                each drawn image.
            max: value stored as the "xmax" option; if None it is taken from
                each drawn image.
            xlabel: optional x axis label.
            ylabel: optional y axis label.
            colormap: colormap name passed to Visdom.
            dumpdir: if not None, drawn data is also saved there as .npy.
        """
        _start_if_not_running()

        self.win = None
        self.opt = dict(title=title, colormap=colormap)
        self.dumpdir = dumpdir

        for key, bound in (("xmin", min), ("xmax", max)):
            if bound is not None:
                self.opt[key] = bound

        for key, label in (("xlabel", xlabel), ("ylabel", ylabel)):
            if label:
                self.opt[key] = label

    def set_dump_dir(self, dumpdir):
        """Set the directory where drawn heatmaps are dumped."""
        self.dumpdir = dumpdir

    def draw(self, img):
        """Show ``img``; does nothing when no Visdom client is connected."""
        if vis is None:
            return

        options = self.opt.copy()
        # Limits not fixed in the constructor follow the data of this draw.
        if "xmin" not in options:
            options["xmin"] = float(img.min())

        if "xmax" not in options:
            options["xmax"] = float(img.max())

        save_heatmap(self.dumpdir, options["title"], img)

        if self.win is not None:
            vis.heatmap(img, win=self.win, opts=options)
        else:
            self.win = vis.heatmap(img, opts=options)

    def __call__(self, img):
        """Calling the object is the same as ``draw(img)``."""
        self.draw(img)
|
189
Utils/download.py
Normal file
189
Utils/download.py
Normal file
@ -0,0 +1,189 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import requests, tarfile, io, os, zipfile
|
||||
|
||||
from io import BytesIO, SEEK_SET, SEEK_END
|
||||
|
||||
|
||||
class UrlStream:
    """Read-only, file-like view of a URL.

    If the server advertises byte ranges and a content length, seeking is done
    by issuing a new ranged request. Otherwise the response is buffered in
    memory and seeking happens inside that buffer.
    """

    def __init__(self, url):
        """Issue a HEAD request to find out whether range requests are supported."""
        self._url = url
        headers = requests.head(url).headers
        headers = {k.lower(): v for k, v in headers.items()}
        self._seek_supported = headers.get('accept-ranges') == 'bytes' and 'content-length' in headers
        if self._seek_supported:
            self._size = int(headers['content-length'])
        self._curr_pos = 0        # absolute position of the stream
        self._buf_start_pos = 0   # absolute position of byte 0 of the buffer
        self._iter = None         # chunk iterator of the running request
        self._buffer = None
        self._buf_size = 0
        self._loaded_all = False

    def seekable(self):
        """Return True if the server supports ranged requests."""
        return self._seek_supported

    def _load_all(self):
        """Append everything left in the response to the buffer."""
        if self._loaded_all:
            return
        self._make_request()
        old_buf_pos = self._buffer.tell()
        self._buffer.seek(0, SEEK_END)
        for chunk in self._iter:
            self._buffer.write(chunk)
        self._buf_size = self._buffer.tell()
        self._buffer.seek(old_buf_pos, SEEK_SET)
        self._loaded_all = True

    def seek(self, position, whence=SEEK_SET):
        """Move the stream position and return the new absolute position.

        Supports ``SEEK_SET``, ``SEEK_CUR`` and ``SEEK_END`` (the latter with
        a non-positive offset).

        Raises:
            ValueError: for any other ``whence``.
        """
        if whence == SEEK_END:
            assert position <= 0
            if self._seek_supported:
                self.seek(self._size + position)
            else:
                self._load_all()
                self._buffer.seek(position, SEEK_END)
                self._curr_pos = self._buffer.tell()
        elif whence == SEEK_SET:
            if self._curr_pos != position:
                self._curr_pos = position
                if self._seek_supported:
                    # Drop the running request; the next read starts a new
                    # ranged request at the new position.
                    self._iter = None
                    self._buffer = None
                else:
                    self._load_until(position)
                    self._buffer.seek(position)
                    self._curr_pos = position
        elif whence == io.SEEK_CUR:
            self.seek(self._curr_pos + position)
        else:
            # Previously ``assert "<message>"``, which can never fail.
            raise ValueError("Invalid whence %s" % whence)

        return self.tell()

    def tell(self):
        """Return the current absolute position."""
        return self._curr_pos

    def _load_until(self, goal_position):
        """Buffer response data until the absolute ``goal_position`` or the end."""
        self._make_request()
        old_buf_pos = self._buffer.tell()
        current_position = self._buffer.seek(0, SEEK_END)

        goal_position = goal_position - self._buf_start_pos
        while current_position < goal_position:
            try:
                d = next(self._iter)
                self._buffer.write(d)
                current_position += len(d)
            except StopIteration:
                break
        self._buf_size = current_position
        self._buffer.seek(old_buf_pos, SEEK_SET)

    def _new_buffer(self):
        """Start a fresh buffer, carrying over the unread part of the old one."""
        remaining = self._buffer.read() if self._buffer is not None else None
        self._buffer = BytesIO()
        if remaining is not None:
            self._buffer.write(remaining)
        # Without range support the response always starts at byte 0, whatever
        # position was requested before the first read.
        self._buf_start_pos = self._curr_pos if self._seek_supported else 0
        self._buf_size = 0 if remaining is None else len(remaining)
        self._buffer.seek(0, SEEK_SET)
        self._loaded_all = False

    def _make_request(self):
        """Start the HTTP request if none is running; recycle an oversized buffer."""
        if self._iter is None:
            h = {
                "User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36",
            }
            if self._seek_supported:
                h["Range"] = "bytes=%d-%d" % (self._curr_pos, self._size - 1)

            r = requests.get(self._url, headers=h, stream=True)

            self._iter = r.iter_content(1024 * 1024)
            self._new_buffer()
        elif self._seek_supported and self._buf_size > 128 * 1024 * 1024:
            # Already consumed data can be fetched again with a ranged request,
            # so the buffer does not need to keep it.
            self._new_buffer()

    def size(self):
        """Return the total size in bytes (loads everything without range support)."""
        if self._seek_supported:
            return self._size
        else:
            self._load_all()
            return self._buf_size

    def read(self, size=None):
        """Read up to ``size`` bytes; None or a negative value reads to the end."""
        if size is None or size < 0:
            size = self.size()

        self._load_until(self._curr_pos + size)
        data = self._buffer.read(size)
        # Advance by what was actually delivered, with or without range support,
        # so that consecutive reads and ``tell`` stay in sync with the buffer.
        self._curr_pos += len(data)
        return data

    def iter_content(self, block_size):
        """Yield the remaining content in blocks of at most ``block_size`` bytes."""
        while True:
            d = self.read(block_size)
            if not len(d):
                break
            yield d
|
||||
|
||||
|
||||
def download(url, dest=None, extract=True, ignore_if_exists=False):
    """
    Download a file from the internet.

    Args:
        url: the url to download
        dest: destination file if extract=False, or destination dir if extract=True. If None, it will be the last part of URL.
        extract: extract a tar.gz, tar.bz2 or zip file?
        ignore_if_exists: don't do anything if file exists

    Returns:
        the destination filename.
    """

    base_url = url.split("?")[0]

    if dest is None:
        dest = [f for f in base_url.split("/") if f][-1]

    if os.path.exists(dest) and ignore_if_exists:
        return dest

    stream = UrlStream(url)
    extension = base_url.split(".")[-1].lower()

    if extract and extension in ['gz', 'bz2', 'zip']:
        os.makedirs(dest, exist_ok=True)

        if extension in ['gz', 'bz2']:
            decompressed_file = tarfile.open(fileobj=stream, mode='r|'+extension)
        elif extension=='zip':
            decompressed_file = zipfile.ZipFile(stream, mode='r')
        else:
            assert False, "Invalid extension: %s" % extension

        # NOTE(review): extractall trusts the member paths of the archive;
        # only use with archives from trusted sources.
        with decompressed_file:
            decompressed_file.extractall(dest)
    else:
        try:
            with open(dest, 'wb') as f:
                for d in stream.iter_content(1024*1024):
                    f.write(d)
        except BaseException:
            # Do not leave a partial file behind (also on Ctrl-C). The file
            # may not exist if opening it was what failed.
            if os.path.exists(dest):
                os.remove(dest)
            raise
    return dest
|
111
Utils/gpu_allocator.py
Normal file
111
Utils/gpu_allocator.py
Normal file
@ -0,0 +1,111 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import torch
|
||||
from Utils.lockfile import LockFile
|
||||
|
||||
def get_memory_usage():
    """Query ``nvidia-smi`` for the used memory of every GPU.

    Returns:
        A dict mapping GPU index to the used-memory number reported by
        ``nvidia-smi``, or None if the tool could not be run or its output
        could not be parsed.
    """
    try:
        proc = subprocess.Popen("nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits".split(" "),
                                stdout=subprocess.PIPE)
        usage = {}
        for line in proc.communicate()[0].decode().split("\n"):
            if not line:
                continue
            # Each line looks like "<index>, <used memory>".
            index, used = line.strip().split(" ")
            usage[int(index[:-1])] = int(used)
        return usage
    except Exception:
        # Best effort: no nvidia-smi or unexpected output means "unknown".
        return None
|
||||
|
||||
|
||||
def get_free_gpus():
    """Return the indices of GPUs that currently run no compute process.

    Uses ``nvidia-smi`` to list the UUIDs of GPUs with compute applications
    and compares them with the list of all GPUs.

    Returns:
        A list of GPU indices, or None if ``nvidia-smi`` could not be run or
        its output could not be parsed.
    """
    try:
        proc = subprocess.Popen("nvidia-smi --query-compute-apps=gpu_uuid --format=csv,noheader,nounits".split(" "),
                                stdout=subprocess.PIPE)
        busy_uuids = {s.strip() for s in proc.communicate()[0].decode().split("\n") if s}

        proc = subprocess.Popen("nvidia-smi --query-gpu=index,uuid --format=csv,noheader,nounits".split(" "),
                                stdout=subprocess.PIPE)

        free = []
        for line in proc.communicate()[0].decode().split("\n"):
            if not line:
                continue
            id, uid = line.strip().split(", ")
            if uid not in busy_uuids:
                free.append(int(id))

        return free
    except Exception:
        # Best effort: no nvidia-smi or unexpected output means "unknown".
        return None
|
||||
|
||||
def _fix_order():
|
||||
os.environ["CUDA_DEVICE_ORDER"] = os.environ.get("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
|
||||
|
||||
def allocate(n:int = 1):
    """Pick ``n`` GPUs and publish them through ``CUDA_VISIBLE_DEVICES``.

    Runs under a lock file. GPUs without compute processes are preferred; if
    there are not enough, the remaining ones are taken in order of increasing
    memory usage. Nothing is changed when ``CUDA_VISIBLE_DEVICES`` is already
    set or the GPU state cannot be queried. Finally a small tensor is moved
    to every selected device.
    """
    _fix_order()
    with LockFile("/tmp/gpu_allocation_lock"):
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            print("WARNING: trying to allocate %d GPUs, but CUDA_VISIBLE_DEVICES already set to %s" %
                  (n, os.environ["CUDA_VISIBLE_DEVICES"]))
            return

        free = get_free_gpus()
        if free is None:
            print("WARNING: failed to allocate %d GPUs" % n)
            return
        allocated = free[:n]

        if len(allocated) < n:
            print("There is no more free GPUs. Allocating the one with least memory usage.")
            usage = get_memory_usage()
            if usage is None:
                print("WARNING: failed to allocate %d GPUs" % n)
                return

            # Devices by increasing usage; ties keep the order of ``usage``.
            by_usage = sorted(usage, key=usage.get)
            min_usage_devs = [dev for dev in by_usage if dev not in allocated]

            n2 = n - len(allocated)
            if n2 > len(min_usage_devs):
                print("WARNING: trying to allocate %d GPUs but only %d available" % (n, len(min_usage_devs)+len(allocated)))
                n2 = len(min_usage_devs)

            allocated += min_usage_devs[:n2]

        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(a) for a in allocated)
        for i, _ in enumerate(allocated):
            # Move a tensor to each device while the lock is still held.
            torch.FloatTensor([0.0]).cuda(i)
|
||||
|
||||
def use_gpu(gpu="auto", n_autoalloc=1):
    """Select the GPUs to use.

    Args:
        gpu: "auto" or "" (case-insensitive) to allocate automatically,
            otherwise the value written to ``CUDA_VISIBLE_DEVICES``
            (lower-cased).
        n_autoalloc: number of GPUs to allocate in automatic mode.
    """
    _fix_order()

    gpu = gpu.lower()
    if gpu not in ["auto", ""]:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        return

    allocate(n_autoalloc)
|
41
Utils/lockfile.py
Normal file
41
Utils/lockfile.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import fcntl
|
||||
|
||||
|
||||
class LockFile:
    """Exclusive lock on a file using ``fcntl.lockf``.

    Usable as a context manager: the lock is taken on entry and released on
    exit.
    """

    def __init__(self, fname):
        """Remember the lock file name; nothing is opened yet."""
        self._fname = fname
        self._fd = None

    def acquire(self):
        """Open the lock file and block until the exclusive lock is obtained."""
        fd = open(self._fname, "w")
        try:
            try:
                os.chmod(self._fname, 0o777)
            except PermissionError:
                # The file belongs to another user, who has already made it
                # accessible; locking still works.
                pass

            fcntl.lockf(fd, fcntl.LOCK_EX)
        except BaseException:
            # Do not leak the file handle if locking failed.
            fd.close()
            raise
        self._fd = fd

    def release(self):
        """Release the lock and close the file; no-op if the lock is not held."""
        if self._fd is None:
            return
        try:
            fcntl.lockf(self._fd, fcntl.LOCK_UN)
        finally:
            self._fd.close()
            self._fd = None

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
|
54
Utils/timer.py
Normal file
54
Utils/timer.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class OnceEvery:
    """Callable that returns True at most once per ``interval`` seconds."""

    def __init__(self, interval):
        """Args:
            interval: minimal number of seconds between two True results.
        """
        self._interval = interval
        self._last_check = 0

    def __call__(self):
        """Return True (and restart the interval) if the interval has elapsed."""
        now = time.time()
        due = now - self._last_check >= self._interval
        if due:
            self._last_check = now
        return due
|
||||
|
||||
|
||||
class Measure:
    """Stopwatch reporting a moving average of the measured durations."""

    def __init__(self, average=1):
        """Args:
            average: number of most recent measurements that are averaged.
        """
        self._start = None
        self._average = average
        self._accu_value = 0.0
        self._history_list = []

    def start(self):
        """Remember the current time as the start of the measurement."""
        self._start = time.time()

    def passed(self):
        """Return the averaged time since ``start``, or None if never started.

        Every call records the time elapsed since the last ``start`` and
        returns the mean of the most recent ``average`` recorded values.
        """
        if self._start is None:
            return None

        elapsed = time.time() - self._start
        history = self._history_list
        history.append(elapsed)
        self._accu_value += elapsed
        if len(history) > self._average:
            # Drop the oldest measurement from the window and from the sum.
            self._accu_value -= history.pop(0)

        return self._accu_value / len(history)
|
263
Utils/universal.py
Normal file
263
Utils/universal.py
Normal file
@ -0,0 +1,263 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
# Each universal type is a pair: [numpy type, torch dtype].
float32 = [np.float32, torch.float32]
float64 = [np.float64, torch.float64]
uint8 = [np.uint8, torch.uint8]

_all_types = [float32, float64, uint8]

# Lookup tables from the numpy dtype name / torch dtype to the pair.
_dtype_numpy_map = {np.dtype(pair[0]).name: pair for pair in _all_types}
_dtype_pytorch_map = {pair[1]: pair for pair in _all_types}


def dtype(t):
    """Return the universal type pair of a torch tensor or numpy array.

    Raises:
        KeyError: if the element type is not one of the known pairs.
    """
    if not torch.is_tensor(t):
        return _dtype_numpy_map[t.dtype.name]
    return _dtype_pytorch_map[t.dtype]
|
||||
|
||||
|
||||
def cast(t, type):
    """Convert ``t`` to the element type given by a [numpy type, torch dtype] pair.

    Torch tensors use the second entry of ``type``, everything else is
    converted with ``astype`` and the first entry.
    """
    np_type, torch_type = type[0], type[1]
    return t.type(torch_type) if torch.is_tensor(t) else t.astype(np_type)
|
||||
|
||||
|
||||
def to_numpy(t):
    """Convert a torch tensor to a numpy array; return anything else unchanged."""
    if not torch.is_tensor(t):
        return t
    # Detach from the graph and move to the CPU before converting.
    return t.detach().cpu().numpy()
|
||||
|
||||
|
||||
def to_list(t):
    """Convert a torch tensor or numpy array to nested Python lists.

    Values that are neither are returned unchanged.
    """
    if torch.is_tensor(t):
        t = t.detach().cpu().numpy()
    return t.tolist() if isinstance(t, np.ndarray) else t
|
||||
|
||||
|
||||
def is_tensor(t):
    """Return True when `t` is a PyTorch tensor or a NumPy array."""
    if torch.is_tensor(t):
        return True
    return isinstance(t, np.ndarray)
|
||||
|
||||
|
||||
def first_batch(t):
    """Return `t[0]` for tensors and arrays; return any other value unchanged."""
    return t[0] if is_tensor(t) else t
|
||||
|
||||
|
||||
def ndim(t):
    """Return the number of dimensions of a PyTorch tensor or NumPy array."""
    return t.dim() if torch.is_tensor(t) else t.ndim
|
||||
|
||||
|
||||
def shape(t):
    """Return `t.shape` as a plain Python list."""
    return [*t.shape]
|
||||
|
||||
|
||||
def transpose(t, axis):
    """Reorder the dimensions of `t` according to the sequence `axis`.

    NOTE: this module defines `transpose` a second time further down; at
    import time the later definition replaces this one.
    """
    if not torch.is_tensor(t):
        return np.transpose(t, axis)
    return t.permute(axis)
|
||||
|
||||
|
||||
def apply_recursive(d, fn, filter=None):
    """Apply `fn` to every leaf of a nested list/tuple/dict structure.

    Lists, tuples and dicts are rebuilt with the same container type (dict
    keys are kept as is); everything else is treated as a leaf.

    Args:
        d: the (possibly nested) structure or a single leaf value.
        fn: callable applied to each selected leaf.
        filter: optional predicate; when given, only leaves for which it
            returns a truthy value are passed to `fn`, the others are
            returned unchanged.

    Returns:
        A new structure of the same layout with the transformed leaves.
    """
    # The filter must be forwarded on every recursive call; otherwise it
    # would only be honoured for a bare top-level leaf.
    if isinstance(d, list):
        return [apply_recursive(item, fn, filter) for item in d]
    if isinstance(d, tuple):
        return tuple(apply_recursive(item, fn, filter) for item in d)
    if isinstance(d, dict):
        return {k: apply_recursive(v, fn, filter) for k, v in d.items()}
    if filter is None or filter(d):
        return fn(d)
    return d
|
||||
|
||||
|
||||
def apply_to_tensors(d, fn):
    """Run `apply_recursive` over `d` with `torch.is_tensor` as the filter."""
    return apply_recursive(d, fn, filter=torch.is_tensor)
|
||||
|
||||
|
||||
def recursive_decorator(apply_this_fn):
    """Build a decorator that preprocesses all call arguments.

    The decorated function receives its positional and keyword arguments
    after each has been run through `apply_recursive(..., apply_this_fn)`.
    """
    def decorator(func):
        def wrapped_funct(*args, **kwargs):
            converted_args = apply_recursive(args, apply_this_fn)
            converted_kwargs = apply_recursive(kwargs, apply_this_fn)
            return func(*converted_args, **converted_kwargs)

        return wrapped_funct

    return decorator
|
||||
|
||||
# Decorator: runs every leaf of the call's args/kwargs through to_numpy
# before invoking the decorated function.
untensor = recursive_decorator(to_numpy)

# Decorator: runs every leaf of the call's args/kwargs through to_list
# before invoking the decorated function.
unnumpy = recursive_decorator(to_list)
|
||||
|
||||
def unbatch(only_if_dim_equal=None):
    """Build a decorator that replaces tensor/array arguments by their `[0]` item.

    Args:
        only_if_dim_equal: optional dimension count, or list of counts. When
            given, only tensors whose `ndim` is among them are indexed; all
            other values are passed through untouched.
    """
    allowed_dims = only_if_dim_equal
    if allowed_dims is not None and not isinstance(allowed_dims, list):
        # A single value is accepted as shorthand for a one-element list.
        allowed_dims = [allowed_dims]

    def get_first_batch(t):
        if not is_tensor(t):
            return t
        if allowed_dims is not None and ndim(t) not in allowed_dims:
            return t
        return t[0]

    return recursive_decorator(get_first_batch)
|
||||
|
||||
|
||||
def sigmoid(t):
    """Element-wise logistic sigmoid for PyTorch tensors and NumPy values."""
    if not torch.is_tensor(t):
        return 1.0 / (1.0 + np.exp(-t))
    return torch.sigmoid(t)
|
||||
|
||||
|
||||
def argmax(t, dim):
    """Return the indices of the maximum values of `t` along dimension `dim`."""
    if torch.is_tensor(t):
        # Tensor.max(dim) yields (values, indices); only the indices are needed.
        return t.max(dim)[1]
    return np.argmax(t, axis=dim)
|
||||
|
||||
|
||||
def flip(t, axis):
    """Reverse the element order of `t` along `axis`."""
    return t.flip(axis) if torch.is_tensor(t) else np.flip(t, axis)
|
||||
|
||||
|
||||
def transpose(t, axes):
    """Reorder the dimensions of `t` according to the sequence `axes`.

    PyTorch tensors are permuted with the axes unpacked as separate
    arguments; everything else goes through `np.transpose`.
    """
    if not torch.is_tensor(t):
        return np.transpose(t, axes)
    return t.permute(*axes)
|
||||
|
||||
|
||||
def split_n(t, axis):
    """Split `t` along `axis` into pieces of length 1 (the axis is kept)."""
    if torch.is_tensor(t):
        return t.split(1, dim=axis)

    # One section per element along the axis.
    n_sections = t.shape[axis]
    return np.split(t, n_sections, axis=axis)
|
||||
|
||||
|
||||
def cat(array_of_tensors, axis):
    """Concatenate a sequence of tensors or arrays along `axis`.

    The backend is chosen from the first element only.
    """
    use_torch = torch.is_tensor(array_of_tensors[0])
    if use_torch:
        return torch.cat(array_of_tensors, axis)
    return np.concatenate(array_of_tensors, axis)
|
||||
|
||||
|
||||
def clamp(t, min=None, max=None):
    """Limit the values of `t` to the range [min, max].

    For non-tensor inputs a bound that is None is simply not applied.
    """
    if torch.is_tensor(t):
        return t.clamp(min, max)

    clipped = t
    if min is not None:
        clipped = np.maximum(clipped, min)
    if max is not None:
        clipped = np.minimum(clipped, max)
    return clipped
|
||||
|
||||
|
||||
def power(t, p):
    """Raise `t` to the power `p`, element-wise.

    `torch.pow` is used as soon as either operand is a PyTorch tensor;
    otherwise `np.power` is used.
    """
    any_torch = torch.is_tensor(t) or torch.is_tensor(p)
    return torch.pow(t, p) if any_torch else np.power(t, p)
|
||||
|
||||
|
||||
def random_normal_as(a, mean, std, seed=None):
    """Draw normally distributed noise with the same shape as `a`.

    Args:
        a: tensor or array whose shape is reused for the result.
        mean: mean of the distribution.
        std: standard deviation of the distribution.
        seed: object providing a `normal(loc, scale, size)` method; defaults
            to the `np.random` module. Only consulted for non-tensor `a`.
    """
    if torch.is_tensor(a):
        return torch.randn_like(a) * std + mean

    generator = np.random if seed is None else seed
    return generator.normal(loc=mean, scale=std, size=shape(a))
|
||||
|
||||
|
||||
def pad(t, pad):
    """Zero-pad the last two axes of a 4-D tensor or array.

    Args:
        t: 4-D PyTorch tensor or NumPy array.
        pad: padding amounts in `torch.nn.functional.pad` order: the first
            pair (before, after) applies to the last axis, the second pair
            to the second-to-last axis. The second pair may be omitted.

    Returns:
        The padded tensor/array, same backend as `t`.

    Raises:
        AssertionError: if `t` does not have exactly 4 dimensions.
    """
    assert len(t.shape) == 4

    if torch.is_tensor(t):
        return F.pad(t, pad)

    # np.pad takes one (before, after) pair per axis in axis order, whereas
    # F.pad lists the pairs starting from the LAST axis. Map them explicitly
    # so both backends pad the same axes.
    last_axis_pad = tuple(pad[0:2])
    second_last_axis_pad = tuple(pad[2:4]) or (0, 0)
    return np.pad(t, ((0, 0), (0, 0), second_last_axis_pad, last_axis_pad))
|
||||
|
||||
|
||||
def dx(img):
    """Central difference of a 4-D image along its last axis.

    Computes 0.5 * (img[..., i + 2] - img[..., i]) and hands the result to
    `pad` with the padding spec (1, 1, 0, 0).
    """
    central_diff = 0.5 * (img[:, :, :, 2:] - img[:, :, :, :-2])
    return pad(central_diff, (1, 1, 0, 0))
|
||||
|
||||
|
||||
def dy(img):
    """Central difference of a 4-D image along its second-to-last axis.

    Computes 0.5 * (img[..., i + 2, :] - img[..., i, :]) and hands the result
    to `pad` with the padding spec (0, 0, 1, 1).
    """
    central_diff = 0.5 * (img[:, :, 2:, :] - img[:, :, :-2, :])
    return pad(central_diff, (0, 0, 1, 1))
|
||||
|
||||
|
||||
def reshape(t, shape):
    """Reshape `t` to `shape` (`view` for PyTorch tensors, `reshape` otherwise)."""
    if not torch.is_tensor(t):
        return t.reshape(*shape)
    return t.view(*shape)
|
||||
|
||||
|
||||
def broadcast_to_beginning(t, target):
    """Append size-1 axes to `t` until it has as many dimensions as `target`.

    The existing axes of `t` stay at the front, so `t` lines up with the
    leading dimensions of `target` for broadcasting.
    """
    own_shape = list(t.shape)

    if torch.is_tensor(t):
        n_missing = target.dim() - len(own_shape)
        return t.view(*own_shape, *([1] * n_missing))

    n_missing = target.ndim - len(own_shape)
    return t.reshape(*own_shape, *([1] * n_missing))
|
||||
|
||||
|
||||
def lin_combine(d1, w1, d2, w2, bcast_begin=False):
    """Compute `d1 * w1 + d2 * w2` over matching nested structures.

    `d1` and `d2` must share the same layout of lists, tuples and dicts; the
    result has that layout with every pair of leaves combined.

    Args:
        d1, d2: leaves (anything supporting `*` and `+`) or nested
            lists/tuples/dicts of such leaves.
        w1, w2: weights applied to `d1` and `d2` respectively.
        bcast_begin: when True, each weight is reshaped with
            `broadcast_to_beginning` to match the dimensionality of the leaf
            it multiplies.

    Returns:
        The combined value; lists stay lists, tuples stay tuples and dicts
        keep the keys of `d1`.

    Raises:
        AssertionError: if two sequences differ in length.
    """
    if isinstance(d1, (list, tuple)):
        assert len(d1) == len(d2)
        # bcast_begin has to reach the leaves, so forward it on recursion.
        res = [lin_combine(a, w1, b, w2, bcast_begin) for a, b in zip(d1, d2)]
        if isinstance(d1, tuple):
            # Convert the combined values (not the untouched input).
            res = tuple(res)
    elif isinstance(d1, dict):
        res = {k: lin_combine(v, w1, d2[k], w2, bcast_begin) for k, v in d1.items()}
    else:
        if bcast_begin:
            w1 = broadcast_to_beginning(w1, d1)
            w2 = broadcast_to_beginning(w2, d2)

        res = d1 * w1 + d2 * w2

    return res
|
BIN
Visualize/.DS_Store
vendored
Normal file
BIN
Visualize/.DS_Store
vendored
Normal file
Binary file not shown.
73
Visualize/BitmapTask.py
Normal file
73
Visualize/BitmapTask.py
Normal file
@ -0,0 +1,73 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
try:
|
||||
import cv2
|
||||
except:
|
||||
cv2=None
|
||||
|
||||
from Utils.Helpers import as_numpy
|
||||
|
||||
def visualize_bitmap_task(i_data, o_data, zoom=8):
    """Render the input and output(s) of a bitmap task as one stacked image.

    Args:
        i_data: input data, or None to leave it out.
        o_data: a single output or a list of outputs; None entries are skipped.
        zoom: factor forwarded to `nearest_zoom`.

    Each entry is converted with `as_numpy`, reduced to its first element
    when it has more than two dimensions, transposed, multiplied by 255 and
    cast to uint8 with a trailing axis added. The pieces are concatenated
    along axis 0.
    """
    outputs = o_data if isinstance(o_data, list) else [o_data]

    panels = []
    for item in [i_data] + outputs:
        if item is None:
            continue

        arr = as_numpy(item)
        if arr.ndim > 2:
            # Only the first element along the leading axis is shown.
            arr = arr[0]

        scaled = arr.T * 255
        panels.append(np.expand_dims(scaled, -1).astype(np.uint8))

    stacked = np.concatenate(panels, 0)
    return nearest_zoom(stacked, zoom)
|
||||
|
||||
def visualize_01(t, zoom=8):
    """Turn `t` into a uint8 image (`t * 255`, trailing axis added) and zoom it."""
    scaled = np.expand_dims(t * 255, -1)
    as_bytes = scaled.astype(np.uint8)
    return nearest_zoom(as_bytes, zoom)
|
||||
|
||||
def nearest_zoom(img, zoom=1):
    """Enlarge `img` by an integer factor using nearest-neighbour resizing.

    The image is returned unchanged when `zoom` is not greater than 1 or
    when OpenCV (`cv2`) could not be imported.
    """
    can_resize = zoom > 1 and cv2 is not None
    if not can_resize:
        return img

    height, width = img.shape[0], img.shape[1]
    # cv2.resize expects the target size as (width, height).
    target_size = (width * zoom, height * zoom)
    return cv2.resize(img, target_size, interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
def concatenate_tensors(tensors):
    """Stack arrays of differing sizes into one zero-padded array.

    The result has shape `[len(tensors)] + per-axis maximum size`; each
    input is copied into the leading corner of its slot and the rest stays
    zero. All inputs must share the same number of dimensions and dtype.
    """
    max_size = None
    dtype = None

    for tensor in tensors:
        dims = tensor.shape
        if max_size is None:
            # The first tensor fixes the reference rank and dtype.
            max_size = list(dims)
            dtype = tensor.dtype
            continue

        assert tensor.ndim == len(max_size), "Can only concatenate tensors with same ndim."
        assert tensor.dtype == dtype, "Tensors must have the same type"
        max_size = [max(current, dims[axis]) for axis, current in enumerate(max_size)]

    res = np.zeros([len(tensors)] + max_size, dtype=dtype)
    for slot, tensor in enumerate(tensors):
        region = tuple(slice(0, extent) for extent in tensor.shape)
        res[slot][region] = tensor
    return res
|
||||
|
||||
|
||||
|
1
Visualize/__init__.py
Normal file
1
Visualize/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .BitmapTask import *
|
64
Visualize/preview.py
Normal file
64
Visualize/preview.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ==============================================================================
|
||||
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
import sys
|
||||
from Utils import universal as U
|
||||
|
||||
|
||||
def preview(vis_interval=None, to_numpy=True, debatch=False):
    """Decorator factory: run a visualization function in a background thread.

    Calling the decorated function returns immediately (always None). The
    call is dropped when a previous invocation is still running, or when
    fewer than `vis_interval` seconds have passed since the last accepted one.

    Args:
        vis_interval: minimum number of seconds between two accepted calls,
            or None for no rate limit.
        to_numpy: when True, all arguments are passed through `U.to_numpy`
            (recursively) before the function runs.
        debatch: when True, all arguments are first passed through
            `U.first_batch` (recursively).
    """
    def decorator(func):
        state = {
            "last_vis_time": 0,
            "thread_running": False
        }

        def wrapper_function(*args, **kwargs):
            if state["thread_running"]:
                return

            if vis_interval is not None:
                now = time.time()
                if now - state["last_vis_time"] < vis_interval:
                    return
                state["last_vis_time"] = now

            # Convert the arguments in the calling thread, BEFORE marking the
            # worker as running: if a conversion raises, the flag must not be
            # left set or every later call would be dropped silently.
            if debatch:
                args = U.apply_recursive(args, U.first_batch)
                kwargs = U.apply_recursive(kwargs, U.first_batch)

            if to_numpy:
                args = U.apply_recursive(args, U.to_numpy)
                kwargs = U.apply_recursive(kwargs, U.to_numpy)

            state["thread_running"] = True

            def run():
                try:
                    func(*args, **kwargs)
                except Exception:
                    traceback.print_exc()
                    # NOTE: in a non-main thread sys.exit() only raises
                    # SystemExit here, ending this worker, not the process.
                    sys.exit(-1)
                finally:
                    # Always release the flag, so one failing preview does
                    # not disable all subsequent ones.
                    state["thread_running"] = False

            worker = threading.Thread(target=run)
            worker.start()

        return wrapper_function
    return decorator
|
BIN
assets/demon.png
Normal file
BIN
assets/demon.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 90 KiB |
1070
memory_demon.py
Normal file
1070
memory_demon.py
Normal file
File diff suppressed because it is too large
Load Diff
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@ -0,0 +1,8 @@
|
||||
tqdm
|
||||
torch
|
||||
visdom
|
||||
numpy
|
||||
tensorboard
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user