mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
update data memorizer and test
This commit is contained in:
parent
b5fb34aa90
commit
28a98d3460
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import os
|
||||
import pathlib
|
||||
import pickle
|
||||
import hashlib
|
||||
from collections import OrderedDict
|
||||
@ -21,27 +22,33 @@ from collections import OrderedDict
|
||||
class DataMemorizer():
|
||||
def __init__(self, config, tmp_dir):
|
||||
|
||||
self.hash = self.make_config_hash(config)
|
||||
self.hash_name = self.make_config_hash(config)
|
||||
if isinstance(tmp_dir, pathlib.Path):
|
||||
self.tmp_dir = tmp_dir
|
||||
else:
|
||||
self.tmp_dir = pathlib.Path(tmp_dir)
|
||||
|
||||
if not self.tmp_dir.exists():
|
||||
self.tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.check_existent()
|
||||
|
||||
def check_existent(self):
|
||||
file_name = os.path.join(self.tmp_dir, self.hash + '.pkl')
|
||||
return os.path.isfile(file_name)
|
||||
file_name = self.tmp_dir / self.hash_name
|
||||
return file_name.exists()
|
||||
|
||||
def load_data(self):
|
||||
with open(os.path.join(self.tmp_dir, self.hash + '.pkl'), 'rb') as outfile:
|
||||
with open(str(self.tmp_dir / self.hash_name), 'rb') as outfile:
|
||||
data = pickle.load(outfile)
|
||||
return data
|
||||
|
||||
def dump_data(self, data_to_save):
|
||||
with open(os.path.join(self.tmp_dir, self.hash + '.pkl'), 'wb') as outfile:
|
||||
with open(str(self.tmp_dir / self.hash_name), 'wb') as outfile:
|
||||
pickle.dump(data_to_save, outfile)
|
||||
|
||||
def purge_data(self):
|
||||
file_name = os.path.join(self.tmp_dir, self.hash + '.pkl')
|
||||
file_name = str(self.tmp_dir / self.hash_name)
|
||||
if os.path.isfile(file_name):
|
||||
os.remove(file_name)
|
||||
|
||||
@ -60,4 +67,4 @@ class DataMemorizer():
|
||||
|
||||
hash_object = hashlib.md5(str(sort_dict).encode())
|
||||
hash = str(hash_object.hexdigest())
|
||||
return hash
|
||||
return hash + '.pkl'
|
||||
|
@ -25,7 +25,7 @@ def tmp_dir():
|
||||
tmp_dir = TMP_DIR
|
||||
pathlib.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
|
||||
yield tmp_dir
|
||||
shutil.rmtree
|
||||
shutil.rmtree(tmp_dir)
|
||||
|
||||
|
||||
class TestDataMemorizer():
|
||||
@ -38,8 +38,8 @@ class TestDataMemorizer():
|
||||
data_memory_2 = DataMemorizer(hash_config_2, tmp_dir)
|
||||
data_memory_3 = DataMemorizer(hash_config_3, tmp_dir)
|
||||
|
||||
assert data_memory_1.hash == data_memory_2.hash
|
||||
assert data_memory_1.hash != data_memory_3.hash
|
||||
assert data_memory_1.hash_name == data_memory_2.hash_name
|
||||
assert data_memory_1.hash_name != data_memory_3.hash_name
|
||||
|
||||
def test_data_memorizing(self, tmp_dir):
|
||||
hash_config = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}
|
||||
|
Loading…
Reference in New Issue
Block a user