dnc-with-demon/Utils/ArgumentParser.py

168 lines
5.4 KiB
Python
Raw Normal View History

2022-11-06 05:59:40 +08:00
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import os
import json
import argparse
class ArgumentParser:
class Parsed:
pass
_type = type
@staticmethod
def str_or_none(none_string="none"):
def parse(s):
return None if s.lower()==none_string else s
return parse
@staticmethod
def list_or_none(none_string="none", type=int):
def parse(s):
return None if s.lower() == none_string else [type(a) for a in s.split(",") if a]
return parse
@staticmethod
def _merge_args(args, new_args, arg_schemas):
for name, val in new_args.items():
old = args.get(name)
if old is None:
args[name] = val
else:
args[name] = arg_schemas[name]["updater"](old, val)
class Profile:
def __init__(self, name, args=None, include=[]):
assert not (args is None and not include), "One of args or include must be defined"
self.name = name
self.args = args
if not isinstance(include, list):
include=[include]
self.include = include
def get_args(self, arg_schemas, profile_by_name):
res = {}
for n in self.include:
p = profile_by_name.get(n)
assert p is not None, "Included profile %s doesn't exists" % n
ArgumentParser._merge_args(res, p.get_args(arg_schemas, profile_by_name), arg_schemas)
ArgumentParser._merge_args(res, self.args, arg_schemas)
return res
def __init__(self, description=None):
self.parser = argparse.ArgumentParser(description=description)
self.loaded = {}
self.profiles = {}
self.args = {}
self.raw = None
self.parsed = None
self.parser.add_argument("-profile", type=str, help="Pre-defined profiles.")
def add_argument(self, name, type=None, default=None, help="", save=True, parser=lambda x:x, updater=lambda old, new:new):
assert name not in ["profile"], "Argument name %s is reserved" % name
assert not (type is None and default is None), "Either type or default must be given"
if type is None:
type = ArgumentParser._type(default)
self.parser.add_argument(name, type=int if type==bool else type, default=None, help=help)
if name[0] == '-':
name = name[1:]
self.args[name] = {
"type": type,
"default": int(default) if type==bool else default,
"save": save,
"parser": parser,
"updater": updater
}
def add_profile(self, prof):
if isinstance(prof, list):
for p in prof:
self.add_profile(p)
else:
self.profiles[prof.name] = prof
def do_parse_args(self, loaded={}):
self.raw = self.parser.parse_args()
profile = {}
if self.raw.profile:
assert not loaded, "Loading arguments from file, but profile given."
for pr in self.raw.profile.split(","):
p = self.profiles.get(pr)
assert p is not None, "Invalid profile: %s. Valid profiles: %s" % (pr, self.profiles.keys())
p = p.get_args(self.args, self.profiles)
self._merge_args(profile, p, self.args)
for k, v in self.raw.__dict__.items():
if k in ["profile"]:
continue
if v is None:
if k in loaded and self.args[k]["save"]:
self.raw.__dict__[k] = loaded[k]
else:
self.raw.__dict__[k] = profile.get(k, self.args[k]["default"])
self.parsed = ArgumentParser.Parsed()
self.parsed.__dict__.update({k: self.args[k]["parser"](self.args[k]["type"](v)) if v is not None else None
for k,v in self.raw.__dict__.items() if k in self.args})
return self.parsed
def parse_or_cache(self):
if self.parsed is None:
self.do_parse_args()
def parse(self):
self.parse_or_cache()
return self.parsed
def save(self, fname):
self.parse_or_cache()
with open(fname, 'w') as outfile:
json.dump(self.raw.__dict__, outfile, indent=4)
return True
def load(self, fname):
if os.path.isfile(fname):
map = {}
with open(fname, "r") as data_file:
map = json.load(data_file)
self.do_parse_args(map)
return self.parsed
def sync(self, fname, save=True):
if os.path.isfile(fname):
self.load(fname)
if save:
dir = os.path.dirname(fname)
if not os.path.isdir(dir):
os.makedirs(dir)
self.save(fname)
return self.parsed