mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 22:08:04 +08:00
add data utils and test
This commit is contained in:
parent
c1dd5a0935
commit
453ffb9b58
0
adnc/data/utils/__init__.py
Normal file
0
adnc/data/utils/__init__.py
Normal file
67
adnc/data/utils/batch_generator.py
Executable file
67
adnc/data/utils/batch_generator.py
Executable file
@ -0,0 +1,67 @@
|
||||
# Copyright 2018 Jörg Franke
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
|
||||
class BatchGenerator():
|
||||
def __init__(self, data_set, set, batch_size, shuffle=True, max_len=False):
|
||||
|
||||
self.set = set
|
||||
self.data_set = data_set
|
||||
self.batch_size = batch_size
|
||||
self.sample_amount = self.data_set.sample_amount(self.set)
|
||||
self.shuffle = shuffle
|
||||
self.max_len = max_len
|
||||
|
||||
self.lock = threading.Lock()
|
||||
|
||||
if self.shuffle:
|
||||
self.order = self.data_set.rng.permutation(np.arange(self.sample_amount))
|
||||
else:
|
||||
self.order = np.arange(self.sample_amount)
|
||||
|
||||
self.sample_count = 0
|
||||
|
||||
def shuffle_order(self):
|
||||
self.order = self.data_set.rng.permutation(self.order)
|
||||
|
||||
def increase_sample_count(self):
|
||||
with self.lock:
|
||||
self.sample_count += 1
|
||||
if self.sample_count >= self.sample_amount:
|
||||
self.sample_count = 0
|
||||
if self.shuffle:
|
||||
self.order = self.data_set.rng.permutation(self.order)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
|
||||
batch_list = []
|
||||
for b in range(self.batch_size):
|
||||
|
||||
sample = self.data_set.get_sample(self.set, self.order[self.sample_count])
|
||||
|
||||
while self.max_len and sample['x'].shape[0] > self.max_len:
|
||||
self.increase_sample_count()
|
||||
sample = self.data_set.get_sample(self.set, self.order[self.sample_count])
|
||||
|
||||
batch_list.append(sample)
|
||||
self.increase_sample_count()
|
||||
|
||||
batch = self.data_set.patch_batch(batch_list)
|
||||
return batch
|
0
test/adnc/data/__init__.py
Normal file
0
test/adnc/data/__init__.py
Normal file
0
test/adnc/data/utils/__init__.py
Normal file
0
test/adnc/data/utils/__init__.py
Normal file
60
test/adnc/data/utils/test_data_memorizer.py
Executable file
60
test/adnc/data/utils/test_data_memorizer.py
Executable file
@ -0,0 +1,60 @@
|
||||
# Copyright 2018 Jörg Franke
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import shutil
|
||||
import pathlib
|
||||
from adnc.data.utils.data_memorizer import DataMemorizer
|
||||
|
||||
TMP_DIR = '.tmp_dir'
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_dir():
|
||||
tmp_dir = TMP_DIR
|
||||
pathlib.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
|
||||
yield tmp_dir
|
||||
shutil.rmtree
|
||||
|
||||
|
||||
class TestDataMemorizer():
|
||||
def test_hashing(self, tmp_dir):
|
||||
hash_config_1 = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}
|
||||
hash_config_2 = {'seed': 123, 'set_types': 'tokens', 'target_mode': 'mode1'}
|
||||
hash_config_3 = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 124}
|
||||
|
||||
data_memory_1 = DataMemorizer(hash_config_1, tmp_dir)
|
||||
data_memory_2 = DataMemorizer(hash_config_2, tmp_dir)
|
||||
data_memory_3 = DataMemorizer(hash_config_3, tmp_dir)
|
||||
|
||||
assert data_memory_1.hash == data_memory_2.hash
|
||||
assert data_memory_1.hash != data_memory_3.hash
|
||||
|
||||
def test_data_memorizing(self, tmp_dir):
|
||||
hash_config = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}
|
||||
dummy_data = [{'dict': 'test'}, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 'string_test']
|
||||
|
||||
data_memory = DataMemorizer(hash_config, tmp_dir)
|
||||
assert not data_memory()
|
||||
|
||||
data_memory.dump_data(dummy_data)
|
||||
assert data_memory()
|
||||
|
||||
dict_dummy, list_dummy, str_dummy = data_memory.load_data()
|
||||
assert dict_dummy['dict'] == 'test'
|
||||
assert list_dummy[4] == 4
|
||||
assert str_dummy == 'string_test'
|
||||
|
||||
data_memory.purge_data()
|
||||
assert not data_memory()
|
Loading…
Reference in New Issue
Block a user