fix(#456): gate best-model relaxation by score

This commit is contained in:
Dima
2026-04-09 14:20:00 +02:00
parent 4f173d9aec
commit 12476a55c2
5 changed files with 191 additions and 7 deletions

View File

@@ -39,6 +39,46 @@ RELAX_EXCLUDE_RESIDUES = []
RELAX_MAX_OUTER_ITERATIONS = 3
def _select_models_to_relax(
    ranked_order: List[str],
    *,
    models_to_relax: "ModelsToRelax",
    iptm_scores: Dict[str, float],
    ptm_scores: Dict[str, float],
    relax_best_score_threshold,
) -> List[str]:
    """Return the names of the models that should be relaxed.

    Args:
        ranked_order: Model names; the first entry is treated as the best model.
        models_to_relax: Relaxation policy. ``ModelsToRelax.ALL`` selects every
            entry of ``ranked_order`` and ``ModelsToRelax.NONE`` selects
            nothing. Any other value falls through to "best model only".
        iptm_scores: Per-model iPTM scores. Used for the threshold check
            whenever the best model has an entry here.
        ptm_scores: Per-model pTM scores. Used only when the best model has no
            entry in ``iptm_scores``.
        relax_best_score_threshold: Minimum score the best model must reach to
            be relaxed, or ``None`` to relax it unconditionally.

    Returns:
        A new list of model names to relax (possibly empty).
    """
    if models_to_relax == ModelsToRelax.ALL:
        # Hand back a copy so callers never alias the ranking list itself.
        return list(ranked_order)
    if models_to_relax == ModelsToRelax.NONE or not ranked_order:
        return []

    best_model = ranked_order[0]
    if relax_best_score_threshold is None:
        return [best_model]

    # Prefer iPTM when the best model has one; otherwise fall back to pTM.
    if best_model in iptm_scores:
        score_name = "iPTM"
        score_value = iptm_scores[best_model]
    else:
        score_name = "pTM"
        score_value = ptm_scores.get(best_model)

    if score_value is None:
        # A missing score cannot satisfy the gate. Report it as missing rather
        # than printing a made-up 0.000 value.
        logging.info(
            "Skipping relaxation for %s because its %s score is unavailable "
            "and cannot be compared against the requested threshold %.3f.",
            best_model,
            score_name,
            relax_best_score_threshold,
        )
        return []

    if score_value < relax_best_score_threshold:
        logging.info(
            "Skipping relaxation for %s because its %s score %.3f is below the "
            "requested threshold %.3f.",
            best_model,
            score_name,
            score_value,
            relax_best_score_threshold,
        )
        return []

    return [best_model]
@enum.unique
class ModelsToRelax(enum.Enum):
ALL = 0
@@ -788,6 +828,7 @@ class AlphaFold2Backend(FoldingBackend):
output_dir: str,
features_directory: str,
models_to_relax: ModelsToRelax,
relax_best_score_threshold = None,
compress_pickles: bool = False,
remove_pickles: bool = False,
remove_keys_from_pickles: bool = False,
@@ -917,12 +958,13 @@ class AlphaFold2Backend(FoldingBackend):
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
use_gpu=_resolve_gpu_relax(use_gpu_relax))
if models_to_relax == ModelsToRelax.BEST:
to_relax = [ranked_order[0]]
elif models_to_relax == ModelsToRelax.ALL:
to_relax = ranked_order
elif models_to_relax == ModelsToRelax.NONE:
to_relax = []
to_relax = _select_models_to_relax(
ranked_order,
models_to_relax=models_to_relax,
iptm_scores=iptm_scores,
ptm_scores=ptm_scores,
relax_best_score_threshold=relax_best_score_threshold,
)
for model_name in to_relax:
if f'relax_{model_name}' in timings:

View File

@@ -164,6 +164,7 @@ def main(argv):
"--desired_num_res": FLAGS.desired_num_res,
"--desired_num_msa": FLAGS.desired_num_msa,
"--models_to_relax": FLAGS.models_to_relax,
"--relax_best_score_threshold": FLAGS.relax_best_score_threshold,
"--threshold_clashes": FLAGS.threshold_clashes,
"--hb_allowance": FLAGS.hb_allowance,
"--plddt_threshold": FLAGS.plddt_threshold,

View File

