Files
AlphaPulldown/test/integration/test_create_individual_features.py
2026-04-23 10:28:45 +02:00

1992 lines
85 KiB
Python

#!/usr/bin/env python3
"""
Comprehensive parametrized tests for create_individual_features.py using pytest.
Tests both AlphaFold2 and AlphaFold3 pipelines with various configurations.
"""
import os
import sys
import tempfile
import json
import lzma
import pickle
import pytest
import logging
import types
import numpy as np
from pathlib import Path
from unittest.mock import patch, MagicMock, mock_open
from parameterized import parameterized
# Import the module under test
import alphapulldown.objects as objects_mod
import alphapulldown.scripts.create_individual_features as create_features
from alphapulldown.objects import MonomericObject
from alphapulldown.utils import mmseqs_species_identifiers
logger = logging.getLogger(__name__)
# Minimal real MonomericObject for pickling
class DummyMonomer:
    """Lightweight substitute for ``MonomericObject`` whose feature builders do nothing."""

    def __init__(self, description, sequence=None):
        self.description, self.sequence = description, sequence
        self.feature_dict = dict()
        self.uniprot_runner = None

    def make_features(self, *a, **k):
        """Accept any arguments and build nothing."""
        return None

    def make_mmseq_features(self, *a, **k):
        """Accept any arguments and build nothing."""
        return None

    def all_seq_msa_features(self, *a, **k):
        """Accept any arguments and return an empty feature mapping."""
        return dict()
class RecordingDummyMonomer(DummyMonomer):
    """``DummyMonomer`` that remembers the keyword arguments of every feature call."""

    def __init__(self, description, sequence=None):
        super().__init__(description, sequence)
        # One entry (the kwargs dict) is stored per call, in call order.
        self.feature_calls, self.mmseq_calls = [], []

    def make_features(self, *args, **kwargs):
        """Record the keyword arguments; positional arguments are ignored."""
        self.feature_calls += [kwargs]

    def make_mmseq_features(self, *args, **kwargs):
        """Record the keyword arguments; positional arguments are ignored."""
        self.mmseq_calls += [kwargs]
class DummyJsonObj:
    """Object exposing ``to_json`` that always yields the same small JSON document."""

    def to_json(self):
        """Return a fixed JSON string."""
        payload = '{"test": "features"}'
        return payload
def real_write_text(self, content, *args, **kwargs):
    """Write ``content`` to the path ``self``, creating parent directories first.

    Meant to be patched in for ``pathlib.Path.write_text``. Extra arguments
    are read the way ``Path.write_text`` reads them (``encoding``, ``errors``
    and ``newline``, positionally or by keyword) and are forwarded to
    ``open`` so the caller's requested encoding is honoured. Any other
    keyword is ignored.

    Args:
        self: Path-like object providing ``parent``.
        content: Text to write.

    Returns:
        The number of characters in ``content``.
    """
    option_names = ("encoding", "errors", "newline")
    # Positional extras map onto the option names in order; surplus ones are dropped.
    open_options = dict(zip(option_names, args))
    open_options.update(
        (name, value) for name, value in kwargs.items() if name in option_names
    )
    self.parent.mkdir(parents=True, exist_ok=True)
    with open(self, "w", **open_options) as handle:
        handle.write(content)
    return len(content)
def build_af3_stub_modules():
    """Build in-memory stand-ins for the ``alphafold3`` modules used by the tests.

    Returns:
        A ``(modules, folding_input)`` tuple. ``modules`` maps dotted module
        names to stub module objects, and ``folding_input`` is the stub for
        ``alphafold3.common.folding_input`` carrying the chain and input
        classes defined below.
    """

    def _new_module(dotted_name, is_package=False):
        # Packages get an empty ``__path__`` so they look like packages.
        module = types.ModuleType(dotted_name)
        if is_package:
            module.__path__ = []
        return module

    def _as_list(values):
        # ``None`` becomes an empty list; anything else is copied into a list.
        return list(values) if values is not None else []

    class ProteinChain:
        def __init__(
            self,
            sequence,
            id,
            ptms=None,
            residue_ids=None,
            description=None,
            paired_msa=None,
            unpaired_msa=None,
            templates=None,
        ):
            self.sequence, self.id = sequence, id
            self.ptms = _as_list(ptms)
            self.residue_ids, self.description = residue_ids, description
            self.paired_msa, self.unpaired_msa = paired_msa, unpaired_msa
            self.templates = templates

    class RnaChain:
        def __init__(
            self,
            sequence,
            id,
            modifications=None,
            residue_ids=None,
            description=None,
            unpaired_msa=None,
        ):
            self.sequence, self.id = sequence, id
            self.modifications = _as_list(modifications)
            self.residue_ids, self.description = residue_ids, description
            self.unpaired_msa = unpaired_msa

    class DnaChain:
        def __init__(self, sequence, id, modifications=None, residue_ids=None, description=None):
            self.sequence, self.id = sequence, id
            self.modifications = _as_list(modifications)
            self.residue_ids, self.description = residue_ids, description

    class Input:
        def __init__(self, name, chains, rng_seeds):
            self.name = name
            self.chains, self.rng_seeds = list(chains), list(rng_seeds)

    root_pkg = _new_module("alphafold3", is_package=True)
    common = _new_module("alphafold3.common", is_package=True)
    structure = _new_module("alphafold3.structure", is_package=True)
    folding_input = _new_module("alphafold3.common.folding_input")
    mmcif = _new_module("alphafold3.structure.mmcif")

    # Expose each stub class on the folding_input module under its own name.
    for stub_cls in (ProteinChain, RnaChain, DnaChain, Input):
        setattr(folding_input, stub_cls.__name__, stub_cls)
    mmcif.int_id_to_str_id = lambda idx: chr(ord("A") + idx - 1)

    # Wire submodules onto their parents so attribute access works too.
    root_pkg.common, root_pkg.structure = common, structure
    common.folding_input = folding_input
    structure.mmcif = mmcif

    modules = {
        "alphafold3": root_pkg,
        "alphafold3.common": common,
        "alphafold3.common.folding_input": folding_input,
        "alphafold3.structure": structure,
        "alphafold3.structure.mmcif": mmcif,
    }
    return modules, folding_input
class TestCreateIndividualFeaturesComprehensive:
"""Comprehensive test cases for create_individual_features.py."""
@pytest.fixture(autouse=True)
def setup_and_teardown(self, tmp_flags):
    """Create a scratch directory with FASTA inputs before each test; delete it afterwards."""
    import shutil

    scratch_dir = tempfile.mkdtemp()
    self.test_dir = scratch_dir
    self.fasta_dir = os.path.join(scratch_dir, "fastas")
    os.makedirs(self.fasta_dir, exist_ok=True)
    # Populate the FASTA fixtures the tests read.
    self.create_test_fastas()
    # Database locations are only ever handed to mocked code in this chunk.
    self.af2_db, self.af3_db = (
        "/g/alphafold/AlphaFold_DBs/2.3.0",
        "/g/alphafold/AlphaFold_DBs/3.0.0",
    )
    logger.info(f"Test setup complete. Using temp directory: {self.test_dir}")

    yield

    # Teardown: drop everything the test wrote.
    shutil.rmtree(self.test_dir)
    logger.info("Test cleanup complete")
def create_test_fastas(self):
    """Write the protein, RNA and DNA FASTA fixtures into ``self.fasta_dir``."""
    logger.info("Creating test FASTA files")
    first_protein = ">A0A024R1R8\nMSSHEGGKKKALKQPKKQAKEMDEEEKAFKQKQKEEQKKLEVLKAKVVGKGPLATGGIKKSGKK\n"
    second_protein = ">P61626\nMKALIVLGLVLLSVTVQGKVFERCELARTLKRLGMDGYRGISLANWMCLAKWESGYNTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRDPQGIRAWVAWRNRCQNRDVRQYVQGCGV\n"
    # File name -> full FASTA text.
    fixtures = {
        # Single protein
        "single_protein.fasta": first_protein,
        # Multiple proteins
        "multi_protein.fasta": first_protein + second_protein,
        # RNA
        "rna.fasta": ">RNA_TEST\nAUGGCUACGUAGCUAGCUAGCUAGCUAGCUAGCUAG\n",
        # DNA
        "dna.fasta": ">DNA_TEST\nATGGCATCGATCGATCGATCGATCGATCGATCGATCGATC\n",
    }
    for file_name, fasta_text in fixtures.items():
        with open(os.path.join(self.fasta_dir, file_name), "w") as f:
            f.write(fasta_text)
    logger.info("Test FASTA files created successfully")
@parameterized.expand([
    ("alphafold2", "single_protein.fasta", False, False),
    ("alphafold2", "multi_protein.fasta", False, False),
    ("alphafold2", "single_protein.fasta", True, False),  # mmseqs2
    ("alphafold2", "single_protein.fasta", False, True),  # compressed
    ("alphafold3", "single_protein.fasta", False, False),
    ("alphafold3", "multi_protein.fasta", False, False),
    ("alphafold3", "rna.fasta", False, False),
    ("alphafold3", "dna.fasta", False, False),
])
def test_feature_creation(self, pipeline, fasta_file, use_mmseqs2, compress_features):
    """Run the selected pipeline on one FASTA fixture and check the files it writes.

    The AlphaFold2 cases expect one ``.pkl`` (or ``.pkl.xz``) per sequence.
    The AlphaFold3 cases expect one ``*_af3_input.json`` per sequence and
    also check which chain class was handed to the mocked pipeline.
    """
    from absl import flags

    logger.info(f"Testing feature creation: pipeline={pipeline}, file={fasta_file}, mmseqs2={use_mmseqs2}, compress={compress_features}")
    fasta_path = os.path.join(self.fasta_dir, fasta_file)
    output_dir = os.path.join(self.test_dir, f"output_{pipeline}_{fasta_file}")

    FLAGS = flags.FLAGS
    FLAGS(['test'])  # Parse flags with dummy argv
    # Assign flags directly to avoid UnrecognizedFlagError.
    flag_settings = {
        "data_pipeline": pipeline,
        "fasta_paths": [fasta_path],
        "data_dir": self.af2_db if pipeline == "alphafold2" else self.af3_db,
        "output_dir": output_dir,
        "max_template_date": "2021-09-30",
        "use_mmseqs2": use_mmseqs2,
        "compress_features": compress_features,
        "save_msa_files": False,
        "skip_existing": False,
    }
    for flag_name, flag_value in flag_settings.items():
        setattr(FLAGS, flag_name, flag_value)

    if pipeline != "alphafold2":
        logger.info("Testing AlphaFold3 pipeline")
        af3_modules, folding_input_stub = build_af3_stub_modules()
        with patch.dict(sys.modules, af3_modules), \
                patch.object(create_features, 'create_pipeline_af3') as mock_af3_pipeline, \
                patch.object(create_features, 'folding_input', folding_input_stub), \
                patch('pathlib.Path.write_text', new=real_write_text), \
                patch('alphapulldown.utils.save_meta_data.get_meta_dict', return_value={}):
            mock_af3_pipeline.return_value = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
            create_features.create_af3_individual_features()
        process_calls = mock_af3_pipeline.return_value.process.call_args_list
        observed_chain_types = []
        for call in process_calls:
            observed_chain_types.append(type(call.args[0].chains[0]).__name__)
        # FASTA fixture -> (expected output files, expected chain class names)
        af3_expectations = {
            "single_protein.fasta": (
                ["A0A024R1R8_af3_input.json"],
                ["ProteinChain"],
            ),
            "multi_protein.fasta": (
                ["A0A024R1R8_af3_input.json", "P61626_af3_input.json"],
                ["ProteinChain", "ProteinChain"],
            ),
            "rna.fasta": (["RNA_TEST_af3_input.json"], ["RnaChain"]),
            "dna.fasta": (["DNA_TEST_af3_input.json"], ["DnaChain"]),
        }
        expected_files, expected_chain_types = af3_expectations.get(fasta_file, ([], []))
        logger.info(f"Checking for expected files: {expected_files}")
        assert observed_chain_types == expected_chain_types
        for expected_file in expected_files:
            file_path = os.path.join(output_dir, expected_file)
            assert os.path.exists(file_path), f"Expected file {file_path} not found"
            logger.info(f"Verified file exists: {file_path}")
    else:
        logger.info("Testing AlphaFold2 pipeline")
        with patch.object(create_features, 'create_pipeline_af2') as mock_af2_pipeline, \
                patch.object(create_features, 'create_uniprot_runner') as mock_uniprot_runner, \
                patch('alphapulldown.utils.save_meta_data.get_meta_dict', return_value={}), \
                patch.object(create_features, 'MonomericObject', DummyMonomer):
            mock_af2_pipeline.return_value = "pipeline"
            mock_uniprot_runner.return_value = "runner"
            create_features.create_individual_features()
        # FASTA fixture -> expected pickle names
        af2_expectations = {
            "single_protein.fasta": ["A0A024R1R8.pkl"],
            "multi_protein.fasta": ["A0A024R1R8.pkl", "P61626.pkl"],
        }
        expected_files = af2_expectations.get(fasta_file, [])
        logger.info(f"Checking for expected files: {expected_files}")
        compression_suffix = ".xz" if compress_features else ""
        for expected_file in expected_files:
            file_path = os.path.join(output_dir, expected_file) + compression_suffix
            assert os.path.exists(file_path), f"Expected file {file_path} not found"
            logger.info(f"Verified file exists: {file_path}")
    logger.info("Feature creation test completed successfully")
