mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
fix(#464): support template filters in quick multimer mode
This commit is contained in:
@@ -525,7 +525,10 @@ class MultimericObject:
|
||||
def __init__(self, interactors: list, pair_msa: bool = True,
|
||||
multimeric_template: bool = False,
|
||||
multimeric_template_meta_data: str = None,
|
||||
multimeric_template_dir:str = None) -> None:
|
||||
multimeric_template_dir:str = None,
|
||||
threshold_clashes: float = 1000,
|
||||
hb_allowance: float = 0.4,
|
||||
plddt_threshold: float = 0) -> None:
|
||||
self.description = ""
|
||||
self.interactors = interactors
|
||||
self.build_description_monomer_mapping()
|
||||
@@ -534,6 +537,9 @@ class MultimericObject:
|
||||
self.chain_id_map = dict()
|
||||
self.input_seqs = []
|
||||
self.multimeric_template_dir = multimeric_template_dir
|
||||
self.threshold_clashes = threshold_clashes
|
||||
self.hb_allowance = hb_allowance
|
||||
self.plddt_threshold = plddt_threshold
|
||||
self.create_output_name()
|
||||
|
||||
if multimeric_template_meta_data is not None:
|
||||
@@ -644,7 +650,10 @@ create_individual_features.py
|
||||
pdb_id = k.split('.cif')[0]
|
||||
multimeric_template_features = extract_multimeric_template_features_for_single_chain(query_seq=curr_monomer.sequence,
|
||||
pdb_id=pdb_id,chain_id=v,
|
||||
mmcif_file=os.path.join(self.multimeric_template_dir,k))
|
||||
mmcif_file=os.path.join(self.multimeric_template_dir,k),
|
||||
threshold_clashes=getattr(self, "threshold_clashes", 1000),
|
||||
hb_allowance=getattr(self, "hb_allowance", 0.4),
|
||||
plddt_threshold=getattr(self, "plddt_threshold", 0))
|
||||
curr_monomer.feature_dict.update(multimeric_template_features.features)
|
||||
|
||||
|
||||
|
||||
@@ -163,7 +163,10 @@ def main(argv):
|
||||
"--protein_delimiter": FLAGS.protein_delimiter,
|
||||
"--desired_num_res": FLAGS.desired_num_res,
|
||||
"--desired_num_msa": FLAGS.desired_num_msa,
|
||||
"--models_to_relax": FLAGS.models_to_relax
|
||||
"--models_to_relax": FLAGS.models_to_relax,
|
||||
"--threshold_clashes": FLAGS.threshold_clashes,
|
||||
"--hb_allowance": FLAGS.hb_allowance,
|
||||
"--plddt_threshold": FLAGS.plddt_threshold,
|
||||
}
|
||||
|
||||
command_args = {}
|
||||
|
||||
@@ -73,6 +73,18 @@ flags.DEFINE_string('description_file', None,
|
||||
'Path to the text file with multimeric template instruction.')
|
||||
flags.DEFINE_string('path_to_mmt', None,
|
||||
'Path to directory with multimeric template mmCIF files.')
|
||||
flags.DEFINE_float(
|
||||
'threshold_clashes',
|
||||
1000,
|
||||
'Threshold for VDW overlap used to remove clashes from quick-mode multimeric templates.')
|
||||
flags.DEFINE_float(
|
||||
'hb_allowance',
|
||||
0.4,
|
||||
'Allowance for hydrogen bonding when filtering quick-mode multimeric templates.')
|
||||
flags.DEFINE_float(
|
||||
'plddt_threshold',
|
||||
0,
|
||||
'Threshold for removing low-pLDDT residues from quick-mode multimeric templates.')
|
||||
flags.DEFINE_integer('desired_num_res', None,
|
||||
'A desired number of residues to pad')
|
||||
flags.DEFINE_integer('desired_num_msa', None,
|
||||
@@ -220,7 +232,8 @@ def _validate_flags_for_backend(backend_name: str) -> None:
|
||||
'num_cycle', 'num_predictions_per_model', 'pair_msa',
|
||||
'save_features_for_multimeric_object', 'skip_templates',
|
||||
'msa_depth_scan', 'multimeric_template', 'model_names', 'msa_depth',
|
||||
'description_file', 'path_to_mmt', 'desired_num_res', 'desired_num_msa',
|
||||
'description_file', 'path_to_mmt', 'threshold_clashes', 'hb_allowance',
|
||||
'plddt_threshold', 'desired_num_res', 'desired_num_msa',
|
||||
'benchmark', 'model_preset', 'use_ap_style', 'use_gpu_relax', 'dropout',
|
||||
}
|
||||
alphalink_extra = {'crosslinks'}
|
||||
@@ -336,6 +349,9 @@ def pre_modelling_setup(
|
||||
multimeric_template=FLAGS.multimeric_template,
|
||||
multimeric_template_meta_data=FLAGS.description_file,
|
||||
multimeric_template_dir=FLAGS.path_to_mmt,
|
||||
threshold_clashes=FLAGS.threshold_clashes,
|
||||
hb_allowance=FLAGS.hb_allowance,
|
||||
plddt_threshold=FLAGS.plddt_threshold,
|
||||
)
|
||||
if FLAGS.save_features_for_multimeric_object:
|
||||
pickle.dump(MultimericObject.feature_dict, open(join(output_dir, "multimeric_object_features.pkl"), "wb"))
|
||||
|
||||
@@ -63,7 +63,14 @@ def obtain_kalign_binary_path() -> Optional[str]:
|
||||
return shutil.which('kalign')
|
||||
|
||||
|
||||
def parse_mmcif_file(file_id: str, mmcif_file: str, chain_id: str) -> ParsingResult:
|
||||
def parse_mmcif_file(
|
||||
file_id: str,
|
||||
mmcif_file: str,
|
||||
chain_id: str,
|
||||
threshold_clashes: float = 1000,
|
||||
hb_allowance: float = 0.4,
|
||||
plddt_threshold: float = 0,
|
||||
) -> ParsingResult:
|
||||
"""
|
||||
Args:
|
||||
file_id: A string identifier for this file. Should be unique within the
|
||||
@@ -76,6 +83,8 @@ def parse_mmcif_file(file_id: str, mmcif_file: str, chain_id: str) -> ParsingRes
|
||||
try:
|
||||
mmcif_filtered_obj = MmcifChainFiltered(
|
||||
Path(mmcif_file), file_id, chain_id=chain_id)
|
||||
mmcif_filtered_obj.remove_clashes(threshold_clashes, hb_allowance)
|
||||
mmcif_filtered_obj.remove_low_plddt(plddt_threshold)
|
||||
parsing_result = mmcif_filtered_obj.parsing_result
|
||||
except FileNotFoundError as e:
|
||||
parsing_result = None
|
||||
@@ -121,6 +130,9 @@ def extract_multimeric_template_features_for_single_chain(
|
||||
pdb_id: str,
|
||||
chain_id: str,
|
||||
mmcif_file: str,
|
||||
threshold_clashes: float = 1000,
|
||||
hb_allowance: float = 0.4,
|
||||
plddt_threshold: float = 0,
|
||||
) -> SingleHitResult:
|
||||
"""
|
||||
Args:
|
||||
@@ -134,7 +146,13 @@ def extract_multimeric_template_features_for_single_chain(
|
||||
A SingleHitResult object
|
||||
"""
|
||||
mmcif_parse_result = parse_mmcif_file(
|
||||
pdb_id, mmcif_file, chain_id=chain_id)
|
||||
pdb_id,
|
||||
mmcif_file,
|
||||
chain_id=chain_id,
|
||||
threshold_clashes=threshold_clashes,
|
||||
hb_allowance=hb_allowance,
|
||||
plddt_threshold=plddt_threshold,
|
||||
)
|
||||
if (mmcif_parse_result is not None) and (mmcif_parse_result.mmcif_object is not None):
|
||||
mapping,template_sequence = _obtain_mapping(mmcif_parse_result=mmcif_parse_result,
|
||||
chain_id=chain_id,
|
||||
|
||||
@@ -42,6 +42,7 @@ def test_obtain_kalign_binary_path_asserts_when_binary_missing(monkeypatch):
|
||||
|
||||
def test_parse_mmcif_file_returns_parsing_result(monkeypatch, tmp_path):
|
||||
expected = SimpleNamespace(name="parsed")
|
||||
calls = []
|
||||
|
||||
class FakeFiltered:
|
||||
def __init__(self, path, file_id, chain_id):
|
||||
@@ -50,11 +51,28 @@ def test_parse_mmcif_file_returns_parsing_result(monkeypatch, tmp_path):
|
||||
assert chain_id == "A"
|
||||
self.parsing_result = expected
|
||||
|
||||
def remove_clashes(self, threshold, hb_allowance):
|
||||
calls.append(("remove_clashes", threshold, hb_allowance))
|
||||
|
||||
def remove_low_plddt(self, threshold):
|
||||
calls.append(("remove_low_plddt", threshold))
|
||||
|
||||
monkeypatch.setattr(mtu, "MmcifChainFiltered", FakeFiltered)
|
||||
|
||||
result = mtu.parse_mmcif_file("1abc", str(tmp_path / "template.cif"), "A")
|
||||
result = mtu.parse_mmcif_file(
|
||||
"1abc",
|
||||
str(tmp_path / "template.cif"),
|
||||
"A",
|
||||
threshold_clashes=12.5,
|
||||
hb_allowance=0.7,
|
||||
plddt_threshold=42.0,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
assert calls == [
|
||||
("remove_clashes", 12.5, 0.7),
|
||||
("remove_low_plddt", 42.0),
|
||||
]
|
||||
|
||||
|
||||
def test_parse_mmcif_file_returns_none_when_file_missing(monkeypatch, tmp_path):
|
||||
|
||||
@@ -968,16 +968,30 @@ def test_create_multimeric_template_features_updates_matching_monomer(monkeypatc
|
||||
multimer.multimeric_template_dir = str(tmp_path)
|
||||
multimer.multimeric_template_meta_data = {"proteinA": {"1abc.cif": "B"}}
|
||||
multimer.monomers_mapping = {"proteinA": monomer}
|
||||
multimer.threshold_clashes = 12.5
|
||||
multimer.hb_allowance = 0.7
|
||||
multimer.plddt_threshold = 42.0
|
||||
extractor_calls = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
objects_mod,
|
||||
"extract_multimeric_template_features_for_single_chain",
|
||||
lambda **kwargs: SimpleNamespace(features={"templated": kwargs["chain_id"]}),
|
||||
lambda **kwargs: extractor_calls.append(kwargs)
|
||||
or SimpleNamespace(features={"templated": kwargs["chain_id"]}),
|
||||
)
|
||||
|
||||
multimer.create_multimeric_template_features()
|
||||
|
||||
assert monomer.feature_dict["templated"] == "B"
|
||||
assert extractor_calls == [{
|
||||
"query_seq": "ACDE",
|
||||
"pdb_id": "1abc",
|
||||
"chain_id": "B",
|
||||
"mmcif_file": str(template_file),
|
||||
"threshold_clashes": 12.5,
|
||||
"hb_allowance": 0.7,
|
||||
"plddt_threshold": 42.0,
|
||||
}]
|
||||
|
||||
|
||||
def test_create_multimeric_template_features_rejects_non_mmcif_files(tmp_path):
|
||||
|
||||
@@ -103,6 +103,9 @@ class _FakeFlagsModule(types.ModuleType):
|
||||
def DEFINE_integer(self, name, default, *_args, **_kwargs):
|
||||
return self.FLAGS.define(name, default)
|
||||
|
||||
def DEFINE_float(self, name, default, *_args, **_kwargs):
|
||||
return self.FLAGS.define(name, default)
|
||||
|
||||
def DEFINE_boolean(self, name, default, *_args, **_kwargs):
|
||||
return self.FLAGS.define(name, default)
|
||||
|
||||
@@ -208,12 +211,18 @@ def _load_run_structure_prediction_module():
|
||||
multimeric_template,
|
||||
multimeric_template_meta_data,
|
||||
multimeric_template_dir,
|
||||
threshold_clashes=1000,
|
||||
hb_allowance=0.4,
|
||||
plddt_threshold=0,
|
||||
):
|
||||
self.interactors = list(interactors)
|
||||
self.pair_msa = pair_msa
|
||||
self.multimeric_template = multimeric_template
|
||||
self.multimeric_template_meta_data = multimeric_template_meta_data
|
||||
self.multimeric_template_dir = multimeric_template_dir
|
||||
self.threshold_clashes = threshold_clashes
|
||||
self.hb_allowance = hb_allowance
|
||||
self.plddt_threshold = plddt_threshold
|
||||
self.description = "_and_".join(interactor.description for interactor in interactors)
|
||||
self.input_seqs = [interactor.sequence for interactor in interactors]
|
||||
self.multimeric_mode = True
|
||||
@@ -315,6 +324,9 @@ def _load_run_multimer_jobs_module():
|
||||
"fold_backend": "alphafold2",
|
||||
"description_file": None,
|
||||
"path_to_mmt": None,
|
||||
"threshold_clashes": 1000,
|
||||
"hb_allowance": 0.4,
|
||||
"plddt_threshold": 0,
|
||||
"compress_result_pickles": False,
|
||||
"remove_result_pickles": False,
|
||||
"remove_keys_from_pickles": True,
|
||||
@@ -649,6 +661,45 @@ def test_pre_modelling_setup_saves_multimer_features_and_builds_unique_ap_style_
|
||||
]
|
||||
|
||||
|
||||
def test_pre_modelling_setup_passes_multimeric_template_filters(
|
||||
run_structure_prediction_module,
|
||||
tmp_path,
|
||||
):
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "pair_msa", True)
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "multimeric_template", True)
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "description_file", "meta.csv")
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "path_to_mmt", "/tmp/templates")
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "threshold_clashes", 12.5)
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "hb_allowance", 0.7)
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "plddt_threshold", 42.0)
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "save_features_for_multimeric_object", False)
|
||||
_set_flag(
|
||||
run_structure_prediction_module.FLAGS,
|
||||
"features_directory",
|
||||
[str(tmp_path / "features")],
|
||||
)
|
||||
_set_flag(run_structure_prediction_module.FLAGS, "use_ap_style", False)
|
||||
|
||||
feature_dir = tmp_path / "features"
|
||||
feature_dir.mkdir()
|
||||
for description in ("protA", "protB"):
|
||||
(feature_dir / f"{description}_feature_metadata_2026-03-30.json").write_text(
|
||||
'{"meta": 1}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monomer_a = run_structure_prediction_module.MonomericObject("protA", "AAAA")
|
||||
monomer_b = run_structure_prediction_module.MonomericObject("protB", "BBBB")
|
||||
returned_object, _ = run_structure_prediction_module.pre_modelling_setup(
|
||||
[monomer_a, monomer_b],
|
||||
output_dir=str(tmp_path / "outputs"),
|
||||
)
|
||||
|
||||
assert returned_object.threshold_clashes == 12.5
|
||||
assert returned_object.hb_allowance == 0.7
|
||||
assert returned_object.plddt_threshold == 42.0
|
||||
|
||||
|
||||
def test_pre_modelling_setup_builds_ap_style_homo_oligomer_dir(
|
||||
run_structure_prediction_module,
|
||||
tmp_path,
|
||||
@@ -1103,3 +1154,40 @@ def test_run_multimer_jobs_combines_inputs_when_padding_requested(
|
||||
input_index = calls[0].index("--input")
|
||||
assert calls[0][input_index + 1] == "job1,job2"
|
||||
assert "--nopair_msa" in calls[0]
|
||||
|
||||
|
||||
def test_run_multimer_jobs_forwards_multimeric_template_filters(
|
||||
run_multimer_jobs_module,
|
||||
monkeypatch,
|
||||
):
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
run_multimer_jobs_module.subprocess,
|
||||
"run",
|
||||
lambda command, check, env: calls.append(command),
|
||||
)
|
||||
run_multimer_jobs_module.generate_fold_specifications = (
|
||||
lambda input_files, delimiter, exclude_permutations: ["job1"]
|
||||
)
|
||||
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "mode", "custom")
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "protein_lists", ["proteins.txt"])
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "dry_run", False)
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "fold_backend", "alphafold2")
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "output_path", "/tmp/output")
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "data_dir", "/tmp/models")
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "monomer_objects_dir", ["/tmp/features"])
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "multimeric_template", True)
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "threshold_clashes", 12.5)
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "hb_allowance", 0.7)
|
||||
_set_flag(run_multimer_jobs_module.FLAGS, "plddt_threshold", 42.0)
|
||||
|
||||
run_multimer_jobs_module.main(["prog"])
|
||||
|
||||
assert len(calls) == 1
|
||||
assert "--threshold_clashes" in calls[0]
|
||||
assert calls[0][calls[0].index("--threshold_clashes") + 1] == "12.5"
|
||||
assert "--hb_allowance" in calls[0]
|
||||
assert calls[0][calls[0].index("--hb_allowance") + 1] == "0.7"
|
||||
assert "--plddt_threshold" in calls[0]
|
||||
assert calls[0][calls[0].index("--plddt_threshold") + 1] == "42.0"
|
||||
|
||||
Reference in New Issue
Block a user