diff --git a/adnc/data/utils/data_memorizer.py b/adnc/data/utils/data_memorizer.py
index 4af8387..b3765e0 100755
--- a/adnc/data/utils/data_memorizer.py
+++ b/adnc/data/utils/data_memorizer.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import os
+import pathlib
 import pickle
 import hashlib
 from collections import OrderedDict
@@ -21,27 +22,33 @@ from collections import OrderedDict
 class DataMemorizer():
 
     def __init__(self, config, tmp_dir):
-        self.hash = self.make_config_hash(config)
-        self.tmp_dir = tmp_dir
+        self.hash_name = self.make_config_hash(config)
+        if isinstance(tmp_dir, pathlib.Path):
+            self.tmp_dir = tmp_dir
+        else:
+            self.tmp_dir = pathlib.Path(tmp_dir)
+
+        if not self.tmp_dir.exists():
+            self.tmp_dir.mkdir(parents=True, exist_ok=True)
 
     def __call__(self, *args, **kwargs):
         return self.check_existent()
 
     def check_existent(self):
-        file_name = os.path.join(self.tmp_dir, self.hash + '.pkl')
-        return os.path.isfile(file_name)
+        file_name = self.tmp_dir / self.hash_name
+        return file_name.exists()
 
     def load_data(self):
-        with open(os.path.join(self.tmp_dir, self.hash + '.pkl'), 'rb') as outfile:
+        with open(str(self.tmp_dir / self.hash_name), 'rb') as outfile:
             data = pickle.load(outfile)
         return data
 
     def dump_data(self, data_to_save):
-        with open(os.path.join(self.tmp_dir, self.hash + '.pkl'), 'wb') as outfile:
+        with open(str(self.tmp_dir / self.hash_name), 'wb') as outfile:
             pickle.dump(data_to_save, outfile)
 
     def purge_data(self):
-        file_name = os.path.join(self.tmp_dir, self.hash + '.pkl')
+        file_name = str(self.tmp_dir / self.hash_name)
         if os.path.isfile(file_name):
             os.remove(file_name)
 
@@ -60,4 +67,4 @@ class DataMemorizer():
 
         hash_object = hashlib.md5(str(sort_dict).encode())
         hash = str(hash_object.hexdigest())
-        return hash
+        return hash + '.pkl'
diff --git a/test/adnc/data/utils/test_data_memorizer.py b/test/adnc/data/utils/test_data_memorizer.py
index 248def3..d7c4d5d 100755
--- a/test/adnc/data/utils/test_data_memorizer.py
+++ b/test/adnc/data/utils/test_data_memorizer.py
@@ -25,7 +25,7 @@ def tmp_dir():
     tmp_dir = TMP_DIR
     pathlib.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
     yield tmp_dir
-    shutil.rmtree
+    shutil.rmtree(tmp_dir)
 
 
 class TestDataMemorizer():
@@ -38,8 +38,8 @@ class TestDataMemorizer():
         data_memory_2 = DataMemorizer(hash_config_2, tmp_dir)
         data_memory_3 = DataMemorizer(hash_config_3, tmp_dir)
 
-        assert data_memory_1.hash == data_memory_2.hash
-        assert data_memory_1.hash != data_memory_3.hash
+        assert data_memory_1.hash_name == data_memory_2.hash_name
+        assert data_memory_1.hash_name != data_memory_3.hash_name
 
     def test_data_memorizing(self, tmp_dir):
         hash_config = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}