def test_af3_invalid_sequence_fails_run(self):
    """An invalid AF3 sequence must abort the AF3 run rather than be skipped."""
    from absl import flags

    invalid_fasta = os.path.join(self.fasta_dir, "invalid_af3.fasta")
    with open(invalid_fasta, "w") as handle:
        handle.write(">INVALID\nACDZ*\n")

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    for flag_name, flag_value in (
        ("data_pipeline", "alphafold3"),
        ("fasta_paths", [invalid_fasta]),
        ("data_dir", self.af3_db),
        ("output_dir", os.path.join(self.test_dir, "output_invalid_af3")),
        ("max_template_date", "2021-09-30"),
    ):
        setattr(FLAGS, flag_name, flag_value)

    # Every message passed to logging.error is collected here.
    error_messages = []
    af3_modules, folding_input_stub = build_af3_stub_modules()
    with patch.dict(sys.modules, af3_modules), \
            patch.object(create_features, "create_pipeline_af3") as mock_af3_pipeline, \
            patch.object(create_features, "folding_input", folding_input_stub), \
            patch.object(create_features.logging, "error", side_effect=error_messages.append):
        mock_af3_pipeline.return_value = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
        with pytest.raises(RuntimeError, match="INVALID"):
            create_features.create_af3_individual_features()

    # The pipeline must not have been run and nothing may have been written.
    mock_af3_pipeline.return_value.process.assert_not_called()
    unexpected_output = os.path.join(FLAGS.output_dir, "INVALID_af3_input.json")
    assert not os.path.exists(unexpected_output)
    assert any("Failed to create AlphaFold3 input object" in message for message in error_messages)
def test_af3_ambiguous_sequence_requires_chain_hint(self):
    """A sequence with an ambiguous alphabet must fail unless the header hints the chain type."""
    from absl import flags

    ambiguous_fasta = os.path.join(self.fasta_dir, "ambiguous_af3.fasta")
    with open(ambiguous_fasta, "w") as handle:
        handle.write(">AMBIG\nACGT\n")

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    for flag_name, flag_value in (
        ("data_pipeline", "alphafold3"),
        ("fasta_paths", [ambiguous_fasta]),
        ("data_dir", self.af3_db),
        ("output_dir", os.path.join(self.test_dir, "output_ambiguous_af3")),
        ("max_template_date", "2021-09-30"),
    ):
        setattr(FLAGS, flag_name, flag_value)

    # Every message passed to logging.error is collected here.
    error_messages = []
    af3_modules, folding_input_stub = build_af3_stub_modules()
    with patch.dict(sys.modules, af3_modules), \
            patch.object(create_features, "create_pipeline_af3") as mock_af3_pipeline, \
            patch.object(create_features, "folding_input", folding_input_stub), \
            patch.object(create_features.logging, "error", side_effect=error_messages.append):
        mock_af3_pipeline.return_value = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
        with pytest.raises(RuntimeError, match="AMBIG"):
            create_features.create_af3_individual_features()

    # The pipeline must not have been run and nothing may have been written.
    mock_af3_pipeline.return_value.process.assert_not_called()
    unexpected_output = os.path.join(FLAGS.output_dir, "AMBIG_af3_input.json")
    assert not os.path.exists(unexpected_output)
    assert any("Ambiguous sequence alphabet" in message for message in error_messages)
def test_create_individual_features_truemultimer_respects_seq_index(self):
    """With ``seq_index`` set, TrueMultimer mode must process only that CSV row."""
    from absl import flags

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.description_file = os.path.join(self.test_dir, "description.csv")
    FLAGS.fasta_paths = [os.path.join(self.fasta_dir, "multi_protein.fasta")]
    FLAGS.path_to_mmt = os.path.join(self.test_dir, "templates")
    FLAGS.multiple_mmts = True
    FLAGS.seq_index = 2

    # Three parsed rows; seq_index=2 should select the middle one.
    feats = [{"protein": name} for name in ("prot1", "prot2", "prot3")]
    with patch.object(create_features, "parse_csv_file", return_value=feats) as mock_parse, \
            patch.object(create_features, "process_multimeric_features") as mock_process:
        create_features.create_individual_features_truemultimer()

    expected_parse_args = (
        FLAGS.description_file,
        FLAGS.fasta_paths,
        FLAGS.path_to_mmt,
        FLAGS.multiple_mmts,
    )
    mock_parse.assert_called_once_with(*expected_parse_args)
    mock_process.assert_called_once_with(feats[1], 2)
def test_process_multimeric_features_rejects_missing_templates(self):
    """A template path that is not on disk must raise ``FileNotFoundError``."""
    missing_template = os.path.join(self.test_dir, "missing_template.cif")
    feat = dict(
        protein="complexA",
        chains=["A"],
        templates=[missing_template],
        sequence="ACDE",
    )
    with pytest.raises(FileNotFoundError, match="does not exist"):
        create_features.process_multimeric_features(feat, 1)
def test_process_multimeric_features_creates_custom_db_and_saves_monomer(self):
    """TrueMultimer processing should build a custom DB and hand a monomer to the saver.

    Every collaborator of ``process_multimeric_features`` is mocked. The test
    checks the arguments each one receives and the monomer that reaches
    ``create_and_save_monomer_objects``.
    """
    template_path = os.path.join(self.test_dir, "template1.cif")
    Path(template_path).write_text("data_template\n", encoding="utf-8")

    class RecordingMonomer:
        """Bare attribute holder patched in for ``MonomericObject``."""

        def __init__(self, description, sequence):
            self.description = description
            self.sequence = sequence
            self.feature_dict = {}
            self.uniprot_runner = None

    feat = {
        "protein": "complexB",
        "chains": ["A", "B"],
        "templates": [template_path],
        "sequence": "ACDEFG",
    }
    from absl import flags
    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.output_dir = os.path.join(self.test_dir, "truemultimer_output")
    FLAGS.data_dir = self.af2_db
    FLAGS.max_template_date = "2021-09-30"
    FLAGS.use_mmseqs2 = False
    FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer"
    FLAGS.uniprot_database_path = "/db/uniprot.fasta"
    with patch.object(create_features, "MonomericObject", RecordingMonomer), \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_create_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)
    mock_custom_db.assert_called_once()
    # The first positional argument is not checked here; only the rest is pinned.
    custom_db_args = mock_custom_db.call_args.args
    assert custom_db_args[1:] == (
        "complexB",
        [template_path],
        ["A", "B"],
    )
    mock_create_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_called_once_with(
        FLAGS.jackhmmer_binary_path,
        FLAGS.uniprot_database_path,
    )
    # Fail with a clear message, not an AttributeError on ``None``,
    # if the saver was never invoked.
    mock_save.assert_called_once()
    saved_monomer, saved_pipeline = mock_save.call_args.args
    assert saved_pipeline == "pipeline"
    assert saved_monomer.description == "complexB"
    assert saved_monomer.sequence == "ACDEFG"
    assert saved_monomer.uniprot_runner == "runner"
