Tweak sampling to work with updated models and params

This commit is contained in:
Kevin Wu
2022-09-09 13:11:12 -07:00
parent 6a15991a4b
commit 27294e5216

View File

@@ -26,7 +26,7 @@ import beta_schedules
import sampling
import plotting
from datasets import NoisedAnglesDataset, CathCanonicalAnglesDataset
from angles_and_coords import create_new_chain
from angles_and_coords import create_new_chain_nerf
import tmalign
# :)
@@ -46,22 +46,18 @@ def build_datasets(
dset_args = dict(
timesteps=training_args["timesteps"],
variance_schedule=training_args["variance_schedule"],
noise_prior=training_args["noise_prior"],
shift_to_zero_twopi=training_args["shift_angles_zero_twopi"],
max_seq_len=training_args["max_seq_len"],
min_seq_len=training_args["min_seq_len"],
var_scale=training_args["variance_scale"],
toy=training_args["subset"],
syn_noiser=training_args["syn_noiser"],
exhaustive_t=training_args["exhaustive_validation_t"],
single_dist_debug=training_args["single_dist_debug"],
single_angle_debug=training_args["single_angle_debug"],
single_time_debug=training_args["single_timestep_debug"],
toy=training_args["subset"],
angles_definitions=training_args["angles_definitions"],
zero_center=training_args["zero_center"],
train_only = True,
)
if "angles_definitions" in training_args:
dset_args["angles_definitions"] = training_args["angles_definitions"]
if "zero_center" in training_args:
dset_args["zero_center"] = training_args["zero_center"]
else:
dset_args["zero_center"] = False # Old default value
train_dset, valid_dset, test_dset = get_train_valid_test_sets(**dset_args)
logging.info(
@@ -69,22 +65,8 @@ def build_datasets(
)
return train_dset, valid_dset, test_dset
def write_as_pdb(
    preds: pd.DataFrame,
    all_ft_train_dset: CathCanonicalAnglesDataset,
    fname: str,
    angles_to_use: Optional[List[str]] = None,
):
    """
    Write a table of predictions out as a PDB file at ``fname``.

    If ``angles_to_use`` is given and non-empty, only those columns of
    ``preds`` are forwarded; otherwise ``preds`` is passed along unchanged.
    Building the file itself is delegated to ``create_new_chain``, which
    receives ``all_ft_train_dset`` as its ``sampled_values_dset`` argument.
    """
    # A missing or empty column list means "use every column as-is".
    selected = preds.loc[:, angles_to_use] if angles_to_use else preds
    create_new_chain(fname, selected, sampled_values_dset=all_ft_train_dset)
def write_preds_pdb_folder(
final_sampled: Sequence[pd.DataFrame],
all_ft_train_dset: CathCanonicalAnglesDataset,
outdir: str,
) -> List[str]:
"""
@@ -95,8 +77,8 @@ def write_preds_pdb_folder(
logging.info(f"Writing sampled anlges as PDB files to {outdir}")
retval = []
for i, samp in enumerate(final_sampled):
fname = os.path.join(outdir, f"generated_{i}.pdb")
write_as_pdb(samp, all_ft_train_dset, fname)
fname = create_new_chain_nerf(os.path.join(outdir, f"generated_{i}.pdb"), samp)
assert fname
retval.append(fname)
return retval
@@ -154,6 +136,7 @@ def build_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--legacy", action="store_true", help="Use legacy model loading code"
)
parser.add_argument("--skiptm", action="store_true", help="Skip calculation of TM scores against training set")
parser.add_argument("--seed", type=int, default=SEED, help="Random seed")
parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
return parser
@@ -241,23 +224,23 @@ def main():
)
# Write the sampled angles as pdb files
all_ft_train_dset = CathCanonicalAnglesDataset(split="train")
sampled_dfs = [
pd.DataFrame(s, columns=train_dset.feature_names["angles"])
for s in final_sampled
]
pdb_files = write_preds_pdb_folder(
sampled_dfs, all_ft_train_dset, outdir / "sampled_pdb"
sampled_dfs, outdir / "sampled_pdb"
)
logging.info(f"Done writing main outputs! Calculating tm scores...")
all_tm_scores = {}
for i, fname in enumerate(pdb_files):
samp_name = os.path.splitext(os.path.basename(fname))[0]
tm_score = tmalign.max_tm_across_refs(fname, train_dset.dset.filenames)
all_tm_scores[samp_name] = tm_score
with open(outdir / "tm_scores.json", "w") as sink:
json.dump(all_tm_scores, sink, indent=4)
if not args.skiptm:
logging.info(f"Done writing main outputs! Calculating tm scores...")
all_tm_scores = {}
for i, fname in enumerate(pdb_files):
samp_name = os.path.splitext(os.path.basename(fname))[0]
tm_score = tmalign.max_tm_across_refs(fname, train_dset.dset.filenames)
all_tm_scores[samp_name] = tm_score
with open(outdir / "tm_scores.json", "w") as sink:
json.dump(all_tm_scores, sink, indent=4)
if __name__ == "__main__":