diff --git a/bin/train.py b/bin/train.py index 3de61f7..f5fb54f 100644 --- a/bin/train.py +++ b/bin/train.py @@ -153,7 +153,7 @@ def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("config_json", type=str, help="json of params") + parser.add_argument("-c", "--config", type=str, help="json of params") parser.add_argument( "-o", "--outdir", @@ -173,8 +173,10 @@ def main(): args = parser.parse_args() # Load in parameters and run training loop - with open(args.config_json) as source: - config_args = json.load(source) + config_args = {} # Empty dictionary as default + if args.config: + with open(args.config) as source: + config_args = json.load(source) train(results_dir=args.outdir, toy=args.toy, **config_args)