@pytest.mark.parametrize("compressed_source", [False, True])
def test_process_multimeric_features_reuses_existing_source_pickle(
    self, tmp_flags, compressed_source
):
    """An existing ``proteinA`` pickle (plain or xz) is reused with fresh template features.

    No custom DB, pipeline or uniprot runner may be created. The saved object
    keeps the source MSA features and takes its template features from the
    mocked extractor.
    """
    from absl import flags

    template_path = Path(self.test_dir) / "template1.cif"
    template_path.write_text("data_template\n", encoding="utf-8")

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.output_dir = os.path.join(self.test_dir, "reused_truemultimer_output")
    FLAGS.use_mmseqs2 = False
    FLAGS.compress_features = False
    FLAGS.skip_existing = False
    output_dir = Path(FLAGS.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Features stored in the pre-existing source pickle.
    source = MonomericObject("proteinA", "ACDE")
    source.feature_dict = {
        "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
        "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32),
        "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32),
        "msa_species_identifiers": np.asarray([b"9606"], dtype=object),
        "msa_all_seq": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
        "deletion_matrix_int_all_seq": np.zeros((1, 4), dtype=np.int32),
        "msa_species_identifiers_all_seq": np.asarray([b"9606"], dtype=object),
        "template_aatype": np.zeros((1, 4, 22), dtype=np.float32),
        "template_all_atom_masks": np.ones((1, 4, 37), dtype=np.float32),
        "template_all_atom_positions": np.ones((1, 4, 37, 3), dtype=np.float32),
        "template_domain_names": np.asarray([b"old_template"], dtype=object),
        "template_sequence": np.asarray([b"OLD"], dtype=object),
        "template_sum_probs": np.asarray([0.5], dtype=np.float32),
        "template_confidence_scores": np.full((1, 4), 0.75, dtype=np.float32),
        "template_release_date": np.asarray(["2024-01-01"], dtype=object),
    }
    # Pick the opener and file name matching the parametrized storage format.
    if compressed_source:
        open_source, source_name = lzma.open, "proteinA.pkl.xz"
    else:
        open_source, source_name = open, "proteinA.pkl"
    with open_source(output_dir / source_name, "wb") as handle:
        pickle.dump(source, handle)

    # What the mocked template extractor hands back.
    new_template_features = {
        "template_aatype": np.ones((2, 4, 22), dtype=np.float32),
        "template_all_atom_masks": np.full((2, 4, 37), 2.0, dtype=np.float32),
        "template_all_atom_positions": np.full((2, 4, 37, 3), 3.0, dtype=np.float32),
        "template_domain_names": np.asarray([b"newA", b"newB"], dtype=object),
        "template_sequence": np.asarray([b"NEWA", b"NEWB"], dtype=object),
        "template_sum_probs": np.asarray([0.1, 0.2], dtype=np.float32),
    }
    feat = {
        "protein": "proteinA.template1.cif.A",
        "chains": ["A"],
        "templates": [str(template_path)],
        "sequence": "ACDE",
    }
    extractor_result = types.SimpleNamespace(features=new_template_features)
    with patch.object(
        create_features,
        "extract_multimeric_template_features_for_single_chain",
        return_value=extractor_result,
    ) as mock_extract, \
            patch.object(create_features, "create_custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner") as mock_runner, \
            patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
        create_features.process_multimeric_features(feat, 1)

    mock_extract.assert_called_once_with(
        query_seq="ACDE",
        pdb_id="template1",
        chain_id="A",
        mmcif_file=str(template_path),
        threshold_clashes=create_features.FLAGS.threshold_clashes,
        hb_allowance=create_features.FLAGS.hb_allowance,
        plddt_threshold=create_features.FLAGS.plddt_threshold,
    )
    # None of the from-scratch machinery may have been touched.
    for untouched in (mock_custom_db, mock_arguments, mock_pipeline, mock_runner):
        untouched.assert_not_called()

    output_pickle = output_dir / "proteinA.template1.cif.A.pkl"
    assert output_pickle.exists()
    with open(output_pickle, "rb") as handle:
        reused = pickle.load(handle)
    reused_features = reused.feature_dict
    assert reused.description == "proteinA.template1.cif.A"
    assert np.array_equal(reused_features["msa"], source.feature_dict["msa"])
    assert reused_features["template_sequence"].tolist() == [b"NEWA", b"NEWB"]
    assert np.array_equal(
        reused_features["template_confidence_scores"],
        np.ones((2, 4), dtype=np.float32),
    )
    assert reused_features["template_release_date"].tolist() == ["none", "none"]
    assert list(output_dir.glob("proteinA.template1.cif.A_feature_metadata_*.json"))
def test_process_multimeric_features_falls_back_when_source_sequence_mismatches(
    self, tmp_flags
):
    """A source pickle with a different sequence must not be reused.

    The stored ``proteinA`` pickle holds ``ACDE`` while the request is for
    ``ACDF``. The template extractor must stay unused and the custom DB,
    pipeline and uniprot runner path must run instead.
    """
    template_path = Path(self.test_dir) / "template1.cif"
    template_path.write_text("data_template\n", encoding="utf-8")
    from absl import flags
    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.output_dir = os.path.join(self.test_dir, "mismatched_truemultimer_output")
    FLAGS.use_mmseqs2 = False
    FLAGS.compress_features = False
    FLAGS.skip_existing = False
    FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer"
    FLAGS.uniprot_database_path = "/db/uniprot.fasta"
    output_dir = Path(FLAGS.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    source = MonomericObject("proteinA", "ACDE")
    source.feature_dict = {
        "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
        "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32),
        "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32),
        "msa_species_identifiers": np.asarray([b"9606"], dtype=object),
    }
    with open(output_dir / "proteinA.pkl", "wb") as handle:
        pickle.dump(source, handle)
    feat = {
        "protein": "proteinA.template1.cif.A",
        "chains": ["A"],
        "templates": [str(template_path)],
        # Differs from the pickled sequence in the last residue.
        "sequence": "ACDF",
    }
    with patch.object(
        create_features,
        "extract_multimeric_template_features_for_single_chain",
    ) as mock_extract, \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)
    mock_extract.assert_not_called()
    mock_custom_db.assert_called_once()
    mock_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_called_once_with(
        FLAGS.jackhmmer_binary_path,
        FLAGS.uniprot_database_path,
    )
    # Fail with a clear message, not an AttributeError on ``None``,
    # if the saver was never invoked.
    mock_save.assert_called_once()
    saved_monomer, saved_pipeline = mock_save.call_args.args
    assert saved_pipeline == "pipeline"
    assert saved_monomer.description == "proteinA.template1.cif.A"
    assert saved_monomer.sequence == "ACDF"
    assert saved_monomer.uniprot_runner == "runner"
def test_process_multimeric_features_does_not_reuse_bulk_msa_pickle_for_skip_msa(
    self, tmp_flags
):
    """With ``skip_msa`` on, an existing source pickle must not be reused.

    The template extractor must stay unused, the custom DB and pipeline must
    be created, and no uniprot runner may be built or attached to the saved
    monomer.
    """
    template_path = Path(self.test_dir) / "template1.cif"
    template_path.write_text("data_template\n", encoding="utf-8")
    from absl import flags
    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.output_dir = os.path.join(self.test_dir, "skip_msa_truemultimer_output")
    FLAGS.use_mmseqs2 = False
    FLAGS.compress_features = False
    FLAGS.skip_existing = False
    FLAGS.skip_msa = True
    FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer"
    FLAGS.uniprot_database_path = "/db/uniprot.fasta"
    output_dir = Path(FLAGS.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    source = MonomericObject("proteinA", "ACDE")
    source.feature_dict = {
        "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
        "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32),
        "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32),
        "msa_species_identifiers": np.asarray([b"9606"], dtype=object),
    }
    with open(output_dir / "proteinA.pkl", "wb") as handle:
        pickle.dump(source, handle)
    feat = {
        "protein": "proteinA.template1.cif.A",
        "chains": ["A"],
        "templates": [str(template_path)],
        "sequence": "ACDE",
    }
    with patch.object(
        create_features,
        "extract_multimeric_template_features_for_single_chain",
    ) as mock_extract, \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)
    mock_extract.assert_not_called()
    mock_custom_db.assert_called_once()
    mock_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_not_called()
    # Fail with a clear message, not an AttributeError on ``None``,
    # if the saver was never invoked.
    mock_save.assert_called_once()
    saved_monomer, saved_pipeline = mock_save.call_args.args
    assert saved_pipeline == "pipeline"
    assert saved_monomer.uniprot_runner is None
def test_process_multimeric_features_does_not_reuse_skip_msa_pickle_for_full_msa(
    self, tmp_flags
):
    """A source pickle flagged ``skip_msa`` must not be reused when a full MSA is requested."""
    from absl import flags

    template_path = Path(self.test_dir) / "template1.cif"
    template_path.write_text("data_template\n", encoding="utf-8")

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    for flag_name, flag_value in (
        ("output_dir", os.path.join(self.test_dir, "full_msa_truemultimer_output")),
        ("use_mmseqs2", False),
        ("compress_features", False),
        ("skip_existing", False),
        ("skip_msa", False),
        ("jackhmmer_binary_path", "/usr/bin/jackhmmer"),
        ("uniprot_database_path", "/db/uniprot.fasta"),
    ):
        setattr(FLAGS, flag_name, flag_value)
    output_dir = Path(FLAGS.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # The stored object carries skip_msa=True, unlike the current request.
    source = MonomericObject("proteinA", "ACDE")
    source.skip_msa = True
    source.feature_dict = {
        "msa": np.asarray([[1, 2, 3, 4]], dtype=np.int32),
        "deletion_matrix_int": np.zeros((1, 4), dtype=np.int32),
        "num_alignments": np.asarray([1, 1, 1, 1], dtype=np.int32),
        "msa_species_identifiers": np.asarray([b""], dtype=object),
    }
    with open(output_dir / "proteinA.pkl", "wb") as handle:
        pickle.dump(source, handle)

    feat = dict(
        protein="proteinA.template1.cif.A",
        chains=["A"],
        templates=[str(template_path)],
        sequence="ACDE",
    )
    with patch.object(
        create_features,
        "extract_multimeric_template_features_for_single_chain",
    ) as mock_extract, \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)

    mock_extract.assert_not_called()
    mock_custom_db.assert_called_once()
    mock_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_called_once_with("/usr/bin/jackhmmer", "/db/uniprot.fasta")
    mock_save.assert_called_once()
def test_main_dispatches_to_truemultimer_for_af2_template_runs(self):
    """``main`` must send an AF2 job that has ``path_to_mmt`` set to the TrueMultimer path."""
    from absl import flags

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    for flag_name, flag_value in (
        ("data_pipeline", "alphafold2"),
        ("fasta_paths", [os.path.join(self.fasta_dir, "single_protein.fasta")]),
        ("data_dir", self.af2_db),
        ("output_dir", os.path.join(self.test_dir, "main_truemultimer")),
        ("max_template_date", "2021-09-30"),
        ("path_to_mmt", os.path.join(self.test_dir, "templates")),
    ):
        setattr(FLAGS, flag_name, flag_value)

    with patch.object(create_features, "check_template_date") as mock_check, \
            patch.object(create_features, "create_individual_features_truemultimer") as mock_tm, \
            patch.object(create_features, "create_individual_features") as mock_single:
        create_features.main([])

    # The date check and the TrueMultimer entry point each run once;
    # the plain single-sequence entry point is never reached.
    for expected_once in (mock_check, mock_tm):
        expected_once.assert_called_once_with()
    mock_single.assert_not_called()
def test_database_path_mapping(self):
    """``get_database_path`` must resolve each known key and reject ``uniref30`` for AF3."""
    from absl import flags

    logger.info("Testing database path mapping")
    FLAGS = flags.FLAGS
    FLAGS(['test'])  # Parse flags with dummy argv

    db_root = "/test/db"
    # (pipeline, database key, path expected below db_root)
    test_cases = (
        ("alphafold2", "uniref90", "uniref90/uniref90.fasta"),
        ("alphafold2", "uniref30", "uniref30/UniRef30_2023_02"),
        ("alphafold3", "uniref90", "uniref90_2022_05.fa"),
    )
    for pipeline, key, expected_subpath in test_cases:
        logger.info(f"Testing {pipeline} pipeline with key '{key}'")
        FLAGS.data_pipeline, FLAGS.data_dir = pipeline, db_root
        expected_path = os.path.join(db_root, expected_subpath)
        actual_path = create_features.get_database_path(key)
        assert actual_path == expected_path, f"Expected {expected_path}, got {actual_path}"
        logger.info(f"Database path mapping correct: {actual_path}")

    # A key with no AF3 mapping must raise a descriptive KeyError.
    FLAGS.data_pipeline, FLAGS.data_dir = "alphafold3", db_root
    unsupported_message = "Database 'uniref30' is not configured for the alphafold3 pipeline"
    with pytest.raises(KeyError, match=unsupported_message):
        create_features.get_database_path("uniref30")
def test_af3_pipeline_creation_failure(self):
    """``create_pipeline_af3`` must raise ``ImportError`` when the AF3 classes are ``None``."""
    from absl import flags

    logger.info("Testing AF3 pipeline creation failure")
    FLAGS = flags.FLAGS
    FLAGS(['test'])  # Parse flags with dummy argv

    module_path = 'alphapulldown.scripts.create_individual_features'
    # Replace both AF3 pipeline names in the script module with None.
    with patch(module_path + '.AF3DataPipeline', None), \
            patch(module_path + '.AF3DataPipelineConfig', None):
        FLAGS.data_pipeline, FLAGS.data_dir = "alphafold3", "/test/db"
        with pytest.raises(ImportError, match="pip install -e .*alphafold3,test.*build_data"):
            create_features.create_pipeline_af3()
    logger.info("AF3 pipeline creation correctly failed with ImportError")
def test_template_date_check(self):
    """``check_template_date`` accepts a set date and exits when the date is ``None``."""
    from absl import flags

    logger.info("Testing template date validation")
    FLAGS = flags.FLAGS
    FLAGS(['test'])  # Parse flags with dummy argv

    # A populated date must pass without terminating the process.
    FLAGS.max_template_date = "2021-09-30"
    try:
        create_features.check_template_date()
    except SystemExit:
        pytest.fail("Valid date should not cause SystemExit")
    else:
        logger.info("Valid template date accepted")

    # A missing date must trigger SystemExit.
    FLAGS.max_template_date = None
    with pytest.raises(SystemExit):
        create_features.check_template_date()
    logger.info("Invalid template date correctly rejected")
def test_sequence_index_filtering(self):
    """Check that a 1-based ``seq_index`` selects exactly one entry from a sequence list.

    NOTE(review): the selection is done inline on the mock's return value and
    no ``create_features`` function is called here.
    """
    from absl import flags

    logger.info("Testing sequence index filtering")
    FLAGS = flags.FLAGS
    FLAGS(['test'])  # Parse flags with dummy argv

    FLAGS.seq_index = 1
    FLAGS.fasta_paths = ["test.fasta"]

    fake_records = [("SEQ1", "desc1"), ("SEQ2", "desc2"), ("SEQ3", "desc3")]
    with patch('alphapulldown.utils.file_handling.iter_seqs') as mock_iter_seqs:
        mock_iter_seqs.return_value = fake_records
        sequences = list(mock_iter_seqs.return_value)
        if FLAGS.seq_index is not None:
            zero_based = FLAGS.seq_index - 1  # seq_index is 1-based
            sequences = [sequences[zero_based]]
        assert len(sequences) == 1, f"Expected 1 sequence, got {len(sequences)}"
        assert sequences[0][0] == "SEQ1", f"Expected SEQ1, got {sequences[0][0]}"
        logger.info("Sequence filtering with valid index successful")
def test_skip_existing_flag(self):
    """Check that a mocked feature run with ``skip_existing`` leaves prior output untouched.

    ``create_individual_features`` itself is mocked to avoid database access,
    so this only verifies that the call happens exactly once and that the
    pre-existing output file is neither modified nor joined by new files.
    """
    logger.info("Testing skip existing functionality")
    # Initialize flags properly
    from absl import flags
    FLAGS = flags.FLAGS
    FLAGS(['test'])  # Parse flags with dummy argv
    output_dir = os.path.join(self.test_dir, "skip_test")
    os.makedirs(output_dir, exist_ok=True)
    # Create a dummy existing file
    existing_file = os.path.join(output_dir, "test.pkl")
    with open(existing_file, 'w') as f:
        f.write("dummy")
    FLAGS.output_dir = output_dir
    FLAGS.skip_existing = True
    # Mock the create_individual_features function to avoid database access
    with patch.object(create_features, 'create_individual_features') as mock_create_features:
        mock_create_features.return_value = None
        # This should not create new files when skip_existing is True
        create_features.create_individual_features()
    # Without these checks the test could never fail.
    mock_create_features.assert_called_once_with()
    assert os.listdir(output_dir) == ["test.pkl"], "Unexpected files in output directory"
    with open(existing_file) as f:
        assert f.read() == "dummy", "Existing file was modified"
    logger.info("Skip existing functionality tested successfully")
def test_output_directory_creation(self):
    """Running ``main`` must create the configured output directory."""
    from absl import flags

    logger.info("Testing output directory creation")
    output_dir = os.path.join(self.test_dir, "new_output_dir")

    FLAGS = flags.FLAGS
    FLAGS(['test'])  # Parse flags with dummy argv
    for flag_name, flag_value in (
        ("output_dir", output_dir),
        ("max_template_date", "2021-09-30"),
        ("data_pipeline", "alphafold2"),
        ("fasta_paths", ["dummy.fasta"]),  # Use a dummy path instead of empty list
        ("data_dir", "/test/db"),
    ):
        setattr(FLAGS, flag_name, flag_value)

    # Pipeline, runner and FASTA iteration are mocked so no real database is read.
    with patch.object(create_features, 'create_pipeline_af2') as mock_af2_pipeline, \
            patch.object(create_features, 'create_uniprot_runner') as mock_uniprot_runner, \
            patch('alphapulldown.scripts.create_individual_features.iter_seqs') as mock_iter_seqs:
        mock_af2_pipeline.return_value = MagicMock()
        mock_uniprot_runner.return_value = MagicMock()
        mock_iter_seqs.return_value = []  # Return empty iterator
        # The main function should create the output directory
        create_features.main([])

    assert os.path.exists(output_dir), f"Output directory {output_dir} was not created"
    assert os.path.isdir(output_dir), f"{output_dir} is not a directory"
    logger.info(f"Output directory created successfully: {output_dir}")
def test_alphafold3_chain_type_detection(self):
    """Check that sample protein, RNA and DNA strings fit their alphabets.

    Each sample sequence is upper-cased and every character must belong to
    the alphabet listed next to it.
    """
    logger.info("Testing AlphaFold3 chain type detection")
    # (label, sample sequence, allowed alphabet, assertion message)
    samples = (
        (
            "Protein",
            "MKALIVLGLVLLSVTVQGKVFERCELARTLKRLGMDGYRGISLANWMCLAKWESGYNTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRDPQGIRAWVAWRNRCQNRDVRQYVQGCGV",
            "ACDEFGHIKLMNPQRSTVWY",
            "Protein sequence contains invalid amino acids",
        ),
        (
            "RNA",
            "AUGGCUACGUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAG",
            "ACGU",
            "RNA sequence contains invalid nucleotides",
        ),
        (
            "DNA",
            "ATGGCATCGATCGATCGATCGATCGATCGATCGATCGATC",
            "ACGT",
            "DNA sequence contains invalid nucleotides",
        ),
    )
    for label, sequence, alphabet, failure_message in samples:
        # Subset test: every residue must be a member of the alphabet.
        assert set(sequence.upper()) <= set(alphabet), failure_message
        logger.info(f"{label} chain type detection successful")
def test_compression_flag(self):
    """Toggle ``compress_features`` and check the value reads back.

    Also checks that a ``.xz`` suffix is not appended to a file name while
    the flag is disabled.
    """
    logger.info("Testing feature compression functionality")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv

    # Enabled state
    flag_values.compress_features = True
    assert flag_values.compress_features, "Compression flag should be True"
    logger.info("Compression flag enabled successfully")

    # Disabled state
    flag_values.compress_features = False
    assert not flag_values.compress_features, "Compression flag should be False"
    logger.info("Compression flag disabled successfully")

    # The suffix is only added when compression is on (it is off here).
    suffix = ".xz" if flag_values.compress_features else ""
    test_file = "test.pkl" + suffix
    assert test_file == "test.pkl", "File extension should not be modified when compression is disabled"
    logger.info("File extension handling tested successfully")
def test_create_arguments_function(self):
    """Check the flag values ``create_arguments`` derives from ``data_dir``.

    All database path flags are reset to ``None`` first so that nothing set
    earlier lingers. A second call with a custom template database path is
    expected to repoint ``pdb_seqres_database_path``.
    """
    logger.info("Testing create_arguments function")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv

    flag_values.data_dir = "/test/db"
    flag_values.max_template_date = "2021-09-30"
    flag_values.use_mmseqs2 = False
    flag_values.data_pipeline = "alphafold2"
    # Clear every path flag so the defaults are composed from data_dir.
    for path_flag in (
        "uniref90_database_path",
        "uniref30_database_path",
        "mgnify_database_path",
        "bfd_database_path",
        "small_bfd_database_path",
        "pdb70_database_path",
        "uniprot_database_path",
        "pdb_seqres_database_path",
        "template_mmcif_dir",
        "obsolete_pdbs_path",
    ):
        setattr(flag_values, path_flag, None)

    create_features.create_arguments()
    expected_uniref90 = "/test/db/uniref90/uniref90.fasta"
    assert flag_values.uniref90_database_path == expected_uniref90, \
        f"Expected '{expected_uniref90}', got '{flag_values.uniref90_database_path}'"
    expected_date = "2021-09-30"
    assert flag_values.max_template_date == expected_date, \
        f"Expected '{expected_date}', got '{flag_values.max_template_date}'"
    logger.info("Basic argument creation successful")

    # Second pass: custom template database
    create_features.create_arguments("/custom/templates")
    expected_seqres = "/custom/templates/pdb_seqres.txt"
    assert flag_values.pdb_seqres_database_path == expected_seqres, \
        f"Expected '{expected_seqres}', got '{flag_values.pdb_seqres_database_path}'"
    logger.info("Custom template database argument creation successful")
def test_create_arguments_with_custom_template_db(self):
    """Check ``create_arguments`` when given a custom template database.

    The first call passes a custom path and checks
    ``pdb_seqres_database_path``; the second call (no argument) checks that
    ``data_dir`` and ``max_template_date`` keep the values assigned here.
    """
    logger.info("Testing create_arguments with custom template database")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv

    # Custom template database path handling
    create_features.create_arguments("/custom/template/db")
    expected_seqres = "/custom/template/db/pdb_seqres.txt"
    assert flag_values.pdb_seqres_database_path == expected_seqres, \
        f"Expected '{expected_seqres}', got '{flag_values.pdb_seqres_database_path}'"
    logger.info("Custom template database path handling successful")

    # Unrelated flags must survive a call to create_arguments
    flag_values.data_dir = "/test/db"
    flag_values.max_template_date = "2021-09-30"
    create_features.create_arguments()
    assert flag_values.data_dir == "/test/db", "Data directory should be preserved"
    assert flag_values.max_template_date == "2021-09-30", "Max template date should be preserved"
    logger.info("Flag preservation in custom template database mode successful")
def test_create_arguments_alphafold3_clears_af2_only_databases(self):
    """Check that AF3 argument creation drops stale AF2-only database paths.

    Four AF2-only path flags are pre-loaded with ``/stale/...`` values; after
    ``create_arguments`` they must be ``None`` while the remaining path flags
    point below ``data_dir``. A further call with a custom template database
    must leave the template-related paths unchanged.
    """
    logger.info("Testing AF3 argument creation without AF2-only database leftovers")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])

    flag_values.use_mmseqs2 = False
    flag_values.data_pipeline = "alphafold3"
    flag_values.data_dir = "/test/db"
    for cleared_flag in (
        "uniref90_database_path",
        "mgnify_database_path",
        "small_bfd_database_path",
        "uniprot_database_path",
        "pdb_seqres_database_path",
        "template_mmcif_dir",
    ):
        setattr(flag_values, cleared_flag, None)
    # Leftovers that the AF3 code path is expected to wipe.
    stale_values = {
        "uniref30_database_path": "/stale/uniref30",
        "bfd_database_path": "/stale/bfd",
        "pdb70_database_path": "/stale/pdb70",
        "obsolete_pdbs_path": "/stale/obsolete.dat",
    }
    for stale_flag, stale_value in stale_values.items():
        setattr(flag_values, stale_flag, stale_value)

    create_features.create_arguments()
    expected_paths = {
        "uniref90_database_path": "/test/db/uniref90_2022_05.fa",
        "mgnify_database_path": "/test/db/mgy_clusters_2022_05.fa",
        "small_bfd_database_path": "/test/db/bfd-first_non_consensus_sequences.fasta",
        "uniprot_database_path": "/test/db/uniprot_all_2021_04.fa",
        "pdb_seqres_database_path": "/test/db/pdb_seqres_2022_09_28.fasta",
        "template_mmcif_dir": "/test/db/mmcif_files",
    }
    for flag_name, expected in expected_paths.items():
        assert getattr(flag_values, flag_name) == expected
    for stale_flag in stale_values:
        assert getattr(flag_values, stale_flag) is None

    # A custom template database must not change the template paths.
    create_features.create_arguments("/custom/template/db")
    assert flag_values.pdb_seqres_database_path == "/test/db/pdb_seqres_2022_09_28.fasta"
    assert flag_values.template_mmcif_dir == "/test/db/mmcif_files"
    assert flag_values.obsolete_pdbs_path is None
    logger.info("AF3 argument creation only kept AF3-relevant database paths")
