mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Save reconst scores to json
This commit is contained in:
@@ -54,8 +54,11 @@ def load_dataset(pdb_files: Collection[str], model_dir: Path):
|
||||
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = argparse.ArgumentParser(
|
||||
usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument("pdb_files", nargs="+", help="PDB files to reconstruct")
|
||||
parser.add_argument("output_json", type=str, help="Output JSON file")
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--timesteps",
|
||||
@@ -75,7 +78,7 @@ def get_reconstruction_error(
|
||||
timesteps: int = 800,
|
||||
model: str = "wukevin/foldingdiff_cath",
|
||||
device: torch.device = torch.device("cuda:1"),
|
||||
):
|
||||
) -> np.ndarray:
|
||||
"""Get the reconstruction error for a set of PDB files"""
|
||||
if utils.is_huggingface_hub_id(model):
|
||||
logging.info(f"Detected huggingface repo ID {model}")
|
||||
@@ -102,6 +105,7 @@ def get_reconstruction_error(
|
||||
logging.info(
|
||||
f"Reconstuction scores from t={timesteps}: {(np.min(scores_wrt_coords), np.max(scores_wrt_coords))}"
|
||||
)
|
||||
return scores_wrt_coords
|
||||
|
||||
|
||||
def main():
|
||||
@@ -109,12 +113,16 @@ def main():
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
get_reconstruction_error(
|
||||
scores = get_reconstruction_error(
|
||||
args.pdb_files,
|
||||
timesteps=args.timesteps,
|
||||
model=args.model,
|
||||
device=torch.device(f"cuda:{args.device}"),
|
||||
)
|
||||
|
||||
scores_dict = {pdb: score for pdb, score in zip(args.pdb_files, scores)}
|
||||
with open(args.output_json, "w") as sink:
|
||||
json.dump(scores_dict, sink, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user