mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
fix(#456): gate best-model relaxation by score
This commit is contained in:
@@ -39,6 +39,46 @@ RELAX_EXCLUDE_RESIDUES = []
|
||||
RELAX_MAX_OUTER_ITERATIONS = 3
|
||||
|
||||
|
||||
def _select_models_to_relax(
    ranked_order: List[str],
    *,
    models_to_relax: "ModelsToRelax",
    iptm_scores: Dict[str, float],
    ptm_scores: Dict[str, float],
    relax_best_score_threshold,
) -> List[str]:
    """Return the names of the models that should go through relaxation.

    Args:
        ranked_order: Model names; the first entry is treated as the best
            model.
        models_to_relax: Relaxation policy. ``ALL`` returns ``ranked_order``
            unchanged and ``NONE`` returns an empty list. Any other value is
            handled as "best model only".
        iptm_scores: Mapping of model name to iPTM score. Used for the
            threshold check whenever the best model has an entry here.
        ptm_scores: Mapping of model name to pTM score. Used only when the
            best model has no entry in ``iptm_scores``.
        relax_best_score_threshold: Minimum score the best model must reach
            to be relaxed, or ``None`` to relax it unconditionally.

    Returns:
        The model names to relax, possibly empty.
    """
    if models_to_relax == ModelsToRelax.ALL:
        return ranked_order
    if models_to_relax == ModelsToRelax.NONE:
        return []
    if not ranked_order:
        return []

    best_model = ranked_order[0]
    if relax_best_score_threshold is None:
        return [best_model]

    # Prefer iPTM when the model has one; otherwise fall back to pTM.
    if best_model in iptm_scores:
        score_name = "iPTM"
        score_value = iptm_scores[best_model]
    else:
        score_name = "pTM"
        score_value = ptm_scores.get(best_model)

    if score_value is None:
        # Without a score the gate cannot be evaluated. Relaxation is still
        # skipped, but the reason is reported as a missing score instead of
        # a made-up 0.000 value.
        logging.warning(
            "Skipping relaxation for %s because no %s score is available to "
            "compare against the requested threshold %.3f.",
            best_model,
            score_name,
            relax_best_score_threshold,
        )
        return []

    if score_value < relax_best_score_threshold:
        logging.info(
            "Skipping relaxation for %s because its %s score %.3f is below the "
            "requested threshold %.3f.",
            best_model,
            score_name,
            score_value,
            relax_best_score_threshold,
        )
        return []

    return [best_model]
|
||||
|
||||
|
||||
@enum.unique
|
||||
class ModelsToRelax(enum.Enum):
|
||||
ALL = 0
|
||||
@@ -788,6 +828,7 @@ class AlphaFold2Backend(FoldingBackend):
|
||||
output_dir: str,
|
||||
features_directory: str,
|
||||
models_to_relax: ModelsToRelax,
|
||||
relax_best_score_threshold = None,
|
||||
compress_pickles: bool = False,
|
||||
remove_pickles: bool = False,
|
||||
remove_keys_from_pickles: bool = False,
|
||||
@@ -917,12 +958,13 @@ class AlphaFold2Backend(FoldingBackend):
|
||||
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
|
||||
use_gpu=_resolve_gpu_relax(use_gpu_relax))
|
||||
|
||||
if models_to_relax == ModelsToRelax.BEST:
|
||||
to_relax = [ranked_order[0]]
|
||||
elif models_to_relax == ModelsToRelax.ALL:
|
||||
to_relax = ranked_order
|
||||
elif models_to_relax == ModelsToRelax.NONE:
|
||||
to_relax = []
|
||||
to_relax = _select_models_to_relax(
|
||||
ranked_order,
|
||||
models_to_relax=models_to_relax,
|
||||
iptm_scores=iptm_scores,
|
||||
ptm_scores=ptm_scores,
|
||||
relax_best_score_threshold=relax_best_score_threshold,
|
||||
)
|
||||
|
||||
for model_name in to_relax:
|
||||
if f'relax_{model_name}' in timings:
|
||||
|
||||
@@ -164,6 +164,7 @@ def main(argv):
|
||||
"--desired_num_res": FLAGS.desired_num_res,
|
||||
"--desired_num_msa": FLAGS.desired_num_msa,
|
||||
"--models_to_relax": FLAGS.models_to_relax,
|
||||
"--relax_best_score_threshold": FLAGS.relax_best_score_threshold,
|
||||
"--threshold_clashes": FLAGS.threshold_clashes,
|
||||
"--hb_allowance": FLAGS.hb_allowance,
|
||||
"--plddt_threshold": FLAGS.plddt_threshold,
|
||||
|
||||
@@ -102,6 +102,12 @@ flags.DEFINE_enum_class(
|
||||
"in case you are having issues with the relaxation "
|
||||
"stage.",
|
||||
)
|
||||
# The threshold is compared against iPTM/pTM scores, which are assumed to lie
# in [0, 1]. Out-of-range values are rejected when flags are parsed rather
# than silently disabling (or never applying) the gate.
flags.DEFINE_float(
    "relax_best_score_threshold",
    None,
    "Optional minimum iPTM/pTM score required before relaxing the best-ranked "
    "model when --models_to_relax=Best.",
    lower_bound=0.0,
    upper_bound=1.0,
)
|
||||
flags.DEFINE_enum('model_preset', 'monomer',
|
||||
['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
|
||||
'Choose preset model configuration - the monomer model, '
|
||||
@@ -228,7 +234,8 @@ def _validate_flags_for_backend(backend_name: str) -> None:
|
||||
# Backend-specific flags
|
||||
af2_like_flags = {
|
||||
'compress_result_pickles', 'remove_result_pickles', 'models_to_relax',
|
||||
'remove_keys_from_pickles', 'convert_to_modelcif', 'allow_resume',
|
||||
'relax_best_score_threshold', 'remove_keys_from_pickles',
|
||||
'convert_to_modelcif', 'allow_resume',
|
||||
'num_cycle', 'num_predictions_per_model', 'pair_msa',
|
||||
'save_features_for_multimeric_object', 'skip_templates',
|
||||
'msa_depth_scan', 'multimeric_template', 'model_names', 'msa_depth',
|
||||
@@ -475,6 +482,7 @@ def main(argv):
|
||||
"remove_keys_from_pickles": FLAGS.remove_keys_from_pickles,
|
||||
"use_gpu_relax": FLAGS.use_gpu_relax,
|
||||
"models_to_relax": FLAGS.models_to_relax,
|
||||
"relax_best_score_threshold": FLAGS.relax_best_score_threshold,
|
||||
"features_directory": FLAGS.features_directory,
|
||||
"convert_to_modelcif": FLAGS.convert_to_modelcif
|
||||
}
|
||||
|
||||
@@ -971,3 +971,102 @@ def test_postprocess_handles_monomers_without_relaxation_and_logs_modelcif_error
|
||||
assert plot_calls == [0]
|
||||
assert cleanup_calls
|
||||
assert modelcif_errors == ["Error: convert failed"]
|
||||
|
||||
|
||||
def test_postprocess_skips_best_relaxation_below_score_threshold(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """BEST relaxation is skipped when the top model's iPTM (0.7) is under 0.8."""
    captured_info = []

    def _ignore_plot(**kwargs):
        return None

    def _ignore_post_prediction(*args, **kwargs):
        return None

    def _capture_info(message, *args):
        # Expand %-style arguments so the assertion can match the final text.
        captured_info.append(message % args if args else message)

    monkeypatch.setattr(af2_backend_module, "plot_pae_from_matrix", _ignore_plot)
    monkeypatch.setattr(
        af2_backend_module, "post_prediction_process", _ignore_post_prediction
    )
    monkeypatch.setattr(af2_backend_module.logging, "info", _capture_info)

    complex_object = af2_backend_module.MultimericObject(
        description="complex",
        input_seqs=["AA", "BB"],
        feature_dict={},
        multimeric_mode=True,
    )
    model_high_result = {
        "plddt": np.array([90.0, 91.0, 92.0, 93.0], dtype=np.float32),
        "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32),
        "max_predicted_aligned_error": 31.0,
        "ranking_confidence": 0.9,
        "iptm": 0.7,
        "ptm": 0.8,
        "unrelaxed_protein": SimpleNamespace(name="high"),
        "seqs": ["AA", "BB"],
    }

    af2_backend_module.AlphaFold2Backend.postprocess(
        prediction_results={"model_high": model_high_result},
        multimeric_object=complex_object,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=af2_backend_module.ModelsToRelax.BEST,
        relax_best_score_threshold=0.8,
        convert_to_modelcif=False,
    )

    relaxed_path = tmp_path / "relaxed_model_high.pdb"
    ranked_path = tmp_path / "ranked_0.pdb"
    assert not relaxed_path.exists()
    assert ranked_path.read_text(encoding="utf-8") == "PDB:high"
    assert any(
        "Skipping relaxation for model_high" in entry for entry in captured_info
    )
|
||||
|
||||
|
||||
def test_postprocess_uses_ptm_threshold_for_best_monomer_relaxation(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """With no iPTM in the results, a pTM of 0.4 clears the 0.3 threshold."""

    def _ignore_plot(**kwargs):
        return None

    def _ignore_post_prediction(*args, **kwargs):
        return None

    monkeypatch.setattr(af2_backend_module, "plot_pae_from_matrix", _ignore_plot)
    monkeypatch.setattr(
        af2_backend_module, "post_prediction_process", _ignore_post_prediction
    )

    single_chain = af2_backend_module.MonomericObject("single", "AB")
    model_a_result = {
        "plddt": np.array([81.0, 82.0], dtype=np.float32),
        "predicted_aligned_error": np.zeros((2, 2), dtype=np.float32),
        "max_predicted_aligned_error": 31.0,
        "ranking_confidence": 81.5,
        "ptm": 0.4,
        "seqs": ["AB"],
        "unrelaxed_protein": SimpleNamespace(name="mono"),
    }

    af2_backend_module.AlphaFold2Backend.postprocess(
        prediction_results={"modelA": model_a_result},
        multimeric_object=single_chain,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=af2_backend_module.ModelsToRelax.BEST,
        relax_best_score_threshold=0.3,
        convert_to_modelcif=False,
    )

    # Both the relaxed file and the top-ranked file must hold the relaxed model.
    for file_name in ("relaxed_modelA.pdb", "ranked_0.pdb"):
        assert (tmp_path / file_name).read_text(encoding="utf-8") == "RELAXED:mono"
|
||||
|
||||
@@ -313,6 +313,7 @@ def _load_run_multimer_jobs_module():
|
||||
# Predefine the shared FLAGS that run_multimer_jobs expects from run_structure_prediction.
|
||||
shared_flag_defaults = {
|
||||
"models_to_relax": "NONE",
|
||||
"relax_best_score_threshold": None,
|
||||
"num_cycle": 3,
|
||||
"num_predictions_per_model": 1,
|
||||
"pair_msa": True,
|
||||
@@ -908,6 +909,7 @@ def test_main_sets_multimer_model_flags_for_multimer_jobs(
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "msa_depth_scan", True)
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "model_names", ["model_2_multimer_v3"])
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "msa_depth", 64)
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "relax_best_score_threshold", 0.6)
|
||||
|
||||
monkeypatch.setattr(run_structure_prediction_module, "parse_fold", lambda *args: [["parsed"]])
|
||||
monkeypatch.setattr(run_structure_prediction_module, "create_custom_info", lambda parsed: "data")
|
||||
@@ -934,6 +936,7 @@ def test_main_sets_multimer_model_flags_for_multimer_jobs(
|
||||
assert captured_calls[0]["model_flags"]["msa_depth_scan"] is True
|
||||
assert captured_calls[0]["model_flags"]["model_names_custom"] == ["model_2_multimer_v3"]
|
||||
assert captured_calls[0]["model_flags"]["msa_depth"] == 64
|
||||
assert captured_calls[0]["postprocess_flags"]["relax_best_score_threshold"] == 0.6
|
||||
|
||||
|
||||
def test_main_rejects_mismatched_output_directories(
|
||||
@@ -1191,3 +1194,34 @@ def test_run_multimer_jobs_forwards_multimeric_template_filters(
|
||||
assert calls[0][calls[0].index("--hb_allowance") + 1] == "0.7"
|
||||
assert "--plddt_threshold" in calls[0]
|
||||
assert calls[0][calls[0].index("--plddt_threshold") + 1] == "42.0"
|
||||
|
||||
|
||||
def test_run_multimer_jobs_forwards_relax_best_score_threshold(
    run_multimer_jobs_module,
    monkeypatch,
):
    """The threshold flag is passed through to the spawned command line."""
    calls = []
    monkeypatch.setattr(
        run_multimer_jobs_module.subprocess,
        "run",
        lambda command, check, env: calls.append(command),
    )
    # Patch through monkeypatch so the stub is reverted on teardown instead of
    # staying on the module object. raising=False keeps the plain-assignment
    # behaviour of not requiring the attribute to exist beforehand.
    monkeypatch.setattr(
        run_multimer_jobs_module,
        "generate_fold_specifications",
        lambda input_files, delimiter, exclude_permutations: ["job1"],
        raising=False,
    )

    _set_flag(run_multimer_jobs_module.FLAGS, "mode", "custom")
    _set_flag(run_multimer_jobs_module.FLAGS, "protein_lists", ["proteins.txt"])
    _set_flag(run_multimer_jobs_module.FLAGS, "dry_run", False)
    _set_flag(run_multimer_jobs_module.FLAGS, "fold_backend", "alphafold2")
    _set_flag(run_multimer_jobs_module.FLAGS, "output_path", "/tmp/output")
    _set_flag(run_multimer_jobs_module.FLAGS, "data_dir", "/tmp/models")
    _set_flag(run_multimer_jobs_module.FLAGS, "monomer_objects_dir", ["/tmp/features"])
    _set_flag(run_multimer_jobs_module.FLAGS, "models_to_relax", "Best")
    _set_flag(run_multimer_jobs_module.FLAGS, "relax_best_score_threshold", 0.6)

    run_multimer_jobs_module.main(["prog"])

    assert len(calls) == 1
    assert "--relax_best_score_threshold" in calls[0]
    assert calls[0][calls[0].index("--relax_best_score_threshold") + 1] == "0.6"
|
||||
|
||||
Reference in New Issue
Block a user