def test_create_arguments_reduced_dbs_clears_unused_af2_databases(self):
    """Check that ``reduced_dbs`` without HHsearch clears full-database paths.

    ``uniref30``, ``bfd`` and ``pdb70`` path flags are pre-loaded with
    ``/stale/...`` values and must be ``None`` after ``create_arguments``;
    the remaining path flags must point below ``data_dir``.
    """
    logger.info("Testing reduced_dbs argument creation without full-db leftovers")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])

    flag_values.use_mmseqs2 = False
    flag_values.data_pipeline = "alphafold2"
    flag_values.db_preset = "reduced_dbs"
    flag_values.use_hhsearch = False
    flag_values.data_dir = "/test/db"
    for cleared_flag in (
        "uniref90_database_path",
        "mgnify_database_path",
        "small_bfd_database_path",
        "uniprot_database_path",
        "pdb_seqres_database_path",
        "template_mmcif_dir",
        "obsolete_pdbs_path",
    ):
        setattr(flag_values, cleared_flag, None)
    # Leftovers from a full-database configuration.
    stale_values = {
        "uniref30_database_path": "/stale/uniref30",
        "bfd_database_path": "/stale/bfd",
        "pdb70_database_path": "/stale/pdb70",
    }
    for stale_flag, stale_value in stale_values.items():
        setattr(flag_values, stale_flag, stale_value)

    create_features.create_arguments()
    expected_paths = {
        "uniref90_database_path": "/test/db/uniref90/uniref90.fasta",
        "mgnify_database_path": "/test/db/mgnify/mgy_clusters_2022_05.fa",
        "small_bfd_database_path": "/test/db/small_bfd/bfd-first_non_consensus_sequences.fasta",
        "uniprot_database_path": "/test/db/uniprot/uniprot.fasta",
        "pdb_seqres_database_path": "/test/db/pdb_seqres/pdb_seqres.txt",
        "template_mmcif_dir": "/test/db/pdb_mmcif/mmcif_files",
        "obsolete_pdbs_path": "/test/db/pdb_mmcif/obsolete.dat",
    }
    for flag_name, expected in expected_paths.items():
        assert getattr(flag_values, flag_name) == expected
    for stale_flag in stale_values:
        assert getattr(flag_values, stale_flag) is None
    logger.info("Reduced-dbs argument creation cleared unused full-database paths")
