# dnc/Utils/Saver.py
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import os
import inspect
import time
class SaverElement:
    """Abstract interface for anything the Saver can persist and restore."""

    def load(self, saved_state):
        """Restore state from a snapshot previously produced by save()."""
        raise NotImplementedError

    def save(self):
        """Return a serializable snapshot of the wrapped object's state."""
        raise NotImplementedError
class CallbackSaver(SaverElement):
    """SaverElement assembled from a pair of user-supplied callables.

    The callables are bound directly onto the instance, shadowing the
    abstract save/load methods of the base class.
    """

    def __init__(self, save_fn, load_fn):
        super().__init__()
        # Instance-attribute assignment overrides the class-level methods.
        self.load = load_fn
        self.save = save_fn
class StateSaver(SaverElement):
    """Persists any object exposing the ``state_dict()``/``load_state_dict()``
    protocol (e.g. torch.nn.Module, torch.optim.Optimizer)."""

    def __init__(self, model):
        super().__init__()
        self._model = model

    def load(self, state):
        """Load a previously saved state dict into the wrapped object.

        On failure, prints a diagnostic diff between the parameter names in
        the checkpoint and those in the model. Optimizer state mismatches
        are tolerated with a warning; for any other object the original
        exception is re-raised.
        """
        try:
            self._model.load_state_dict(state)
        except Exception as e:
            if hasattr(self._model, "named_parameters"):
                names = set(n for n, _ in self._model.named_parameters())
                # BUGFIX: compare against the keys of the *checkpoint*, not
                # self._model — modules/optimizers have no .keys() method,
                # so the original `self._model.keys()` raised AttributeError
                # and the diagnostic below was never printed.
                loaded = set(state.keys())
                if names != loaded:
                    d = loaded.difference(names)
                    if d:
                        print("Loaded, but not in model: %s" % list(d))
                    d = names.difference(loaded)
                    if d:
                        print("In model, but not loaded: %s" % list(d))
            if isinstance(self._model, torch.optim.Optimizer):
                print("WARNING: optimizer parameters not loaded!")
            else:
                raise e

    def save(self):
        """Return the wrapped object's state dict."""
        return self._model.state_dict()
class GlobalVarSaver(SaverElement):
    """Checkpoints a single global variable of the *calling* module.

    The constructor captures the caller's global namespace through frame
    introspection; save()/load() then read and write that one global.
    """

    def __init__(self, name):
        # f_back is the frame of whoever invoked this constructor —
        # equivalent to inspect.getouterframes(...)[1], but cheaper since
        # it does not build the whole FrameInfo list.
        self._vars = inspect.currentframe().f_back.f_globals
        self._name = name

    def save(self):
        return self._vars[self._name]

    def load(self, state):
        self._vars.update({self._name: state})
class PyObjectSaver(SaverElement):
    """Recursively snapshots/restores plain Python containers and objects.

    Supported targets are dicts, lists, and anything carrying a __dict__;
    leaf values (numbers, strings, tensors, ...) are stored by reference.
    """

    def __init__(self, obj):
        self._obj = obj

    @staticmethod
    def obj_supported(obj):
        """Return True if *obj* is a container/object this saver can handle."""
        return isinstance(obj, (list, dict)) or hasattr(obj, "__dict__")

    def save(self):
        """Return a nested snapshot mirroring the wrapped object's structure."""
        def _snapshot(node):
            if isinstance(node, dict):
                return {key: _snapshot(val) for key, val in node.items()}
            if isinstance(node, list):
                return [_snapshot(val) for val in node]
            if hasattr(node, "__dict__"):
                return {key: _snapshot(val) for key, val in node.__dict__.items()}
            return node

        return _snapshot(self._obj)

    def load(self, state):
        """Restore a snapshot in place, mutating the wrapped object."""
        def _restore(node, snap):
            if isinstance(node, dict):
                for key, val in snap.items():
                    node[key] = _restore(node.get(key), val)
            elif isinstance(node, list):
                if len(node) == len(snap):
                    for idx, val in enumerate(snap):
                        node[idx] = _restore(node[idx], val)
                else:
                    # Structure changed since the snapshot was taken:
                    # replace the list contents wholesale.
                    node.clear()
                    node.extend(snap)
            elif hasattr(node, "__dict__"):
                _restore(node.__dict__, snap)
            else:
                return snap
            return node

        _restore(self._obj, state)