@@ -102,6 +102,12 @@ flags.DEFINE_enum_class(
"in case you are having issues with the relaxation "
"stage.",
)
# Optional score gate for relaxing the best-ranked model. The default is None
# (flag unset); main() forwards the value as-is in the postprocess flags.
flags.DEFINE_float(
"relax_best_score_threshold",
None,
"Optional minimum iPTM/pTM score required before relaxing the best-ranked "
"model when --models_to_relax=Best.",
)
flags.DEFINE_enum('model_preset', 'monomer',
['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
'Choose preset model configuration - the monomer model, '
@@ -228,7 +234,8 @@ def _validate_flags_for_backend(backend_name: str) -> None:
# Backend-specific flags
af2_like_flags = {
'compress_result_pickles', 'remove_result_pickles', 'models_to_relax',
'remove_keys_from_pickles', 'convert_to_modelcif', 'allow_resume',
'relax_best_score_threshold', 'remove_keys_from_pickles',
'convert_to_modelcif', 'allow_resume',
'num_cycle', 'num_predictions_per_model', 'pair_msa',
'save_features_for_multimeric_object', 'skip_templates',
'msa_depth_scan', 'multimeric_template', 'model_names', 'msa_depth',
@@ -475,6 +482,7 @@ def main(argv):
"remove_keys_from_pickles": FLAGS.remove_keys_from_pickles,
"use_gpu_relax": FLAGS.use_gpu_relax,
"models_to_relax": FLAGS.models_to_relax,
"relax_best_score_threshold": FLAGS.relax_best_score_threshold,
"features_directory": FLAGS.features_directory,
"convert_to_modelcif": FLAGS.convert_to_modelcif
}

View File

@@ -971,3 +971,102 @@ def test_postprocess_handles_monomers_without_relaxation_and_logs_modelcif_error
assert plot_calls == [0]
assert cleanup_calls
assert modelcif_errors == ["Error: convert failed"]
def test_postprocess_skips_best_relaxation_below_score_threshold(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """BEST-mode postprocessing must not relax a model under the score gate.

    The only prediction carries ``iptm=0.7`` while the requested threshold is
    0.8, so no relaxed PDB may be written and a skip message must be logged.
    """
    logged = []

    def _capture_info(message, *args):
        # Render %-style log calls so the assertion can match the final text.
        logged.append(message % args if args else message)

    # Silence plotting and downstream post-processing; capture info logs.
    for target, attribute, stub in (
        (af2_backend_module, "plot_pae_from_matrix", lambda **kwargs: None),
        (af2_backend_module, "post_prediction_process", lambda *args, **kwargs: None),
        (af2_backend_module.logging, "info", _capture_info),
    ):
        monkeypatch.setattr(target, attribute, stub)

    complex_object = af2_backend_module.MultimericObject(
        description="complex",
        input_seqs=["AA", "BB"],
        feature_dict={},
        multimeric_mode=True,
    )
    high_model_result = {
        "plddt": np.array([90.0, 91.0, 92.0, 93.0], dtype=np.float32),
        "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32),
        "max_predicted_aligned_error": 31.0,
        "ranking_confidence": 0.9,
        "iptm": 0.7,
        "ptm": 0.8,
        "unrelaxed_protein": SimpleNamespace(name="high"),
        "seqs": ["AA", "BB"],
    }

    af2_backend_module.AlphaFold2Backend.postprocess(
        prediction_results={"model_high": high_model_result},
        multimeric_object=complex_object,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=af2_backend_module.ModelsToRelax.BEST,
        relax_best_score_threshold=0.8,
        convert_to_modelcif=False,
    )

    relaxed_pdb = tmp_path / "relaxed_model_high.pdb"
    ranked_pdb = tmp_path / "ranked_0.pdb"
    assert not relaxed_pdb.exists()
    # The top-ranked file must still hold the unrelaxed structure.
    assert ranked_pdb.read_text(encoding="utf-8") == "PDB:high"
    assert any("Skipping relaxation for model_high" in entry for entry in logged)
def test_postprocess_uses_ptm_threshold_for_best_monomer_relaxation(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """A monomer result with only a pTM score is still gated and relaxed.

    The prediction has no ``iptm`` entry and ``ptm=0.4``, which clears the 0.3
    threshold, so both the relaxed file and the top-ranked file are expected
    to contain the relaxed structure.
    """

    def _ignore_plot(**kwargs):
        return None

    def _ignore_post_prediction(*args, **kwargs):
        return None

    # Plotting and downstream post-processing are irrelevant to this check.
    monkeypatch.setattr(af2_backend_module, "plot_pae_from_matrix", _ignore_plot)
    monkeypatch.setattr(
        af2_backend_module, "post_prediction_process", _ignore_post_prediction
    )

    single_chain = af2_backend_module.MonomericObject("single", "AB")
    model_a_result = {
        "plddt": np.array([81.0, 82.0], dtype=np.float32),
        "predicted_aligned_error": np.zeros((2, 2), dtype=np.float32),
        "max_predicted_aligned_error": 31.0,
        "ranking_confidence": 81.5,
        "ptm": 0.4,
        "seqs": ["AB"],
        "unrelaxed_protein": SimpleNamespace(name="mono"),
    }

    af2_backend_module.AlphaFold2Backend.postprocess(
        prediction_results={"modelA": model_a_result},
        multimeric_object=single_chain,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=af2_backend_module.ModelsToRelax.BEST,
        relax_best_score_threshold=0.3,
        convert_to_modelcif=False,
    )

    expected_text = "RELAXED:mono"
    for produced_name in ("relaxed_modelA.pdb", "ranked_0.pdb"):
        assert (tmp_path / produced_name).read_text(encoding="utf-8") == expected_text

View File

@@ -313,6 +313,7 @@ def _load_run_multimer_jobs_module():
# Predefine the shared FLAGS that run_multimer_jobs expects from run_structure_prediction.
shared_flag_defaults = {
"models_to_relax": "NONE",
"relax_best_score_threshold": None,
"num_cycle": 3,
"num_predictions_per_model": 1,
"pair_msa": True,
@@ -908,6 +909,7 @@ def test_main_sets_multimer_model_flags_for_multimer_jobs(
_set_flag(run_structure_prediction_module.FLAGS, "msa_depth_scan", True)
_set_flag(run_structure_prediction_module.FLAGS, "model_names", ["model_2_multimer_v3"])
_set_flag(run_structure_prediction_module.FLAGS, "msa_depth", 64)
_set_flag(run_structure_prediction_module.FLAGS, "relax_best_score_threshold", 0.6)
monkeypatch.setattr(run_structure_prediction_module, "parse_fold", lambda *args: [["parsed"]])
monkeypatch.setattr(run_structure_prediction_module, "create_custom_info", lambda parsed: "data")
@@ -934,6 +936,7 @@ def test_main_sets_multimer_model_flags_for_multimer_jobs(
assert captured_calls[0]["model_flags"]["msa_depth_scan"] is True
assert captured_calls[0]["model_flags"]["model_names_custom"] == ["model_2_multimer_v3"]
assert captured_calls[0]["model_flags"]["msa_depth"] == 64
assert captured_calls[0]["postprocess_flags"]["relax_best_score_threshold"] == 0.6
def test_main_rejects_mismatched_output_directories(
@@ -1191,3 +1194,34 @@ def test_run_multimer_jobs_forwards_multimeric_template_filters(
assert calls[0][calls[0].index("--hb_allowance") + 1] == "0.7"
assert "--plddt_threshold" in calls[0]
assert calls[0][calls[0].index("--plddt_threshold") + 1] == "42.0"
def test_run_multimer_jobs_forwards_relax_best_score_threshold(
    run_multimer_jobs_module,
    monkeypatch,
):
    """The job launcher must pass --relax_best_score_threshold to the child command."""
    calls = []
    # Record the command instead of spawning a subprocess.
    monkeypatch.setattr(
        run_multimer_jobs_module.subprocess,
        "run",
        lambda command, check, env: calls.append(command),
    )
    # Stub via monkeypatch rather than plain assignment so the replacement is
    # undone at teardown and cannot leak into other tests. raising=False keeps
    # the old tolerance for the attribute not existing on the module.
    monkeypatch.setattr(
        run_multimer_jobs_module,
        "generate_fold_specifications",
        lambda input_files, delimiter, exclude_permutations: ["job1"],
        raising=False,
    )

    _set_flag(run_multimer_jobs_module.FLAGS, "mode", "custom")
    _set_flag(run_multimer_jobs_module.FLAGS, "protein_lists", ["proteins.txt"])
    _set_flag(run_multimer_jobs_module.FLAGS, "dry_run", False)
    _set_flag(run_multimer_jobs_module.FLAGS, "fold_backend", "alphafold2")
    _set_flag(run_multimer_jobs_module.FLAGS, "output_path", "/tmp/output")
    _set_flag(run_multimer_jobs_module.FLAGS, "data_dir", "/tmp/models")
    _set_flag(run_multimer_jobs_module.FLAGS, "monomer_objects_dir", ["/tmp/features"])
    _set_flag(run_multimer_jobs_module.FLAGS, "models_to_relax", "Best")
    _set_flag(run_multimer_jobs_module.FLAGS, "relax_best_score_threshold", 0.6)

    run_multimer_jobs_module.main(["prog"])

    assert len(calls) == 1
    assert "--relax_best_score_threshold" in calls[0]
    # The flag value is expected in its string form, right after the flag name.
    assert calls[0][calls[0].index("--relax_best_score_threshold") + 1] == "0.6"