def test_create_arguments_reduced_dbs_keeps_pdb70_for_hhsearch(self):
    """Check that ``reduced_dbs`` with HHsearch still populates ``pdb70``.

    ``bfd_database_path`` and ``uniref30_database_path`` must be ``None``
    after the call.
    """
    logger.info("Testing reduced_dbs HHsearch argument creation")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])
    flag_settings = {
        "use_mmseqs2": False,
        "data_pipeline": "alphafold2",
        "db_preset": "reduced_dbs",
        "use_hhsearch": True,
        "data_dir": "/test/db",
        "pdb70_database_path": None,
    }
    for flag_name, flag_value in flag_settings.items():
        setattr(flag_values, flag_name, flag_value)

    create_features.create_arguments()

    assert flag_values.pdb70_database_path == "/test/db/pdb70/pdb70"
    for unused_flag in ("bfd_database_path", "uniref30_database_path"):
        assert getattr(flag_values, unused_flag) is None
    logger.info("Reduced-dbs HHsearch argument creation kept pdb70 without restoring full BFD")
def test_mmseqs2_without_data_dir(self):
    """Check that ``main`` does not call ``sys.exit`` with MMseqs2 and no ``data_dir``.

    Pipeline and runner factories, metadata collection, ``open`` and
    ``pickle.dump`` are all patched out.

    NOTE(review): other tests in this file patch
    ``create_features.MonomericObject``; confirm that patching
    ``alphapulldown.objects.MonomericObject`` reaches the name the script uses.
    """
    logger.info("Testing MMseqs2 without data_dir flag")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv

    # MMseqs2 mode with no local database directory
    flag_values.use_mmseqs2 = True
    flag_values.data_dir = None
    flag_values.fasta_paths = [os.path.join(self.fasta_dir, "single_protein.fasta")]
    flag_values.output_dir = os.path.join(self.test_dir, "test_output")
    flag_values.max_template_date = "2021-09-30"

    script_path = 'alphapulldown.scripts.create_individual_features'
    with patch('sys.exit') as exit_spy, \
            patch(script_path + '.create_pipeline_af2', return_value=MagicMock()), \
            patch(script_path + '.create_uniprot_runner', return_value=MagicMock()), \
            patch('alphapulldown.objects.MonomericObject', DummyMonomer), \
            patch('alphapulldown.utils.save_meta_data.get_meta_dict', return_value={}), \
            patch('builtins.open', mock_open()), \
            patch('pickle.dump'):
        create_features.main([])
        # main() must not bail out just because data_dir is missing
        exit_spy.assert_not_called()
    logger.info("MMseqs2 without data_dir flag test successful")
def test_mmseqs2_with_data_dir(self):
    """Check that ``main`` does not call ``sys.exit`` with MMseqs2 and a ``data_dir``.

    Pipeline and runner factories, metadata collection, ``open`` and
    ``pickle.dump`` are all patched out.
    """
    logger.info("Testing MMseqs2 with data_dir flag")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv

    # MMseqs2 mode with a local database directory supplied as well
    flag_values.use_mmseqs2 = True
    flag_values.data_dir = "/test/db"
    flag_values.fasta_paths = [os.path.join(self.fasta_dir, "single_protein.fasta")]
    flag_values.output_dir = os.path.join(self.test_dir, "test_output")
    flag_values.max_template_date = "2021-09-30"

    script_path = 'alphapulldown.scripts.create_individual_features'
    with patch('sys.exit') as exit_spy, \
            patch(script_path + '.create_pipeline_af2', return_value=MagicMock()), \
            patch(script_path + '.create_uniprot_runner', return_value=MagicMock()), \
            patch('alphapulldown.objects.MonomericObject', DummyMonomer), \
            patch('alphapulldown.utils.save_meta_data.get_meta_dict', return_value={}), \
            patch('builtins.open', mock_open()), \
            patch('pickle.dump'):
        create_features.main([])
        # Supplying data_dir alongside use_mmseqs2 must not trigger an exit
        exit_spy.assert_not_called()
    logger.info("MMseqs2 with data_dir flag test successful")
def test_non_mmseqs2_without_data_dir(self):
    """Check that ``get_database_path`` rejects a missing ``data_dir``.

    With ``use_mmseqs2`` disabled and ``data_dir`` unset the lookup must
    raise ``ValueError`` with the expected message.
    """
    logger.info("Testing non-MMseqs2 without data_dir flag")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv
    flag_settings = {
        "use_mmseqs2": False,
        "data_dir": None,
        "fasta_paths": [os.path.join(self.fasta_dir, "single_protein.fasta")],
        "output_dir": os.path.join(self.test_dir, "test_output"),
        "max_template_date": "2021-09-30",
    }
    for flag_name, flag_value in flag_settings.items():
        setattr(flag_values, flag_name, flag_value)

    expected_error = "data_dir is required when not using MMseqs2"
    with pytest.raises(ValueError, match=expected_error):
        create_features.get_database_path("uniref90")
    logger.info("Non-MMseqs2 without data_dir flag correctly failed")
def test_database_path_handling_mmseqs2(self):
    """Check that database paths resolve to ``None`` in MMseqs2 mode.

    With ``use_mmseqs2`` enabled and no ``data_dir``, both
    ``get_database_path`` and the flags set by ``create_arguments`` must be
    ``None``.
    """
    logger.info("Testing database path handling with MMseqs2")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv
    flag_values.use_mmseqs2 = True
    flag_values.data_dir = None

    # Direct lookup
    result = create_features.get_database_path("uniref90")
    assert result is None, f"Expected None, got {result}"

    # Flags populated by create_arguments
    create_features.create_arguments()
    for flag_name in (
        "uniref90_database_path",
        "mgnify_database_path",
        "bfd_database_path",
    ):
        assert getattr(flag_values, flag_name) is None, f"{flag_name} should be None"
    logger.info("Database path handling with MMseqs2 successful")
def test_pipeline_creation_mmseqs2(self):
    """Check that ``create_pipeline_af2`` builds a pipeline in MMseqs2 mode.

    ``AF2DataPipeline`` is replaced by a mock, so only the fact that it was
    constructed exactly once is verified; the template searcher and
    featurizer arguments are not inspected here.
    """
    logger.info("Testing pipeline creation with MMseqs2")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv
    flag_values.use_mmseqs2 = True
    flag_values.data_dir = None
    flag_values.db_preset = "full_dbs"

    # Replace the pipeline class so no real database is touched.
    pipeline_target = 'alphapulldown.scripts.create_individual_features.AF2DataPipeline'
    with patch(pipeline_target, return_value=MagicMock()) as pipeline_cls:
        built_pipeline = create_features.create_pipeline_af2()
        assert built_pipeline is not None, "Pipeline should be created successfully"
        pipeline_cls.assert_called_once()
    logger.info("Pipeline creation with MMseqs2 successful")
def test_feature_creation_mmseqs2(self):
    """Check ``create_individual_features`` in MMseqs2 mode.

    With one mocked sequence and ``DummyMonomer`` standing in for
    ``MonomericObject``, neither the AF2 pipeline factory nor the UniProt
    runner factory may be called, and ``test_protein.pkl`` must appear in
    the output directory.
    """
    logger.info("Testing feature creation with MMseqs2")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv
    flag_values.use_mmseqs2 = True
    flag_values.data_dir = None
    flag_values.fasta_paths = [os.path.join(self.fasta_dir, "single_protein.fasta")]
    flag_values.output_dir = os.path.join(self.test_dir, "test_output")
    flag_values.max_template_date = "2021-09-30"

    script_path = 'alphapulldown.scripts.create_individual_features'
    fake_sequences = [("TESTSEQ", "test_protein")]
    with patch(script_path + '.create_pipeline_af2') as pipeline_factory, \
            patch(script_path + '.create_uniprot_runner') as runner_factory, \
            patch.object(create_features, 'iter_seqs', return_value=fake_sequences), \
            patch.object(create_features, 'MonomericObject', DummyMonomer), \
            patch('alphapulldown.utils.save_meta_data.get_meta_dict', return_value={}):
        create_features.create_individual_features()
        # MMseqs2 mode must not build a local pipeline or UniProt runner
        pipeline_factory.assert_not_called()
        runner_factory.assert_not_called()
    expected_pickle = os.path.join(flag_values.output_dir, "test_protein.pkl")
    assert os.path.exists(expected_pickle)
    logger.info("Feature creation with MMseqs2 successful")
