mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 05:58:11 +08:00
fix(#42): add skip-MSA feature generation mode
This commit is contained in:
@@ -26,6 +26,17 @@ from alphapulldown.utils.mmseqs_species_identifiers import (
|
||||
strip_mmseq_comment_lines,
|
||||
)
|
||||
|
||||
|
||||
def _query_only_a3m(sequence: str, query_id: str = "query") -> str:
|
||||
"""Return a single-sequence A3M string for query-only workflows."""
|
||||
return f">{query_id}\n{sequence}\n"
|
||||
|
||||
|
||||
def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
|
||||
"""Return a single-sequence Stockholm alignment string."""
|
||||
return f"# STOCKHOLM 1.0\n{query_id} {sequence}\n//\n"
|
||||
|
||||
|
||||
class MonomericObject:
|
||||
"""
|
||||
monomeric objects
|
||||
@@ -41,6 +52,7 @@ class MonomericObject:
|
||||
self.sequence = sequence
|
||||
self.feature_dict = dict()
|
||||
self._uniprot_runner = None
|
||||
self.skip_msa = False
|
||||
pass
|
||||
|
||||
@property
|
||||
@@ -140,7 +152,8 @@ class MonomericObject:
|
||||
def make_features(
|
||||
self, pipeline, output_dir: str,
|
||||
use_precomputed_msa: bool = False,
|
||||
save_msa: bool = True, compress_msa_files: bool = False
|
||||
save_msa: bool = True, compress_msa_files: bool = False,
|
||||
skip_msa: bool = False,
|
||||
):
|
||||
"""a method that make msa and template features"""
|
||||
os.makedirs(os.path.join(output_dir, self.description), exist_ok=True)
|
||||
@@ -155,13 +168,20 @@ class MonomericObject:
|
||||
logging.info(
|
||||
"will save msa files in :{}".format(msa_output_dir))
|
||||
plPath(msa_output_dir).mkdir(parents=True, exist_ok=True)
|
||||
with temp_fasta_file(sequence_str) as fasta_file:
|
||||
self.feature_dict = pipeline.process(
|
||||
fasta_file, msa_output_dir)
|
||||
pairing_results = self.all_seq_msa_features(
|
||||
fasta_file, self._uniprot_runner, msa_output_dir, use_precomputed_msa
|
||||
self.skip_msa = skip_msa
|
||||
if skip_msa:
|
||||
self.feature_dict = self._build_query_only_feature_dict()
|
||||
self.feature_dict.update(
|
||||
self._search_templates_with_query_only_msa(pipeline, msa_output_dir)
|
||||
)
|
||||
self.feature_dict.update(pairing_results)
|
||||
else:
|
||||
with temp_fasta_file(sequence_str) as fasta_file:
|
||||
self.feature_dict = pipeline.process(
|
||||
fasta_file, msa_output_dir)
|
||||
pairing_results = self.all_seq_msa_features(
|
||||
fasta_file, self._uniprot_runner, msa_output_dir, use_precomputed_msa
|
||||
)
|
||||
self.feature_dict.update(pairing_results)
|
||||
|
||||
# Add extra features to make it compatible with pickle features obtained from mmseqs2
|
||||
template_confidence_scores = self.feature_dict.get('template_confidence_scores', None)
|
||||
@@ -187,6 +207,60 @@ class MonomericObject:
|
||||
MonomericObject.zip_msa_files(
|
||||
os.path.join(output_dir, self.description))
|
||||
|
||||
def _build_query_only_feature_dict(self) -> Dict[str, Any]:
    """Assemble sequence, MSA and ``*_all_seq`` features from the query alone.

    The query sequence is wrapped into a one-record A3M, parsed, and fed to
    the sequence/MSA feature builders. Every MSA feature is additionally
    copied under a ``<key>_all_seq`` name, and
    ``msa_uniprot_accession_identifiers_all_seq`` is set to a single empty
    byte string.

    Returns:
        One dict holding the sequence features, the MSA features and the
        ``*_all_seq`` copies, inserted in that order.
    """
    single_row_msa = parsers.parse_a3m(_query_only_a3m(self.sequence))

    features = dict(
        pipeline.make_sequence_features(
            sequence=self.sequence,
            description=self.description,
            num_res=len(self.sequence),
        )
    )

    msa_features = pipeline.make_msa_features((single_row_msa,))
    features.update(msa_features)

    # Independent copies so the ``*_all_seq`` arrays never alias the MSA arrays.
    for key, value in msa_features.items():
        features[f"{key}_all_seq"] = np.array(value, copy=True)

    features["msa_uniprot_accession_identifiers_all_seq"] = np.array(
        [b""], dtype=object
    )
    return features
|
||||
|
||||
def _search_templates_with_query_only_msa(
    self, af2_pipeline: pipeline.DataPipeline, msa_output_dir: str
) -> Dict[str, Any]:
    """Search templates using an alignment that contains only the query.

    Args:
        af2_pipeline: Object whose ``template_searcher`` and
            ``template_featurizer`` attributes are used when both are present.
        msa_output_dir: Directory that receives the raw search output as
            ``pdb_hits.<output_format>``.

    Returns:
        A dict copy of the featurizer result's ``features``, or an empty dict
        when the pipeline has no searcher or no featurizer.

    Raises:
        ValueError: If the searcher's ``input_format`` is neither ``"sto"``
            nor ``"a3m"``.
    """
    searcher = getattr(af2_pipeline, "template_searcher", None)
    featurizer = getattr(af2_pipeline, "template_featurizer", None)
    if searcher is None or featurizer is None:
        return {}

    # Build the synthetic alignment in whichever format the searcher consumes.
    input_format = searcher.input_format
    if input_format == "sto":
        query_alignment = _query_only_stockholm(self.sequence)
    elif input_format == "a3m":
        query_alignment = _query_only_a3m(self.sequence)
    else:
        raise ValueError(
            "Unrecognized template input format: "
            f"{input_format}"
        )

    raw_hits = searcher.query(query_alignment)

    # Keep the raw search output next to the other MSA artefacts.
    hits_path = os.path.join(
        msa_output_dir, f"pdb_hits.{searcher.output_format}"
    )
    with open(hits_path, "w") as hits_file:
        hits_file.write(raw_hits)

    parsed_hits = searcher.get_template_hits(
        output_string=raw_hits,
        input_sequence=self.sequence,
    )
    template_result = featurizer.get_templates(
        query_sequence=self.sequence,
        hits=parsed_hits,
    )
    return dict(template_result.features)
|
||||
|
||||
def make_mmseq_features(
|
||||
self, DEFAULT_API_SERVER,
|
||||
@@ -195,6 +269,7 @@ class MonomericObject:
|
||||
use_precomputed_msa=False,
|
||||
use_templates=False,
|
||||
custom_template_path=None,
|
||||
skip_msa: bool = False,
|
||||
):
|
||||
"""
|
||||
A method to use mmseq_remote to calculate MSA.
|
||||
@@ -212,7 +287,29 @@ class MonomericObject:
|
||||
logging.info(f"Skipping {self.description} (result.zip)")
|
||||
|
||||
a3m_path = os.path.join(result_dir, self.description + ".a3m")
|
||||
if use_precomputed_msa and os.path.isfile(a3m_path):
|
||||
self.skip_msa = skip_msa
|
||||
if skip_msa:
|
||||
a3m_lines = [_query_only_a3m(self.sequence, query_id="101")]
|
||||
plPath(a3m_path).write_text(a3m_lines[0])
|
||||
(
|
||||
unpaired_msa,
|
||||
paired_msa,
|
||||
query_seqs_unique,
|
||||
query_seqs_cardinality,
|
||||
template_features,
|
||||
) = get_msa_and_templates(
|
||||
jobname=self.description,
|
||||
query_sequences=self.sequence,
|
||||
a3m_lines=a3m_lines,
|
||||
result_dir=plPath(result_dir),
|
||||
msa_mode="single_sequence",
|
||||
use_templates=use_templates,
|
||||
custom_template_path=custom_template_path,
|
||||
pair_mode="none",
|
||||
host_url=DEFAULT_API_SERVER,
|
||||
user_agent="alphapulldown",
|
||||
)
|
||||
elif use_precomputed_msa and os.path.isfile(a3m_path):
|
||||
logging.info(f"Using precomputed MSA from {a3m_path}")
|
||||
a3m_lines = [plPath(a3m_path).read_text()]
|
||||
(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality,
|
||||
|
||||
@@ -142,6 +142,7 @@ flags.DEFINE_boolean('use_precomputed_msas', False, '')
|
||||
flags.DEFINE_boolean('re_search_templates_mmseqs2', False, '')
|
||||
flags.DEFINE_bool("use_mmseqs2", False, "")
|
||||
flags.DEFINE_bool("save_msa_files", False, "")
|
||||
flags.DEFINE_bool("skip_msa", False, "")
|
||||
flags.DEFINE_bool("skip_existing", False, "")
|
||||
flags.DEFINE_string("new_uniclust_dir", None, "")
|
||||
flags.DEFINE_integer("seq_index", None, "")
|
||||
@@ -272,11 +273,40 @@ def get_af3_chain_kind(description, sequence):
|
||||
def create_af3_chain(sequence, description, chain_id):
|
||||
"""Construct an AF3 chain object for the provided sequence."""
|
||||
chain_kind = get_af3_chain_kind(description, sequence)
|
||||
query_only_a3m = f">query\n{sequence}\n" if FLAGS.skip_msa else None
|
||||
if chain_kind == "dna":
|
||||
return folding_input.DnaChain(sequence=sequence, id=chain_id, modifications=[])
|
||||
return folding_input.DnaChain(
|
||||
sequence=sequence,
|
||||
id=chain_id,
|
||||
modifications=[],
|
||||
description=description,
|
||||
)
|
||||
if chain_kind == "rna":
|
||||
return folding_input.RnaChain(sequence=sequence, id=chain_id, modifications=[])
|
||||
return folding_input.ProteinChain(sequence=sequence, id=chain_id, ptms=[])
|
||||
kwargs = {
|
||||
"sequence": sequence,
|
||||
"id": chain_id,
|
||||
"modifications": [],
|
||||
"description": description,
|
||||
}
|
||||
if FLAGS.skip_msa:
|
||||
kwargs["unpaired_msa"] = query_only_a3m
|
||||
return folding_input.RnaChain(**kwargs)
|
||||
|
||||
kwargs = {
|
||||
"sequence": sequence,
|
||||
"id": chain_id,
|
||||
"ptms": [],
|
||||
"description": description,
|
||||
}
|
||||
if FLAGS.skip_msa:
|
||||
kwargs.update(
|
||||
{
|
||||
"paired_msa": "",
|
||||
"unpaired_msa": query_only_a3m,
|
||||
"templates": None,
|
||||
}
|
||||
)
|
||||
return folding_input.ProteinChain(**kwargs)
|
||||
|
||||
# =================== AlphaFold 2 Feature Creation ===================
|
||||
|
||||
@@ -444,6 +474,13 @@ def _reuse_truemultimer_monomer_features(feat):
|
||||
monomer = _load_existing_monomer_from_output_dir(source_name)
|
||||
if monomer is None:
|
||||
return None
|
||||
if FLAGS.skip_msa and not getattr(monomer, "skip_msa", False):
|
||||
logging.info(
|
||||
"Existing monomer features for %s were generated with bulk MSAs. "
|
||||
"Recomputing query-only features for --skip_msa.",
|
||||
source_name,
|
||||
)
|
||||
return None
|
||||
if monomer.sequence != feat["sequence"]:
|
||||
logging.warning(
|
||||
"Existing monomer features for %s use sequence %s, but the current "
|
||||
@@ -489,6 +526,7 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None
|
||||
|
||||
if _should_skip_monomer_output(monomer.description):
|
||||
return
|
||||
monomer.skip_msa = FLAGS.skip_msa
|
||||
if FLAGS.use_mmseqs2:
|
||||
monomer.make_mmseq_features(
|
||||
DEFAULT_API_SERVER=DEFAULT_API_SERVER,
|
||||
@@ -496,12 +534,15 @@ def create_and_save_monomer_objects(monomer, pipeline, custom_template_path=None
|
||||
use_precomputed_msa=FLAGS.use_precomputed_msas,
|
||||
use_templates=FLAGS.re_search_templates_mmseqs2 or custom_template_path is not None,
|
||||
custom_template_path=custom_template_path,
|
||||
skip_msa=FLAGS.skip_msa,
|
||||
)
|
||||
else:
|
||||
monomer.make_features(
|
||||
pipeline=pipeline, output_dir=FLAGS.output_dir,
|
||||
use_precomputed_msa=FLAGS.use_precomputed_msas,
|
||||
save_msa=FLAGS.save_msa_files)
|
||||
save_msa=FLAGS.save_msa_files,
|
||||
skip_msa=FLAGS.skip_msa,
|
||||
)
|
||||
_persist_monomer_outputs(monomer)
|
||||
|
||||
def create_individual_features_truemultimer():
|
||||
|
||||
@@ -348,6 +348,12 @@ def pre_modelling_setup(
|
||||
A MultimericObject or MonomericObject
|
||||
output_directory for this particular modelling job
|
||||
"""
|
||||
if FLAGS.pair_msa and any(getattr(interactor, "skip_msa", False) for interactor in interactors):
|
||||
raise ValueError(
|
||||
"--skip_msa generates query-only MSAs and cannot be combined with "
|
||||
"--pair_msa=True. Re-run structure prediction with --pair_msa=False."
|
||||
)
|
||||
|
||||
if len(interactors) > 1:
|
||||
# this means it's going to be a MultimericObject
|
||||
object_to_model = MultimericObject(
|
||||
|
||||
@@ -271,6 +271,7 @@ def create_interactors(data : List[Dict[str, List[str]]],
|
||||
monomer.feature_dict,
|
||||
curr_interactor_region,
|
||||
)
|
||||
chopped_object.skip_msa = getattr(monomer, "skip_msa", False)
|
||||
chopped_object.prepare_final_sliced_feature_dict()
|
||||
interactors.append(chopped_object)
|
||||
return interactors
|
||||
|
||||
@@ -163,6 +163,7 @@ def tmp_flags(monkeypatch, tmp_path):
|
||||
use_mmseqs2=False,
|
||||
use_precomputed_msas=False,
|
||||
save_msa_files=False,
|
||||
skip_msa=False,
|
||||
skip_existing=False,
|
||||
compress_features=False,
|
||||
db_preset="full_dbs",
|
||||
|
||||
@@ -78,22 +78,50 @@ def build_af3_stub_modules():
|
||||
mmcif_mod = types.ModuleType("alphafold3.structure.mmcif")
|
||||
|
||||
class ProteinChain:
|
||||
def __init__(self, sequence, id, ptms=None):
|
||||
def __init__(
|
||||
self,
|
||||
sequence,
|
||||
id,
|
||||
ptms=None,
|
||||
residue_ids=None,
|
||||
description=None,
|
||||
paired_msa=None,
|
||||
unpaired_msa=None,
|
||||
templates=None,
|
||||
):
|
||||
self.sequence = sequence
|
||||
self.id = id
|
||||
self.ptms = [] if ptms is None else list(ptms)
|
||||
self.residue_ids = residue_ids
|
||||
self.description = description
|
||||
self.paired_msa = paired_msa
|
||||
self.unpaired_msa = unpaired_msa
|
||||
self.templates = templates
|
||||
|
||||
class RnaChain:
|
||||
def __init__(self, sequence, id, modifications=None):
|
||||
def __init__(
|
||||
self,
|
||||
sequence,
|
||||
id,
|
||||
modifications=None,
|
||||
residue_ids=None,
|
||||
description=None,
|
||||
unpaired_msa=None,
|
||||
):
|
||||
self.sequence = sequence
|
||||
self.id = id
|
||||
self.modifications = [] if modifications is None else list(modifications)
|
||||
self.residue_ids = residue_ids
|
||||
self.description = description
|
||||
self.unpaired_msa = unpaired_msa
|
||||
|
||||
class DnaChain:
|
||||
def __init__(self, sequence, id, modifications=None):
|
||||
def __init__(self, sequence, id, modifications=None, residue_ids=None, description=None):
|
||||
self.sequence = sequence
|
||||
self.id = id
|
||||
self.modifications = [] if modifications is None else list(modifications)
|
||||
self.residue_ids = residue_ids
|
||||
self.description = description
|
||||
|
||||
class Input:
|
||||
def __init__(self, name, chains, rng_seeds):
|
||||
@@ -598,6 +626,62 @@ class TestCreateIndividualFeaturesComprehensive:
|
||||
assert saved_monomer.sequence == "ACDF"
|
||||
assert saved_monomer.uniprot_runner == "runner"
|
||||
|
||||
def test_process_multimeric_features_does_not_reuse_bulk_msa_pickle_for_skip_msa(
    self, tmp_flags
):
    """An existing pickle without the skip_msa marker must not be reused."""
    template_file = Path(self.test_dir) / "template1.cif"
    template_file.write_text("data_template\n", encoding="utf-8")

    from absl import flags

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    flag_values = (
        ("output_dir", os.path.join(self.test_dir, "skip_msa_truemultimer_output")),
        ("use_mmseqs2", False),
        ("compress_features", False),
        ("skip_existing", False),
        ("skip_msa", True),
        ("jackhmmer_binary_path", "/usr/bin/jackhmmer"),
        ("uniprot_database_path", "/db/uniprot.fasta"),
    )
    for flag_name, flag_value in flag_values:
        setattr(FLAGS, flag_name, flag_value)

    features_root = Path(FLAGS.output_dir)
    features_root.mkdir(parents=True, exist_ok=True)

    # Pickle a monomer that was created without touching its skip_msa marker.
    existing_monomer = MonomericObject("proteinA", "ACDE")
    existing_monomer.feature_dict = {
        "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
        "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32),
        "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32),
        "msa_species_identifiers": np.asarray([b"9606"], dtype=object),
    }
    with open(features_root / "proteinA.pkl", "wb") as pickle_file:
        pickle.dump(existing_monomer, pickle_file)

    job = {
        "protein": "proteinA.template1.cif.A",
        "chains": ["A"],
        "templates": [str(template_file)],
        "sequence": "ACDE",
    }

    with patch.object(
        create_features,
        "extract_multimeric_template_features_for_single_chain",
    ) as mock_extract, \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(job, 1)

    # The reuse shortcut is skipped and the full rebuild path runs once.
    mock_extract.assert_not_called()
    mock_custom_db.assert_called_once()
    mock_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_called_once_with("/usr/bin/jackhmmer", "/db/uniprot.fasta")
    mock_save.assert_called_once()
|
||||
|
||||
def test_main_dispatches_to_truemultimer_for_af2_template_runs(self):
|
||||
"""The main entrypoint should route AF2 template jobs to the TrueMultimer path."""
|
||||
from absl import flags
|
||||
@@ -1318,6 +1402,7 @@ def test_create_and_save_monomer_objects_writes_compressed_af2_outputs(tmp_flags
|
||||
"output_dir": str(tmp_path),
|
||||
"use_precomputed_msa": True,
|
||||
"save_msa": True,
|
||||
"skip_msa": False,
|
||||
}
|
||||
]
|
||||
assert monomer.mmseq_calls == []
|
||||
@@ -1359,11 +1444,61 @@ def test_create_and_save_monomer_objects_uses_mmseqs_when_requested(tmp_flags, t
|
||||
"use_precomputed_msa": True,
|
||||
"use_templates": True,
|
||||
"custom_template_path": None,
|
||||
"skip_msa": False,
|
||||
}
|
||||
]
|
||||
assert (tmp_path / "protA.pkl").exists()
|
||||
|
||||
|
||||
def test_create_and_save_monomer_objects_passes_skip_msa_to_af2_builder(tmp_flags, tmp_path):
    """With mmseqs2 disabled, ``make_features`` must receive ``skip_msa=True``."""
    test_flags = create_features.FLAGS
    test_flags.output_dir = str(tmp_path)
    test_flags.compress_features = False
    test_flags.skip_existing = False
    test_flags.use_mmseqs2 = False
    test_flags.use_precomputed_msas = False
    test_flags.save_msa_files = False
    test_flags.skip_msa = True

    recorder = RecordingDummyMonomer("protA")
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
        create_features.create_and_save_monomer_objects(recorder, pipeline="pipeline")

    expected_call = {
        "pipeline": "pipeline",
        "output_dir": str(tmp_path),
        "use_precomputed_msa": False,
        "save_msa": False,
        "skip_msa": True,
    }
    assert recorder.feature_calls == [expected_call]
|
||||
|
||||
|
||||
def test_create_and_save_monomer_objects_passes_skip_msa_to_mmseqs_builder(tmp_flags, tmp_path):
    """With mmseqs2 enabled, ``make_mmseq_features`` must receive ``skip_msa=True``."""
    test_flags = create_features.FLAGS
    test_flags.output_dir = str(tmp_path)
    test_flags.compress_features = False
    test_flags.skip_existing = False
    test_flags.use_mmseqs2 = True
    test_flags.use_precomputed_msas = False
    test_flags.re_search_templates_mmseqs2 = False
    test_flags.skip_msa = True

    recorder = RecordingDummyMonomer("protA")
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
        create_features.create_and_save_monomer_objects(recorder, pipeline=None)

    expected_call = {
        "DEFAULT_API_SERVER": create_features.DEFAULT_API_SERVER,
        "output_dir": str(tmp_path),
        "use_precomputed_msa": False,
        "use_templates": False,
        "custom_template_path": None,
        "skip_msa": True,
    }
    assert recorder.mmseq_calls == [expected_call]
|
||||
|
||||
|
||||
def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_flags, tmp_path):
|
||||
create_features.FLAGS.output_dir = str(tmp_path)
|
||||
create_features.FLAGS.compress_features = False
|
||||
@@ -1390,6 +1525,7 @@ def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_f
|
||||
"use_precomputed_msa": False,
|
||||
"use_templates": True,
|
||||
"custom_template_path": custom_template_path,
|
||||
"skip_msa": False,
|
||||
}
|
||||
]
|
||||
assert (tmp_path / "protA.pkl").exists()
|
||||
@@ -1661,6 +1797,37 @@ def test_create_af3_individual_features_skips_existing_outputs(tmp_flags, tmp_pa
|
||||
assert existing_output.read_text(encoding="utf-8") == "{}"
|
||||
|
||||
|
||||
def test_create_af3_individual_features_prefills_query_only_msas_when_skip_msa(
    tmp_flags, tmp_path
):
    """AF3 chains must be pre-populated with query-only MSAs under skip_msa."""
    create_features.FLAGS.output_dir = str(tmp_path)
    create_features.FLAGS.data_pipeline = "alphafold3"
    create_features.FLAGS.skip_msa = True

    af3_modules, folding_input_stub = build_af3_stub_modules()
    pipeline = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
    sequences = [("ACDE", "protein_chain protein"), ("AUGA", "rna_chain RNA")]
    with patch.dict(sys.modules, af3_modules), \
            patch.object(create_features, "create_pipeline_af3", return_value=pipeline), \
            patch.object(create_features, "folding_input", folding_input_stub), \
            patch.object(create_features, "iter_seqs", return_value=sequences), \
            patch("pathlib.Path.write_text", new=real_write_text):
        create_features.create_af3_individual_features()

    process_calls = pipeline.process.call_args_list

    # First processed input: the protein chain.
    protein_chain = process_calls[0].args[0].chains[0]
    assert protein_chain.unpaired_msa == ">query\nACDE\n"
    assert protein_chain.paired_msa == ""
    assert protein_chain.templates is None

    # Second processed input: the RNA chain.
    rna_chain = process_calls[1].args[0].chains[0]
    assert rna_chain.unpaired_msa == ">query\nAUGA\n"
|
||||
|
||||
|
||||
def test_main_dispatches_to_af3_feature_creation(tmp_flags, tmp_path):
|
||||
create_features.FLAGS.data_pipeline = "alphafold3"
|
||||
create_features.FLAGS.output_dir = str(tmp_path / "af3_out")
|
||||
|
||||
@@ -220,6 +220,42 @@ def test_create_interactors_builds_chopped_object_for_region_lists(monkeypatch):
|
||||
assert calls["args"] == ("proteinA", "ACDEFG", monomer.feature_dict, [(2, 4)])
|
||||
|
||||
|
||||
def test_create_interactors_propagates_skip_msa_marker_to_chopped_objects(monkeypatch):
    """A chopped object must carry the ``skip_msa`` value of its source monomer."""
    source_monomer = MonomericObject("proteinA", "ACDEFG")
    source_monomer.feature_dict = {"template_aatype": np.ones((1,), dtype=np.float32)}
    source_monomer.skip_msa = True

    class FakeChoppedObject:
        """Stand-in that stores its constructor inputs and the prepare call."""

        def __init__(self, description, sequence, feature_dict, regions):
            self.description = description
            self.sequence = sequence
            self.feature_dict = feature_dict
            self.regions = regions
            self.prepared = False

        def prepare_final_sliced_feature_dict(self):
            self.prepared = True

    replacements = (
        ("make_dir_monomer_dictionary", lambda _: {"proteinA.pkl": "/unused"}),
        ("load_monomer_objects", lambda *_: source_monomer),
        ("check_empty_templates", lambda _: False),
        ("ChoppedObject", FakeChoppedObject),
    )
    for attr_name, replacement in replacements:
        monkeypatch.setattr(modelling_setup, attr_name, replacement)

    interactors = modelling_setup.create_interactors(
        [{"col_1": [{"proteinA": [(2, 4)]}]}],
        ["/unused"],
    )

    chopped = interactors[0][0]
    assert isinstance(chopped, FakeChoppedObject)
    assert chopped.prepared is True
    assert chopped.skip_msa is True
|
||||
|
||||
|
||||
def test_create_interactors_currently_skips_append_when_templates_are_empty(monkeypatch):
|
||||
monomer = MonomericObject("proteinA", "ACDE")
|
||||
monomer.feature_dict = {}
|
||||
|
||||
@@ -145,7 +145,9 @@ def test_make_features_rezips_when_inputs_were_zipped_and_compression_is_enabled
|
||||
staticmethod(lambda path: zip_calls.append(path)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
MonomericObject, "remove_msa_files", staticmethod(lambda _path: None)
|
||||
MonomericObject,
|
||||
"remove_msa_files",
|
||||
staticmethod(lambda msa_output_path=None, **_kwargs: None),
|
||||
)
|
||||
|
||||
monomer.make_features(
|
||||
@@ -196,6 +198,86 @@ def test_make_features_removes_msa_when_precomputed_inputs_are_not_saved(
|
||||
assert remove_calls == [str(tmp_path / "proteinA")]
|
||||
|
||||
|
||||
def test_make_features_skip_msa_builds_query_only_features_and_templates(
    monkeypatch, tmp_path
):
    """``skip_msa=True`` must bypass MSA generation yet still search templates."""
    monomer = MonomericObject("proteinA", "ACDE")
    recorded = {}

    class FakeTemplateSearcher:
        input_format = "a3m"
        output_format = "hhr"

        def query(self, alignment):
            recorded["template_query"] = alignment
            return "template_hits"

        def get_template_hits(self, output_string, input_sequence):
            recorded["template_hits"] = (output_string, input_sequence)
            return ["hitA"]

    class FakeTemplateFeaturizer:
        def get_templates(self, query_sequence, hits):
            recorded["template_features"] = (query_sequence, hits)
            template_features = {
                "template_aatype": np.ones((1, 4, 22), dtype=np.float32),
                "template_all_atom_masks": np.ones((1, 4, 37), dtype=np.float32),
                "template_all_atom_positions": np.ones(
                    (1, 4, 37, 3), dtype=np.float32
                ),
                "template_domain_names": np.asarray([b"1abc_A"], dtype=object),
                "template_sequence": np.asarray([b"ACDE"], dtype=object),
                "template_sum_probs": np.asarray([0.5], dtype=np.float32),
            }
            return SimpleNamespace(features=template_features)

    class FakePipeline:
        template_searcher = FakeTemplateSearcher()
        template_featurizer = FakeTemplateFeaturizer()

        def process(self, *_args, **_kwargs):
            raise AssertionError("skip_msa should bypass pipeline.process")

    def fail_all_seq_msa_features(*_args, **_kwargs):
        raise AssertionError("skip_msa should bypass all_seq_msa_features")

    def ignore_msa_removal(msa_output_path=None, **_kwargs):
        return None

    monkeypatch.setattr(
        MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False)
    )
    monkeypatch.setattr(monomer, "all_seq_msa_features", fail_all_seq_msa_features)
    monkeypatch.setattr(
        MonomericObject, "remove_msa_files", staticmethod(ignore_msa_removal)
    )
    monkeypatch.setattr(
        MonomericObject, "zip_msa_files", staticmethod(lambda _path: None)
    )

    monomer.make_features(
        pipeline=FakePipeline(),
        output_dir=str(tmp_path),
        save_msa=False,
        skip_msa=True,
    )

    # The template search ran on the synthetic query-only alignment.
    assert recorded["template_query"] == ">query\nACDE\n"
    assert recorded["template_hits"] == ("template_hits", "ACDE")
    assert recorded["template_features"] == ("ACDE", ["hitA"])

    # The resulting features describe a one-row MSA plus the template output.
    assert monomer.skip_msa is True
    features = monomer.feature_dict
    assert features["msa"].shape == (1, 4)
    assert features["msa_all_seq"].shape == (1, 4)
    assert np.array_equal(
        features["num_alignments"], np.asarray([1, 1, 1, 1], dtype=np.int32)
    )
    assert features["msa_species_identifiers_all_seq"].tolist() == [b""]
    assert features["template_domain_names"].tolist() == [b"1abc_A"]
|
||||
|
||||
|
||||
def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
@@ -265,6 +347,68 @@ def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
|
||||
assert monomer.feature_dict["template_release_date"] == ["none"]
|
||||
|
||||
|
||||
def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
    monkeypatch, tmp_path
):
    """``skip_msa=True`` must request single-sequence, unpaired MMseqs features."""
    monomer = MonomericObject("proteinA", "ACDE")
    recorded = {}

    def fake_get_msa_and_templates(**kwargs):
        recorded["get_msa_and_templates"] = kwargs
        return (["UNPAIRED"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])

    def fake_build_monomer_feature(sequence, msa, template_features):
        return {
            "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
            "deletion_matrix_int": np.asarray([[0, 0, 0, 0]], dtype=np.int32),
            "template_confidence_scores": None,
            "template_release_date": None,
        }

    def fake_enrich(feature_dict, a3m, **kwargs):
        recorded["enrich"] = {"a3m": a3m, "kwargs": kwargs}
        feature_dict["msa_species_identifiers"] = np.asarray([b""], dtype=object)
        feature_dict["msa_uniprot_accession_identifiers"] = np.asarray(
            [b""], dtype=object
        )

    monkeypatch.setattr(
        MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False)
    )
    monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
    monkeypatch.setattr(objects_mod, "build_monomer_feature", fake_build_monomer_feature)
    monkeypatch.setattr(
        objects_mod,
        "enrich_mmseq_feature_dict_with_identifiers",
        fake_enrich,
    )
    monkeypatch.setattr(
        MonomericObject, "zip_msa_files", staticmethod(lambda _path: None)
    )

    monomer.make_mmseq_features(
        DEFAULT_API_SERVER="https://fake.server",
        output_dir=str(tmp_path),
        use_templates=True,
        skip_msa=True,
    )

    search_kwargs = recorded["get_msa_and_templates"]
    assert search_kwargs["msa_mode"] == "single_sequence"
    assert search_kwargs["pair_mode"] == "none"
    assert search_kwargs["a3m_lines"] == [">101\nACDE"]
    assert search_kwargs["use_templates"] is True
    assert recorded["enrich"]["a3m"] == ">101\nACDE"
    assert monomer.skip_msa is True
    assert monomer.feature_dict["msa"].shape == (1, 4)
    assert monomer.feature_dict["msa_all_seq"].shape == (1, 4)
    assert monomer.feature_dict["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
        b""
    ]
|
||||
|
||||
|
||||
def test_make_mmseq_features_compresses_fresh_mmseq_result_dir(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
|
||||
@@ -790,6 +790,76 @@ def test_pre_modelling_setup_warns_for_long_paths_and_uses_chopped_metadata_name
|
||||
assert any("No feature metadata found for fragmentA" in message for message in warnings)
|
||||
|
||||
|
||||
def test_pre_modelling_setup_rejects_pair_msa_for_skip_msa_interactors(
    run_structure_prediction_module,
    tmp_path,
):
    """``pair_msa=True`` combined with a skip_msa interactor must raise."""
    module = run_structure_prediction_module
    flag_settings = (
        ("pair_msa", True),
        ("multimeric_template", False),
        ("description_file", None),
        ("path_to_mmt", None),
        ("save_features_for_multimeric_object", False),
        ("features_directory", [str(tmp_path / "features")]),
        ("use_ap_style", False),
    )
    for flag_name, flag_value in flag_settings:
        _set_flag(module.FLAGS, flag_name, flag_value)

    feature_dir = tmp_path / "features"
    feature_dir.mkdir()
    metadata_file = feature_dir / "protA_feature_metadata_2026-03-30.json"
    metadata_file.write_text('{"meta": 1}', encoding="utf-8")

    monomer = module.MonomericObject("protA", "ACDE")
    monomer.skip_msa = True

    with pytest.raises(ValueError, match="--pair_msa=False"):
        module.pre_modelling_setup(
            [monomer],
            output_dir=str(tmp_path / "outputs"),
        )
|
||||
|
||||
|
||||
def test_pre_modelling_setup_allows_skip_msa_when_pairing_disabled(
    run_structure_prediction_module,
    tmp_path,
):
    """With ``pair_msa=False`` a skip_msa interactor is accepted in a multimer."""
    module = run_structure_prediction_module
    flag_settings = (
        ("pair_msa", False),
        ("multimeric_template", False),
        ("description_file", None),
        ("path_to_mmt", None),
        ("save_features_for_multimeric_object", False),
        ("features_directory", [str(tmp_path / "features")]),
        ("use_ap_style", False),
    )
    for flag_name, flag_value in flag_settings:
        _set_flag(module.FLAGS, flag_name, flag_value)

    feature_dir = tmp_path / "features"
    feature_dir.mkdir()
    for description in ("protA", "protB"):
        metadata_file = feature_dir / f"{description}_feature_metadata_2026-03-30.json"
        metadata_file.write_text('{"meta": 1}', encoding="utf-8")

    # Only the first interactor carries the skip_msa marker.
    monomer_a = module.MonomericObject("protA", "AAAA")
    monomer_a.skip_msa = True
    monomer_b = module.MonomericObject("protB", "BBBB")

    returned_object, _ = module.pre_modelling_setup(
        [monomer_a, monomer_b],
        output_dir=str(tmp_path / "outputs"),
    )

    assert isinstance(returned_object, module.MultimericObject)
    assert returned_object.pair_msa is False
|
||||
|
||||
|
||||
def test_main_routes_protein_and_json_jobs_to_predict_structure(
|
||||
run_structure_prediction_module,
|
||||
monkeypatch,
|
||||
|
||||
Reference in New Issue
Block a user