fix(#42): add skip-MSA feature generation mode

This commit is contained in:
Dima
2026-04-10 11:35:27 +02:00
parent 098792fa21
commit 3dfd6d8aad
9 changed files with 579 additions and 16 deletions

View File

@@ -26,6 +26,17 @@ from alphapulldown.utils.mmseqs_species_identifiers import (
strip_mmseq_comment_lines,
)
def _query_only_a3m(sequence: str, query_id: str = "query") -> str:
"""Return a single-sequence A3M string for query-only workflows."""
return f">{query_id}\n{sequence}\n"
def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
"""Return a single-sequence Stockholm alignment string."""
return f"# STOCKHOLM 1.0\n{query_id} {sequence}\n//\n"
class MonomericObject:
"""
monomeric objects
@@ -41,6 +52,7 @@ class MonomericObject:
self.sequence = sequence
self.feature_dict = dict()
self._uniprot_runner = None
self.skip_msa = False
pass
@property
@@ -140,7 +152,8 @@ class MonomericObject:
def make_features(
self, pipeline, output_dir: str,
use_precomputed_msa: bool = False,
save_msa: bool = True, compress_msa_files: bool = False
save_msa: bool = True, compress_msa_files: bool = False,
skip_msa: bool = False,
):
"""a method that make msa and template features"""
os.makedirs(os.path.join(output_dir, self.description), exist_ok=True)
@@ -155,13 +168,20 @@ class MonomericObject:
logging.info(
"will save msa files in :{}".format(msa_output_dir))
plPath(msa_output_dir).mkdir(parents=True, exist_ok=True)
with temp_fasta_file(sequence_str) as fasta_file:
self.feature_dict = pipeline.process(
fasta_file, msa_output_dir)
pairing_results = self.all_seq_msa_features(
fasta_file, self._uniprot_runner, msa_output_dir, use_precomputed_msa
self.skip_msa = skip_msa
if skip_msa:
self.feature_dict = self._build_query_only_feature_dict()
self.feature_dict.update(
self._search_templates_with_query_only_msa(pipeline, msa_output_dir)
)
self.feature_dict.update(pairing_results)
else:
with temp_fasta_file(sequence_str) as fasta_file:
self.feature_dict = pipeline.process(
fasta_file, msa_output_dir)
pairing_results = self.all_seq_msa_features(
fasta_file, self._uniprot_runner, msa_output_dir, use_precomputed_msa
)
self.feature_dict.update(pairing_results)
# Add extra features to make it compatible with pickle features obtaiend from mmseqs2
template_confidence_scores = self.feature_dict.get('template_confidence_scores', None)
@@ -187,6 +207,60 @@ class MonomericObject:
MonomericObject.zip_msa_files(
os.path.join(output_dir, self.description))
def _build_query_only_feature_dict(self) -> Dict[str, Any]:
"""Build AF2-compatible features with the query as the only MSA row."""
query_only_msa = parsers.parse_a3m(_query_only_a3m(self.sequence))
sequence_features = pipeline.make_sequence_features(
sequence=self.sequence,
description=self.description,
num_res=len(self.sequence),
)
msa_features = pipeline.make_msa_features((query_only_msa,))
all_seq_features = {
f"{key}_all_seq": np.array(value, copy=True)
for key, value in msa_features.items()
}
all_seq_features["msa_uniprot_accession_identifiers_all_seq"] = np.array(
[b""], dtype=object
)
return {**sequence_features, **msa_features, **all_seq_features}
def _search_templates_with_query_only_msa(
self, af2_pipeline: pipeline.DataPipeline, msa_output_dir: str
) -> Dict[str, Any]:
"""Run template search from a synthetic single-sequence alignment."""
template_searcher = getattr(af2_pipeline, "template_searcher", None)
template_featurizer = getattr(af2_pipeline, "template_featurizer", None)
if template_searcher is None or template_featurizer is None:
return {}
stockholm_msa = _query_only_stockholm(self.sequence)
if template_searcher.input_format == "sto":
template_query = stockholm_msa
elif template_searcher.input_format == "a3m":
template_query = _query_only_a3m(self.sequence)
else:
raise ValueError(
"Unrecognized template input format: "
f"{template_searcher.input_format}"
)
pdb_templates_result = template_searcher.query(template_query)
pdb_hits_out_path = os.path.join(
msa_output_dir, f"pdb_hits.{template_searcher.output_format}"
)
with open(pdb_hits_out_path, "w") as handle:
handle.write(pdb_templates_result)
pdb_template_hits = template_searcher.get_template_hits(
output_string=pdb_templates_result,
input_sequence=self.sequence,
)
templates_result = template_featurizer.get_templates(
query_sequence=self.sequence,
hits=pdb_template_hits,
)
return dict(templates_result.features)
def make_mmseq_features(
self, DEFAULT_API_SERVER,
@@ -195,6 +269,7 @@ class MonomericObject:
use_precomputed_msa=False,
use_templates=False,
custom_template_path=None,
skip_msa: bool = False,
):
"""
A method to use mmseq_remote to calculate MSA.
@@ -212,7 +287,29 @@ class MonomericObject:
logging.info(f"Skipping {self.description} (result.zip)")
a3m_path = os.path.join(result_dir, self.description + ".a3m")
if use_precomputed_msa and os.path.isfile(a3m_path):
self.skip_msa = skip_msa
if skip_msa:
a3m_lines = [_query_only_a3m(self.sequence, query_id="101")]
plPath(a3m_path).write_text(a3m_lines[0])
(
unpaired_msa,
paired_msa,
query_seqs_unique,
query_seqs_cardinality,
template_features,
) = get_msa_and_templates(
jobname=self.description,
query_sequences=self.sequence,
a3m_lines=a3m_lines,
result_dir=plPath(result_dir),
msa_mode="single_sequence",
use_templates=use_templates,
custom_template_path=custom_template_path,
pair_mode="none",
host_url=DEFAULT_API_SERVER,
user_agent="alphapulldown",
)
elif use_precomputed_msa and os.path.isfile(a3m_path):
logging.info(f"Using precomputed MSA from {a3m_path}")
a3m_lines = [plPath(a3m_path).read_text()]
(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality,

View File

@@ -142,6 +142,7 @@ flags.DEFINE_boolean('use_precomputed_msas', False, '')
flags.DEFINE_boolean('re_search_templates_mmseqs2', False, '')
flags.DEFINE_bool("use_mmseqs2", False, "")
flags.DEFINE_bool("save_msa_files", False, "")
flags.DEFINE_bool("skip_msa", False, "")
flags.DEFINE_bool("skip_existing", False, "")
flags.DEFINE_string("new_uniclust_dir", None, "")
flags.DEFINE_integer("seq_index", None, "")
@@ -272,11 +273,40 @@ def get_af3_chain_kind(description, sequence):
def create_af3_chain(sequence, description, chain_id):
"""Construct an AF3 chain object for the provided sequence."""
chain_kind = get_af3_chain_kind(description, sequence)
query_only_a3m = f">query\n{sequence}\n" if FLAGS.skip_msa else None
if chain_kind == "dna":
return folding_input.DnaChain(sequence=sequence, id=chain_id, modifications=[])
return folding_input.DnaChain(
sequence=sequence,
id=chain_id,
modifications=[],
description=description,
)
if chain_kind == "rna":
return folding_input.RnaChain(sequence=sequence, id=chain_id, modifications=[])
return folding_input.ProteinChain(sequence=sequence, id=chain_id, ptms=[])
kwargs = {
"sequence": sequence,
"id": chain_id,
"modifications": [],
"description": description,
}
if FLAGS.skip_msa:
kwargs["unpaired_msa"] = query_only_a3m
return folding_input.RnaChain(**kwargs)
kwargs = {
"sequence": sequence,
"id": chain_id,
"ptms": [],
"description": description,
}
if FLAGS.skip_msa:
kwargs.update(
{
"paired_msa": "",
"unpaired_msa": query_only_a3m,
"templates": None,
}
)
return folding_input.ProteinChain(**kwargs)
# =================== AlphaFold 2 Feature Creation ===================
@@ -444,6 +474,13 @@ def _reuse_truemultimer_monomer_features(feat):
monomer = _load_existing_monomer_from_output_dir(source_name)
if monomer is None:
return None
if FLAGS.skip_msa and not getattr(monomer, "skip_msa", False):
logging.info(
"Existing monomer features for %s were generated with bulk MSAs. "
"Recomputing query-only features for --skip_msa.",
source_name,
)
return None
if monomer.sequence != feat["sequence"]:
logging.warning(
"Existing monomer features for %s use sequence %s, but the current "
@@ -489,6 +526,7 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None
if _should_skip_monomer_output(monomer.description):
return
monomer.skip_msa = FLAGS.skip_msa
if FLAGS.use_mmseqs2:
monomer.make_mmseq_features(
DEFAULT_API_SERVER=DEFAULT_API_SERVER,
@@ -496,12 +534,15 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None
use_precomputed_msa=FLAGS.use_precomputed_msas,
use_templates=FLAGS.re_search_templates_mmseqs2 or custom_template_path is not None,
custom_template_path=custom_template_path,
skip_msa=FLAGS.skip_msa,
)
else:
monomer.make_features(
pipeline=pipeline, output_dir=FLAGS.output_dir,
use_precomputed_msa=FLAGS.use_precomputed_msas,
save_msa=FLAGS.save_msa_files)
save_msa=FLAGS.save_msa_files,
skip_msa=FLAGS.skip_msa,
)
_persist_monomer_outputs(monomer)
def create_individual_features_truemultimer():

View File

@@ -348,6 +348,12 @@ def pre_modelling_setup(
A MultimericObject or MonomericObject
output_directory for this particular modelling job
"""
if FLAGS.pair_msa and any(getattr(interactor, "skip_msa", False) for interactor in interactors):
raise ValueError(
"--skip_msa generates query-only MSAs and cannot be combined with "
"--pair_msa=True. Re-run structure prediction with --pair_msa=False."
)
if len(interactors) > 1:
# this means it's going to be a MultimericObject
object_to_model = MultimericObject(

View File

@@ -271,6 +271,7 @@ def create_interactors(data : List[Dict[str, List[str]]],
monomer.feature_dict,
curr_interactor_region,
)
chopped_object.skip_msa = getattr(monomer, "skip_msa", False)
chopped_object.prepare_final_sliced_feature_dict()
interactors.append(chopped_object)
return interactors

View File

@@ -163,6 +163,7 @@ def tmp_flags(monkeypatch, tmp_path):
use_mmseqs2=False,
use_precomputed_msas=False,
save_msa_files=False,
skip_msa=False,
skip_existing=False,
compress_features=False,
db_preset="full_dbs",

View File

@@ -78,22 +78,50 @@ def build_af3_stub_modules():
mmcif_mod = types.ModuleType("alphafold3.structure.mmcif")
class ProteinChain:
def __init__(self, sequence, id, ptms=None):
def __init__(
self,
sequence,
id,
ptms=None,
residue_ids=None,
description=None,
paired_msa=None,
unpaired_msa=None,
templates=None,
):
self.sequence = sequence
self.id = id
self.ptms = [] if ptms is None else list(ptms)
self.residue_ids = residue_ids
self.description = description
self.paired_msa = paired_msa
self.unpaired_msa = unpaired_msa
self.templates = templates
class RnaChain:
def __init__(self, sequence, id, modifications=None):
def __init__(
self,
sequence,
id,
modifications=None,
residue_ids=None,
description=None,
unpaired_msa=None,
):
self.sequence = sequence
self.id = id
self.modifications = [] if modifications is None else list(modifications)
self.residue_ids = residue_ids
self.description = description
self.unpaired_msa = unpaired_msa
class DnaChain:
def __init__(self, sequence, id, modifications=None):
def __init__(self, sequence, id, modifications=None, residue_ids=None, description=None):
self.sequence = sequence
self.id = id
self.modifications = [] if modifications is None else list(modifications)
self.residue_ids = residue_ids
self.description = description
class Input:
def __init__(self, name, chains, rng_seeds):
@@ -598,6 +626,62 @@ class TestCreateIndividualFeaturesComprehensive:
assert saved_monomer.sequence == "ACDF"
assert saved_monomer.uniprot_runner == "runner"
def test_process_multimeric_features_does_not_reuse_bulk_msa_pickle_for_skip_msa(
self, tmp_flags
):
template_path = Path(self.test_dir) / "template1.cif"
template_path.write_text("data_template\n", encoding="utf-8")
from absl import flags
FLAGS = flags.FLAGS
FLAGS(["test"])
FLAGS.output_dir = os.path.join(self.test_dir, "skip_msa_truemultimer_output")
FLAGS.use_mmseqs2 = False
FLAGS.compress_features = False
FLAGS.skip_existing = False
FLAGS.skip_msa = True
FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer"
FLAGS.uniprot_database_path = "/db/uniprot.fasta"
output_dir = Path(FLAGS.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
source = MonomericObject("proteinA", "ACDE")
source.feature_dict = {
"msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
"deletion_matrix_int": np.zeros((1, 4), dtype=np.int32),
"num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32),
"msa_species_identifiers": np.asarray([b"9606"], dtype=object),
}
with open(output_dir / "proteinA.pkl", "wb") as handle:
pickle.dump(source, handle)
feat = {
"protein": "proteinA.template1.cif.A",
"chains": ["A"],
"templates": [str(template_path)],
"sequence": "ACDE",
}
with patch.object(
create_features,
"extract_multimeric_template_features_for_single_chain",
) as mock_extract, \
patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
patch.object(create_features, "create_arguments") as mock_arguments, \
patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
create_features.process_multimeric_features(feat, 1)
mock_extract.assert_not_called()
mock_custom_db.assert_called_once()
mock_arguments.assert_called_once_with("/tmp/custom_db")
mock_pipeline.assert_called_once_with()
mock_runner.assert_called_once_with("/usr/bin/jackhmmer", "/db/uniprot.fasta")
mock_save.assert_called_once()
def test_main_dispatches_to_truemultimer_for_af2_template_runs(self):
"""The main entrypoint should route AF2 template jobs to the TrueMultimer path."""
from absl import flags
@@ -1318,6 +1402,7 @@ def test_create_and_save_monomer_objects_writes_compressed_af2_outputs(tmp_flags
"output_dir": str(tmp_path),
"use_precomputed_msa": True,
"save_msa": True,
"skip_msa": False,
}
]
assert monomer.mmseq_calls == []
@@ -1359,11 +1444,61 @@ def test_create_and_save_monomer_objects_uses_mmseqs_when_requested(tmp_flags, t
"use_precomputed_msa": True,
"use_templates": True,
"custom_template_path": None,
"skip_msa": False,
}
]
assert (tmp_path / "protA.pkl").exists()
def test_create_and_save_monomer_objects_passes_skip_msa_to_af2_builder(tmp_flags, tmp_path):
create_features.FLAGS.output_dir = str(tmp_path)
create_features.FLAGS.compress_features = False
create_features.FLAGS.skip_existing = False
create_features.FLAGS.use_mmseqs2 = False
create_features.FLAGS.use_precomputed_msas = False
create_features.FLAGS.save_msa_files = False
create_features.FLAGS.skip_msa = True
monomer = RecordingDummyMonomer("protA")
with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
create_features.create_and_save_monomer_objects(monomer, pipeline="pipeline")
assert monomer.feature_calls == [
{
"pipeline": "pipeline",
"output_dir": str(tmp_path),
"use_precomputed_msa": False,
"save_msa": False,
"skip_msa": True,
}
]
def test_create_and_save_monomer_objects_passes_skip_msa_to_mmseqs_builder(tmp_flags, tmp_path):
create_features.FLAGS.output_dir = str(tmp_path)
create_features.FLAGS.compress_features = False
create_features.FLAGS.skip_existing = False
create_features.FLAGS.use_mmseqs2 = True
create_features.FLAGS.use_precomputed_msas = False
create_features.FLAGS.re_search_templates_mmseqs2 = False
create_features.FLAGS.skip_msa = True
monomer = RecordingDummyMonomer("protA")
with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
create_features.create_and_save_monomer_objects(monomer, pipeline=None)
assert monomer.mmseq_calls == [
{
"DEFAULT_API_SERVER": create_features.DEFAULT_API_SERVER,
"output_dir": str(tmp_path),
"use_precomputed_msa": False,
"use_templates": False,
"custom_template_path": None,
"skip_msa": True,
}
]
def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_flags, tmp_path):
create_features.FLAGS.output_dir = str(tmp_path)
create_features.FLAGS.compress_features = False
@@ -1390,6 +1525,7 @@ def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_f
"use_precomputed_msa": False,
"use_templates": True,
"custom_template_path": custom_template_path,
"skip_msa": False,
}
]
assert (tmp_path / "protA.pkl").exists()
@@ -1661,6 +1797,37 @@ def test_create_af3_individual_features_skips_existing_outputs(tmp_flags, tmp_pa
assert existing_output.read_text(encoding="utf-8") == "{}"
def test_create_af3_individual_features_prefills_query_only_msas_when_skip_msa(
tmp_flags, tmp_path
):
create_features.FLAGS.output_dir = str(tmp_path)
create_features.FLAGS.data_pipeline = "alphafold3"
create_features.FLAGS.skip_msa = True
af3_modules, folding_input_stub = build_af3_stub_modules()
pipeline = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
with patch.dict(sys.modules, af3_modules), \
patch.object(create_features, "create_pipeline_af3", return_value=pipeline), \
patch.object(create_features, "folding_input", folding_input_stub), \
patch.object(
create_features,
"iter_seqs",
return_value=[("ACDE", "protein_chain protein"), ("AUGA", "rna_chain RNA")],
), \
patch("pathlib.Path.write_text", new=real_write_text):
create_features.create_af3_individual_features()
protein_input = pipeline.process.call_args_list[0].args[0]
protein_chain = protein_input.chains[0]
assert protein_chain.unpaired_msa == ">query\nACDE\n"
assert protein_chain.paired_msa == ""
assert protein_chain.templates is None
rna_input = pipeline.process.call_args_list[1].args[0]
rna_chain = rna_input.chains[0]
assert rna_chain.unpaired_msa == ">query\nAUGA\n"
def test_main_dispatches_to_af3_feature_creation(tmp_flags, tmp_path):
create_features.FLAGS.data_pipeline = "alphafold3"
create_features.FLAGS.output_dir = str(tmp_path / "af3_out")

View File

@@ -220,6 +220,42 @@ def test_create_interactors_builds_chopped_object_for_region_lists(monkeypatch):
assert calls["args"] == ("proteinA", "ACDEFG", monomer.feature_dict, [(2, 4)])
def test_create_interactors_propagates_skip_msa_marker_to_chopped_objects(monkeypatch):
monomer = MonomericObject("proteinA", "ACDEFG")
monomer.feature_dict = {"template_aatype": np.ones((1,), dtype=np.float32)}
monomer.skip_msa = True
class FakeChoppedObject:
def __init__(self, description, sequence, feature_dict, regions):
self.description = description
self.sequence = sequence
self.feature_dict = feature_dict
self.regions = regions
self.prepared = False
def prepare_final_sliced_feature_dict(self):
self.prepared = True
monkeypatch.setattr(
modelling_setup,
"make_dir_monomer_dictionary",
lambda _: {"proteinA.pkl": "/unused"},
)
monkeypatch.setattr(modelling_setup, "load_monomer_objects", lambda *_: monomer)
monkeypatch.setattr(modelling_setup, "check_empty_templates", lambda _: False)
monkeypatch.setattr(modelling_setup, "ChoppedObject", FakeChoppedObject)
result = modelling_setup.create_interactors(
[{"col_1": [{"proteinA": [(2, 4)]}]}],
["/unused"],
)
chopped = result[0][0]
assert isinstance(chopped, FakeChoppedObject)
assert chopped.prepared is True
assert chopped.skip_msa is True
def test_create_interactors_currently_skips_append_when_templates_are_empty(monkeypatch):
monomer = MonomericObject("proteinA", "ACDE")
monomer.feature_dict = {}

View File

@@ -145,7 +145,9 @@ def test_make_features_rezips_when_inputs_were_zipped_and_compression_is_enabled
staticmethod(lambda path: zip_calls.append(path)),
)
monkeypatch.setattr(
MonomericObject, "remove_msa_files", staticmethod(lambda _path: None)
MonomericObject,
"remove_msa_files",
staticmethod(lambda msa_output_path=None, **_kwargs: None),
)
monomer.make_features(
@@ -196,6 +198,86 @@ def test_make_features_removes_msa_when_precomputed_inputs_are_not_saved(
assert remove_calls == [str(tmp_path / "proteinA")]
def test_make_features_skip_msa_builds_query_only_features_and_templates(
monkeypatch, tmp_path
):
monomer = MonomericObject("proteinA", "ACDE")
calls = {}
class FakeTemplateSearcher:
input_format = "a3m"
output_format = "hhr"
def query(self, alignment):
calls["template_query"] = alignment
return "template_hits"
def get_template_hits(self, output_string, input_sequence):
calls["template_hits"] = (output_string, input_sequence)
return ["hitA"]
class FakeTemplateFeaturizer:
def get_templates(self, query_sequence, hits):
calls["template_features"] = (query_sequence, hits)
return SimpleNamespace(
features={
"template_aatype": np.ones((1, 4, 22), dtype=np.float32),
"template_all_atom_masks": np.ones((1, 4, 37), dtype=np.float32),
"template_all_atom_positions": np.ones(
(1, 4, 37, 3), dtype=np.float32
),
"template_domain_names": np.asarray([b"1abc_A"], dtype=object),
"template_sequence": np.asarray([b"ACDE"], dtype=object),
"template_sum_probs": np.asarray([0.5], dtype=np.float32),
}
)
class FakePipeline:
template_searcher = FakeTemplateSearcher()
template_featurizer = FakeTemplateFeaturizer()
def process(self, *_args, **_kwargs):
raise AssertionError("skip_msa should bypass pipeline.process")
monkeypatch.setattr(
MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False)
)
monkeypatch.setattr(
monomer,
"all_seq_msa_features",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
AssertionError("skip_msa should bypass all_seq_msa_features")
),
)
monkeypatch.setattr(
MonomericObject,
"remove_msa_files",
staticmethod(lambda msa_output_path=None, **_kwargs: None),
)
monkeypatch.setattr(
MonomericObject, "zip_msa_files", staticmethod(lambda _path: None)
)
monomer.make_features(
pipeline=FakePipeline(),
output_dir=str(tmp_path),
save_msa=False,
skip_msa=True,
)
assert calls["template_query"] == ">query\nACDE\n"
assert calls["template_hits"] == ("template_hits", "ACDE")
assert calls["template_features"] == ("ACDE", ["hitA"])
assert monomer.skip_msa is True
assert monomer.feature_dict["msa"].shape == (1, 4)
assert monomer.feature_dict["msa_all_seq"].shape == (1, 4)
assert np.array_equal(
monomer.feature_dict["num_alignments"], np.asarray([1, 1, 1, 1], dtype=np.int32)
)
assert monomer.feature_dict["msa_species_identifiers_all_seq"].tolist() == [b""]
assert monomer.feature_dict["template_domain_names"].tolist() == [b"1abc_A"]
def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
monkeypatch, tmp_path
):
@@ -265,6 +347,68 @@ def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
assert monomer.feature_dict["template_release_date"] == ["none"]
def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
monkeypatch, tmp_path
):
monomer = MonomericObject("proteinA", "ACDE")
calls = {}
monkeypatch.setattr(
MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False)
)
def fake_get_msa_and_templates(**kwargs):
calls["get_msa_and_templates"] = kwargs
return (["UNPAIRED"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])
monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
monkeypatch.setattr(
objects_mod,
"build_monomer_feature",
lambda sequence, msa, template_features: {
"msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
"deletion_matrix_int": np.asarray([[0, 0, 0, 0]], dtype=np.int32),
"template_confidence_scores": None,
"template_release_date": None,
},
)
def fake_enrich(feature_dict, a3m, **kwargs):
calls["enrich"] = {"a3m": a3m, "kwargs": kwargs}
feature_dict["msa_species_identifiers"] = np.asarray([b""], dtype=object)
feature_dict["msa_uniprot_accession_identifiers"] = np.asarray(
[b""], dtype=object
)
monkeypatch.setattr(
objects_mod,
"enrich_mmseq_feature_dict_with_identifiers",
fake_enrich,
)
monkeypatch.setattr(
MonomericObject, "zip_msa_files", staticmethod(lambda _path: None)
)
monomer.make_mmseq_features(
DEFAULT_API_SERVER="https://fake.server",
output_dir=str(tmp_path),
use_templates=True,
skip_msa=True,
)
assert calls["get_msa_and_templates"]["msa_mode"] == "single_sequence"
assert calls["get_msa_and_templates"]["pair_mode"] == "none"
assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"]
assert calls["get_msa_and_templates"]["use_templates"] is True
assert calls["enrich"]["a3m"] == ">101\nACDE"
assert monomer.skip_msa is True
assert monomer.feature_dict["msa"].shape == (1, 4)
assert monomer.feature_dict["msa_all_seq"].shape == (1, 4)
assert monomer.feature_dict["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
b""
]
def test_make_mmseq_features_compresses_fresh_mmseq_result_dir(
monkeypatch, tmp_path
):

View File

@@ -790,6 +790,76 @@ def test_pre_modelling_setup_warns_for_long_paths_and_uses_chopped_metadata_name
assert any("No feature metadata found for fragmentA" in message for message in warnings)
def test_pre_modelling_setup_rejects_pair_msa_for_skip_msa_interactors(
run_structure_prediction_module,
tmp_path,
):
_set_flag(run_structure_prediction_module.FLAGS, "pair_msa", True)
_set_flag(run_structure_prediction_module.FLAGS, "multimeric_template", False)
_set_flag(run_structure_prediction_module.FLAGS, "description_file", None)
_set_flag(run_structure_prediction_module.FLAGS, "path_to_mmt", None)
_set_flag(run_structure_prediction_module.FLAGS, "save_features_for_multimeric_object", False)
_set_flag(
run_structure_prediction_module.FLAGS,
"features_directory",
[str(tmp_path / "features")],
)
_set_flag(run_structure_prediction_module.FLAGS, "use_ap_style", False)
feature_dir = tmp_path / "features"
feature_dir.mkdir()
(feature_dir / "protA_feature_metadata_2026-03-30.json").write_text(
'{"meta": 1}',
encoding="utf-8",
)
monomer = run_structure_prediction_module.MonomericObject("protA", "ACDE")
monomer.skip_msa = True
with pytest.raises(ValueError, match="--pair_msa=False"):
run_structure_prediction_module.pre_modelling_setup(
[monomer],
output_dir=str(tmp_path / "outputs"),
)
def test_pre_modelling_setup_allows_skip_msa_when_pairing_disabled(
run_structure_prediction_module,
tmp_path,
):
_set_flag(run_structure_prediction_module.FLAGS, "pair_msa", False)
_set_flag(run_structure_prediction_module.FLAGS, "multimeric_template", False)
_set_flag(run_structure_prediction_module.FLAGS, "description_file", None)
_set_flag(run_structure_prediction_module.FLAGS, "path_to_mmt", None)
_set_flag(run_structure_prediction_module.FLAGS, "save_features_for_multimeric_object", False)
_set_flag(
run_structure_prediction_module.FLAGS,
"features_directory",
[str(tmp_path / "features")],
)
_set_flag(run_structure_prediction_module.FLAGS, "use_ap_style", False)
feature_dir = tmp_path / "features"
feature_dir.mkdir()
for description in ("protA", "protB"):
(feature_dir / f"{description}_feature_metadata_2026-03-30.json").write_text(
'{"meta": 1}',
encoding="utf-8",
)
monomer_a = run_structure_prediction_module.MonomericObject("protA", "AAAA")
monomer_a.skip_msa = True
monomer_b = run_structure_prediction_module.MonomericObject("protB", "BBBB")
returned_object, _ = run_structure_prediction_module.pre_modelling_setup(
[monomer_a, monomer_b],
output_dir=str(tmp_path / "outputs"),
)
assert isinstance(returned_object, run_structure_prediction_module.MultimericObject)
assert returned_object.pair_msa is False
def test_main_routes_protein_and_json_jobs_to_predict_structure(
run_structure_prediction_module,
monkeypatch,