def test_flag_validation_mmseqs2(self):
    """Check how ``main`` validates ``use_mmseqs2`` / ``data_dir`` combinations.

    MMseqs2 with and without ``data_dir`` must not call ``sys.exit``; the
    non-MMseqs2 configuration without ``data_dir`` must raise ``SystemExit``.
    """
    logger.info("Testing flag validation for MMseqs2")
    from absl import flags
    flag_values = flags.FLAGS
    flag_values(['test'])  # Parse flags with dummy argv

    script_path = 'alphapulldown.scripts.create_individual_features'

    def run_main_expecting_no_exit():
        # Run main() with every external dependency stubbed out.
        with patch('sys.exit') as exit_spy, \
                patch(script_path + '.create_pipeline_af2', return_value=MagicMock()), \
                patch(script_path + '.create_uniprot_runner', return_value=MagicMock()), \
                patch('alphapulldown.objects.MonomericObject', DummyMonomer), \
                patch('alphapulldown.utils.save_meta_data.get_meta_dict', return_value={}), \
                patch('builtins.open', mock_open()), \
                patch('pickle.dump'):
            create_features.main([])
            exit_spy.assert_not_called()

    # Case 1: MMseqs2 with data_dir (should work)
    flag_values.use_mmseqs2 = True
    flag_values.data_dir = "/test/db"
    flag_values.fasta_paths = [os.path.join(self.fasta_dir, "single_protein.fasta")]
    flag_values.output_dir = os.path.join(self.test_dir, "test_output")
    flag_values.max_template_date = "2021-09-30"
    run_main_expecting_no_exit()

    # Case 2: MMseqs2 without data_dir (should work)
    flag_values.data_dir = None
    run_main_expecting_no_exit()

    # Case 3: non-MMseqs2 without data_dir (should fail)
    flag_values.use_mmseqs2 = False
    flag_values.data_dir = None
    with pytest.raises(SystemExit):
        create_features.main([])
    logger.info("Flag validation for MMseqs2 scenarios successful")
def test_create_pipeline_af2_uses_hhsearch_template_stack(tmp_flags):
    """With ``use_hhsearch`` set, the AF2 pipeline is assembled from HHSearch parts.

    ``HHSearch``, ``HhsearchHitFeaturizer`` and ``AF2DataPipeline`` are
    mocked; the test checks the keyword arguments they receive and that the
    searcher and featurizer are handed on to the pipeline.
    """
    flag_values = create_features.FLAGS
    flag_values.use_mmseqs2 = False
    flag_values.use_hhsearch = True
    flag_values.hhsearch_binary_path = "/bin/hhsearch"
    flag_values.pdb70_database_path = "/db/pdb70"
    flag_values.template_mmcif_dir = "/db/mmcif"
    flag_values.max_template_date = "2021-09-30"
    flag_values.kalign_binary_path = "/bin/kalign"
    flag_values.obsolete_pdbs_path = "/db/obsolete.dat"

    searcher_kwargs = dict(
        binary_path="/bin/hhsearch",
        databases=["/db/pdb70"],
    )
    featurizer_kwargs = dict(
        mmcif_dir="/db/mmcif",
        max_template_date="2021-09-30",
        max_hits=20,
        kalign_binary_path="/bin/kalign",
        release_dates_path=None,
        obsolete_pdbs_path="/db/obsolete.dat",
    )

    with patch.object(create_features.hhsearch, "HHSearch", return_value="searcher") as searcher_cls, \
            patch.object(create_features.templates, "HhsearchHitFeaturizer", return_value="featurizer") as featurizer_cls, \
            patch.object(create_features, "AF2DataPipeline", return_value="pipeline") as pipeline_cls:
        built = create_features.create_pipeline_af2()

    assert built == "pipeline"
    searcher_cls.assert_called_once_with(**searcher_kwargs)
    featurizer_cls.assert_called_once_with(**featurizer_kwargs)
    pipeline_kwargs = pipeline_cls.call_args.kwargs
    assert pipeline_kwargs["template_searcher"] == "searcher"
    assert pipeline_kwargs["template_featurizer"] == "featurizer"
def test_create_pipeline_af2_uses_hmmsearch_template_stack(tmp_flags):
    """With ``use_hhsearch`` unset, the AF2 pipeline is assembled from Hmmsearch parts.

    ``Hmmsearch``, ``HmmsearchHitFeaturizer`` and ``AF2DataPipeline`` are
    mocked; the test checks the keyword arguments they receive and that the
    searcher and featurizer are handed on to the pipeline.
    """
    flag_values = create_features.FLAGS
    flag_values.use_mmseqs2 = False
    flag_values.use_hhsearch = False
    flag_values.hmmsearch_binary_path = "/bin/hmmsearch"
    flag_values.hmmbuild_binary_path = "/bin/hmmbuild"
    flag_values.pdb_seqres_database_path = "/db/pdb_seqres.txt"
    flag_values.template_mmcif_dir = "/db/mmcif"
    flag_values.max_template_date = "2021-09-30"
    flag_values.kalign_binary_path = "/bin/kalign"
    flag_values.obsolete_pdbs_path = "/db/obsolete.dat"

    searcher_kwargs = dict(
        binary_path="/bin/hmmsearch",
        hmmbuild_binary_path="/bin/hmmbuild",
        database_path="/db/pdb_seqres.txt",
    )
    featurizer_kwargs = dict(
        mmcif_dir="/db/mmcif",
        max_template_date="2021-09-30",
        max_hits=20,
        kalign_binary_path="/bin/kalign",
        obsolete_pdbs_path="/db/obsolete.dat",
        release_dates_path=None,
    )

    with patch.object(create_features.hmmsearch, "Hmmsearch", return_value="searcher") as searcher_cls, \
            patch.object(create_features.templates, "HmmsearchHitFeaturizer", return_value="featurizer") as featurizer_cls, \
            patch.object(create_features, "AF2DataPipeline", return_value="pipeline") as pipeline_cls:
        built = create_features.create_pipeline_af2()

    assert built == "pipeline"
    searcher_cls.assert_called_once_with(**searcher_kwargs)
    featurizer_cls.assert_called_once_with(**featurizer_kwargs)
    pipeline_kwargs = pipeline_cls.call_args.kwargs
    assert pipeline_kwargs["template_searcher"] == "searcher"
    assert pipeline_kwargs["template_featurizer"] == "featurizer"
def test_create_pipeline_af2_skip_msa_returns_template_only_pipeline(tmp_flags):
    """With ``skip_msa`` set, ``AF2DataPipeline`` must not be constructed.

    The returned object is still expected to expose the mocked Hmmsearch
    searcher and featurizer as ``template_searcher`` and
    ``template_featurizer``.
    """
    flag_values = create_features.FLAGS
    flag_values.use_mmseqs2 = False
    flag_values.use_hhsearch = False
    flag_values.skip_msa = True
    flag_values.hmmsearch_binary_path = "/bin/hmmsearch"
    flag_values.hmmbuild_binary_path = "/bin/hmmbuild"
    flag_values.pdb_seqres_database_path = "/db/pdb_seqres.txt"
    flag_values.template_mmcif_dir = "/db/mmcif"
    flag_values.max_template_date = "2021-09-30"
    flag_values.kalign_binary_path = "/bin/kalign"
    flag_values.obsolete_pdbs_path = "/db/obsolete.dat"

    searcher_kwargs = dict(
        binary_path="/bin/hmmsearch",
        hmmbuild_binary_path="/bin/hmmbuild",
        database_path="/db/pdb_seqres.txt",
    )
    featurizer_kwargs = dict(
        mmcif_dir="/db/mmcif",
        max_template_date="2021-09-30",
        max_hits=20,
        kalign_binary_path="/bin/kalign",
        obsolete_pdbs_path="/db/obsolete.dat",
        release_dates_path=None,
    )

    with patch.object(create_features.hmmsearch, "Hmmsearch", return_value="searcher") as searcher_cls, \
            patch.object(create_features.templates, "HmmsearchHitFeaturizer", return_value="featurizer") as featurizer_cls, \
            patch.object(create_features, "AF2DataPipeline") as pipeline_cls:
        built = create_features.create_pipeline_af2()

    searcher_cls.assert_called_once_with(**searcher_kwargs)
    featurizer_cls.assert_called_once_with(**featurizer_kwargs)
    # The full MSA pipeline class must stay untouched in skip_msa mode.
    pipeline_cls.assert_not_called()
    assert built.template_searcher == "searcher"
    assert built.template_featurizer == "featurizer"
def test_create_individual_features_only_saves_selected_sequence(tmp_flags):
    """With ``seq_index`` 2 the second of two mocked sequences is the one saved.

    Checks that the arguments, pipeline and runner factories are each called
    once and that the monomer passed to the save function carries the mocked
    runner.
    """
    create_features.FLAGS.seq_index = 2
    two_sequences = [("AAAA", "first"), ("BBBB", "second")]

    with patch.object(create_features, "create_arguments") as arguments_spy, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as pipeline_spy, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as runner_spy, \
            patch.object(create_features, "MonomericObject", DummyMonomer), \
            patch.object(create_features, "iter_seqs", return_value=two_sequences), \
            patch.object(create_features, "create_and_save_monomer_objects") as save_spy:
        create_features.create_individual_features()

    arguments_spy.assert_called_once_with()
    pipeline_spy.assert_called_once_with()
    runner_spy.assert_called_once()
    # Unpacking also verifies that exactly two positional arguments were passed.
    monomer_arg, pipeline_arg = save_spy.call_args.args
    assert pipeline_arg == "pipeline"
    assert monomer_arg.description == "second"
    assert monomer_arg.uniprot_runner == "runner"
def test_create_individual_features_skip_msa_avoids_uniprot_runner(tmp_flags):
    """With ``skip_msa`` set, no UniProt runner is created.

    The saved monomer must have ``uniprot_runner`` left as ``None`` and the
    pipeline passed to the save function must be the mocked one.
    """
    flag_values = create_features.FLAGS
    flag_values.seq_index = None
    flag_values.use_mmseqs2 = False
    flag_values.skip_msa = True
    single_sequence = [("AAAA", "first")]

    with patch.object(create_features, "create_arguments") as arguments_spy, \
            patch.object(create_features, "create_pipeline_af2", return_value="template-only-pipeline") as pipeline_spy, \
            patch.object(create_features, "create_uniprot_runner") as runner_spy, \
            patch.object(create_features, "MonomericObject", DummyMonomer), \
            patch.object(create_features, "iter_seqs", return_value=single_sequence), \
            patch.object(create_features, "create_and_save_monomer_objects") as save_spy:
        create_features.create_individual_features()

    arguments_spy.assert_called_once_with()
    pipeline_spy.assert_called_once_with()
    runner_spy.assert_not_called()
    # Unpacking also verifies that exactly two positional arguments were passed.
    monomer_arg, pipeline_arg = save_spy.call_args.args
    assert pipeline_arg == "template-only-pipeline"
    assert monomer_arg.description == "first"
    assert monomer_arg.uniprot_runner is None
def test_create_and_save_monomer_objects_writes_compressed_af2_outputs(tmp_flags, tmp_path):
    """With ``compress_features`` set, the AF2 branch writes ``.xz`` outputs.

    Expects one xz-compressed metadata JSON holding the mocked metadata, a
    ``protA.pkl.xz`` file, a single ``make_features`` call with the flag
    values, and no ``make_mmseq_features`` call.
    """
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = True
    flag_values.skip_existing = False
    flag_values.use_mmseqs2 = False
    flag_values.use_precomputed_msas = True
    flag_values.save_msa_files = True

    monomer = RecordingDummyMonomer("protA")
    fake_metadata = {"source": "test"}
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value=fake_metadata):
        create_features.create_and_save_monomer_objects(monomer, pipeline="pipeline")

    # Exactly one compressed metadata file, containing the mocked metadata
    metadata_files = list(tmp_path.glob("protA_feature_metadata_*.json.xz"))
    assert len(metadata_files) == 1
    with lzma.open(metadata_files[0], "rt", encoding="utf-8") as handle:
        stored_metadata = json.load(handle)
    assert stored_metadata == {"source": "test"}
    assert (tmp_path / "protA.pkl.xz").exists()

    expected_call = dict(
        pipeline="pipeline",
        output_dir=str(tmp_path),
        use_precomputed_msa=True,
        save_msa=True,
        skip_msa=False,
    )
    assert monomer.feature_calls == [expected_call]
    assert monomer.mmseq_calls == []
