mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
update data memorizer and test
This commit is contained in:
parent
b5fb34aa90
commit
28a98d3460
@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import pickle
|
import pickle
|
||||||
import hashlib
|
import hashlib
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@ -21,27 +22,33 @@ from collections import OrderedDict
|
|||||||
class DataMemorizer():
|
class DataMemorizer():
|
||||||
def __init__(self, config, tmp_dir):
|
def __init__(self, config, tmp_dir):
|
||||||
|
|
||||||
self.hash = self.make_config_hash(config)
|
self.hash_name = self.make_config_hash(config)
|
||||||
self.tmp_dir = tmp_dir
|
if isinstance(tmp_dir, pathlib.Path):
|
||||||
|
self.tmp_dir = tmp_dir
|
||||||
|
else:
|
||||||
|
self.tmp_dir = pathlib.Path(tmp_dir)
|
||||||
|
|
||||||
|
if not self.tmp_dir.exists():
|
||||||
|
self.tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.check_existent()
|
return self.check_existent()
|
||||||
|
|
||||||
def check_existent(self):
|
def check_existent(self):
|
||||||
file_name = os.path.join(self.tmp_dir, self.hash + '.pkl')
|
file_name = self.tmp_dir / self.hash_name
|
||||||
return os.path.isfile(file_name)
|
return file_name.exists()
|
||||||
|
|
||||||
def load_data(self):
|
def load_data(self):
|
||||||
with open(os.path.join(self.tmp_dir, self.hash + '.pkl'), 'rb') as outfile:
|
with open(str(self.tmp_dir / self.hash_name), 'rb') as outfile:
|
||||||
data = pickle.load(outfile)
|
data = pickle.load(outfile)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def dump_data(self, data_to_save):
|
def dump_data(self, data_to_save):
|
||||||
with open(os.path.join(self.tmp_dir, self.hash + '.pkl'), 'wb') as outfile:
|
with open(str(self.tmp_dir / self.hash_name), 'wb') as outfile:
|
||||||
pickle.dump(data_to_save, outfile)
|
pickle.dump(data_to_save, outfile)
|
||||||
|
|
||||||
def purge_data(self):
|
def purge_data(self):
|
||||||
file_name = os.path.join(self.tmp_dir, self.hash + '.pkl')
|
file_name = str(self.tmp_dir / self.hash_name)
|
||||||
if os.path.isfile(file_name):
|
if os.path.isfile(file_name):
|
||||||
os.remove(file_name)
|
os.remove(file_name)
|
||||||
|
|
||||||
@ -60,4 +67,4 @@ class DataMemorizer():
|
|||||||
|
|
||||||
hash_object = hashlib.md5(str(sort_dict).encode())
|
hash_object = hashlib.md5(str(sort_dict).encode())
|
||||||
hash = str(hash_object.hexdigest())
|
hash = str(hash_object.hexdigest())
|
||||||
return hash
|
return hash + '.pkl'
|
||||||
|
@ -25,7 +25,7 @@ def tmp_dir():
|
|||||||
tmp_dir = TMP_DIR
|
tmp_dir = TMP_DIR
|
||||||
pathlib.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
|
pathlib.Path(tmp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
yield tmp_dir
|
yield tmp_dir
|
||||||
shutil.rmtree
|
shutil.rmtree(tmp_dir)
|
||||||
|
|
||||||
|
|
||||||
class TestDataMemorizer():
|
class TestDataMemorizer():
|
||||||
@ -38,8 +38,8 @@ class TestDataMemorizer():
|
|||||||
data_memory_2 = DataMemorizer(hash_config_2, tmp_dir)
|
data_memory_2 = DataMemorizer(hash_config_2, tmp_dir)
|
||||||
data_memory_3 = DataMemorizer(hash_config_3, tmp_dir)
|
data_memory_3 = DataMemorizer(hash_config_3, tmp_dir)
|
||||||
|
|
||||||
assert data_memory_1.hash == data_memory_2.hash
|
assert data_memory_1.hash_name == data_memory_2.hash_name
|
||||||
assert data_memory_1.hash != data_memory_3.hash
|
assert data_memory_1.hash_name != data_memory_3.hash_name
|
||||||
|
|
||||||
def test_data_memorizing(self, tmp_dir):
|
def test_data_memorizing(self, tmp_dir):
|
||||||
hash_config = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}
|
hash_config = {'set_types': 'tokens', 'target_mode': 'mode1', 'seed': 123}
|
||||||
|
Loading…
Reference in New Issue
Block a user