class Saver:
    """Coordinates periodic checkpointing of registered objects to *dir*.

    A checkpoint is written every ``short_interval`` iterations. Old files
    are pruned so that, beyond the two most recent checkpoints, roughly one
    file per ``keep_every_n_hours`` window survives.
    """

    def __init__(self, dir, short_interval, keep_every_n_hours=4):
        self.savers = {}
        self.short_interval = short_interval
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self._keep_every_n_seconds = keep_every_n_hours * 3600

    def register(self, name, saver):
        """Register an object to be checkpointed under *name*.

        Accepts a SaverElement, anything with a callable ``state_dict``
        method, or a plain Python container/object.

        Raises:
            ValueError: if the object cannot be saved.
        """
        assert name not in self.savers, "Saver %s already registered" % name
        if isinstance(saver, SaverElement):
            self.savers[name] = saver
        elif hasattr(saver, "state_dict") and callable(saver.state_dict):
            self.savers[name] = StateSaver(saver)
        elif PyObjectSaver.obj_supported(saver):
            self.savers[name] = PyObjectSaver(saver)
        else:
            # BUGFIX: was `assert "Unsupported..."`, which always passes
            # (a non-empty string is truthy) and silently ignored
            # unsupported objects.
            raise ValueError("Unsupported thing to save: %s" % type(saver))

    def __setitem__(self, key, value):
        self.register(key, value)

    def write(self, iter):
        """Write a checkpoint for iteration *iter*, then prune old files."""
        fname = os.path.join(self.dir, self.model_name_from_index(iter))
        print("Saving %s" % fname)
        state = {}
        for name, fns in self.savers.items():
            state[name] = fns.save()
        torch.save(state, fname)
        print("Saved.")
        self._cleanup()

    def tick(self, iter):
        """Write a checkpoint iff *iter* is a multiple of short_interval."""
        if iter % self.short_interval != 0:
            return
        self.write(iter)

    @staticmethod
    def model_name_from_index(index):
        """Map a checkpoint index to its on-disk file name."""
        return "model-%d.pth" % index

    @staticmethod
    def get_checkpoint_index_list(dir):
        """Return all checkpoint indices found in *dir*, newest (highest) first."""
        return list(reversed(sorted(
            [int(fn.split(".")[0].split("-")[-1]) for fn in os.listdir(dir) if fn.split(".")[-1] == "pth"])))

    @staticmethod
    def get_ckpts_in_time_window(dir, time_window_s, index_list=None):
        """Return names of checkpoints modified within the last *time_window_s*
        seconds, scanning newest-first and stopping at the first older file."""
        if index_list is None:
            index_list = Saver.get_checkpoint_index_list(dir)
        now = time.time()
        res = []
        for i in index_list:
            name = Saver.model_name_from_index(i)
            mtime = os.path.getmtime(os.path.join(dir, name))
            if now - mtime > time_window_s:
                break
            res.append(name)
        return res

    @staticmethod
    def load_last_checkpoint(dir):
        """Load the newest readable checkpoint in *dir*; None if none load."""
        for index in Saver.get_checkpoint_index_list(dir):
            fname = Saver.model_name_from_index(index)
            try:
                print("Loading %s" % fname)
                data = torch.load(os.path.join(dir, fname))
            except Exception:
                # BUGFIX: narrowed from a bare `except:`, which also
                # swallowed KeyboardInterrupt/SystemExit.
                print("WARNING: Loading %s failed. Maybe file is corrupted?" % fname)
                continue
            return data
        return None

    def _cleanup(self):
        """Prune checkpoints: the 2 newest are always kept; among the rest,
        files newer than the retention window are removed except the oldest
        of them, leaving roughly one checkpoint per window."""
        index_list = self.get_checkpoint_index_list(self.dir)
        new_files = self.get_ckpts_in_time_window(self.dir, self._keep_every_n_seconds, index_list[2:])
        new_files = new_files[:-1]
        for f in new_files:
            os.remove(os.path.join(self.dir, f))

    def load(self, fname=None):
        """Restore all registered objects from *fname*, or from the newest
        checkpoint in self.dir when *fname* is None. Returns True on success,
        False if no checkpoint could be loaded."""
        if fname is None:
            state = self.load_last_checkpoint(self.dir)
            if not state:
                return False
        else:
            state = torch.load(fname)
        for k, s in state.items():
            if k not in self.savers:
                print("WARNING: failed to load state of %s. It doesn't exist." % k)
                continue
            self.savers[k].load(s)
        return True