def test_create_and_save_monomer_objects_skips_existing_outputs(tmp_flags, tmp_path):
    """With ``skip_existing`` set and ``protA.pkl`` present, nothing is rebuilt.

    Neither feature builder may be called and no metadata JSON may be
    written.
    """
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = False
    flag_values.skip_existing = True
    flag_values.use_mmseqs2 = True

    # Seed the output directory with an already-computed pickle.
    (tmp_path / "protA.pkl").write_bytes(b"already-there")

    monomer = RecordingDummyMonomer("protA")
    create_features.create_and_save_monomer_objects(monomer, pipeline=None)

    assert monomer.feature_calls == []
    assert monomer.mmseq_calls == []
    assert not list(tmp_path.glob("protA_feature_metadata_*.json"))
def test_create_and_save_monomer_objects_uses_mmseqs_when_requested(tmp_flags, tmp_path):
    """With ``use_mmseqs2`` set, only the MMseqs2 feature builder is called.

    ``re_search_templates_mmseqs2`` is enabled and the recorded call must
    carry ``use_templates=True``; ``protA.pkl`` must be written.
    """
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = False
    flag_values.skip_existing = False
    flag_values.use_mmseqs2 = True
    flag_values.use_precomputed_msas = True
    flag_values.re_search_templates_mmseqs2 = True

    monomer = RecordingDummyMonomer("protA")
    fake_metadata = {"source": "test"}
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value=fake_metadata):
        create_features.create_and_save_monomer_objects(monomer, pipeline=None)

    expected_call = dict(
        DEFAULT_API_SERVER=create_features.DEFAULT_API_SERVER,
        output_dir=str(tmp_path),
        use_precomputed_msa=True,
        use_templates=True,
        custom_template_path=None,
        skip_msa=False,
    )
    assert monomer.feature_calls == []
    assert monomer.mmseq_calls == [expected_call]
    assert (tmp_path / "protA.pkl").exists()
def test_create_and_save_monomer_objects_passes_skip_msa_to_af2_builder(tmp_flags, tmp_path):
    """The ``skip_msa`` flag must be forwarded to ``make_features``."""
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = False
    flag_values.skip_existing = False
    flag_values.use_mmseqs2 = False
    flag_values.use_precomputed_msas = False
    flag_values.save_msa_files = False
    flag_values.skip_msa = True

    monomer = RecordingDummyMonomer("protA")
    fake_metadata = {"source": "test"}
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value=fake_metadata):
        create_features.create_and_save_monomer_objects(monomer, pipeline="pipeline")

    expected_call = dict(
        pipeline="pipeline",
        output_dir=str(tmp_path),
        use_precomputed_msa=False,
        save_msa=False,
        skip_msa=True,
    )
    assert monomer.feature_calls == [expected_call]
def test_create_and_save_monomer_objects_passes_skip_msa_to_mmseqs_builder(tmp_flags, tmp_path):
    """The ``skip_msa`` flag must be forwarded to ``make_mmseq_features``."""
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = False
    flag_values.skip_existing = False
    flag_values.use_mmseqs2 = True
    flag_values.use_precomputed_msas = False
    flag_values.re_search_templates_mmseqs2 = False
    flag_values.skip_msa = True

    monomer = RecordingDummyMonomer("protA")
    fake_metadata = {"source": "test"}
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value=fake_metadata):
        create_features.create_and_save_monomer_objects(monomer, pipeline=None)

    expected_call = dict(
        DEFAULT_API_SERVER=create_features.DEFAULT_API_SERVER,
        output_dir=str(tmp_path),
        use_precomputed_msa=False,
        use_templates=False,
        custom_template_path=None,
        skip_msa=True,
    )
    assert monomer.mmseq_calls == [expected_call]
def test_create_and_save_monomer_objects_passes_custom_templates_to_mmseqs(tmp_flags, tmp_path):
    """A custom template path must reach ``make_mmseq_features``.

    With ``re_search_templates_mmseqs2`` disabled, the recorded call is still
    expected to carry ``use_templates=True`` alongside the custom path.
    """
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = False
    flag_values.skip_existing = False
    flag_values.use_mmseqs2 = True
    flag_values.use_precomputed_msas = False
    flag_values.re_search_templates_mmseqs2 = False

    monomer = RecordingDummyMonomer("protA")
    custom_template_path = str(tmp_path / "custom_db" / "templates")
    fake_metadata = {"source": "test"}
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value=fake_metadata):
        create_features.create_and_save_monomer_objects(
            monomer,
            pipeline=None,
            custom_template_path=custom_template_path,
        )

    expected_call = dict(
        DEFAULT_API_SERVER=create_features.DEFAULT_API_SERVER,
        output_dir=str(tmp_path),
        use_precomputed_msa=False,
        use_templates=True,
        custom_template_path=custom_template_path,
        skip_msa=False,
    )
    assert monomer.feature_calls == []
    assert monomer.mmseq_calls == [expected_call]
    assert (tmp_path / "protA.pkl").exists()
def test_create_and_save_monomer_objects_reuses_mmseq_identifier_sidecar(
    tmp_flags, tmp_path, monkeypatch
):
    """A second MMseqs2 run must reuse ``protA.mmseq_ids.json`` instead of UniProt.

    First run (``use_precomputed_msas`` off): one UniProt batch lookup for
    ``A0A636IKY3`` is expected, and ``protA.a3m``, ``protA.mmseq_ids.json``
    and ``protA.pkl`` must be written.

    Second run (``use_precomputed_msas`` on, in-memory species cache
    cleared): the UniProt lookup is replaced by a stub that raises, so it
    must not be called, and the feature dict must still hold the species and
    accession identifiers.
    """
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = False
    flag_values.skip_existing = False
    flag_values.use_mmseqs2 = True
    flag_values.use_precomputed_msas = False
    flag_values.re_search_templates_mmseqs2 = False

    a3m_rows = [
        "# mmseqs header",
        ">101",
        "ACDE",
        ">UniRef100_A0A636IKY3\t136\t0.883",
        "ACDF",
        "",
    ]
    a3m_text = "\n".join(a3m_rows)

    # --- stand-ins for the MSA / template machinery ---------------------
    def fake_get_msa_and_templates(**_kwargs):
        return ([a3m_text], ["PAIRED"], ["UNIQUE"], ["CARD"], ["TEMPLATE"])

    def fake_msa_to_str(*args):
        return a3m_text

    def fake_unserialize_msa(a3m_lines, sequence):
        return (
            [a3m_text],
            ["PRECOMP_PAIRED"],
            ["UNIQUE"],
            ["CARD"],
            ["PRECOMP_TEMPLATE"],
        )

    def fake_build_monomer_feature(*_args, **_kwargs):
        return {
            "msa": np.asarray([[1, 2, 3, 4], [1, 2, 3, 5]], dtype=np.int32),
            "deletion_matrix_int": np.zeros((2, 4), dtype=np.int32),
            "template_confidence_scores": None,
            "template_release_date": None,
        }

    monkeypatch.setattr(
        MonomericObject, "unzip_msa_files", staticmethod(lambda _path: False)
    )
    monkeypatch.setattr(
        MonomericObject, "zip_msa_files", staticmethod(lambda _path: None)
    )
    monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
    monkeypatch.setattr(objects_mod, "msa_to_str", fake_msa_to_str)
    monkeypatch.setattr(objects_mod, "unserialize_msa", fake_unserialize_msa)
    monkeypatch.setattr(objects_mod, "build_monomer_feature", fake_build_monomer_feature)

    # --- stand-ins for the UniProt / UniParc lookups --------------------
    first_calls = []

    def fake_uniprot_batch(accessions, *, urlopen):
        first_calls.append(tuple(accessions))
        hit = {
            "primaryAccession": "A0A636IKY3",
            "organism": {"taxonId": 562},
        }
        return {"results": [hit]}

    def fake_uniparc_batch(accessions, *, urlopen):
        return {"results": []}

    monkeypatch.setattr(
        mmseqs_species_identifiers, "_query_uniprot_batch", fake_uniprot_batch
    )
    monkeypatch.setattr(
        mmseqs_species_identifiers, "_query_uniparc_batch", fake_uniparc_batch
    )

    def build_and_save():
        # Build a fresh monomer and run it through the save routine with
        # metadata collection stubbed out.
        with patch(
            "alphapulldown.utils.save_meta_data.get_meta_dict",
            return_value={"source": "test"},
        ):
            monomer = MonomericObject("protA", "ACDE")
            create_features.create_and_save_monomer_objects(monomer, pipeline=None)
        return monomer

    # --- first run: identifiers are looked up and the sidecar is written
    build_and_save()
    assert first_calls == [("A0A636IKY3",)]
    for produced_name in ("protA.a3m", "protA.mmseq_ids.json", "protA.pkl"):
        assert (tmp_path / produced_name).exists()

    # --- second run: precomputed MSAs, cold in-memory cache -------------
    flag_values.use_precomputed_msas = True
    mmseqs_species_identifiers._SPECIES_ID_CACHE.clear()
    second_calls = []

    def fail_uniprot_batch(accessions, *, urlopen):
        second_calls.append(tuple(accessions))
        raise AssertionError("expected mmseq sidecar cache to skip UniProt lookups")

    monkeypatch.setattr(
        mmseqs_species_identifiers, "_query_uniprot_batch", fail_uniprot_batch
    )
    second = build_and_save()

    assert second_calls == []
    species_ids = second.feature_dict["msa_species_identifiers_all_seq"].tolist()
    assert species_ids == [b"", b"562"]
    accession_ids = second.feature_dict[
        "msa_uniprot_accession_identifiers_all_seq"
    ].tolist()
    assert accession_ids == [b"", b"A0A636IKY3"]
def test_process_multimeric_features_uses_mmseqs_without_local_pipeline(tmp_flags, tmp_path):
    """With ``use_mmseqs2`` set, no AF2 pipeline or UniProt runner is created.

    ``process_multimeric_features`` is run on a single-chain feature entry
    with every collaborator patched. The test expects the custom DB to be
    built once, ``create_arguments`` to receive the mocked DB path, and the
    monomer to be passed to ``create_and_save_monomer_objects`` with a
    ``None`` pipeline and ``custom_template_path="/tmp/custom_db/templates"``.
    """
    template_path = tmp_path / "template.cif"
    template_path.write_text("data_template\n", encoding="utf-8")
    create_features.FLAGS.output_dir = str(tmp_path / "out")
    create_features.FLAGS.use_mmseqs2 = True
    feat = {
        "protein": "complex_mmseqs",
        "chains": ["A"],
        "templates": [str(template_path)],
        "sequence": "ACDE",
    }

    with patch.object(create_features, "MonomericObject", RecordingDummyMonomer), \
         patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
         patch.object(create_features, "create_arguments") as mock_arguments, \
         patch.object(create_features, "create_pipeline_af2") as mock_pipeline, \
         patch.object(create_features, "create_uniprot_runner") as mock_runner, \
         patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)

    mock_custom_db.assert_called_once()
    mock_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_not_called()
    mock_runner.assert_not_called()

    # ``call_args`` stays None until the mock is invoked; assert the call
    # explicitly so a missing save fails with a readable message instead of
    # an AttributeError on None.
    mock_save.assert_called()
    saved_monomer, saved_pipeline = mock_save.call_args.args
    saved_kwargs = mock_save.call_args.kwargs
    assert saved_pipeline is None
    assert saved_monomer.description == "complex_mmseqs"
    assert saved_monomer.uniprot_runner is None
    assert saved_kwargs == {"custom_template_path": "/tmp/custom_db/templates"}
