Save reconst scores to json

This commit is contained in:
Kevin Wu
2023-02-13 23:12:47 -08:00
parent a162535ad0
commit 52db3a274a

View File

@@ -54,8 +54,11 @@ def load_dataset(pdb_files: Collection[str], model_dir: Path):
def build_parser():
parser = argparse.ArgumentParser(usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(
usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("pdb_files", nargs="+", help="PDB files to reconstruct")
parser.add_argument("output_json", type=str, help="Output JSON file")
parser.add_argument(
"-t",
"--timesteps",
@@ -75,7 +78,7 @@ def get_reconstruction_error(
timesteps: int = 800,
model: str = "wukevin/foldingdiff_cath",
device: torch.device = torch.device("cuda:1"),
):
) -> np.ndarray:
"""Get the reconstruction error for a set of PDB files"""
if utils.is_huggingface_hub_id(model):
logging.info(f"Detected huggingface repo ID {model}")
@@ -102,6 +105,7 @@ def get_reconstruction_error(
logging.info(
f"Reconstuction scores from t={timesteps}: {(np.min(scores_wrt_coords), np.max(scores_wrt_coords))}"
)
return scores_wrt_coords
def main():
@@ -109,12 +113,16 @@ def main():
parser = build_parser()
args = parser.parse_args()
get_reconstruction_error(
scores = get_reconstruction_error(
args.pdb_files,
timesteps=args.timesteps,
model=args.model,
device=torch.device(f"cuda:{args.device}"),
)
scores_dict = {pdb: score for pdb, score in zip(args.pdb_files, scores)}
with open(args.output_json, "w") as sink:
json.dump(scores_dict, sink, indent=4)
if __name__ == "__main__":