update data memorizer and test

This commit is contained in:
Joerg Franke 2018-06-25 11:03:29 +02:00
parent b5fb34aa90
commit 28a98d3460
2 changed files with 18 additions and 11 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import os
import pathlib
import pickle
import hashlib
from collections import OrderedDict
@ -21,27 +22,33 @@ from collections import OrderedDict
class DataMemorizer():
def __init__(self, config, tmp_dir):
self.hash = self.make_config_hash(config)
self.tmp_dir = tmp_dir
self.hash_name = self.make_config_hash(config)
if isinstance(tmp_dir, pathlib.Path):
self.tmp_dir = tmp_dir
else:
self.tmp_dir = pathlib.Path(tmp_dir)
if not self.tmp_dir.exists():
self.tmp_dir.mkdir(parents=True, exist_ok=True)
def __call__(self, *args, **kwargs):
return self.check_existent()
def check_existent(self):
file_name = os.path.join(self.tmp_dir, self.hash + '.pkl')
return os.path.isfile(file_name)
file_name = self.tmp_dir / self.hash_name
return file_name.exists()
def load_data(self):
with open(os.path.join(self.tmp_dir, self.hash + '.pkl'), 'rb') as outfile:
with open(str(self.tmp_dir / self.hash_name), 'rb') as outfile:
data = pickle.load(outfile)
return data
def dump_data(self, data_to_save):
with open(os.path.join(self.tmp_dir, self.hash + '.pkl'), 'wb') as outfile:
with open(str(self.tmp_dir / self.hash_name), 'wb') as outfile:
pickle.dump(data_to_save, outfile)
def purge_data(self):
file_name = os.path.join(self.tmp_dir, self.hash + '.pkl')
file_name = str(self.tmp_dir / self.hash_name)
if os.path.isfile(file_name):
os.remove(file_name)
@ -60,4 +67,4 @@ class DataMemorizer():
hash_object = hashlib.md5(str(sort_dict).encode())
hash = str(hash_object.hexdigest())
return hash
return hash + '.pkl'

View File

@ -25,7 +25,7 @@ def tmp_dir():
tmp_dir = TMP_DIR
pathlib.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
yield tmp_dir
shutil.rmtree
shutil.rmtree(tmp_dir)
class TestDataMemorizer():
@ -38,8 +38,8 @@ class TestDataMemorizer():
data_memory_2 = DataMemorizer(hash_config_2, tmp_dir)
data_memory_3 = DataMemorizer(hash_config_3, tmp_dir)
assert data_memory_1.hash == data_memory_2.hash
assert data_memory_1.hash != data_memory_3.hash
assert data_memory_1.hash_name == data_memory_2.hash_name
assert data_memory_1.hash_name != data_memory_3.hash_name
def test_data_memorizing(self, tmp_dir):
hash_config = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}