def test_process_multimeric_features_skip_msa_avoids_uniprot_runner(tmp_flags, tmp_path):
    """With ``skip_msa`` set (and mmseqs2 off), no UniProt runner is created.

    ``process_multimeric_features`` is run on a single-chain feature entry
    with every collaborator patched. The test expects ``create_pipeline_af2``
    to be called once without arguments, its return value to be forwarded to
    ``create_and_save_monomer_objects``, and ``custom_template_path`` to be
    ``None``.
    """
    template_path = tmp_path / "template.cif"
    template_path.write_text("data_template\n", encoding="utf-8")
    create_features.FLAGS.output_dir = str(tmp_path / "out")
    create_features.FLAGS.use_mmseqs2 = False
    create_features.FLAGS.skip_msa = True
    feat = {
        "protein": "complex_local",
        "chains": ["A"],
        "templates": [str(template_path)],
        "sequence": "ACDE",
    }

    with patch.object(create_features, "MonomericObject", RecordingDummyMonomer), \
         patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
         patch.object(create_features, "create_arguments") as mock_arguments, \
         patch.object(create_features, "create_pipeline_af2", return_value="template-only-pipeline") as mock_pipeline, \
         patch.object(create_features, "create_uniprot_runner") as mock_runner, \
         patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)

    mock_custom_db.assert_called_once()
    mock_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_not_called()

    # ``call_args`` stays None until the mock is invoked; assert the call
    # explicitly so a missing save fails with a readable message instead of
    # an AttributeError on None.
    mock_save.assert_called()
    saved_monomer, saved_pipeline = mock_save.call_args.args
    saved_kwargs = mock_save.call_args.kwargs
    assert saved_pipeline == "template-only-pipeline"
    assert saved_monomer.description == "complex_local"
    assert saved_monomer.uniprot_runner is None
    assert saved_kwargs == {"custom_template_path": None}
def test_create_custom_db_passes_thresholds_to_builder(tmp_flags):
    """``create_custom_db`` hands the three threshold FLAGS to ``create_db``.

    ``create_db`` is patched out; the test checks the returned path and that
    the builder received the path, templates, chains and thresholds
    positionally, in that order.
    """
    flags = create_features.FLAGS
    flags.threshold_clashes = 12.5
    flags.hb_allowance = 0.7
    flags.plddt_threshold = 42.0

    with patch.object(create_features, "create_db") as builder:
        custom_db_path = create_features.create_custom_db(
            "/tmp/base", "proteinX", ["a.cif"], ["A"]
        )

    assert str(custom_db_path) == "/tmp/base/custom_template_db/proteinX"
    builder.assert_called_once_with(custom_db_path, ["a.cif"], ["A"], 12.5, 0.7, 42.0)
def test_create_pipeline_af3_prefers_explicit_database_overrides(tmp_flags):
    """Explicit database FLAGS win over ``data_dir`` in the AF3 pipeline config.

    The config class is replaced by a recorder that keeps its keyword
    arguments, and the pipeline class by an identity function, so the value
    returned from ``create_pipeline_af3`` is the recorded config itself.
    """

    class KwargsRecorder:
        # Stand-in for the AF3 config: just remembers what it was built with.
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    flags = create_features.FLAGS
    flags.max_template_date = "2021-09-30"
    flags.data_pipeline = "alphafold3"
    flags.data_dir = "/db"
    flags.small_bfd_database_path = "/override/small_bfd"
    flags.uniref90_database_path = "/override/uniref90"
    flags.template_mmcif_dir = "/override/mmcif"

    config_patch = patch.object(
        create_features, "AF3DataPipelineConfig", side_effect=KwargsRecorder
    )
    pipeline_patch = patch.object(
        create_features, "AF3DataPipeline", side_effect=lambda config: config
    )
    with config_patch as mock_config, pipeline_patch as mock_pipeline:
        recorded = create_features.create_pipeline_af3()

    mock_config.assert_called_once()
    mock_pipeline.assert_called_once()

    # Overridden paths come from the explicit flags ...
    assert recorded.kwargs["small_bfd_database_path"] == "/override/small_bfd"
    assert recorded.kwargs["uniref90_database_path"] == "/override/uniref90"
    assert recorded.kwargs["pdb_database_path"] == "/override/mmcif"
    # ... while the remaining ones are expected under data_dir.
    assert recorded.kwargs["mgnify_database_path"] == "/db/mgy_clusters_2022_05.fa"
    assert recorded.kwargs["seqres_database_path"] == "/db/pdb_seqres_2022_09_28.fasta"
def test_main_rejects_af3_mmseqs2(tmp_flags, tmp_path):
    """``main`` logs one error and exits when AF3 is combined with ``use_mmseqs2``."""
    flags = create_features.FLAGS
    flags.data_pipeline = "alphafold3"
    flags.use_mmseqs2 = True
    flags.data_dir = None
    flags.output_dir = str(tmp_path / "af3_out")

    expected_message = (
        "AlphaFold3 does not support --use_mmseqs2. "
        "Please provide local databases via --data_dir."
    )

    with patch.object(create_features.logging, "error") as mock_error:
        with pytest.raises(SystemExit):
            create_features.main([])

    # Checked after the context managers: nothing following main() inside
    # pytest.raises would run once SystemExit is raised.
    mock_error.assert_called_once_with(expected_message)
def test_create_af3_individual_features_falls_back_to_double_letter_chain_ids(tmp_flags, tmp_path):
    """AF3 feature creation works for sequence 27 when the ``mmcif`` stub is absent.

    The ``alphafold3.structure.mmcif`` stub is removed both as an attribute
    and as a module entry, ``seq_index`` is set to 27 and 27 sequences are
    supplied. The pipeline returns a plain dict, and the test expects it to
    be written as JSON to ``chain_27_af3_input.json``.
    """
    create_features.FLAGS.output_dir = str(tmp_path)
    create_features.FLAGS.seq_index = 27

    stub_modules, folding_input_stub = build_af3_stub_modules()
    # Remove the mmcif helper before the module mapping is patched in.
    del stub_modules["alphafold3.structure"].mmcif
    stub_modules.pop("alphafold3.structure.mmcif")

    sequences = [("ACDE", f"chain_{idx}") for idx in range(1, 28)]
    fake_pipeline = MagicMock(process=MagicMock(return_value={"plain": "json"}))

    modules_patch = patch.dict(sys.modules, stub_modules)
    pipeline_patch = patch.object(create_features, "create_pipeline_af3", return_value=fake_pipeline)
    folding_patch = patch.object(create_features, "folding_input", folding_input_stub)
    seqs_patch = patch.object(create_features, "iter_seqs", return_value=sequences)
    write_patch = patch("pathlib.Path.write_text", new=real_write_text)
    with modules_patch, pipeline_patch, folding_patch, seqs_patch, write_patch:
        create_features.create_af3_individual_features()

    outpath = tmp_path / "chain_27_af3_input.json"
    assert outpath.exists()
    written = json.loads(outpath.read_text(encoding="utf-8"))
    assert written == {"plain": "json"}
def test_create_af3_individual_features_skips_existing_outputs(tmp_flags, tmp_path):
    """With ``skip_existing`` set, an existing AF3 output is left untouched.

    ``protA_af3_input.json`` is created up front; the test expects the
    pipeline's ``process`` never to be called and the file content to stay
    ``"{}"``.
    """
    create_features.FLAGS.output_dir = str(tmp_path)
    create_features.FLAGS.skip_existing = True

    stub_modules, folding_input_stub = build_af3_stub_modules()
    existing_output = tmp_path / "protA_af3_input.json"
    existing_output.write_text("{}", encoding="utf-8")
    pipeline = MagicMock(process=MagicMock(return_value=DummyJsonObj()))

    modules_patch = patch.dict(sys.modules, stub_modules)
    pipeline_patch = patch.object(create_features, "create_pipeline_af3", return_value=pipeline)
    folding_patch = patch.object(create_features, "folding_input", folding_input_stub)
    seqs_patch = patch.object(create_features, "iter_seqs", return_value=[("ACDE", "protA")])
    write_patch = patch("pathlib.Path.write_text", new=real_write_text)
    with modules_patch, pipeline_patch, folding_patch, seqs_patch, write_patch:
        create_features.create_af3_individual_features()

    pipeline.process.assert_not_called()
    assert existing_output.read_text(encoding="utf-8") == "{}"
def test_create_af3_individual_features_prefills_query_only_msas_when_skip_msa(
    tmp_flags, tmp_path
):
    """With ``skip_msa`` set, chains reach the pipeline with query-only MSAs.

    Two sequences are supplied. The first chain of each input passed to
    ``pipeline.process`` must carry an unpaired MSA of ``">query\\n<seq>\\n"``;
    the first one must also have an empty paired MSA and no templates.
    """
    flags = create_features.FLAGS
    flags.output_dir = str(tmp_path)
    flags.data_pipeline = "alphafold3"
    flags.skip_msa = True

    stub_modules, folding_input_stub = build_af3_stub_modules()
    pipeline = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
    sequences = [("ACDE", "protein_chain protein"), ("AUGA", "rna_chain RNA")]

    modules_patch = patch.dict(sys.modules, stub_modules)
    pipeline_patch = patch.object(create_features, "create_pipeline_af3", return_value=pipeline)
    folding_patch = patch.object(create_features, "folding_input", folding_input_stub)
    seqs_patch = patch.object(create_features, "iter_seqs", return_value=sequences)
    write_patch = patch("pathlib.Path.write_text", new=real_write_text)
    with modules_patch, pipeline_patch, folding_patch, seqs_patch, write_patch:
        create_features.create_af3_individual_features()

    first_call, second_call = pipeline.process.call_args_list[0], pipeline.process.call_args_list[1]

    # First processed input: the "ACDE" sequence.
    protein_chain = first_call.args[0].chains[0]
    assert protein_chain.unpaired_msa == ">query\nACDE\n"
    assert protein_chain.paired_msa == ""
    assert protein_chain.templates is None

    # Second processed input: the "AUGA" sequence.
    rna_chain = second_call.args[0].chains[0]
    assert rna_chain.unpaired_msa == ">query\nAUGA\n"
def test_main_dispatches_to_af3_feature_creation(tmp_flags, tmp_path):
    """``main`` calls ``create_af3_individual_features`` for the alphafold3 pipeline.

    Both the AF3 entry point and ``check_template_date`` are patched; the
    former must be called exactly once with no arguments and the latter not
    at all.
    """
    flags = create_features.FLAGS
    flags.data_pipeline = "alphafold3"
    flags.output_dir = str(tmp_path / "af3_out")

    with patch.object(create_features, "create_af3_individual_features") as mock_af3:
        with patch.object(create_features, "check_template_date") as mock_check:
            create_features.main([])

    mock_af3.assert_called_once_with()
    mock_check.assert_not_called()