mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Tweak sampling to work with updated models and params
This commit is contained in:
@@ -26,7 +26,7 @@ import beta_schedules
|
||||
import sampling
|
||||
import plotting
|
||||
from datasets import NoisedAnglesDataset, CathCanonicalAnglesDataset
|
||||
from angles_and_coords import create_new_chain
|
||||
from angles_and_coords import create_new_chain_nerf
|
||||
import tmalign
|
||||
|
||||
# :)
|
||||
@@ -46,22 +46,18 @@ def build_datasets(
|
||||
dset_args = dict(
|
||||
timesteps=training_args["timesteps"],
|
||||
variance_schedule=training_args["variance_schedule"],
|
||||
noise_prior=training_args["noise_prior"],
|
||||
shift_to_zero_twopi=training_args["shift_angles_zero_twopi"],
|
||||
max_seq_len=training_args["max_seq_len"],
|
||||
min_seq_len=training_args["min_seq_len"],
|
||||
var_scale=training_args["variance_scale"],
|
||||
toy=training_args["subset"],
|
||||
syn_noiser=training_args["syn_noiser"],
|
||||
exhaustive_t=training_args["exhaustive_validation_t"],
|
||||
single_dist_debug=training_args["single_dist_debug"],
|
||||
single_angle_debug=training_args["single_angle_debug"],
|
||||
single_time_debug=training_args["single_timestep_debug"],
|
||||
toy=training_args["subset"],
|
||||
angles_definitions=training_args["angles_definitions"],
|
||||
zero_center=training_args["zero_center"],
|
||||
train_only = True,
|
||||
)
|
||||
if "angles_definitions" in training_args:
|
||||
dset_args["angles_definitions"] = training_args["angles_definitions"]
|
||||
if "zero_center" in training_args:
|
||||
dset_args["zero_center"] = training_args["zero_center"]
|
||||
else:
|
||||
dset_args["zero_center"] = False # Old default value
|
||||
|
||||
train_dset, valid_dset, test_dset = get_train_valid_test_sets(**dset_args)
|
||||
logging.info(
|
||||
@@ -69,22 +65,8 @@ def build_datasets(
|
||||
)
|
||||
return train_dset, valid_dset, test_dset
|
||||
|
||||
|
||||
def write_as_pdb(
|
||||
preds: pd.DataFrame,
|
||||
all_ft_train_dset: CathCanonicalAnglesDataset,
|
||||
fname: str,
|
||||
angles_to_use: Optional[List[str]] = None,
|
||||
):
|
||||
"""Write the predictions as a pdb file"""
|
||||
if angles_to_use:
|
||||
preds = preds.loc[:, angles_to_use] # Sample only dihedrals
|
||||
create_new_chain(fname, preds, sampled_values_dset=all_ft_train_dset)
|
||||
|
||||
|
||||
def write_preds_pdb_folder(
|
||||
final_sampled: Sequence[pd.DataFrame],
|
||||
all_ft_train_dset: CathCanonicalAnglesDataset,
|
||||
outdir: str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
@@ -95,8 +77,8 @@ def write_preds_pdb_folder(
|
||||
logging.info(f"Writing sampled anlges as PDB files to {outdir}")
|
||||
retval = []
|
||||
for i, samp in enumerate(final_sampled):
|
||||
fname = os.path.join(outdir, f"generated_{i}.pdb")
|
||||
write_as_pdb(samp, all_ft_train_dset, fname)
|
||||
fname = create_new_chain_nerf(os.path.join(outdir, f"generated_{i}.pdb"), samp)
|
||||
assert fname
|
||||
retval.append(fname)
|
||||
return retval
|
||||
|
||||
@@ -154,6 +136,7 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--legacy", action="store_true", help="Use legacy model loading code"
|
||||
)
|
||||
parser.add_argument("--skiptm", action="store_true", help="Skip calculation of TM scores against training set")
|
||||
parser.add_argument("--seed", type=int, default=SEED, help="Random seed")
|
||||
parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
|
||||
return parser
|
||||
@@ -241,23 +224,23 @@ def main():
|
||||
)
|
||||
|
||||
# Write the sampled angles as pdb files
|
||||
all_ft_train_dset = CathCanonicalAnglesDataset(split="train")
|
||||
sampled_dfs = [
|
||||
pd.DataFrame(s, columns=train_dset.feature_names["angles"])
|
||||
for s in final_sampled
|
||||
]
|
||||
pdb_files = write_preds_pdb_folder(
|
||||
sampled_dfs, all_ft_train_dset, outdir / "sampled_pdb"
|
||||
sampled_dfs, outdir / "sampled_pdb"
|
||||
)
|
||||
|
||||
logging.info(f"Done writing main outputs! Calculating tm scores...")
|
||||
all_tm_scores = {}
|
||||
for i, fname in enumerate(pdb_files):
|
||||
samp_name = os.path.splitext(os.path.basename(fname))[0]
|
||||
tm_score = tmalign.max_tm_across_refs(fname, train_dset.dset.filenames)
|
||||
all_tm_scores[samp_name] = tm_score
|
||||
with open(outdir / "tm_scores.json", "w") as sink:
|
||||
json.dump(all_tm_scores, sink, indent=4)
|
||||
if not args.skiptm:
|
||||
logging.info(f"Done writing main outputs! Calculating tm scores...")
|
||||
all_tm_scores = {}
|
||||
for i, fname in enumerate(pdb_files):
|
||||
samp_name = os.path.splitext(os.path.basename(fname))[0]
|
||||
tm_score = tmalign.max_tm_across_refs(fname, train_dset.dset.filenames)
|
||||
all_tm_scores[samp_name] = tm_score
|
||||
with open(outdir / "tm_scores.json", "w") as sink:
|
||||
json.dump(all_tm_scores, sink, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user