mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Hyperparam search params and generator
This commit is contained in:
148
scripts/scripts_from_hyper_json.py
Normal file
148
scripts/scripts_from_hyper_json.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Code to read in a json describing hyperparam sweep and create
|
||||
scripts to execute each combination
|
||||
"""
|
||||
|
||||
import os, sys
|
||||
import logging
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
from typing import *
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def params_to_cli_args(param_dict: Dict[str, Any]) -> str:
|
||||
"""Format the params to CLI arguments"""
|
||||
tokens = []
|
||||
for k, v in param_dict.items():
|
||||
prefix = "-" if len(k) == 1 else "--"
|
||||
if isinstance(v, bool): # For booleans, treat as flags
|
||||
if v:
|
||||
tokens.append(f"{prefix}{k}")
|
||||
else:
|
||||
tokens.append(f"{prefix}{k} {v}")
|
||||
# Manually add this since we can't provide this in the json
|
||||
tokens.append(f"--outdir {params_to_filename(param_dict)}")
|
||||
retval = " ".join(tokens)
|
||||
return retval
|
||||
|
||||
|
||||
def params_to_filename(
|
||||
param_dict: Dict[str, Any],
|
||||
blacklist_tokens: Collection[str] = {"model", "blacklist", "config", "pretrained"},
|
||||
) -> str:
|
||||
"""
|
||||
Format the params (in key-value pairs (param, arg)) into a filename
|
||||
blacklist_tokens species the params that are excluded from filename (typically path-like)
|
||||
"""
|
||||
tokens = []
|
||||
for k, v in param_dict.items():
|
||||
if k in blacklist_tokens:
|
||||
logging.warning(f"Excluding {k}: {v} from generated fname")
|
||||
continue
|
||||
assert " " not in k, f"Parameter cannot have space: {k}"
|
||||
if isinstance(v, str) and " " in v: # Replace spaces in value
|
||||
v = v.replace(" ", "_")
|
||||
# This is (value, key) for historical readability reasons
|
||||
tokens.append(f"{v}_{k}")
|
||||
return "_".join(tokens)
|
||||
|
||||
|
||||
def build_parser():
|
||||
"""Build argument parser"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument("exec", type=str, help="Executable to run")
|
||||
parser.add_argument(
|
||||
"json_config", type=str, help="Config file specifying hyperparams"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mode",
|
||||
choices=["shellargs", "json"],
|
||||
default="json",
|
||||
help="Write args as shell args or as json file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--num", type=int, default=1, help="Number of replicates to run"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--outdir",
|
||||
type=str,
|
||||
default=os.getcwd(),
|
||||
help="Directory to write shell files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--template", type=str, default="", help="template to append to",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
"""Run script"""
|
||||
args = build_parser().parse_args()
|
||||
assert os.path.isfile(args.exec)
|
||||
assert os.path.isfile(args.json_config)
|
||||
|
||||
if not os.path.isdir(args.outdir):
|
||||
logging.info(f"Creating output directory: {args.outdir}")
|
||||
os.makedirs(args.outdir)
|
||||
|
||||
# Read in the template
|
||||
header_lines = []
|
||||
if args.template:
|
||||
with open(args.template) as source:
|
||||
header_lines = [l.strip() for l in source]
|
||||
|
||||
# Read in the hypeparameters to sweep over
|
||||
with open(args.json_config) as source:
|
||||
params = json.load(source)
|
||||
|
||||
# Create the scripts
|
||||
outdirs = []
|
||||
for p in itertools.product(*params.values()):
|
||||
d = dict(zip(params.keys(), p))
|
||||
logging.info(f"Writing script for {d}")
|
||||
# Create out direcotry name
|
||||
outdir_name = os.path.join(args.outdir, params_to_filename(d))
|
||||
assert outdir_name not in outdirs, f"Duplicated output dir: {outdir_name}"
|
||||
outdirs.append(outdir_name)
|
||||
|
||||
if args.mode == "shellargs":
|
||||
# Build command
|
||||
cli_args = params_to_cli_args(d)
|
||||
cmd = f"python {args.exec} {cli_args}"
|
||||
script_lines = header_lines + [cmd]
|
||||
|
||||
# Write script
|
||||
script_name = outdir_name + ".sh"
|
||||
with open(script_name, "w") as sink:
|
||||
for line in script_lines:
|
||||
sink.write(line + "\n")
|
||||
elif args.mode == "json":
|
||||
# Write a json of all the parameters
|
||||
d = dict(zip(params.keys(), p))
|
||||
in_json_fname = outdir_name + ".json"
|
||||
with open(in_json_fname, "w") as sink:
|
||||
json.dump(d, sink, indent=4)
|
||||
|
||||
cmd = f"python {args.exec} {in_json_fname}"
|
||||
script_lines = header_lines + [cmd]
|
||||
script_name = outdir_name + ".sh"
|
||||
with open(script_name, "w") as sink:
|
||||
for line in script_lines:
|
||||
sink.write(line + "\n")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unrecognized mode: {args.mode}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
|
||||
doctest.testmod()
|
||||
main()
|
||||
Reference in New Issue
Block a user