diff --git a/bin/sample.py b/bin/sample.py index f7d65b9..f2f0d1f 100644 --- a/bin/sample.py +++ b/bin/sample.py @@ -26,7 +26,7 @@ import beta_schedules import sampling import plotting from datasets import NoisedAnglesDataset, CathCanonicalAnglesDataset -from angles_and_coords import create_new_chain +from angles_and_coords import create_new_chain_nerf import tmalign # :) @@ -46,22 +46,18 @@ def build_datasets( dset_args = dict( timesteps=training_args["timesteps"], variance_schedule=training_args["variance_schedule"], - noise_prior=training_args["noise_prior"], - shift_to_zero_twopi=training_args["shift_angles_zero_twopi"], + max_seq_len=training_args["max_seq_len"], + min_seq_len=training_args["min_seq_len"], var_scale=training_args["variance_scale"], - toy=training_args["subset"], syn_noiser=training_args["syn_noiser"], exhaustive_t=training_args["exhaustive_validation_t"], - single_dist_debug=training_args["single_dist_debug"], single_angle_debug=training_args["single_angle_debug"], single_time_debug=training_args["single_timestep_debug"], + toy=training_args["subset"], + angles_definitions=training_args["angles_definitions"], + zero_center=training_args["zero_center"], + train_only = True, ) - if "angles_definitions" in training_args: - dset_args["angles_definitions"] = training_args["angles_definitions"] - if "zero_center" in training_args: - dset_args["zero_center"] = training_args["zero_center"] - else: - dset_args["zero_center"] = False # Old default value train_dset, valid_dset, test_dset = get_train_valid_test_sets(**dset_args) logging.info( @@ -69,22 +65,8 @@ def build_datasets( ) return train_dset, valid_dset, test_dset - -def write_as_pdb( - preds: pd.DataFrame, - all_ft_train_dset: CathCanonicalAnglesDataset, - fname: str, - angles_to_use: Optional[List[str]] = None, -): - """Write the predictions as a pdb file""" - if angles_to_use: - preds = preds.loc[:, angles_to_use] # Sample only dihedrals - create_new_chain(fname, preds, sampled_values_dset=all_ft_train_dset) - - def write_preds_pdb_folder( final_sampled: Sequence[pd.DataFrame], - all_ft_train_dset: CathCanonicalAnglesDataset, outdir: str, ) -> List[str]: """ @@ -95,8 +77,8 @@ def write_preds_pdb_folder( logging.info(f"Writing sampled anlges as PDB files to {outdir}") retval = [] for i, samp in enumerate(final_sampled): - fname = os.path.join(outdir, f"generated_{i}.pdb") - write_as_pdb(samp, all_ft_train_dset, fname) + fname = create_new_chain_nerf(os.path.join(outdir, f"generated_{i}.pdb"), samp) + assert fname retval.append(fname) return retval @@ -154,6 +136,7 @@ def build_parser() -> argparse.ArgumentParser: parser.add_argument( "--legacy", action="store_true", help="Use legacy model loading code" ) + parser.add_argument("--skiptm", action="store_true", help="Skip calculation of TM scores against training set") parser.add_argument("--seed", type=int, default=SEED, help="Random seed") parser.add_argument("--device", type=str, default="cuda:0", help="Device to use") return parser @@ -241,23 +224,23 @@ def main(): ) # Write the sampled angles as pdb files - all_ft_train_dset = CathCanonicalAnglesDataset(split="train") sampled_dfs = [ pd.DataFrame(s, columns=train_dset.feature_names["angles"]) for s in final_sampled ] pdb_files = write_preds_pdb_folder( - sampled_dfs, all_ft_train_dset, outdir / "sampled_pdb" + sampled_dfs, outdir / "sampled_pdb" ) - logging.info(f"Done writing main outputs! Calculating tm scores...") - all_tm_scores = {} - for i, fname in enumerate(pdb_files): - samp_name = os.path.splitext(os.path.basename(fname))[0] - tm_score = tmalign.max_tm_across_refs(fname, train_dset.dset.filenames) - all_tm_scores[samp_name] = tm_score - with open(outdir / "tm_scores.json", "w") as sink: - json.dump(all_tm_scores, sink, indent=4) + if not args.skiptm: + logging.info(f"Done writing main outputs! Calculating tm scores...") + all_tm_scores = {} + for i, fname in enumerate(pdb_files): + samp_name = os.path.splitext(os.path.basename(fname))[0] + tm_score = tmalign.max_tm_across_refs(fname, train_dset.dset.filenames) + all_tm_scores[samp_name] = tm_score + with open(outdir / "tm_scores.json", "w") as sink: + json.dump(all_tm_scores, sink, indent=4) if __name__ == "__main__":