add data utils and test

This commit is contained in:
Joerg Franke 2018-06-24 16:20:29 +02:00
parent c1dd5a0935
commit 453ffb9b58
6 changed files with 127 additions and 0 deletions

View File

View File

@ -0,0 +1,67 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import threading
class BatchGenerator:
    """Endless iterator that draws samples from a data set and groups them into batches.

    The generator keeps a cursor (``sample_count``) into an index array
    (``order``). When the cursor reaches ``sample_amount`` it wraps to zero
    and, if shuffling is enabled, ``order`` is re-permuted with the data set's
    random generator. Iteration therefore never raises ``StopIteration``.

    The ``data_set`` object has to provide:

    * ``sample_amount(set)`` -- number of samples in the given subset,
    * ``get_sample(set, index)`` -- one sample; when ``max_len`` is used the
      sample must support ``sample['x'].shape[0]``,
    * ``patch_batch(samples)`` -- turns a list of samples into one batch,
    * ``rng.permutation(array)`` -- used for shuffling.

    Every read or update of ``order`` / ``sample_count`` happens while holding
    ``self.lock``; the (potentially slow) ``get_sample`` call runs outside of
    the lock so several threads can load samples concurrently.
    """

    def __init__(self, data_set, set, batch_size, shuffle=True, max_len=False):
        """Set up the cursor and the (optionally shuffled) sample order.

        Args:
            data_set: Object providing the methods listed in the class docstring.
            set: Identifier of the subset to draw from; passed through to
                ``data_set``. (The name shadows the builtin but is part of the
                public signature, so it is kept.)
            batch_size: Number of samples per batch.
            shuffle: If true, permute the order now and after every full pass.
            max_len: If truthy, samples whose ``sample['x'].shape[0]`` is
                larger than this value are skipped.
        """
        self.set = set
        self.data_set = data_set
        self.batch_size = batch_size
        self.sample_amount = self.data_set.sample_amount(self.set)
        self.shuffle = shuffle
        self.max_len = max_len
        self.lock = threading.Lock()

        self.order = np.arange(self.sample_amount)
        if self.shuffle:
            self.order = self.data_set.rng.permutation(self.order)
        self.sample_count = 0

    def shuffle_order(self):
        """Re-permute the sample order using the data set's random generator."""
        # Take the lock so a concurrent __next__ never sees the swap half-way.
        with self.lock:
            self.order = self.data_set.rng.permutation(self.order)

    def _advance(self):
        """Move the cursor one step; the caller must already hold ``self.lock``."""
        self.sample_count += 1
        if self.sample_count >= self.sample_amount:
            # End of a pass: wrap around and reshuffle if requested.
            self.sample_count = 0
            if self.shuffle:
                self.order = self.data_set.rng.permutation(self.order)

    def increase_sample_count(self):
        """Advance the cursor by one sample, wrapping (and reshuffling) at the end."""
        with self.lock:
            self._advance()

    def _claim_index(self):
        """Return the sample index under the cursor and advance, as one atomic step.

        Reading the index and moving the cursor under the same lock guarantees
        that the cursor is always in range when it is read and that two
        threads never receive the same position.
        """
        with self.lock:
            index = self.order[self.sample_count]
            self._advance()
        return index

    def __iter__(self):
        """Return the generator itself; it is its own iterator."""
        return self

    def __next__(self):
        """Assemble and return the next batch.

        Returns:
            Whatever ``data_set.patch_batch`` returns for a list of
            ``batch_size`` samples.

        Raises:
            ValueError: If ``max_len`` is set and ``2 * sample_amount``
                consecutive samples were all longer than ``max_len``.
        """
        # In single-threaded use, any run of 2 * sample_amount consecutive
        # draws contains at least one complete pass over the data, so hitting
        # this limit means no sample fits. Without it the loop would spin
        # forever. (With several consumer threads this is a heuristic bound.)
        skip_limit = 2 * self.sample_amount

        batch_list = []
        for _ in range(self.batch_size):
            skipped = 0
            sample = self.data_set.get_sample(self.set, self._claim_index())
            while self.max_len and sample['x'].shape[0] > self.max_len:
                skipped += 1
                if skipped >= skip_limit:
                    raise ValueError(
                        "no sample in set {!r} has length <= max_len={!r}".format(
                            self.set, self.max_len))
                sample = self.data_set.get_sample(self.set, self._claim_index())
            batch_list.append(sample)

        return self.data_set.patch_batch(batch_list)

View File

View File

View File

@ -0,0 +1,60 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import shutil
import pathlib
from adnc.data.utils.data_memorizer import DataMemorizer
# Scratch directory (relative to the working directory) shared by the tests below.
TMP_DIR = '.tmp_dir'


@pytest.fixture()
def tmp_dir():
    """Provide a scratch directory for a test and delete it afterwards.

    Yields:
        The directory path as a string (the value of ``TMP_DIR``).
    """
    pathlib.Path(TMP_DIR).mkdir(parents=True, exist_ok=True)
    yield TMP_DIR
    # Teardown: actually call rmtree (a bare ``shutil.rmtree`` reference is a
    # no-op). Cleanup is best-effort, so a directory that is already gone is
    # not treated as a test error.
    shutil.rmtree(TMP_DIR, ignore_errors=True)
class TestDataMemorizer:
    """Tests for the hashing and the dump/load/purge cycle of ``DataMemorizer``."""

    def test_hashing(self, tmp_dir):
        """Key order must not influence the hash, a changed value must."""
        reference_cfg = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}
        # Same content as the reference, keys in a different order.
        reordered_cfg = {'seed': 123, 'set_types': 'tokens', 'target_mode': 'mode1'}
        # Differs from the reference only in the seed.
        other_seed_cfg = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 124}

        reference, reordered, other_seed = (
            DataMemorizer(cfg, tmp_dir)
            for cfg in (reference_cfg, reordered_cfg, other_seed_cfg)
        )

        assert reference.hash == reordered.hash
        assert reference.hash != other_seed.hash

    def test_data_memorizing(self, tmp_dir):
        """Dumped data can be loaded back unchanged and is gone after a purge."""
        config = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}
        payload = [{'dict': 'test'}, list(range(10)), 'string_test']
        memorizer = DataMemorizer(config, tmp_dir)

        # Calling the memorizer is expected to be falsy while nothing is
        # stored and truthy once data was dumped (as asserted here).
        assert not memorizer()
        memorizer.dump_data(payload)
        assert memorizer()

        restored = memorizer.load_data()
        dict_part, list_part, str_part = restored
        assert dict_part['dict'] == 'test'
        assert list_part[4] == 4
        assert str_part == 'string_test'

        memorizer.purge_data()
        assert not memorizer()