diff --git a/bin/partial_noise_reconstruct.py b/bin/partial_noise_reconstruct.py index cef96a7..d3b6f7f 100644 --- a/bin/partial_noise_reconstruct.py +++ b/bin/partial_noise_reconstruct.py @@ -54,8 +54,11 @@ def load_dataset(pdb_files: Collection[str], model_dir: Path): def build_parser(): - parser = argparse.ArgumentParser(usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument("pdb_files", nargs="+", help="PDB files to reconstruct") + parser.add_argument("output_json", type=str, help="Output JSON file") parser.add_argument( "-t", "--timesteps", @@ -75,7 +78,7 @@ def get_reconstruction_error( timesteps: int = 800, model: str = "wukevin/foldingdiff_cath", device: torch.device = torch.device("cuda:1"), -): +) -> np.ndarray: """Get the reconstruction error for a set of PDB files""" if utils.is_huggingface_hub_id(model): logging.info(f"Detected huggingface repo ID {model}") @@ -102,6 +105,7 @@ def get_reconstruction_error( logging.info( f"Reconstuction scores from t={timesteps}: {(np.min(scores_wrt_coords), np.max(scores_wrt_coords))}" ) + return scores_wrt_coords def main(): @@ -109,12 +113,16 @@ def main(): parser = build_parser() args = parser.parse_args() - get_reconstruction_error( + scores = get_reconstruction_error( args.pdb_files, timesteps=args.timesteps, model=args.model, device=torch.device(f"cuda:{args.device}"), ) + + scores_dict = {pdb: score for pdb, score in zip(args.pdb_files, scores)} + with open(args.output_json, "w") as sink: + json.dump(scores_dict, sink, indent=4) if __name__ == "__main__":