168 lines
5.4 KiB
Python
168 lines
5.4 KiB
Python
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
#
|
||
|
# ==============================================================================
|
||
|
|
||
|
import os
|
||
|
import json
|
||
|
import argparse
|
||
|
|
||
|
|
||
|
class ArgumentParser:
|
||
|
class Parsed:
|
||
|
pass
|
||
|
|
||
|
_type = type
|
||
|
|
||
|
@staticmethod
|
||
|
def str_or_none(none_string="none"):
|
||
|
def parse(s):
|
||
|
return None if s.lower()==none_string else s
|
||
|
return parse
|
||
|
|
||
|
@staticmethod
|
||
|
def list_or_none(none_string="none", type=int):
|
||
|
def parse(s):
|
||
|
return None if s.lower() == none_string else [type(a) for a in s.split(",") if a]
|
||
|
return parse
|
||
|
|
||
|
@staticmethod
|
||
|
def _merge_args(args, new_args, arg_schemas):
|
||
|
for name, val in new_args.items():
|
||
|
old = args.get(name)
|
||
|
if old is None:
|
||
|
args[name] = val
|
||
|
else:
|
||
|
args[name] = arg_schemas[name]["updater"](old, val)
|
||
|
|
||
|
class Profile:
|
||
|
def __init__(self, name, args=None, include=[]):
|
||
|
assert not (args is None and not include), "One of args or include must be defined"
|
||
|
self.name = name
|
||
|
self.args = args
|
||
|
if not isinstance(include, list):
|
||
|
include=[include]
|
||
|
self.include = include
|
||
|
|
||
|
def get_args(self, arg_schemas, profile_by_name):
|
||
|
res = {}
|
||
|
|
||
|
for n in self.include:
|
||
|
p = profile_by_name.get(n)
|
||
|
assert p is not None, "Included profile %s doesn't exists" % n
|
||
|
|
||
|
ArgumentParser._merge_args(res, p.get_args(arg_schemas, profile_by_name), arg_schemas)
|
||
|
|
||
|
ArgumentParser._merge_args(res, self.args, arg_schemas)
|
||
|
return res
|
||
|
|
||
|
|
||
|
def __init__(self, description=None):
|
||
|
self.parser = argparse.ArgumentParser(description=description)
|
||
|
self.loaded = {}
|
||
|
self.profiles = {}
|
||
|
self.args = {}
|
||
|
self.raw = None
|
||
|
self.parsed = None
|
||
|
self.parser.add_argument("-profile", type=str, help="Pre-defined profiles.")
|
||
|
|
||
|
def add_argument(self, name, type=None, default=None, help="", save=True, parser=lambda x:x, updater=lambda old, new:new):
|
||
|
assert name not in ["profile"], "Argument name %s is reserved" % name
|
||
|
assert not (type is None and default is None), "Either type or default must be given"
|
||
|
|
||
|
if type is None:
|
||
|
type = ArgumentParser._type(default)
|
||
|
|
||
|
self.parser.add_argument(name, type=int if type==bool else type, default=None, help=help)
|
||
|
if name[0] == '-':
|
||
|
name = name[1:]
|
||
|
|
||
|
self.args[name] = {
|
||
|
"type": type,
|
||
|
"default": int(default) if type==bool else default,
|
||
|
"save": save,
|
||
|
"parser": parser,
|
||
|
"updater": updater
|
||
|
}
|
||
|
|
||
|
def add_profile(self, prof):
|
||
|
if isinstance(prof, list):
|
||
|
for p in prof:
|
||
|
self.add_profile(p)
|
||
|
else:
|
||
|
self.profiles[prof.name] = prof
|
||
|
|
||
|
def do_parse_args(self, loaded={}):
|
||
|
self.raw = self.parser.parse_args()
|
||
|
|
||
|
profile = {}
|
||
|
if self.raw.profile:
|
||
|
assert not loaded, "Loading arguments from file, but profile given."
|
||
|
for pr in self.raw.profile.split(","):
|
||
|
p = self.profiles.get(pr)
|
||
|
assert p is not None, "Invalid profile: %s. Valid profiles: %s" % (pr, self.profiles.keys())
|
||
|
p = p.get_args(self.args, self.profiles)
|
||
|
self._merge_args(profile, p, self.args)
|
||
|
|
||
|
for k, v in self.raw.__dict__.items():
|
||
|
if k in ["profile"]:
|
||
|
continue
|
||
|
|
||
|
if v is None:
|
||
|
if k in loaded and self.args[k]["save"]:
|
||
|
self.raw.__dict__[k] = loaded[k]
|
||
|
else:
|
||
|
self.raw.__dict__[k] = profile.get(k, self.args[k]["default"])
|
||
|
|
||
|
self.parsed = ArgumentParser.Parsed()
|
||
|
self.parsed.__dict__.update({k: self.args[k]["parser"](self.args[k]["type"](v)) if v is not None else None
|
||
|
for k,v in self.raw.__dict__.items() if k in self.args})
|
||
|
|
||
|
return self.parsed
|
||
|
|
||
|
def parse_or_cache(self):
|
||
|
if self.parsed is None:
|
||
|
self.do_parse_args()
|
||
|
|
||
|
def parse(self):
|
||
|
self.parse_or_cache()
|
||
|
return self.parsed
|
||
|
|
||
|
def save(self, fname):
|
||
|
self.parse_or_cache()
|
||
|
with open(fname, 'w') as outfile:
|
||
|
json.dump(self.raw.__dict__, outfile, indent=4)
|
||
|
return True
|
||
|
|
||
|
def load(self, fname):
|
||
|
if os.path.isfile(fname):
|
||
|
map = {}
|
||
|
with open(fname, "r") as data_file:
|
||
|
map = json.load(data_file)
|
||
|
|
||
|
self.do_parse_args(map)
|
||
|
return self.parsed
|
||
|
|
||
|
def sync(self, fname, save=True):
|
||
|
if os.path.isfile(fname):
|
||
|
self.load(fname)
|
||
|
|
||
|
if save:
|
||
|
dir = os.path.dirname(fname)
|
||
|
if not os.path.isdir(dir):
|
||
|
os.makedirs(dir)
|
||
|
|
||
|
self.save(fname)
|
||
|
return self.parsed
|