# dnc-with-demon/Dataset/Bitmap/KeyValue2Way.py
# 2022-11-05 14:59:40 -07:00

# 90 lines
# 3.1 KiB
# Python

# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import math
import numpy as np
from .BitmapTask import BitmapTask
from Utils.Seed import get_randstate
class KeyValue2Way(BitmapTask):
    """Two-way key/value association task built from random bit patterns.

    A sample is a pair of float32 arrays with ``3 * (length + 1)`` rows.
    The input has ``bit_w + 2`` columns (``key_w`` key bits, ``key_w`` value
    bits and two control flags); the target has ``key_w`` columns.

    Row layout:

    * ``length`` rows storing (key, value) pairs; keys are unique.
    * one row with the first control flag (second-to-last column) set.
    * ``length`` rows presenting the keys in a random order; the target is
      the associated value.
    * one row with the second control flag (last column) set.
    * ``length`` rows presenting the values in another random order; the
      target is the associated key.
    * one final row with both control flags set.
    """

    def __init__(self, length=None, bit_w=8, transform=lambda x: x):
        """Configure the task.

        Args:
            length: Number of key/value pairs per sample. May be an int, a
                zero-argument callable returning an int (called once per
                sample), or ``None``, in which case the index passed to
                ``__getitem__`` is used as the length.
            bit_w: Total width in bits of a (key, value) pair; must be even
                because keys and values each take half of it.
            transform: Callable applied to the ``{"input", "output"}`` dict
                before it is returned.
        """
        assert bit_w % 2 == 0, "bit_w must be even"
        super(KeyValue2Way, self).__init__()
        self.length = length
        self.bit_w = bit_w
        self.transform = transform
        # Random state is created lazily on first access (see __getitem__).
        self.seed = None
        self.key_w = self.bit_w // 2
        self.max_key = 2 ** self.key_w - 1

    def _draw_unique_keys(self, length):
        """Return ``length`` distinct integers from ``[0, max_key]``, sorted.

        Raises:
            ValueError: If more keys are requested than exist; without this
                check the rejection loop below could never terminate.
        """
        n_available = self.max_key + 1
        if length > n_available:
            raise ValueError(
                "cannot draw %d unique keys from only %d possible %d-bit keys"
                % (length, n_available, self.key_w)
            )

        # Starting from an empty array (rather than None) makes length == 0
        # fall straight through the loop and yield an empty key set.
        keys = np.empty(0, dtype=np.int64)
        while keys.size != length:
            # randint's upper bound is exclusive, so max_key + 1 covers the
            # same inclusive range as the deprecated random_integers call.
            fresh = self.seed.randint(0, n_available, size=length - keys.size)
            # Top up with fresh draws and drop duplicates until enough
            # distinct keys have been collected.
            keys = np.unique(np.concatenate((fresh, keys)))
        return keys

    def _keys_to_bits(self, keys):
        """Expand integer keys into a ``(len(keys), key_w)`` float32 bit matrix.

        Column ``i`` holds bit ``i`` of the key (least significant bit
        first). Shifting and masking does not depend on the integer item
        size or on the host byte order.
        """
        shifts = np.arange(self.key_w)
        return ((keys[:, np.newaxis] >> shifts) & 1).astype(np.float32)

    def __getitem__(self, key):
        """Generate one random sample.

        Args:
            key: Sample index. Only used when ``self.length`` is ``None``,
                in which case it is taken as the number of pairs.

        Returns:
            The result of ``self.transform`` applied to a dict with keys
            ``"input"`` and ``"output"`` holding the float32 arrays
            described in the class docstring.

        Raises:
            ValueError: If the requested length exceeds the number of
                distinct ``key_w``-bit keys.
        """
        if self.seed is None:
            self.seed = get_randstate()

        if self.length is None:
            # Random length batch hack.
            length = key
        else:
            length = self.length() if callable(self.length) else self.length

        # Keys must be unique so that every key query has a single answer.
        keys = self._keys_to_bits(self._draw_unique_keys(length))
        values = self.seed.randint(0, 2, keys.shape).astype(np.float32)

        n_steps = 3 * (length + 1)
        i_p = np.zeros((n_steps, self.bit_w + 2), dtype=np.float32)
        o_p = np.zeros((n_steps, self.key_w), dtype=np.float32)

        # Phase 1: store the pairs (keys in the first half, values in the
        # second half of the data columns).
        i_p[:length, : self.key_w] = keys
        i_p[:length, self.key_w : -2] = values

        # Phase 2: query by key in a shuffled order, expect the value.
        perm = self.seed.permutation(length)
        key_query = slice(length + 1, 2 * length + 1)
        i_p[key_query, : self.key_w] = keys[perm, :]
        o_p[key_query] = values[perm, :]

        # Phase 3: query by value in a new shuffled order, expect the key.
        perm = self.seed.permutation(length)
        value_query = slice(2 * (length + 1), -1)
        o_p[value_query] = keys[perm, :]
        # NOTE(review): values are written into the *key* columns here,
        # whereas phase 1 stores them in columns key_w:-2 -- confirm this
        # layout is intended before changing it.
        i_p[value_query, : self.key_w] = values[perm, :]

        # Control flags mark the phase boundaries and the end of sequence.
        i_p[length, -2] = 1
        i_p[2 * length + 1, -1] = 1
        i_p[-1, -2:] = 1

        return self.transform({"input": i_p, "output": o_p})