Files
AlphaPulldown/test/cluster/check_alphafold3_predictions.py
2026-04-23 10:28:45 +02:00

3598 lines
146 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
"""
Functional Alphapulldown tests for AlphaFold3 (parameterised).
The script is identical for Slurm and workstation users; only the
wrapper decides *how* each case is executed.
"""
from __future__ import annotations
import lzma
import os
import subprocess
import time
import sys
import tempfile
import hashlib
from pathlib import Path
import shutil
import pickle
import json
import numpy as np
import re
import unittest
from types import SimpleNamespace
from typing import Dict, List, Tuple, Any
from unittest import mock
from absl.testing import absltest, parameterized
import alphapulldown
from alphafold3.constants import residue_names as af3_residue_names
from alphapulldown.objects import MultimericObject
from alphapulldown.utils.modelling_setup import (
create_custom_info,
create_interactors,
parse_fold,
)
from alphapulldown_input_parser import generate_fold_specifications
# --------------------------------------------------------------------------- #
# configuration / environment guards #
# --------------------------------------------------------------------------- #
# Point to the full Alphafold database once, via env-var.
# Location of the AlphaFold database/weights; $ALPHAFOLD_DATA_DIR overrides
# the hard-coded cluster default below.
DATA_DIR = os.getenv(
"ALPHAFOLD_DATA_DIR",
"/g/kosinski/dima/alphafold3_weights/" # default for EMBL cluster
)
# NOTE(review): absltest.skip appears to be the unittest.skip decorator
# factory; if so, this bare call only builds a decorator and discards it, so
# nothing is skipped when DATA_DIR is missing. Confirm, and consider raising
# unittest.SkipTest if a module-level skip is really intended.
if not os.path.exists(DATA_DIR):
absltest.skip("set $ALPHAFOLD_DATA_DIR to run Alphafold functional tests")
# Repository root: two levels above the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Root of the test tree inside the repository.
TEST_ROOT = REPO_ROOT / "test"
def _has_nvidia_gpu() -> bool:
nvidia_smi = shutil.which("nvidia-smi")
if not nvidia_smi:
return False
try:
result = subprocess.run(
[nvidia_smi, "-L"],
capture_output=True,
text=True,
check=False,
)
except OSError:
return False
return result.returncode == 0 and bool(result.stdout.strip())
def _gpu_functional_test_skip_reason() -> str | None:
if os.getenv("RUN_GPU_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
return None
if os.getenv("CI", "").lower() in ("1", "true", "yes") or os.getenv(
"GITHUB_ACTIONS", ""
).lower() == "true":
return (
"GPU functional tests are disabled on CI/CD. "
"Set RUN_GPU_FUNCTIONAL_TESTS=1 to override."
)
if not _has_nvidia_gpu():
return "GPU functional tests require an NVIDIA GPU and nvidia-smi."
return None
def _mmseqs_functional_test_skip_reason() -> str | None:
if os.getenv("RUN_MMSEQS_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
return None
return (
"MMseqs functional inference tests are disabled by default. "
"Set RUN_MMSEQS_FUNCTIONAL_TESTS=1 to enable."
)
def _a3m_sequences(a3m_text: str) -> list[str]:
if not a3m_text:
return []
lines = [line.strip() for line in a3m_text.splitlines() if line.strip()]
return [lines[index] for index in range(1, len(lines), 2)]
def _a3m_query_sequence(a3m_text: str) -> str:
    """Return the first sequence row of *a3m_text*, or "" when there is none."""
    for first_row in _a3m_sequences(a3m_text):
        return first_row
    return ""
def _a3m_payload_sequences(a3m_text: str) -> list[str]:
    """Return all sequence rows of *a3m_text* except the first (query) row."""
    rows = _a3m_sequences(a3m_text)
    return [row for position, row in enumerate(rows) if position > 0]
def _aligned_a3m_row_length(a3m_row: str) -> int:
return len(re.sub(r"[a-z]", "", a3m_row))
def _protein_entries_from_af3_input(payload: dict[str, Any]) -> list[dict[str, Any]]:
return [
sequence_entry["protein"]
for sequence_entry in payload.get("sequences", [])
if "protein" in sequence_entry
]
def _load_json_payload(path: Path) -> dict[str, Any]:
if path.suffix == ".xz":
with lzma.open(path, "rt", encoding="utf-8") as handle:
return json.load(handle)
return json.loads(path.read_text(encoding="utf-8"))
def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict[str, Any]]:
    """Locate and load the single feature-metadata JSON for *protein_id*.

    Returns:
        ``(path, payload)`` for the one file in *feature_dir* matching
        ``{protein_id}_feature_metadata_*.json*``.

    Raises:
        AssertionError: if zero or more than one file matches.
    """
    pattern = f"{protein_id}_feature_metadata_*.json*"
    matches = sorted(feature_dir.glob(pattern))
    if len(matches) == 1:
        (metadata_path,) = matches
        return metadata_path, _load_json_payload(metadata_path)
    raise AssertionError(
        f"Expected exactly one metadata file for {protein_id} in {feature_dir}, found {matches}"
    )
def _metadata_bool(value: Any) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"true", "1", "yes"}:
return True
if normalized in {"false", "0", "no", ""}:
return False
raise AssertionError(f"Unsupported metadata boolean value: {value!r}")
def _non_empty_a3m_payload_rows(a3m_text: str) -> list[str]:
    """Return the non-query sequence rows of *a3m_text*, or [] for empty text."""
    if not a3m_text:
        return []
    return _a3m_payload_sequences(a3m_text)
def _load_feature_dict(feature_path: Path) -> dict[str, Any]:
    """Load a feature pickle and return its feature dictionary.

    When the unpickled object has a ``feature_dict`` attribute, that
    attribute is returned; otherwise the unpickled object itself is returned.
    """
    loaded = _load_feature_payload(feature_path)
    return getattr(loaded, "feature_dict", loaded)
def _load_feature_payload(feature_path: Path) -> Any:
opener = lzma.open if feature_path.suffix == ".xz" else open
with opener(feature_path, "rb") as handle:
return pickle.load(handle)
def _write_feature_payload(feature_path: Path, payload: Any) -> None:
opener = lzma.open if feature_path.suffix == ".xz" else open
with opener(feature_path, "wb") as handle:
pickle.dump(payload, handle)
def _non_empty_identifier_count(values) -> int:
count = 0
for value in values:
if isinstance(value, bytes):
value = value.decode("utf-8")
if str(value).strip():
count += 1
return count
# --------------------------------------------------------------------------- #
# common helper mix-in / assertions #
# --------------------------------------------------------------------------- #
class _TestBase(parameterized.TestCase):
use_temp_dir = True # Class variable to control directory behavior - default to True
@classmethod
def setUpClass(cls):
    """Skip the class when GPU tests are disabled, then prepare the output root.

    With ``use_temp_dir`` set, a fresh temporary directory is created.
    Otherwise the fixed ``test/test_data/predictions/af3_backend`` path
    (relative to the current working directory) is cleared if possible and
    recreated.

    Raises:
        unittest.SkipTest: when the GPU skip guard reports a reason.
    """
    super().setUpClass()
    reason = _gpu_functional_test_skip_reason()
    if reason:
        raise unittest.SkipTest(reason)

    # Base directory that holds one sub-directory per test method.
    if cls.use_temp_dir:
        cls.base_output_dir = Path(tempfile.mkdtemp(prefix="af3_test_"))
        return

    cls.base_output_dir = Path("test/test_data/predictions/af3_backend")
    if cls.base_output_dir.exists():
        try:
            shutil.rmtree(cls.base_output_dir)
        except (PermissionError, OSError) as e:
            # Stale outputs we cannot delete are tolerated; just report them.
            print(f"Warning: Could not remove existing output directory {cls.base_output_dir}: {e}")
    cls.base_output_dir.mkdir(parents=True, exist_ok=True)
def setUp(self):
    """Set the fixture paths, per-test output directory and CLI script paths."""
    super().setUp()

    # Fixture directories inside the repository's test tree.
    data_root = TEST_ROOT / "test_data"
    self.test_data_dir = data_root
    fixture_dirs = (
        ("test_fastas_dir", "fastas"),
        ("test_features_dir", "features"),
        ("test_protein_lists_dir", "protein_lists"),
        ("test_templates_dir", "templates"),
        ("test_modelling_dir", "predictions"),
    )
    for attribute, subdirectory in fixture_dirs:
        setattr(self, attribute, data_root / subdirectory)

    # Each test method writes into its own directory under the class root.
    self.output_dir = self.base_output_dir / self._testMethodName
    self.output_dir.mkdir(parents=True, exist_ok=True)

    # AlphaPulldown CLI entry points, resolved from the installed package.
    scripts_dir = Path(alphapulldown.__path__[0]) / "scripts"
    self.script_multimer = scripts_dir / "run_multimer_jobs.py"
    self.script_single = scripts_dir / "run_structure_prediction.py"
    self.script_create_features = scripts_dir / "create_individual_features.py"
@classmethod
def tearDownClass(cls):
    """Remove the temporary output root, degrading to best-effort file deletion.

    Nothing is removed when ``use_temp_dir`` is false or the directory is
    already gone. Cleanup failures are reported or ignored, never raised.
    """
    super().tearDownClass()
    if not (cls.use_temp_dir and cls.base_output_dir.exists()):
        return
    try:
        shutil.rmtree(cls.base_output_dir)
    except (PermissionError, OSError) as e:
        print(f"Warning: Could not remove temporary directory {cls.base_output_dir}: {e}")
        # The tree could not be removed wholesale: delete whichever
        # individual files we are allowed to.
        try:
            for leftover in cls.base_output_dir.rglob("*"):
                if not leftover.is_file():
                    continue
                try:
                    leftover.unlink()
                except (PermissionError, OSError):
                    pass  # Skip files we can't remove
        except Exception:
            pass  # Cleanup is best-effort; never fail the run here
def _get_sequence_from_pkl(self, protein_name: str) -> str:
"""Extract sequence from a PKL file."""
pkl_path = self.test_features_dir / f"{protein_name}.pkl"
if pkl_path.exists():
with open(pkl_path, 'rb') as f:
monomeric_object = pickle.load(f)
if hasattr(monomeric_object, 'feature_dict'):
sequence = monomeric_object.feature_dict.get('sequence', [])
if len(sequence) > 0:
return sequence[0].decode('utf-8')
return None
def _get_sequence_from_fasta(self, protein_name: str) -> str:
"""Extract sequence from a FASTA file with case-insensitive search."""
fasta_path = self.test_fastas_dir / f"{protein_name}.fasta"
if not fasta_path.exists():
# Try case-insensitive search
for fasta_file in self.test_fastas_dir.glob("*.fasta"):
if fasta_file.stem.lower() == protein_name.lower():
fasta_path = fasta_file
break
if fasta_path.exists():
with open(fasta_path, 'r') as f:
lines = f.readlines()
if len(lines) >= 2:
return lines[1].strip()
return None
def _get_sequence_from_json(self, json_file: str) -> List[Tuple[str, str]]:
"""Extract sequences from a JSON file."""
sequences = []
json_path = self.test_features_dir / json_file
if json_path.exists():
with open(json_path, 'r') as f:
json_data = json.load(f)
json_sequences = json_data.get('sequences', [])
for seq_data in json_sequences:
if 'protein' in seq_data:
protein_seq = seq_data['protein']
chain_id = protein_seq.get('id', 'A')
sequence = protein_seq.get('sequence', '')
# Apply post-translational modifications if present
modifications = protein_seq.get('modifications', [])
if modifications:
sequence = self._apply_ptms_to_sequence(sequence, modifications)
sequences.append((chain_id, sequence))
elif 'rna' in seq_data:
rna_seq = seq_data['rna']
chain_id = rna_seq.get('id', 'A')
sequence = rna_seq.get('sequence', '')
sequences.append((chain_id, sequence))
elif 'dna' in seq_data:
dna_seq = seq_data['dna']
chain_id = dna_seq.get('id', 'A')
sequence = dna_seq.get('sequence', '')
sequences.append((chain_id, sequence))
elif 'ligand' in seq_data:
ligand_seq = seq_data['ligand']
chain_id = ligand_seq.get('id', 'L')
# For ligands, we use the CCD codes as the "sequence"
ccd_codes = ligand_seq.get('ccdCodes', [])
if ccd_codes:
# Join multiple CCD codes if present (e.g., ["ATP", "MG"] -> "ATP+MG")
sequence = '+'.join(ccd_codes)
else:
# Fallback to SMILES if no CCD codes
smiles = ligand_seq.get('smiles', '')
sequence = f"SMILES:{smiles}" if smiles else "UNKNOWN_LIGAND"
sequences.append((chain_id, sequence))
return sequences
def _apply_ptms_to_sequence(self, sequence: str, modifications: List[Dict]) -> str:
"""
Apply PTMs to the expected structure-side sequence representation.
Args:
sequence: Original protein sequence
modifications: List of PTM dictionaries with 'ptmType' and 'ptmPosition'
Returns:
Modified sequence with PTMs applied (same length as original)
"""
# Convert to list for easier modification
seq_list = list(sequence)
for ptm in modifications:
ptm_type = ptm.get('ptmType')
ptm_position = ptm.get('ptmPosition', 1) - 1 # Convert to 0-based indexing
if ptm_position < len(seq_list):
if ptm_type == "HYS":
seq_list[ptm_position] = "H"
elif ptm_type == "2MG":
seq_list[ptm_position] = "G"
else:
seq_list[ptm_position] = af3_residue_names.letters_three_to_one(
ptm_type,
default='X',
)
return ''.join(seq_list)
def _get_sequence_for_protein(self, protein_name: str, chain_id: str = 'A') -> str:
"""Get sequence for a single protein, trying PKL first, then FASTA."""
# Try PKL file first
sequence = self._get_sequence_from_pkl(protein_name)
if sequence:
return sequence
# Fallback to FASTA
sequence = self._get_sequence_from_fasta(protein_name)
if sequence:
return sequence
return None
def _chain_id_from_index(self, index: int) -> str:
"""Mirror AF3's reverse-spreadsheet chain ID progression."""
if index < 26:
return chr(ord('A') + index)
first_char = chr(ord('A') + (index // 26) - 1)
second_char = chr(ord('A') + (index % 26))
return first_char + second_char
def _get_region_sequences(self, protein_name: str, regions: list[tuple[int, int]]) -> list[str]:
"""Return one sequence fragment per requested 1-based closed interval."""
full_sequence = self._get_sequence_for_protein(protein_name)
if not full_sequence:
return []
region_sequences = []
for start, end in regions:
start_idx = start - 1
end_idx = end
region_sequences.append(full_sequence[start_idx:end_idx])
return region_sequences
def _process_homo_oligomer_line(self, line: str) -> List[Tuple[str, str]]:
"""Process a homo-oligomer line (format: 'PROTEIN,number')."""
if "," not in line:
return []
parts = line.split(",")
protein_name = parts[0].strip()
num_copies = int(parts[1].strip())
sequence = self._get_sequence_for_protein(protein_name)
if not sequence:
return []
sequences = []
for i in range(num_copies):
chain_id = chr(ord('A') + i)
sequences.append((chain_id, sequence))
return sequences
def _process_mixed_line(self, line: str) -> List[Tuple[str, str]]:
"""Process a line with multiple proteins/features separated by semicolons."""
if ";" not in line:
return []
sequences = []
parts = line.split(";")
for i, part in enumerate(parts):
part = part.strip()
if part.endswith('.json'):
# JSON input
json_sequences = self._get_sequence_from_json(part)
for chain_id, sequence in json_sequences:
if chain_id == 'A': # Use default chain ID if not specified
chain_id = chr(ord('A') + i)
sequences.append((chain_id, sequence))
else:
# Protein input (handle chopped proteins)
if "," in part:
# Extract protein name before first comma
protein_name = part.split(",")[0].strip()
else:
protein_name = part
sequence = self._get_sequence_for_protein(protein_name)
if sequence:
chain_id = chr(ord('A') + i)
sequences.append((chain_id, sequence))
return sequences
def _process_single_protein_line(self, line: str) -> List[Tuple[str, str]]:
"""Process a line with a single protein."""
part = line.strip()
if part.endswith('.json'):
# JSON input
return self._get_sequence_from_json(part)
else:
# Protein input (handle chopped proteins)
if "," in part:
# Extract protein name before first comma
protein_name = part.split(",")[0].strip()
else:
protein_name = part
sequence = self._get_sequence_for_protein(protein_name)
if sequence:
return [('A', sequence)]
return []
def _process_homo_oligomer_chopped_line(self, line: str) -> List[Tuple[str, str]]:
"""Process a homo-oligomer of chopped proteins (format: 'PROTEIN,number,regions')."""
if "," not in line:
return []
parts = line.split(",")
if len(parts) < 3:
return []
protein_name = parts[0].strip()
num_copies = int(parts[1].strip())
# Parse regions (everything after the number of copies)
regions = []
for region_str in parts[2:]:
if "-" in region_str:
s, e = region_str.split("-")
regions.append((int(s), int(e)))
# AF3 cannot represent immediately repeated author residue IDs at a
# region boundary (e.g. 6-7 followed by 7-8). Collapse only that shared
# boundary residue while keeping the explicit region naming unchanged.
normalized_regions = []
for start, end in regions:
if normalized_regions and start == normalized_regions[-1][1]:
start += 1
if start <= end:
normalized_regions.append((start, end))
region_sequences = self._get_region_sequences(protein_name, normalized_regions)
if not region_sequences:
return []
concatenated_sequence = "".join(region_sequences)
sequences = []
for copy_index in range(num_copies):
chain_id = self._chain_id_from_index(copy_index)
sequences.append((chain_id, concatenated_sequence))
return sequences
def _extract_expected_sequences(self, protein_list: str) -> List[Tuple[str, str]]:
"""
Extract expected sequences from input files based on test case name.
Args:
protein_list: Name of the protein list file
Returns:
List of tuples (chain_id, sequence) for expected chains
"""
expected_sequences = []
# Read the protein list file
protein_list_path = self.test_protein_lists_dir / protein_list
with open(protein_list_path, 'r') as f:
lines = [line.strip() for line in f.readlines() if line.strip()]
# Extract test case name from filename
test_case = protein_list.replace('.txt', '')
for line in lines:
match test_case:
case "test_homooligomer":
# Homo-oligomer format: "PROTEIN,number"
sequences = self._process_homo_oligomer_line(line)
case "test_monomer":
# Single protein
sequences = self._process_single_protein_line(line)
case "test_dimer" | "test_trimer" | "test_truemultimer":
# Multiple proteins separated by semicolons
sequences = self._process_mixed_line(line)
case "test_dimer_chopped":
# Chopped proteins (comma-separated ranges)
sequences = self._process_chopped_protein_line(line)
case "test_long_name":
# Homo-oligomer of chopped proteins: "PROTEIN,number,regions"
sequences = self._process_homo_oligomer_chopped_line(line)
case "test_monomer_with_rna" | "test_monomer_with_dna" | "test_monomer_with_ligand":
# Mixed inputs (protein + JSON)
sequences = self._process_mixed_line(line)
case "test_protein_with_ptms":
# JSON-only input
sequences = self._process_single_protein_line(line)
case "test_multi_seeds_samples":
# Test case for multiple seeds and diffusion samples (chopped protein)
sequences = self._process_chopped_protein_line(line)
case _:
# Default case: try to process as mixed line
sequences = self._process_mixed_line(line)
expected_sequences.extend(sequences)
return expected_sequences
def _process_chopped_protein_line(self, line: str) -> List[Tuple[str, str]]:
"""Process a line with chopped proteins (comma-separated ranges)."""
def parse_protein_and_regions(part: str):
# Example: A0A075B6L2,1-10,2-5,3-12
tokens = [x.strip() for x in part.split(",")]
protein_name = tokens[0]
regions = []
for region_str in tokens[1:]:
if "-" in region_str:
s, e = region_str.split("-")
regions.append((int(s), int(e)))
return protein_name, regions
if ";" in line:
# Multiple chopped proteins
sequences = []
parts = line.split(";")
for part in parts:
part = part.strip()
if "," in part:
protein_name, regions = parse_protein_and_regions(part)
region_sequences = self._get_region_sequences(protein_name, regions)
if not region_sequences:
continue
chain_id = self._chain_id_from_index(len(sequences))
sequences.append((chain_id, "".join(region_sequences)))
else:
protein_name = part
sequence = self._get_sequence_for_protein(protein_name)
if not sequence:
continue
chain_id = self._chain_id_from_index(len(sequences))
sequences.append((chain_id, sequence))
return sequences
else:
# Single chopped protein
part = line.strip()
if "," in part:
protein_name, regions = parse_protein_and_regions(part)
region_sequences = self._get_region_sequences(protein_name, regions)
if region_sequences:
return [('A', "".join(region_sequences))]
else:
protein_name = part
sequence = self._get_sequence_for_protein(protein_name)
if sequence:
return [('A', sequence)]
return []
def _extract_cif_chains_and_sequences(self, cif_path: Path) -> List[Tuple[str, str]]:
"""
Extract chain IDs and sequences from a CIF file.
Args:
cif_path: Path to the CIF file
Returns:
List of tuples (chain_id, sequence) for chains in the CIF file
"""
# Parsing strategy, in order: (1) AF3's cif_dict reader, (2) Biopython's
# MMCIFParser, (3) the regex fallback when Biopython is missing or fails.
chains_and_sequences = []
try:
from alphafold3.cpp import cif_dict
with open(cif_path, "rt") as handle:
cif = cif_dict.from_string(handle.read())
sequences_by_chain = {}
# Polymer chains: translate each mon_id to a one-letter code, trying the
# protein table, then DNA, then RNA, then the space-padded nucleic keys.
if "_pdbx_poly_seq_scheme.asym_id" in cif:
asym_ids = cif.get_array("_pdbx_poly_seq_scheme.asym_id", dtype=object)
mon_ids = cif.get_array("_pdbx_poly_seq_scheme.mon_id", dtype=object)
for chain_id, mon_id in zip(asym_ids, mon_ids, strict=True):
sequence = sequences_by_chain.setdefault(chain_id, "")
if mon_id in self._protein_letters_3to1:
sequence += self._protein_letters_3to1[mon_id]
elif mon_id in self._dna_letters_3to1:
sequence += self._dna_letters_3to1[mon_id]
elif mon_id in self._rna_letters_3to1:
sequence += self._rna_letters_3to1[mon_id]
elif mon_id + " " in self._rna_letters_3to1:
sequence += self._rna_letters_3to1[mon_id + " "]
elif mon_id + " " in self._dna_letters_3to1:
sequence += self._dna_letters_3to1[mon_id + " "]
elif mon_id == "HYS":
sequence += "H"
elif mon_id == "2MG":
sequence += "G"
else:
sequence += "X"
sequences_by_chain[chain_id] = sequence
# Non-polymer and branched entries: the chain "sequence" is the '+'-joined
# list of component codes. Waters (HOH/DOD) are skipped and any code outside
# _ligand_ccd_codes is recorded as "UNKNOWN".
for scheme_prefix in ("_pdbx_nonpoly_scheme", "_pdbx_branch_scheme"):
asym_key = f"{scheme_prefix}.asym_id"
mon_key = f"{scheme_prefix}.mon_id"
if asym_key not in cif or mon_key not in cif:
continue
asym_ids = cif.get_array(asym_key, dtype=object)
mon_ids = cif.get_array(mon_key, dtype=object)
for chain_id, mon_id in zip(asym_ids, mon_ids, strict=True):
if mon_id in {"HOH", "DOD"}:
continue
sequence = sequences_by_chain.setdefault(chain_id, "")
ligand_codes = [] if not sequence else sequence.split("+")
ligand_codes.append(mon_id if mon_id in self._ligand_ccd_codes else "UNKNOWN")
sequences_by_chain[chain_id] = "+".join(ligand_codes)
# Report chains in _struct_asym order when that table is present, otherwise
# in the order they were first seen above.
chain_order = (
list(cif.get_array("_struct_asym.id", dtype=object))
if "_struct_asym.id" in cif
else list(sequences_by_chain.keys())
)
for chain_id in chain_order:
sequence = sequences_by_chain.get(chain_id)
if sequence:
chains_and_sequences.append((chain_id, sequence))
# Only a non-empty cif_dict result is returned from here; an empty one falls
# through to the Biopython parser below.
if chains_and_sequences:
return chains_and_sequences
except ImportError:
pass
except Exception as e:
print(f"Error parsing CIF with AF3 cif_dict: {e}")
try:
from Bio.PDB import MMCIFParser
# Parse the CIF file
parser = MMCIFParser(QUIET=True)
structure = parser.get_structure("model", str(cif_path))
# Get the first model (should be the only one for AlphaFold3)
model = structure[0]
# Extract sequences for each chain
for chain in model:
chain_id = chain.id
# Keep the residue order from the file instead of sorting by
# residue number so discontinuous numbering remains testable.
residues = list(chain.get_residues())
# Separate standard residues from HETATM records
standard_residues = []
hetatm_residues = []
for residue in residues:
hetfield, resseq, icode = residue.id
res_name = residue.resname
if hetfield == " ":
# Standard residue (protein, DNA, RNA)
standard_residues.append((resseq, res_name))
elif hetfield != "W": # Skip water molecules
# HETATM record (ligand or PTM)
hetatm_residues.append((resseq, res_name))
# Check if this chain contains any HETATM records (ligands)
has_ligand_hetatm = any(res_name in self._ligand_ccd_codes for _, res_name in hetatm_residues)
if has_ligand_hetatm:
# This is a ligand chain - extract HETATM residues
ligand_codes = []
for _, res_name in hetatm_residues:
if res_name in self._ligand_ccd_codes:
ligand_codes.append(res_name)
else:
ligand_codes.append("UNKNOWN")
if ligand_codes:
# Join multiple ligand codes if present (e.g., ["ATP", "MG"] -> "ATP+MG")
sequence = '+'.join(ligand_codes)
else:
sequence = "UNKNOWN_LIGAND"
else:
# This is a polymer chain (protein, DNA, RNA) - extract base sequence
sequence = ""
unknown_residues = []
for _, res_name in standard_residues:
# Try protein first
if res_name in self._protein_letters_3to1:
sequence += self._protein_letters_3to1[res_name]
# Try DNA
elif res_name in self._dna_letters_3to1:
sequence += self._dna_letters_3to1[res_name]
# Try RNA
elif res_name in self._rna_letters_3to1:
sequence += self._rna_letters_3to1[res_name]
# Try RNA with spaces (PDBData format)
elif res_name + " " in self._rna_letters_3to1:
sequence += self._rna_letters_3to1[res_name + " "]
# Try DNA with spaces (PDBData format)
elif res_name + " " in self._dna_letters_3to1:
sequence += self._dna_letters_3to1[res_name + " "]
else:
sequence += "X" # Unknown residue
unknown_residues.append(res_name)
# Debug: print unknown residues
if unknown_residues:
print(f"Warning: Unknown residues in chain {chain_id}: {set(unknown_residues)}")
# Apply PTMs from HETATM records if present
if hetatm_residues and sequence:
sequence = self._apply_ptms_from_hetatm(sequence, hetatm_residues)
if sequence: # Only add if we have a sequence
chains_and_sequences.append((chain_id, sequence))
except ImportError:
# Fallback to regex parsing if Biopython is not available
print("Warning: Biopython not available, using regex parsing")
chains_and_sequences = self._extract_cif_chains_and_sequences_regex(cif_path)
except Exception as e:
print(f"Error parsing CIF with Biopython: {e}")
# Fallback to regex parsing
chains_and_sequences = self._extract_cif_chains_and_sequences_regex(cif_path)
return chains_and_sequences
def _extract_cif_chain_residue_numbers(self, cif_path: Path) -> List[Tuple[str, List[Union[int, str]]]]:
"""Extract author-facing residue numbers for each polymer chain from a CIF file."""
try:
from alphafold3.cpp import cif_dict
with open(cif_path, "rt") as handle:
cif = cif_dict.from_string(handle.read())
asym_ids = cif.get_array("_pdbx_poly_seq_scheme.asym_id", dtype=object)
auth_seq_nums = cif.get_array(
"_pdbx_poly_seq_scheme.auth_seq_num", dtype=object
)
ins_codes = cif.get_array(
"_pdbx_poly_seq_scheme.pdb_ins_code", dtype=object
)
chain_residue_numbers = []
chain_to_numbers = {}
for chain_id, auth_seq_num, ins_code in zip(
asym_ids,
auth_seq_nums,
ins_codes,
strict=True,
):
residue_numbers = chain_to_numbers.setdefault(chain_id, [])
ins_code = str(ins_code)
auth_seq_num = int(auth_seq_num)
if ins_code in {".", "?"}:
residue_numbers.append(auth_seq_num)
else:
residue_numbers.append(f"{auth_seq_num}{ins_code}")
for chain_id, residue_numbers in chain_to_numbers.items():
if residue_numbers:
chain_residue_numbers.append((chain_id, residue_numbers))
return chain_residue_numbers
except Exception as exc:
self.fail(f"Failed to extract CIF residue numbers from {cif_path}: {exc}")
def _apply_ptms_from_hetatm(self, sequence: str, hetatm_residues: List[Tuple[int, str]]) -> str:
"""
Apply PTMs from HETATM records to the protein sequence.
Args:
sequence: Base protein sequence
hetatm_residues: List of (residue_number, residue_name) tuples from HETATM records
Returns:
Modified sequence with PTMs applied
"""
# Convert to list for easier modification
seq_list = list(sequence)
for resseq, res_name in hetatm_residues:
ptm_position = resseq - 1 # Convert to 0-based indexing
if ptm_position < len(seq_list):
if res_name == "HYS":
# N-terminal histidine modification - replace N-terminal methionine with HYS
if ptm_position == 0 and seq_list[0] == 'M':
# Replace M with H (histidine) - HYS is the CCD code, but we use H for sequence
seq_list[0] = 'H'
elif res_name == "2MG":
# 2-methylguanosine modification - replace G with modified G
# For simplicity, we'll keep it as G since the exact representation may vary
pass
# Add more PTM types as needed
else:
print(f"Warning: Unknown PTM type '{res_name}' at position {ptm_position + 1}")
return ''.join(seq_list)
@property
def _dna_letters_3to1(self):
"""DNA three-letter to one-letter code mapping using Bio.Data.PDBData."""
try:
from Bio.Data.PDBData import nucleic_letters_3to1_extended
return nucleic_letters_3to1_extended
except ImportError:
# Fallback if PDBData is not available
return {
'DA': 'A', # deoxyadenosine
'DT': 'T', # deoxythymidine
'DG': 'G', # deoxyguanosine
'DC': 'C', # deoxycytidine
}
@property
def _rna_letters_3to1(self):
"""RNA three-letter to one-letter code mapping using Bio.Data.PDBData."""
try:
from Bio.Data.PDBData import nucleic_letters_3to1_extended
return nucleic_letters_3to1_extended
except ImportError:
# Fallback if PDBData is not available
return {
'A': 'A', # adenosine
'U': 'U', # uridine
'G': 'G', # guanosine
'C': 'C', # cytidine
}
@property
def _protein_letters_3to1(self):
"""Protein three-letter to one-letter code mapping using Bio.Data.PDBData."""
try:
from Bio.Data.PDBData import protein_letters_3to1_extended
return protein_letters_3to1_extended
except ImportError:
# Fallback if PDBData is not available
from Bio.Data.IUPACData import protein_letters_3to1
return {**protein_letters_3to1, 'UNK': 'X'}
@property
def _ligand_ccd_codes(self):
"""Common ligand CCD codes that might appear in CIF files."""
return {
'ATP', 'ADP', 'AMP', 'GTP', 'GDP', 'GMP', 'CTP', 'CDP', 'CMP',
'UTP', 'UDP', 'UMP', 'NAD', 'NADH', 'FAD', 'FADH2', 'COA',
'HEM', 'MG', 'CA', 'ZN', 'FE', 'CU', 'MN', 'K', 'NA', 'CL',
'SO4', 'PO4', 'NO3', 'CO3', 'HCO3', 'OH', 'H2O', 'DMS', 'EDO',
'GOL', 'PEG', 'PEO', 'MPD', 'BME', 'DTT', 'TCEP', 'GSH', 'GSSG'
}
def _extract_cif_chains_and_sequences_regex(self, cif_path: Path) -> List[Tuple[str, str]]:
"""
Fallback method to extract chain IDs and sequences from a CIF file using regex.
Args:
cif_path: Path to the CIF file
Returns:
List of tuples (chain_id, sequence) for chains in the CIF file
"""
chains_and_sequences = []
with open(cif_path, 'r') as f:
cif_content = f.read()
# Extract unique chain IDs from _struct_asym table
# Format: chain_id entity_id (e.g., "A 1")
# NOTE(review): the pattern is run over the whole file text, not only the
# _struct_asym loop, so any "UPPERCASE <whitespace> digits" pair anywhere in
# the file is collected as a (chain, entity) candidate - confirm acceptable.
struct_asym_pattern = r'([A-Z]+)\s+(\d+)'
struct_asym_matches = re.findall(struct_asym_pattern, cif_content)
# Create mapping of entity_id to chain_ids
entity_to_chains = {}
for chain_id, entity_id in struct_asym_matches:
entity_id = int(entity_id)
if entity_id not in entity_to_chains:
entity_to_chains[entity_id] = []
entity_to_chains[entity_id].append(chain_id)
# Extract sequences for each entity from _entity_poly_seq table
# Format: entity_id num mon_id (e.g., "1 n MET 1" or "2 n DA 1")
# The three capture groups are, in order: entity id, 2-3 letter residue
# name, residue number.
entity_poly_seq_pattern = r'(\d+)\s+n\s+([A-Z]{2,3})\s+(\d+)'
entity_poly_seq_matches = re.findall(entity_poly_seq_pattern, cif_content)
# Group residues by entity_id
entity_sequences = {}
for entity_id, mon_id, num in entity_poly_seq_matches:
entity_id = int(entity_id)
if entity_id not in entity_sequences:
entity_sequences[entity_id] = []
entity_sequences[entity_id].append((int(num), mon_id))
# Extract ligand information from _pdbx_nonpoly_scheme entries
# Look for single entries (not loops) with format:
# _pdbx_nonpoly_scheme.asym_id L
# _pdbx_nonpoly_scheme.mon_id ATP
nonpoly_asym_pattern = r'_pdbx_nonpoly_scheme\.asym_id\s+([A-Z]+)'
nonpoly_mon_pattern = r'_pdbx_nonpoly_scheme\.mon_id\s+([A-Z0-9]+)'
nonpoly_asym_matches = re.findall(nonpoly_asym_pattern, cif_content)
nonpoly_mon_matches = re.findall(nonpoly_mon_pattern, cif_content)
# Create ligand chains directly
# Ligand chains are appended first, so the duplicate check further down
# keeps a ligand entry over a polymer entry that shares its chain ID.
for asym_id, mon_id in zip(nonpoly_asym_matches, nonpoly_mon_matches):
chains_and_sequences.append((asym_id, mon_id))
# Convert three-letter codes to one-letter sequences for polymer entities
try:
# Use comprehensive dictionaries from PDBData
three_to_one = {}
three_to_one.update(self._protein_letters_3to1)
three_to_one.update(self._dna_letters_3to1)
three_to_one.update(self._rna_letters_3to1)
except ImportError:
# Fallback if PDBData is not available
from Bio.Data.IUPACData import protein_letters_3to1
three_to_one = {**protein_letters_3to1, 'UNK': 'X'}
# Add DNA and RNA mappings
three_to_one.update(self._dna_letters_3to1)
three_to_one.update(self._rna_letters_3to1)
# Build sequences for each polymer entity
for entity_id, residues in entity_sequences.items():
# Sort by residue number
residues.sort(key=lambda x: x[0])
# Residue names missing from the lookup table become 'X'.
sequence = ''.join([three_to_one.get(res[1], 'X') for res in residues])
# Get chain IDs for this entity - only add one entry per chain
if entity_id in entity_to_chains:
for chain_id in entity_to_chains[entity_id]:
# Check if we already have this chain_id to avoid duplicates
if not any(existing_chain_id == chain_id for existing_chain_id, _ in chains_and_sequences):
chains_and_sequences.append((chain_id, sequence))
return chains_and_sequences
def _assert_exact_chain_mapping(
self,
expected_sequences: List[Tuple[str, str]],
actual_chains_and_sequences: List[Tuple[str, str]],
*,
context: str,
) -> None:
"""Assert an exact chain-id to sequence mapping, independent of file order."""
expected_dict = dict(expected_sequences)
actual_dict = dict(actual_chains_and_sequences)
self.assertLen(
expected_dict,
len(expected_sequences),
f"{context}: expected chain IDs must be unique",
)
self.assertLen(
actual_dict,
len(actual_chains_and_sequences),
f"{context}: actual chain IDs must be unique",
)
print(f"Expected exact chain mapping for {context}: {expected_dict}")
print(f"Actual exact chain mapping for {context}: {actual_dict}")
self.assertEqual(
actual_dict,
expected_dict,
f"{context}: exact chain mapping mismatch",
)
def _requires_exact_chain_mapping(self, protein_list: str) -> bool:
"""Cases where inference must preserve the explicit input chain IDs."""
return protein_list in {
"test_monomer_with_rna.txt",
"test_monomer_with_dna.txt",
"test_monomer_with_ligand.txt",
"test_protein_with_ptms.txt",
}
def _check_chain_counts_and_sequences(self, protein_list: str):
    """Compare the chains of a predicted CIF file with the expected input sequences.

    The number of chains must always match. For protein lists flagged by
    ``_requires_exact_chain_mapping`` the chain-id -> sequence mapping must
    match exactly; otherwise only the multiset of sequences is compared, so
    chain order and chain IDs are free to differ.

    Args:
        protein_list: Name of the protein list file.
    """
    # Get expected sequences from input files
    expected_sequences = self._extract_expected_sequences(protein_list)
    print(f"\nExpected sequences: {expected_sequences}")

    # Path.glob yields matches in arbitrary, filesystem-dependent order, so
    # sort them to inspect the same file on every run.
    result_dir = self._resolve_single_af3_result_dir()
    cif_files = sorted(result_dir.glob("*_model.cif"))
    if not cif_files:
        self.fail("No predicted CIF files found")
    cif_path = cif_files[0]
    print(f"Checking CIF file: {cif_path}")

    # Extract chains and sequences from the CIF file
    actual_chains_and_sequences = self._extract_cif_chains_and_sequences(cif_path)
    print(f"Actual chains and sequences: {actual_chains_and_sequences}")

    # Check that the number of chains matches
    self.assertEqual(
        len(actual_chains_and_sequences),
        len(expected_sequences),
        f"Expected {len(expected_sequences)} chains, but found {len(actual_chains_and_sequences)}"
    )

    if self._requires_exact_chain_mapping(protein_list):
        self._assert_exact_chain_mapping(
            expected_sequences,
            actual_chains_and_sequences,
            context=protein_list,
        )
        return

    # Compare sorted sequences only, since chain order might vary.
    actual_sequences = sorted(seq for _, seq in actual_chains_and_sequences)
    expected_sequences_only = sorted(seq for _, seq in expected_sequences)
    self.assertEqual(
        actual_sequences,
        expected_sequences_only,
        f"Sequences don't match. Expected: {expected_sequences_only}, Actual: {actual_sequences}"
    )
def _make_af3_test_env(self) -> Dict[str, str]:
    """Return a copy of the current process environment adjusted for AF3 test runs.

    The copy gets fixed XLA settings plus the flash-attention implementation
    reported by ``_af3_flash_attention_impl``; any inherited
    ``XLA_PYTHON_CLIENT_MEM_FRACTION`` entry is dropped. ``os.environ``
    itself is left untouched.
    """
    flash_impl = self._af3_flash_attention_impl()
    env = os.environ.copy()
    env.update(
        {
            "XLA_FLAGS": "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter --xla_gpu_force_compilation_parallelism=0",
            "XLA_PYTHON_CLIENT_PREALLOCATE": "true",
            "XLA_CLIENT_MEM_FRACTION": "0.95",
            "JAX_FLASH_ATTENTION_IMPL": flash_impl,
        }
    )
    env.pop("XLA_PYTHON_CLIENT_MEM_FRACTION", None)
    return env
def _af3_flash_attention_impl(self) -> str:
    """Return the flash-attention implementation named by ``AF3_TEST_FLASH_ATTENTION_IMPL`` (``"xla"`` when unset)."""
    return os.environ.get("AF3_TEST_FLASH_ATTENTION_IMPL", "xla")
def _require_af3_functional_environment(self) -> None:
    """Skip the current test when the ``DATA_DIR`` path does not exist."""
    if os.path.exists(DATA_DIR):
        return
    reason = f"AF3 functional tests require ALPHAFOLD_DATA_DIR; missing path: {DATA_DIR}"
    self.skipTest(reason)
def _assert_af3_outputs_present(self, output_dir: Path) -> None:
    """Assert that ``output_dir`` contains a complete set of AlphaFold3 outputs.

    Checks performed:
      * ``TERMS_OF_USE.md`` and ``ranking_scores.csv`` exist;
      * at least one ``*_confidences.json``, ``*_summary_confidences.json``
        and ``*_model.cif`` file exists;
      * every ``seed-*`` sub-directory whose name contains ``sample-`` holds
        ``confidences.json``, ``model.cif`` and ``summary_confidences.json``;
      * ``ranking_scores.csv`` has a three-column header and at least one data
        row, and every data row parses as ``int,int,float``;
      * the number of sample directories whose seed is listed in the CSV
        equals the number of data rows.

    Blank lines in the CSV are ignored consistently: they neither count as
    ranking entries nor fail the per-row format check.

    Args:
        output_dir: Directory expected to hold the AF3 result files.
    """
    files = list(output_dir.iterdir())
    file_names = {f.name for f in files}
    print(f"contents of {output_dir}: {[f.name for f in files]}")
    self.assertIn("TERMS_OF_USE.md", file_names)
    self.assertIn("ranking_scores.csv", file_names)

    conf_files = [f for f in files if f.name.endswith("_confidences.json")]
    summary_conf_files = [f for f in files if f.name.endswith("_summary_confidences.json")]
    model_files = [f for f in files if f.name.endswith("_model.cif")]
    self.assertTrue(len(conf_files) > 0, f"No confidences.json files found in {output_dir}")
    self.assertTrue(len(summary_conf_files) > 0, f"No summary_confidences.json files found in {output_dir}")
    self.assertTrue(len(model_files) > 0, f"No model.cif files found in {output_dir}")

    sample_dirs = [
        f for f in files if f.is_dir() and f.name.startswith("seed-") and "sample-" in f.name
    ]
    for sample_dir in sample_dirs:
        sample_names = {f.name for f in sample_dir.iterdir()}
        self.assertIn("confidences.json", sample_names)
        self.assertIn("model.cif", sample_names)
        self.assertIn("summary_confidences.json", sample_names)

    with open(output_dir / "ranking_scores.csv") as f:
        lines = f.readlines()

    # Keep the 1-based file line number next to each non-blank data row so
    # failure messages still point at the offending line.
    data_rows = [
        (line_number, text.strip())
        for line_number, text in enumerate(lines[1:], 2)
        if text.strip()
    ]
    self.assertTrue(
        bool(lines) and bool(data_rows),
        "ranking_scores.csv should have header and data",
    )
    self.assertEqual(len(lines[0].strip().split(",")), 3, "ranking_scores.csv should have 3 columns")

    seeds_in_csv = {row.split(",")[0] for _, row in data_rows}

    def _seed_from_dirname(name: str) -> str:
        # "seed-<seed>_sample-<n>" -> "<seed>"; names without "seed-" map to "".
        try:
            return name.split("seed-")[1].split("_")[0]
        except IndexError:
            return ""

    # Only count sample directories that belong to seeds listed in the CSV,
    # so directories left over from other runs are not counted.
    sample_dirs_for_this_run = [
        d for d in sample_dirs if _seed_from_dirname(d.name) in seeds_in_csv
    ]
    expected_sample_dirs = len(data_rows)
    self.assertEqual(
        len(sample_dirs_for_this_run), expected_sample_dirs,
        f"Expected {expected_sample_dirs} sample directories, found {len(sample_dirs_for_this_run)}"
    )

    for line_number, row in data_rows:
        parts = row.split(",")
        self.assertEqual(len(parts), 3, f"Line {line_number} should have 3 columns: seed,sample,ranking_score")
        try:
            int(parts[0])
            int(parts[1])
            float(parts[2])
        except ValueError:
            self.fail(f"Line {line_number} has invalid format: {row}")
    print(f"✓ Verified ranking_scores.csv has correct format with {len(data_rows)} entries")
def _resolve_single_af3_result_dir(self) -> Path:
    """Locate the AF3 result directory of a single-job test.

    ``self.output_dir`` is returned when it holds ``ranking_scores.csv``
    itself. Otherwise, if exactly one direct sub-directory holds that file,
    that sub-directory is returned. In every other case the method falls
    back to ``self.output_dir``.
    """
    marker = "ranking_scores.csv"
    base_dir = self.output_dir
    if (base_dir / marker).exists():
        return base_dir

    nested_results = []
    for child in base_dir.iterdir():
        if child.is_dir() and (child / marker).exists():
            nested_results.append(child)

    # Zero or several candidates are ambiguous, so fall back to the base dir.
    if len(nested_results) != 1:
        return base_dir

    only_result = nested_results[0]
    print(f"Resolved nested AF3 result dir: {only_result}")
    return only_result
# ---------------- assertions reused by all subclasses ----------------- #
def _runCommonTests(self, res: subprocess.CompletedProcess):
    """Echo the sub-process output, require exit code 0 and verify the AF3 output files."""
    for captured_stream in (res.stdout, res.stderr):
        print(captured_stream)
    self.assertEqual(res.returncode, 0, "sub-process failed")
    result_dir = self._resolve_single_af3_result_dir()
    self._assert_af3_outputs_present(result_dir)
# convenience builder
def _args(self, *, plist, script):
    """Build the command line for one AF3 test case.

    Args:
        plist: File name of the protein list inside
            ``self.test_protein_lists_dir``. A name containing
            ``"homooligomer"`` selects homo-oligomer mode, anything else
            selects custom mode.
        script: ``"run_structure_prediction.py"`` or
            ``"run_multimer_jobs.py"``.

    Returns:
        The argument list, starting with ``sys.executable`` and the script path.

    Raises:
        ValueError: If ``script`` is not one of the two supported names.
    """
    flash_impl = self._af3_flash_attention_impl()
    # Determine mode from protein list name
    mode = "homo-oligomer" if "homooligomer" in plist else "custom"

    if script == "run_structure_prediction.py":
        # Format from run_multimer_jobs.py input to run_structure_prediction.py input
        specifications = generate_fold_specifications(
            input_files=[str(self.test_protein_lists_dir / plist)],
            delimiter="+",
            exclude_permutations=True,
        )
        formatted_input_lines = [
            spec.replace(",", ":").replace(";", "+")
            for spec in specifications
            if spec.strip()
        ]
        # Only the first fold specification is passed to the script.
        formatted_input = formatted_input_lines[0] if formatted_input_lines else ""
        args = [
            sys.executable,
            str(self.script_single),
            f"--input={formatted_input}",
            f"--output_directory={self.output_dir}",
            f"--data_directory={DATA_DIR}",
            f"--features_directory={self.test_features_dir}",
            "--fold_backend=alphafold3",
            f"--flash_attention_implementation={flash_impl}",
        ]
        # Add special arguments for multi_seeds_samples test
        if "multi_seeds_samples" in plist:
            args.extend([
                "--num_seeds=3",
                "--num_diffusion_samples=4",
            ])
        return args

    if script == "run_multimer_jobs.py":
        # Choose the flag name first, then append "=<path>". Writing this as
        # one conditional expression makes "+" bind tighter than "if/else",
        # which leaves the homo-oligomer flag without its path.
        list_flag = (
            "--oligomer_state_file" if mode == "homo-oligomer" else "--protein_lists"
        )
        return [
            sys.executable,
            str(self.script_multimer),
            "--num_cycle=1",
            "--num_predictions_per_model=1",
            f"--data_dir={DATA_DIR}",
            f"--monomer_objects_dir={self.test_features_dir}",
            "--job_index=1",
            f"--output_path={self.output_dir}",
            f"--mode={mode}",
            f"{list_flag}={self.test_protein_lists_dir / plist}",
            # Ensure AF3 backend and keep runtime small
            "--fold_backend=alphafold3",
            f"--flash_attention_implementation={flash_impl}",
            "--num_diffusion_samples=1",
        ]

    raise ValueError(f"Unsupported script for AF3 test arguments: {script!r}")
# --------------------------------------------------------------------------- #
# backend-only AF3 preparation tests #
# --------------------------------------------------------------------------- #
class _BackendOnlyTestBase(_TestBase):
    """Backend-only AF3 preparation tests that do not run model inference."""

    @classmethod
    def setUpClass(cls):
        """Create ``cls.base_output_dir`` for the backend-only test classes.

        With ``cls.use_temp_dir`` set, a fresh temporary directory is created.
        Otherwise a fixed relative path is used; an existing directory there
        is removed on a best-effort basis and then (re)created.
        """
        # NOTE(review): this calls parameterized.TestCase.setUpClass() directly
        # instead of super().setUpClass(); presumably to bypass the class-level
        # setup of _TestBase — confirm.
        parameterized.TestCase.setUpClass()
        if cls.use_temp_dir:
            cls.base_output_dir = Path(tempfile.mkdtemp(prefix="af3_backend_test_"))
        else:
            # Relative path: resolved against the current working directory.
            cls.base_output_dir = Path("test/test_data/predictions/af3_backend")
            if cls.base_output_dir.exists():
                try:
                    shutil.rmtree(cls.base_output_dir)
                except (PermissionError, OSError) as e:
                    # Stale output is not fatal: warn and reuse the directory.
                    print(
                        "Warning: Could not remove existing output directory "
                        f"{cls.base_output_dir}: {e}"
                    )
            cls.base_output_dir.mkdir(parents=True, exist_ok=True)
class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
    """AF3 input-construction regressions; these tests do not assert end-to-end ipTM quality."""

    # Protein IDs of the issue-588 case; each one names a FASTA fixture
    # ("<id>.fasta") and the feature files generated from it.
    ISSUE_588_IDS = ("A0ABD7FQG0", "P18004")
def _require_issue_588_mmseqs_environment(self) -> None:
    """Skip when the MMseqs functional environment is unavailable, then require the FASTA fixtures."""
    reason = _mmseqs_functional_test_skip_reason()
    if reason:
        self.skipTest(reason)

    fasta_dir = self.test_data_dir / "fastas"
    for protein_id in self.ISSUE_588_IDS:
        fasta_path = fasta_dir / f"{protein_id}.fasta"
        fixture_present = fasta_path.is_file()
        self.assertTrue(fixture_present, f"Missing FASTA fixture {fasta_path}")
def _generate_issue_588_precomputed_mmseq_features(self, env: Dict[str, str]) -> Path:
    """Generate issue-588 features twice: from scratch, then from the cached A3Ms.

    The first feature-creation run writes A3M files and compressed feature
    pickles into a "source" directory. The A3Ms are copied into a second
    directory, where feature creation is re-run with
    ``--use_precomputed_msas=True``.

    Args:
        env: Environment passed to both feature-creation sub-processes.

    Returns:
        The directory holding the features built from the precomputed A3Ms.
    """
    source_dir = self.output_dir / "issue_588_mmseq_source_features"
    precomputed_dir = self.output_dir / "issue_588_mmseq_precomputed_features"
    for directory in (source_dir, precomputed_dir):
        directory.mkdir(parents=True, exist_ok=True)

    fasta_dir = self.test_data_dir / "fastas"
    fasta_paths = ",".join(
        str(fasta_dir / f"{protein_id}.fasta") for protein_id in self.ISSUE_588_IDS
    )

    def _run_create_features(target_dir, trailing_flags):
        # Both runs share the leading arguments and differ only in the flags
        # that follow --use_mmseqs2.
        command = [
            sys.executable,
            str(self.script_create_features),
            f"--fasta_paths={fasta_paths}",
            f"--output_dir={target_dir}",
            f"--data_dir={DATA_DIR}",
            "--max_template_date=2024-05-02",
            "--use_mmseqs2=True",
        ]
        command.extend(trailing_flags)
        return subprocess.run(command, capture_output=True, text=True, env=env)

    source_res = _run_create_features(
        source_dir,
        [
            "--data_pipeline=alphafold2",
            "--save_msa_files=True",
            "--compress_features=True",
            "--skip_existing=False",
        ],
    )
    self.assertEqual(
        source_res.returncode,
        0,
        "MMseqs source feature generation failed.\n"
        f"STDOUT:\n{source_res.stdout}\nSTDERR:\n{source_res.stderr}",
    )
    for protein_id in self.ISSUE_588_IDS:
        source_a3m = source_dir / f"{protein_id}.a3m"
        source_pickle = source_dir / f"{protein_id}.pkl.xz"
        self.assertTrue(
            source_a3m.is_file(),
            f"Expected MMseq A3M {source_a3m} to be created.",
        )
        self.assertTrue(
            source_pickle.is_file(),
            f"Expected compressed feature pickle {source_pickle} to be created.",
        )
        # Seed the second directory with the A3M for the precomputed run.
        shutil.copy2(source_a3m, precomputed_dir / f"{protein_id}.a3m")

    precomputed_res = _run_create_features(
        precomputed_dir,
        [
            "--use_precomputed_msas=True",
            "--data_pipeline=alphafold2",
            "--compress_features=True",
            "--skip_existing=False",
        ],
    )
    self.assertEqual(
        precomputed_res.returncode,
        0,
        "Precomputed-MMseq feature generation failed.\n"
        f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}",
    )
    for protein_id in self.ISSUE_588_IDS:
        copied_a3m = precomputed_dir / f"{protein_id}.a3m"
        precomputed_pickle = precomputed_dir / f"{protein_id}.pkl.xz"
        self.assertTrue(
            copied_a3m.is_file(),
            f"Expected copied MMseq A3M {copied_a3m} to be present.",
        )
        self.assertTrue(
            precomputed_pickle.is_file(),
            f"Expected precomputed feature pickle {precomputed_pickle} to be created.",
        )
    return precomputed_dir
def _prepare_fold_input(
    self,
    *,
    fold_spec: str,
    feature_dir: Path,
    debug_msas: bool = False,
):
    """Build the AF3 fold-input object for one fold specification.

    Args:
        fold_spec: Fold specification, parsed with ``"+"`` as delimiter.
        feature_dir: Directory holding the feature files.
        debug_msas: Forwarded to ``AlphaFold3Backend.prepare_input``.

    Returns:
        The key of the single mapping returned by
        ``AlphaFold3Backend.prepare_input``.
    """
    from alphapulldown.folding_backend.alphafold3_backend import AlphaFold3Backend

    feature_dir_text = str(feature_dir)
    parsed = parse_fold([fold_spec], [feature_dir_text], "+")
    data = create_custom_info(parsed)
    all_interactors = create_interactors(data, [feature_dir_text])
    self.assertLen(all_interactors, 1)
    self.assertGreaterEqual(len(all_interactors[0]), 1)

    interactors = all_interactors[0]
    # One interactor is modelled as-is; several are wrapped in a multimer.
    if len(interactors) != 1:
        object_to_model = MultimericObject(interactors=interactors, pair_msa=True)
    else:
        object_to_model = interactors[0]

    job = {"object": object_to_model, "output_dir": str(self.output_dir)}
    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=[job],
        random_seed=42,
        debug_msas=debug_msas,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))
    return fold_input_obj
def _copy_real_feature_fixture(
    self,
    *,
    source_dir: Path,
    protein_id: str,
    target_dir: Path,
) -> Path:
    """Copy the fixture files of ``protein_id`` into ``target_dir``.

    Copies the feature pickle (``.pkl`` / ``.pkl.xz``), the ``.a3m`` file and
    any ``<id>_feature_metadata_*.json*`` files found in ``source_dir``.

    Returns:
        The path of the copied feature pickle; when several pickles match,
        the last one copied.
    """
    patterns = (
        f"{protein_id}.pkl",
        f"{protein_id}.pkl.xz",
        f"{protein_id}.a3m",
        f"{protein_id}_feature_metadata_*.json*",
    )
    pickle_prefix = f"{protein_id}.pkl"
    copied_feature_path = None

    matches = (
        source_path
        for pattern in patterns
        for source_path in sorted(source_dir.glob(pattern))
    )
    for source_path in matches:
        destination = target_dir / source_path.name
        shutil.copy2(source_path, destination)
        if source_path.name.startswith(pickle_prefix):
            copied_feature_path = destination

    self.assertIsNotNone(
        copied_feature_path,
        f"Missing real feature fixture for {protein_id} in {source_dir}",
    )
    return copied_feature_path
@staticmethod
def _synthetic_accession_ids(species_ids: np.ndarray) -> np.ndarray:
identifiers = []
for index, value in enumerate(species_ids):
if isinstance(value, bytes):
value = value.decode("utf-8")
identifiers.append(
f"ACC{index:05d}".encode("utf-8") if str(value).strip() else b""
)
return np.asarray(identifiers, dtype=object)
def _prepare_mixed_identifier_fixture_dir(self) -> Path:
    """Build a feature directory whose two chains differ in identifier enrichment.

    The MSA rows come unchanged from the repo fixtures under
    ``af2_features/protein``; only the identifier sidecars are rewritten.
    ``A0A024R1R8`` gets synthetic accession IDs, so it looks enriched.
    ``P61626`` gets blank species IDs and no accession IDs, reproducing the
    "no species enrichment / no accession IDs" failure mode from the AF3
    follow-up comment on issue #614.

    Returns:
        The directory containing the rewritten fixtures.
    """

    def _feature_dict_of(payload):
        # A payload is either a bare feature dict or an object wrapping one.
        return payload.feature_dict if hasattr(payload, "feature_dict") else payload

    feature_dir = self.output_dir / "mixed_identifier_features"
    feature_dir.mkdir(parents=True, exist_ok=True)
    source_dir = self.test_features_dir / "af2_features" / "protein"

    enriched_feature_path = self._copy_real_feature_fixture(
        source_dir=source_dir, protein_id="A0A024R1R8", target_dir=feature_dir
    )
    unenriched_feature_path = self._copy_real_feature_fixture(
        source_dir=source_dir, protein_id="P61626", target_dir=feature_dir
    )

    # Enriched chain: derive accession IDs from the species identifiers.
    enriched_payload = _load_feature_payload(enriched_feature_path)
    enriched_features = _feature_dict_of(enriched_payload)
    for accession_key, species_key in (
        ("msa_uniprot_accession_identifiers", "msa_species_identifiers"),
        (
            "msa_uniprot_accession_identifiers_all_seq",
            "msa_species_identifiers_all_seq",
        ),
    ):
        species_ids = np.asarray(enriched_features[species_key])
        enriched_features[accession_key] = self._synthetic_accession_ids(species_ids)
    _write_feature_payload(enriched_feature_path, enriched_payload)

    # Unenriched chain: blank species IDs (one per MSA row), no accession IDs.
    unenriched_payload = _load_feature_payload(unenriched_feature_path)
    unenriched_features = _feature_dict_of(unenriched_payload)
    for species_key, msa_key in (
        ("msa_species_identifiers", "msa"),
        ("msa_species_identifiers_all_seq", "msa_all_seq"),
    ):
        row_count = int(np.asarray(unenriched_features[msa_key]).shape[0])
        unenriched_features[species_key] = np.asarray([b""] * row_count, dtype=object)
    for accession_key in (
        "msa_uniprot_accession_identifiers",
        "msa_uniprot_accession_identifiers_all_seq",
    ):
        unenriched_features.pop(accession_key, None)
    _write_feature_payload(unenriched_feature_path, unenriched_payload)

    return feature_dir
def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self):
    """Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures."""
    from alphapulldown.folding_backend.alphafold3_backend import process_fold_input

    issue_588_dir = self.test_features_dir / "issue_588"
    # Step 1: confirm fixture provenance from the stored feature metadata
    # (mmseqs2 used, alphafold2 pipeline, no MMseqs template re-search).
    for protein_id in ("A0ABD7FQG0", "P18004"):
        metadata_path, metadata = _load_feature_metadata(issue_588_dir, protein_id)
        other = metadata["other"]
        self.assertTrue(
            _metadata_bool(other["use_mmseqs2"]),
            f"{metadata_path} is not a mmseqs2-generated AF2 fixture.",
        )
        self.assertEqual(other["data_pipeline"], "alphafold2")
        self.assertFalse(_metadata_bool(other["re_search_templates_mmseqs2"]))

    # Step 2: build the AF3 fold input with MSA debugging enabled, so the
    # summary JSON and per-chain A3M files checked below are written.
    fold_input_obj = self._prepare_fold_input(
        fold_spec="A0ABD7FQG0+P18004",
        feature_dir=issue_588_dir,
        debug_msas=True,
    )
    # Only chains that expose a ``sequence`` attribute are considered.
    protein_chains = [chain for chain in fold_input_obj.chains if hasattr(chain, "sequence")]
    chain_sequences = {chain.id: chain.sequence for chain in protein_chains}
    self.assertEqual(sorted(chain_sequences), ["A", "B"])

    # Step 3: inspect the AF2 -> AF3 translation summary.
    job_name = fold_input_obj.sanitised_name()
    summary_path = self.output_dir / f"{job_name}_af2_to_af3_translation_summary.json"
    self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
    summary = json.loads(summary_path.read_text(encoding="utf-8"))
    self.assertTrue(summary["paired_rows_valid"])
    self.assertTrue(summary["unpaired_rows_valid"])
    self.assertLen(summary["translation_modes"], 1)
    translation_mode = summary["translation_modes"][0]
    # Either translation mode is accepted; the expectations below branch on it.
    self.assertIn(
        translation_mode,
        {
            "af3_species_pairing_from_af2_individual_msas",
            "manual_unpaired_from_af2_multimer",
        },
    )
    self.assertLen(summary["chains"], 2)
    for chain_summary in summary["chains"]:
        chain_id = chain_summary["chain_id"]
        expected_sequence = chain_sequences[chain_id]
        if translation_mode == "af3_species_pairing_from_af2_individual_msas":
            # Species-pairing mode: paired and unpaired rows both present.
            self.assertGreater(
                chain_summary["paired_msa_row_count"],
                0,
                f"Expected non-empty paired MSA rows for chain {chain_id}",
            )
            self.assertGreater(
                chain_summary["unpaired_msa_row_count"],
                0,
                f"Expected non-empty unpaired MSA rows for chain {chain_id}",
            )
        else:
            # Manual-unpaired mode: no paired rows, unpaired rows present.
            self.assertEqual(chain_summary["paired_msa_row_count"], 0)
            self.assertGreater(
                chain_summary["unpaired_msa_row_count"],
                0,
                f"Expected non-empty unpaired MSA rows for chain {chain_id}",
            )
        # Step 4: check the per-chain debug A3M files.
        for msa_kind in ("paired_input", "unpaired_input"):
            msa_path = self.output_dir / f"{job_name}_chain-{chain_id}_{msa_kind}.a3m"
            self.assertTrue(msa_path.is_file(), f"Missing debug MSA {msa_path}")
            msa_text = msa_path.read_text(encoding="utf-8")
            # An empty A3M has no query row, so only compare when there is text.
            if msa_text:
                self.assertEqual(_a3m_query_sequence(msa_text), expected_sequence)
            payload_sequences = _non_empty_a3m_payload_rows(msa_text)
            if translation_mode == "af3_species_pairing_from_af2_individual_msas":
                self.assertGreater(
                    len(payload_sequences),
                    0,
                    f"Expected payload rows in {msa_path}",
                )
            elif msa_kind == "unpaired_input":
                self.assertGreater(
                    len(payload_sequences),
                    0,
                    f"Expected payload rows in {msa_path}",
                )
            else:
                # Manual-unpaired mode, paired file: no payload rows expected.
                self.assertEqual(payload_sequences, [])
            for payload_sequence in payload_sequences:
                self.assertEqual(
                    _aligned_a3m_row_length(payload_sequence),
                    len(expected_sequence),
                    f"Aligned row length mismatch in {msa_path}",
                )

    # Step 5: write the AF3 input JSON (model_runner=None) and verify what
    # was serialised for each protein chain.
    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )
    input_json = self.output_dir / f"{job_name}_data.json"
    written = json.loads(input_json.read_text(encoding="utf-8"))
    protein_entries = {
        protein_entry["id"]: protein_entry
        for protein_entry in _protein_entries_from_af3_input(written)
    }
    self.assertEqual(set(protein_entries), set(chain_sequences))
    for chain_id, protein_entry in protein_entries.items():
        expected_sequence = chain_sequences[chain_id]
        self.assertEqual(protein_entry["sequence"], expected_sequence)
        if translation_mode == "af3_species_pairing_from_af2_individual_msas":
            self.assertEqual(
                _a3m_query_sequence(protein_entry["pairedMsa"]),
                expected_sequence,
            )
            self.assertEqual(
                _a3m_query_sequence(protein_entry["unpairedMsa"]),
                expected_sequence,
            )
        else:
            self.assertEqual(protein_entry["pairedMsa"], "")
            self.assertEqual(
                _a3m_query_sequence(protein_entry["unpairedMsa"]),
                expected_sequence,
            )
        # These exact issue-588 fixtures are AF2/mmseqs2-derived and were
        # generated without MMseqs template re-search. Empty templates
        # document fixture provenance here, not an AF3 conversion failure.
        self.assertEqual(protein_entry["templates"], [])
def test_issue_588_precomputed_mmseqs_msas_preserve_af3_species_pairing(self):
    """Precomputed MMseq A3Ms should preserve recovered identifiers and AF3 species pairing."""
    from alphapulldown.folding_backend.alphafold3_backend import process_fold_input

    self._require_issue_588_mmseqs_environment()
    feature_dir = self._generate_issue_588_precomputed_mmseq_features(os.environ.copy())

    # The regenerated features must record the precomputed-MSA run and still
    # carry species and accession identifiers.
    for protein_id in self.ISSUE_588_IDS:
        metadata_path, metadata = _load_feature_metadata(feature_dir, protein_id)
        used_precomputed = _metadata_bool(metadata["other"]["use_precomputed_msas"])
        self.assertTrue(
            used_precomputed,
            f"{metadata_path} should record use_precomputed_msas=True",
        )
        feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
        species_id_count = _non_empty_identifier_count(
            feature_dict["msa_species_identifiers_all_seq"]
        )
        self.assertGreater(
            species_id_count,
            0,
            f"{protein_id} should keep recovered species IDs from cached MMseq A3Ms",
        )
        accession_id_count = _non_empty_identifier_count(
            feature_dict["msa_uniprot_accession_identifiers_all_seq"]
        )
        self.assertGreater(
            accession_id_count,
            0,
            f"{protein_id} should keep recovered accession IDs from cached MMseq A3Ms",
        )

    fold_input_obj = self._prepare_fold_input(
        fold_spec="A0ABD7FQG0+P18004",
        feature_dir=feature_dir,
        debug_msas=True,
    )
    job_name = fold_input_obj.sanitised_name()

    # The translation summary must report species pairing for both chains.
    summary_path = self.output_dir / f"{job_name}_af2_to_af3_translation_summary.json"
    self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
    summary = json.loads(summary_path.read_text(encoding="utf-8"))
    self.assertEqual(
        summary["translation_modes"],
        ["af3_species_pairing_from_af2_individual_msas"],
    )
    for validity_key in ("paired_rows_valid", "unpaired_rows_valid"):
        self.assertTrue(summary[validity_key])
    self.assertLen(summary["chains"], 2)
    for chain_summary in summary["chains"]:
        for count_key in (
            "paired_msa_row_count",
            "unpaired_msa_row_count",
            "paired_species_identifier_count",
        ):
            self.assertGreater(chain_summary[count_key], 0)

    # Serialise the AF3 input (no model runner) and check both MSAs per chain.
    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )
    input_json = self.output_dir / f"{job_name}_data.json"
    written = json.loads(input_json.read_text(encoding="utf-8"))
    protein_entries = _protein_entries_from_af3_input(written)
    self.assertLen(protein_entries, 2)
    for protein_entry in protein_entries:
        for msa_key in ("pairedMsa", "unpairedMsa"):
            self.assertEqual(
                _a3m_query_sequence(protein_entry[msa_key]),
                protein_entry["sequence"],
            )
def test_af3_prepare_input_preserves_templates_for_templated_af2_pkl_features(self):
    """Positive control: templated AF2 pkl inputs should keep templates in AF3 JSON."""
    from alphapulldown.folding_backend.alphafold3_backend import process_fold_input

    fold_input_obj = self._prepare_fold_input(
        fold_spec="P61626",
        feature_dir=self.test_features_dir / "af2_features" / "protein",
    )
    # The in-memory fold input must already carry templates.
    self.assertLen(fold_input_obj.chains, 1)
    only_chain = fold_input_obj.chains[0]
    self.assertGreater(len(only_chain.templates), 0)

    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )

    # ...and they must survive serialisation into the AF3 input JSON.
    written_path = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
    written = json.loads(written_path.read_text(encoding="utf-8"))
    protein_entries = _protein_entries_from_af3_input(written)
    self.assertLen(protein_entries, 1)
    written_templates = protein_entries[0]["templates"]
    self.assertGreater(len(written_templates), 0)
    self.assertTrue(all(template["mmcif"] for template in written_templates))
def test_af3_real_fixture_pipeline_tolerates_mixed_missing_accession_ids(self):
    """AF3 prep should tolerate a real mixed-enrichment multimer feature set."""
    from alphapulldown.folding_backend.alphafold3_backend import (
        AlphaFold3Backend,
        process_fold_input,
    )
    from alphapulldown.scripts import run_structure_prediction

    feature_dir = self._prepare_mixed_identifier_fixture_dir()
    # Sanity-check the fixture: one chain has accession IDs, the other has
    # neither species IDs nor an accession-ID entry.
    enriched_feature_dict = _load_feature_dict(feature_dir / "A0A024R1R8.pkl")
    self.assertGreater(
        _non_empty_identifier_count(
            enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"]
        ),
        0,
    )
    unenriched_feature_dict = _load_feature_dict(feature_dir / "P61626.pkl")
    self.assertEqual(
        _non_empty_identifier_count(
            unenriched_feature_dict["msa_species_identifiers_all_seq"]
        ),
        0,
    )
    self.assertNotIn(
        "msa_uniprot_accession_identifiers_all_seq",
        unenriched_feature_dict,
    )

    # Stand-in for the script's absl FLAGS object; it carries the attributes
    # listed here.
    script_flags = SimpleNamespace(
        pair_msa=True,
        multimeric_template=False,
        description_file=None,
        path_to_mmt=None,
        threshold_clashes=1000,
        hb_allowance=0.4,
        plddt_threshold=0,
        save_features_for_multimeric_object=False,
        features_directory=[str(feature_dir)],
        use_ap_style=False,
    )
    # Drive the run_structure_prediction helpers with FLAGS patched to the
    # namespace above.
    with mock.patch.object(run_structure_prediction, "FLAGS", script_flags):
        parsed = run_structure_prediction.parse_fold(
            ["A0A024R1R8+P61626"],
            [str(feature_dir)],
            "+",
        )
        data = run_structure_prediction.create_custom_info(parsed)
        all_interactors = run_structure_prediction.create_interactors(
            data,
            [str(feature_dir)],
        )
        self.assertLen(all_interactors, 1)
        self.assertLen(all_interactors[0], 2)
        object_to_model, prepared_output_dir = (
            run_structure_prediction.pre_modelling_setup(
                all_interactors[0],
                output_dir=str(self.output_dir / "mixed_identifier_prediction"),
            )
        )

    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=[
            {"object": object_to_model, "output_dir": prepared_output_dir}
        ],
        random_seed=42,
        debug_msas=True,
    )
    self.assertLen(mappings, 1)
    # Each mapping maps the fold input to (output_dir, resolve_msa_overlaps).
    fold_input_obj, (
        prepared_output_dir,
        resolve_msa_overlaps,
    ) = next(iter(mappings[0].items()))

    # Write the AF3 input JSON; model_runner=None means no model is supplied.
    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=prepared_output_dir,
        buckets=(512,),
        resolve_msa_overlaps=resolve_msa_overlaps,
    )

    job_name = fold_input_obj.sanitised_name()
    summary_path = (
        Path(prepared_output_dir)
        / f"{job_name}_af2_to_af3_translation_summary.json"
    )
    self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
    summary = json.loads(summary_path.read_text(encoding="utf-8"))
    self.assertLen(summary["chains"], 2)
    # Only the unpaired rows are required to be valid here.
    self.assertTrue(summary["unpaired_rows_valid"])

    input_json = Path(prepared_output_dir) / f"{job_name}_data.json"
    self.assertTrue(input_json.is_file(), f"Missing AF3 input JSON {input_json}")
    written = json.loads(input_json.read_text(encoding="utf-8"))
    protein_entries = {
        protein_entry["id"]: protein_entry
        for protein_entry in _protein_entries_from_af3_input(written)
    }
    self.assertEqual(set(protein_entries), {"A", "B"})
    for chain in fold_input_obj.chains:
        # Skip chains that do not expose a ``sequence`` attribute.
        if not hasattr(chain, "sequence"):
            continue
        protein_entry = protein_entries[chain.id]
        self.assertEqual(protein_entry["sequence"], chain.sequence)
        self.assertEqual(
            _a3m_query_sequence(protein_entry["unpairedMsa"]),
            chain.sequence,
        )
class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
    """Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features."""

    # Protein IDs of the issue-588 case; each one names a FASTA fixture
    # ("<id>.fasta") and the feature files generated from it.
    ISSUE_588_IDS = ("A0ABD7FQG0", "P18004")
def _require_mmseqs_functional_environment(self) -> None:
    """Skip unless the AF3 and MMseqs functional environments are available, then require the FASTA fixtures."""
    self._require_af3_functional_environment()
    reason = _mmseqs_functional_test_skip_reason()
    if reason:
        self.skipTest(reason)

    fasta_dir = self.test_data_dir / "fastas"
    for protein_id in self.ISSUE_588_IDS:
        fasta_path = fasta_dir / f"{protein_id}.fasta"
        fixture_present = fasta_path.is_file()
        self.assertTrue(fixture_present, f"Missing FASTA fixture {fasta_path}")
def _generate_issue_588_mmseq_features(self, env: Dict[str, str]) -> Path:
    """Run feature creation with MMseqs2 for the issue-588 proteins.

    Args:
        env: Environment passed to the feature-creation sub-process.

    Returns:
        The directory the features were written to.
    """
    feature_dir = self.output_dir / "issue_588_mmseq_features"
    feature_dir.mkdir(parents=True, exist_ok=True)

    fasta_dir = self.test_data_dir / "fastas"
    fasta_paths = ",".join(
        str(fasta_dir / f"{protein_id}.fasta") for protein_id in self.ISSUE_588_IDS
    )
    command = [
        sys.executable,
        str(self.script_create_features),
        f"--fasta_paths={fasta_paths}",
        f"--output_dir={feature_dir}",
        f"--data_dir={DATA_DIR}",
        "--max_template_date=2024-05-02",
        "--use_mmseqs2=True",
        "--data_pipeline=alphafold2",
        "--save_msa_files=True",
        "--compress_features=True",
        "--skip_existing=False",
    ]
    res = subprocess.run(command, capture_output=True, text=True, env=env)
    self.assertEqual(
        res.returncode,
        0,
        f"MMseqs feature generation failed.\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
    )
    return feature_dir
def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_inference(self):
    # End-to-end dimer run: regenerate MMseqs features, run AF3 through the
    # single-job script, then check the translation summary and the ipTM.
    self._require_mmseqs_functional_environment()
    env = self._make_af3_test_env()
    feature_dir = self._generate_issue_588_mmseq_features(env)

    for protein_id in self.ISSUE_588_IDS:
        a3m_path = feature_dir / f"{protein_id}.a3m"
        pickle_path = feature_dir / f"{protein_id}.pkl.xz"
        self.assertTrue(
            a3m_path.is_file(),
            f"Expected MMseq A3M {a3m_path} to be created.",
        )
        self.assertTrue(
            pickle_path.is_file(),
            f"Expected compressed feature pickle {pickle_path} to be created.",
        )
        feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
        species_id_count = _non_empty_identifier_count(
            feature_dict["msa_species_identifiers_all_seq"]
        )
        self.assertGreater(
            species_id_count,
            0,
            f"{protein_id} should keep recovered species IDs in msa_species_identifiers_all_seq",
        )
        accession_id_count = _non_empty_identifier_count(
            feature_dict["msa_uniprot_accession_identifiers_all_seq"]
        )
        self.assertGreater(
            accession_id_count,
            0,
            f"{protein_id} should keep recovered accession IDs in msa_uniprot_accession_identifiers_all_seq",
        )

    flash_impl = self._af3_flash_attention_impl()
    command = [
        sys.executable,
        str(self.script_single),
        "--input=A0ABD7FQG0+P18004",
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={feature_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
        "--random_seed=42",
        "--debug_msas",
    ]
    res = subprocess.run(command, capture_output=True, text=True, env=env)
    self._runCommonTests(res)

    result_dir = self._resolve_single_af3_result_dir()

    # Exactly one translation summary, reporting species pairing.
    summary_paths = sorted(
        result_dir.glob("*_af2_to_af3_translation_summary.json")
    )
    self.assertLen(summary_paths, 1)
    summary = json.loads(summary_paths[0].read_text(encoding="utf-8"))
    self.assertEqual(
        summary["translation_modes"],
        ["af3_species_pairing_from_af2_individual_msas"],
    )
    for validity_key in ("paired_rows_valid", "unpaired_rows_valid"):
        self.assertTrue(summary[validity_key])
    for chain_summary in summary["chains"]:
        for count_key in (
            "paired_msa_row_count",
            "unpaired_msa_row_count",
            "paired_species_identifier_count",
        ):
            self.assertGreater(chain_summary[count_key], 0)

    # Exactly one summary-confidence file, with ipTM above the threshold.
    confidence_files = sorted(result_dir.glob("*_summary_confidences.json"))
    self.assertLen(confidence_files, 1)
    confidence_text = confidence_files[0].read_text(encoding="utf-8")
    confidence_payload = json.loads(confidence_text)
    self.assertIn("iptm", confidence_payload)
    self.assertGreater(
        confidence_payload["iptm"],
        0.6,
        f"Expected AF3 ipTM > 0.6, got {confidence_payload['iptm']}",
    )
def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_trimer_inference(self):
    """AF3 should accept trimer jobs built from AF2/mmseqs2 pkl features and report effective pairing."""
    self._require_mmseqs_functional_environment()
    env = self._make_af3_test_env()
    feature_dir = self._generate_issue_588_mmseq_features(env)
    flash_impl = self._af3_flash_attention_impl()

    command = [
        sys.executable,
        str(self.script_single),
        "--input=A0ABD7FQG0+P18004+A0ABD7FQG0",
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={feature_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
        "--random_seed=42",
        "--debug_msas",
    ]
    res = subprocess.run(command, capture_output=True, text=True, env=env)
    self._runCommonTests(res)

    result_dir = self._resolve_single_af3_result_dir()

    # Exactly one translation summary, reporting species pairing.
    summary_paths = sorted(
        result_dir.glob("*_af2_to_af3_translation_summary.json")
    )
    self.assertLen(summary_paths, 1)
    summary = json.loads(summary_paths[0].read_text(encoding="utf-8"))
    self.assertEqual(
        summary["translation_modes"],
        ["af3_species_pairing_from_af2_individual_msas"],
    )
    for validity_key in ("paired_rows_valid", "unpaired_rows_valid"):
        self.assertTrue(summary[validity_key])

    # Paired rows exist, and no more are reported than were translated.
    for row_count_key in ("translated_paired_input_row_count", "paired_row_count"):
        self.assertGreater(summary[row_count_key], 0)
    self.assertGreaterEqual(
        summary["translated_paired_input_row_count"],
        summary["paired_row_count"],
    )

    # The histogram keys are read as chain counts; at least one key must be >= 2.
    histogram = summary["effective_paired_row_histogram_by_num_chains"]
    self.assertTrue(histogram)
    largest_chain_count = max(int(key) for key in histogram)
    self.assertGreaterEqual(largest_chain_count, 2)

    self.assertLen(summary["chains"], 3)
    for chain_summary in summary["chains"]:
        for count_key in (
            "paired_msa_row_count",
            "unpaired_msa_row_count",
            "effective_paired_msa_row_count",
        ):
            self.assertGreater(chain_summary[count_key], 0)

    # The written AF3 input holds two protein entries covering chains A, B, C.
    input_json_paths = sorted(result_dir.glob("*_data.json"))
    self.assertLen(input_json_paths, 1)
    written = json.loads(input_json_paths[0].read_text(encoding="utf-8"))
    protein_entries = _protein_entries_from_af3_input(written)
    self.assertLen(protein_entries, 2)

    all_chain_ids = []
    for protein_entry in protein_entries:
        # An entry's "id" is either one chain ID or a list of chain IDs.
        entry_ids = protein_entry["id"]
        all_chain_ids.extend([entry_ids] if isinstance(entry_ids, str) else entry_ids)
        for msa_key in ("pairedMsa", "unpairedMsa"):
            self.assertEqual(
                _a3m_query_sequence(protein_entry[msa_key]),
                protein_entry["sequence"],
            )
    self.assertCountEqual(all_chain_ids, ["A", "B", "C"])
# --------------------------------------------------------------------------- #
# parameterised "run mode" tests #
# --------------------------------------------------------------------------- #
class TestAlphaFold3RunModes(_TestBase):
def test_af3_on_the_fly_pairing_from_json_features(self):
    """Merging unpaired-only AF3 JSON monomers must promote MSAs to pairedMsa.

    Two AF3 JSON monomer feature files are combined into a single fold input
    by ``AlphaFold3Backend.prepare_input``; ``process_fold_input`` is then
    called with ``model_runner=None`` to write the combined ``*_data.json``.
    Every protein entry in that file must carry a non-empty ``pairedMsa``
    and an empty (or absent) ``unpairedMsa``.  Finally the original monomer
    JSONs are checked to contain no ``pairedMsa``, which confirms the test
    really started from unpaired-only features.
    """
    from alphapulldown.folding_backend.alphafold3_backend import (
        AlphaFold3Backend,
        process_fold_input,
    )

    def _load_json(path):
        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(path, "rt", encoding="utf-8") as handle:
            return json.load(handle)

    def _has_promoted_msa(protein):
        # Promotion means: pairedMsa is a non-empty string AND unpairedMsa
        # has been cleared (empty string or missing/None).
        paired = protein.get("pairedMsa") or ""
        unpaired = protein.get("unpairedMsa")
        return (
            isinstance(paired, str)
            and bool(paired)
            and (unpaired is None or unpaired == "")
        )

    def _paired_empty(payload):
        # True when no protein entry of the payload has a truthy pairedMsa.
        return not any(
            entry["protein"].get("pairedMsa")
            for entry in payload.get("sequences", [])
            if "protein" in entry
        )

    # Input JSONs (repo-relative paths via test_features_dir).
    json_a = self.test_features_dir / "af3_features/protein/A0A024R1R8_af3_input.json"
    json_b = self.test_features_dir / "af3_features/protein/P61626_af3_input.json"

    # Two JSON inputs that the backend merges into one complex.
    objects_to_model = [
        {"object": {"json_input": str(path)}, "output_dir": str(self.output_dir)}
        for path in (json_a, json_b)
    ]
    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=objects_to_model,
        random_seed=42,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))

    # Ask the backend helper to write *_data.json; no model runner is given.
    res = process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )
    self.assertIsNotNone(res)

    out_path = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
    sequences = _load_json(out_path).get("sequences", [])
    self.assertGreaterEqual(len(sequences), 2, "Expected at least two chains in combined input")

    # Every protein chain (not just some) must have been promoted.
    protein_entries = [entry["protein"] for entry in sequences if "protein" in entry]
    promoted_entries = [
        protein for protein in protein_entries if _has_promoted_msa(protein)
    ]
    self.assertGreaterEqual(len(protein_entries), 2, "Expected two protein chains in the dimer test")
    self.assertEqual(len(promoted_entries), len(protein_entries), "All protein chains must have pairedMsa populated and unpairedMsa cleared")

    # The original monomer JSONs must have empty pairedMsa, i.e. the test
    # started from unpaired-only features.
    a_data = _load_json(json_a)
    b_data = _load_json(json_b)
    self.assertTrue(_paired_empty(a_data))
    self.assertTrue(_paired_empty(b_data))
    print("✓ Combined AF3 input JSON created; per-chain MSAs present for backend pairing")
def test_af3_custom_residue_ids_round_trip_through_json_and_structure(self):
    """Custom AF3 residue IDs must survive JSON and structure conversion.

    A single protein chain with gapped residue numbering is serialised with
    ``Input.to_json``, parsed back with ``Input.from_json`` and converted to
    a structure; the residue IDs must be unchanged at both stages.
    """
    from alphafold3.common import folding_input
    from alphafold3.constants import chemical_components

    gapped_ids = [2, 3, 4, 5, 8, 9, 10]
    original_input = folding_input.Input(
        name="gap_test",
        chains=[
            folding_input.ProteinChain(
                id="A",
                sequence="SSHEKKK",
                ptms=[],
                residue_ids=gapped_ids,
                unpaired_msa="",
                paired_msa="",
                templates=[],
            )
        ],
        rng_seeds=[1],
    )

    # Stage 1: JSON serialisation round trip.
    serialised = original_input.to_json()
    restored_input = folding_input.Input.from_json(serialised)
    restored_chain = restored_input.protein_chains[0]
    self.assertEqual(list(restored_chain.residue_ids), gapped_ids)

    # Stage 2: conversion of the restored input into a structure.
    ccd = chemical_components.Ccd()
    restored_structure = restored_input.to_structure(ccd=ccd)
    self.assertEqual(restored_structure.present_residues.id.tolist(), gapped_ids)
def test_af3_custom_residue_ids_propagate_to_token_features(self):
    """AF3 token features must retain custom gapped residue numbering.

    Tokenises a one-chain input whose residue IDs jump from 4 to 8, then
    checks ``residue_index``, ``asym_id`` and the relative encoding computed
    between the two tokens on either side of the gap.
    """
    from alphafold3.common import folding_input
    from alphafold3.constants import chemical_components
    from alphafold3.model import features as af3_features
    from alphafold3.model.atom_layout import atom_layout
    from alphafold3.model.network import featurization as af3_featurization

    gapped_ids = [1, 2, 3, 4, 8, 9, 10]
    residue_count = len(gapped_ids)
    atoms_per_token = 24
    relative_idx_limit = 4
    job_name = "gap_token_test"

    gapped_input = folding_input.Input(
        name=job_name,
        chains=[
            folding_input.ProteinChain(
                id="A",
                sequence="ACDEFGH",
                ptms=[],
                residue_ids=gapped_ids,
                unpaired_msa="",
                paired_msa="",
                templates=[],
            )
        ],
        rng_seeds=[1],
    )

    ccd = chemical_components.Ccd()
    structure = gapped_input.to_structure(ccd=ccd)
    layout = atom_layout.atom_layout_from_structure(structure)
    tokens, _, _ = af3_features.tokenizer(
        layout,
        ccd=ccd,
        max_atoms_per_token=atoms_per_token,
        flatten_non_standard_residues=False,
        logging_name=job_name,
    )

    token_count = len(tokens.atom_name)
    shapes = af3_features.PaddingShapes(
        num_tokens=token_count,
        msa_size=1,
        num_chains=1,
        num_templates=0,
        num_atoms=atoms_per_token * token_count,
    )
    features = af3_features.TokenFeatures.compute_features(
        all_tokens=tokens,
        padding_shapes=shapes,
    )

    # The leading tokens must carry the gapped numbering and one asym ID.
    leading_residue_index = features.residue_index[:residue_count].tolist()
    self.assertEqual(leading_residue_index, gapped_ids)
    leading_asym_ids = features.asym_id[:residue_count].tolist()
    self.assertEqual(sorted(set(leading_asym_ids)), [1])

    encoding = np.asarray(
        af3_featurization.create_relative_encoding(
            features,
            max_relative_idx=relative_idx_limit,
            max_relative_chain=2,
        )
    )
    # Tokens 3 and 4 hold residue IDs 4 and 8 (the gap) within one chain.
    inter_chain_bin = 2 * relative_idx_limit + 1
    gap_pair_encoding = encoding[3, 4]
    self.assertEqual(gap_pair_encoding[inter_chain_bin], 0)
    self.assertEqual(
        np.argmax(gap_pair_encoding[: 2 * relative_idx_limit + 2]),
        0,
    )
def test_af3_duplicate_residue_ids_survive_empty_structure_round_trip(self):
    """AF3 must preserve duplicate residue IDs when rebuilding empty structures.

    The chain repeats residue IDs 2-5; a structure is rebuilt from its atom
    layout with all-zero coordinates and must report the same residue IDs.
    """
    from alphafold3.common import folding_input
    from alphafold3.constants import chemical_components
    from alphafold3.model.atom_layout import atom_layout

    job_name = "duplicate_residue_ids_test"
    duplicated_ids = [*range(1, 11), *range(2, 6), *range(12, 16)]
    duplicated_input = folding_input.Input(
        name=job_name,
        chains=[
            folding_input.ProteinChain(
                id="A",
                sequence="ACDEFGHIKLCDEFMNPQ",
                ptms=[],
                residue_ids=duplicated_ids,
                unpaired_msa="",
                paired_msa="",
                templates=[],
            )
        ],
        rng_seeds=[1],
    )

    ccd = chemical_components.Ccd()
    structure = duplicated_input.to_structure(ccd=ccd)
    layout = atom_layout.atom_layout_from_structure(structure)
    physical_residues = atom_layout.residues_from_structure(structure)

    # One zeroed xyz triple per atom in the flat layout.
    zero_coords = np.zeros((layout.atom_name.shape[0], 3), dtype=np.float32)
    rebuilt_structure = atom_layout.make_structure(
        layout,
        atom_coords=zero_coords,
        name=job_name,
        all_physical_residues=physical_residues,
    )
    self.assertEqual(
        rebuilt_structure.present_residues.id.tolist(),
        duplicated_ids,
    )
def test_af3_output_job_name_compacts_long_homomer_names(self):
    """AF3 job names should stay readable and below common filename limits.

    A 10-copy homomer of a multi-region chopped protein must get the compact
    name ``A0A075B6L2_1-3_4-5_6-7_7-8__x10`` and every chain must carry the
    expected sequence and residue IDs 1-8.
    """
    from alphapulldown.folding_backend.alphafold3_backend import AlphaFold3Backend

    features_dir = str(self.test_features_dir)
    fold_spec = parse_fold(
        ["A0A075B6L2:10:1-3:4-5:6-7:7-8"],
        [features_dir],
        "+",
    )
    custom_info = create_custom_info(fold_spec)
    interactor_groups = create_interactors(custom_info, [features_dir])
    self.assertLen(interactor_groups, 1)
    homomer_interactors = interactor_groups[0]
    self.assertLen(homomer_interactors, 10)

    multimer = MultimericObject(interactors=homomer_interactors, pair_msa=True)
    job = {"object": multimer, "output_dir": str(self.output_dir)}
    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=[job],
        random_seed=42,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))

    job_name = fold_input_obj.sanitised_name()
    self.assertEqual(job_name, "A0A075B6L2_1-3_4-5_6-7_7-8__x10")
    self.assertLessEqual(len(job_name), 200)

    region_sequences = self._get_region_sequences(
        "A0A075B6L2",
        [(1, 3), (4, 5), (6, 7), (8, 8)],
    )
    expected_sequence = "".join(region_sequences)
    expected_ids = [1, 2, 3, 4, 5, 6, 7, 8]

    sequences_match = all(
        chain.sequence == expected_sequence for chain in fold_input_obj.chains
    )
    self.assertTrue(sequences_match)
    residue_ids_match = all(
        list(chain.residue_ids) == expected_ids for chain in fold_input_obj.chains
    )
    self.assertTrue(residue_ids_match)
def test_af3_output_job_name_hashes_overlong_unique_compound_names(self):
    """AF3 job names should fall back to a deterministic hash suffix when needed.

    Twelve long, distinct JSON inputs give a readable name over 200
    characters; the built job name must be at most 200 characters, end with
    ``__`` plus the first 12 hex digits of the SHA-1 of the readable name,
    and be identical when built a second time.
    """
    from alphapulldown.folding_backend.alphafold3_backend import (
        _build_output_job_name,
    )

    long_tail = "verylongsegment" * 4
    fragments = [f"protein_{index:02d}_{long_tail}" for index in range(12)]

    objects_to_model = []
    for fragment in fragments:
        json_path = Path("/tmp") / f"{fragment}_af3_input.json"
        objects_to_model.append(
            {
                "object": {"json_input": str(json_path)},
                "output_dir": str(self.output_dir),
            }
        )

    readable_name = "_and_".join(fragments)
    self.assertGreater(len(readable_name), 200)

    job_name = _build_output_job_name(objects_to_model)
    full_digest = hashlib.sha1(readable_name.encode("utf-8")).hexdigest()
    expected_digest = full_digest[:12]

    self.assertLessEqual(len(job_name), 200)
    self.assertTrue(job_name.endswith(f"__{expected_digest}"))
    self.assertRegex(job_name, r"__[0-9a-f]{12}$")
    # Building the name again must give the same result.
    rebuilt_name = _build_output_job_name(objects_to_model)
    self.assertEqual(job_name, rebuilt_name)
def test_af3_prepare_input_accepts_monomer_plus_ligand_json(self):
    """AF3 mixed protein+ligand JSON inputs must survive prepare_input cloning.

    Builds the fold ``A0A024R1R8+ligand.json``, checks that ``prepare_input``
    yields one fold input with protein chain ``A`` and an ``ATP`` ligand
    chain ``L``, then verifies the same layout in the ``*_data.json`` written
    by ``process_fold_input``.
    """
    from alphafold3.common import folding_input
    from alphapulldown.folding_backend.alphafold3_backend import (
        AlphaFold3Backend,
        process_fold_input,
    )

    parsed = parse_fold(
        ["A0A024R1R8+ligand.json"],
        [str(self.test_features_dir)],
        "+",
    )
    custom_info = create_custom_info(parsed)
    all_interactors = create_interactors(custom_info, [str(self.test_features_dir)])
    self.assertLen(all_interactors, 1)
    self.assertLen(all_interactors[0], 2)

    # Each interactor is submitted as its own object to be merged.
    objects_to_model = [
        {"object": obj, "output_dir": str(self.output_dir)}
        for obj in all_interactors[0]
    ]
    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=objects_to_model,
        random_seed=42,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))

    protein_chain, ligand_chain = fold_input_obj.chains[0], fold_input_obj.chains[1]
    self.assertEqual([chain.id for chain in fold_input_obj.chains], ["A", "L"])
    self.assertIsInstance(protein_chain, folding_input.ProteinChain)
    self.assertIsInstance(ligand_chain, folding_input.Ligand)
    self.assertEqual(list(ligand_chain.ccd_ids), ["ATP"])

    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )

    input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
    # Explicit UTF-8: JSON must not be decoded with the locale default.
    with open(input_json, "rt", encoding="utf-8") as handle:
        written = json.load(handle)

    sequence_entries = written.get("sequences", [])
    protein_entries = [
        entry["protein"] for entry in sequence_entries if "protein" in entry
    ]
    ligand_entries = [
        entry["ligand"] for entry in sequence_entries if "ligand" in entry
    ]
    self.assertLen(protein_entries, 1)
    self.assertLen(ligand_entries, 1)
    self.assertEqual(ligand_entries[0]["id"], "L")
    self.assertEqual(ligand_entries[0]["ccdCodes"], ["ATP"])
def test_af3_prepare_input_skips_invalid_json_templates_for_ptm_input(self):
    """Malformed inline JSON templates should be dropped instead of crashing AF3.

    ``protein_with_ptms.json`` is fed through ``prepare_input``; the
    resulting single protein chain ``P`` must keep its PTMs and have no
    templates, and the written ``*_data.json`` must match the source JSON's
    sequence and modifications with an empty ``templates`` list.
    """
    from alphafold3.common import folding_input
    from alphapulldown.folding_backend.alphafold3_backend import (
        AlphaFold3Backend,
        process_fold_input,
    )

    json_input = self.test_features_dir / "protein_with_ptms.json"
    # Explicit UTF-8: JSON must not be decoded with the locale default.
    raw_payload = json.loads(json_input.read_text(encoding="utf-8"))
    expected_protein = raw_payload["sequences"][0]["protein"]

    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=[
            {
                "object": {"json_input": str(json_input)},
                "output_dir": str(self.output_dir),
            }
        ],
        random_seed=42,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))

    self.assertEqual([chain.id for chain in fold_input_obj.chains], ["P"])
    self.assertLen(fold_input_obj.chains, 1)
    ptm_chain = fold_input_obj.chains[0]
    self.assertIsInstance(ptm_chain, folding_input.ProteinChain)
    self.assertEqual(list(ptm_chain.ptms), [("HYS", 1), ("2MG", 15)])
    self.assertEqual(list(ptm_chain.templates), [])

    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )

    input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
    with open(input_json, "rt", encoding="utf-8") as handle:
        written = json.load(handle)
    protein_entries = [
        sequence_entry["protein"]
        for sequence_entry in written.get("sequences", [])
        if "protein" in sequence_entry
    ]
    self.assertLen(protein_entries, 1)
    written_protein = protein_entries[0]
    self.assertEqual(written_protein["id"], "P")
    self.assertEqual(written_protein["sequence"], expected_protein["sequence"])
    self.assertEqual(
        written_protein["modifications"],
        expected_protein["modifications"],
    )
    self.assertEqual(written_protein["templates"], [])
def test_af3_prepare_input_keeps_valid_json_templates(self):
    """Valid inline JSON templates should survive prepare_input and JSON write-out.

    The template count of ``P61626_af3_input.json`` must be preserved on the
    prepared protein chain and in the written ``*_data.json``, and every
    written template must have truthy ``mmcif``, ``queryIndices`` and
    ``templateIndices`` fields.
    """
    from alphafold3.common import folding_input
    from alphapulldown.folding_backend.alphafold3_backend import (
        AlphaFold3Backend,
        process_fold_input,
    )

    json_input = (
        self.test_features_dir
        / "af3_features"
        / "protein"
        / "P61626_af3_input.json"
    )
    # Explicit UTF-8: JSON must not be decoded with the locale default.
    raw_payload = json.loads(json_input.read_text(encoding="utf-8"))
    expected_protein = raw_payload["sequences"][0]["protein"]
    expected_template_count = len(expected_protein["templates"])
    # A fixture without templates would make the rest of the test vacuous.
    self.assertGreater(expected_template_count, 0)

    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=[
            {
                "object": {"json_input": str(json_input)},
                "output_dir": str(self.output_dir),
            }
        ],
        random_seed=42,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))

    self.assertEqual([chain.id for chain in fold_input_obj.chains], ["A"])
    self.assertLen(fold_input_obj.chains, 1)
    self.assertIsInstance(fold_input_obj.chains[0], folding_input.ProteinChain)
    self.assertLen(fold_input_obj.chains[0].templates, expected_template_count)

    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )

    input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
    with open(input_json, "rt", encoding="utf-8") as handle:
        written = json.load(handle)
    protein_entries = [
        sequence_entry["protein"]
        for sequence_entry in written.get("sequences", [])
        if "protein" in sequence_entry
    ]
    self.assertLen(protein_entries, 1)
    self.assertEqual(protein_entries[0]["id"], "A")

    written_templates = protein_entries[0]["templates"]
    self.assertLen(written_templates, expected_template_count)
    # Each template must keep its structure payload and both index mappings.
    for field in ("mmcif", "queryIndices", "templateIndices"):
        self.assertTrue(
            all(template[field] for template in written_templates),
            f"Every written template must have a non-empty {field!r}",
        )
def test_af3_viewer_output_renumbers_gapped_residue_ids_for_viewers(self):
    """Viewer-safe AF3 output must use sequential label IDs for gapped chains.

    After ``_make_viewer_compatible_inference_result`` the structure and
    token residue IDs must run 1..N, while the original gapped numbering is
    kept as strings in the auth fields and all insertion codes are ``"."``.
    """
    from alphafold3.common import folding_input
    from alphafold3.constants import chemical_components
    from alphafold3.model import model as af3_model
    from alphapulldown.folding_backend.alphafold3_backend import (
        _make_viewer_compatible_inference_result,
    )

    original_residue_ids = [2, 3, 4, 5, 8, 9, 10]
    residue_count = len(original_residue_ids)

    gapped_input = folding_input.Input(
        name="gapped_residue_ids_for_viewers",
        chains=[
            folding_input.ProteinChain(
                id="A",
                sequence="ACDEFGH",
                ptms=[],
                residue_ids=original_residue_ids,
                unpaired_msa="",
                paired_msa="",
                templates=[],
            )
        ],
        rng_seeds=[1],
    )
    structure = gapped_input.to_structure(ccd=chemical_components.Ccd())
    raw_result = af3_model.InferenceResult(
        predicted_structure=structure,
        metadata={
            "token_chain_ids": ["A"] * residue_count,
            "token_res_ids": original_residue_ids,
        },
    )

    viewer_result = _make_viewer_compatible_inference_result(raw_result)
    viewer_structure = viewer_result.predicted_structure
    viewer_metadata = viewer_result.metadata

    sequential_ids = list(range(1, residue_count + 1))
    auth_ids_as_text = [str(residue_id) for residue_id in original_residue_ids]

    # Label numbering becomes sequential ...
    self.assertEqual(viewer_structure.present_residues.id.tolist(), sequential_ids)
    self.assertEqual(viewer_metadata["token_res_ids"], sequential_ids)
    # ... while the author numbering keeps the original gapped IDs.
    self.assertEqual(
        viewer_structure.residues_table.auth_seq_id.tolist(),
        auth_ids_as_text,
    )
    self.assertEqual(
        viewer_structure.residues_table.insertion_code.tolist(),
        ["."] * residue_count,
    )
    self.assertEqual(viewer_metadata["token_auth_res_ids"], auth_ids_as_text)
    self.assertEqual(viewer_metadata["token_auth_res_labels"], auth_ids_as_text)
def test_af3_viewer_output_uses_insertion_codes_for_duplicate_residue_ids(self):
    """Viewer-safe AF3 output must preserve IDs and disambiguate with insertions.

    The chain repeats residue IDs 2-5.  The repeated residues must receive
    insertion code ``"A"`` (labels ``2A``..``5A``), all others ``"."``.
    """
    from alphafold3.common import folding_input
    from alphafold3.constants import chemical_components
    from alphafold3.model import model as af3_model
    from alphapulldown.folding_backend.alphafold3_backend import (
        _make_viewer_compatible_inference_result,
    )

    original_residue_ids = [*range(1, 11), *range(2, 6), *range(12, 16)]
    residue_count = len(original_residue_ids)

    duplicated_input = folding_input.Input(
        name="duplicate_residue_ids_for_chimerax",
        chains=[
            folding_input.ProteinChain(
                id="A",
                sequence="ACDEFGHIKLCDEFMNPQ",
                ptms=[],
                residue_ids=original_residue_ids,
                unpaired_msa="",
                paired_msa="",
                templates=[],
            )
        ],
        rng_seeds=[1],
    )
    structure = duplicated_input.to_structure(ccd=chemical_components.Ccd())
    raw_result = af3_model.InferenceResult(
        predicted_structure=structure,
        metadata={
            "token_chain_ids": ["A"] * residue_count,
            "token_res_ids": original_residue_ids,
        },
    )

    viewer_result = _make_viewer_compatible_inference_result(raw_result)
    viewer_structure = viewer_result.predicted_structure
    viewer_metadata = viewer_result.metadata

    sequential_ids = list(range(1, residue_count + 1))
    auth_ids_as_text = [str(residue_id) for residue_id in original_residue_ids]
    # 10 unique residues, 4 repeats of IDs 2-5, then 4 more unique residues.
    expected_ins_codes = ['.'] * 10 + ['A'] * 4 + ['.'] * 4
    expected_labels = (
        [str(i) for i in range(1, 11)]
        + [f"{i}A" for i in range(2, 6)]
        + [str(i) for i in range(12, 16)]
    )

    self.assertEqual(viewer_structure.present_residues.id.tolist(), sequential_ids)
    self.assertEqual(viewer_metadata["token_res_ids"], sequential_ids)
    self.assertEqual(
        viewer_structure.residues_table.auth_seq_id.tolist(),
        auth_ids_as_text,
    )
    self.assertEqual(
        viewer_structure.residues_table.insertion_code.tolist(),
        expected_ins_codes,
    )
    self.assertEqual(viewer_metadata["token_auth_res_ids"], auth_ids_as_text)
    self.assertEqual(viewer_metadata["token_pdb_ins_codes"], expected_ins_codes)
    self.assertEqual(viewer_metadata["token_auth_res_labels"], expected_labels)
def test_af3_viewer_output_handles_many_tokens_for_one_residue(self):
    """Viewer metadata must not crash when many tokens map to one residue.

    Forty tokens all claim residue 1 of chain ``L``.  The first token keeps
    ``"."``, the next 26 receive insertion codes ``A``..``Z``, and the
    remainder fall back to ``"."`` with bracketed labels such as ``1[28]``.
    """
    from alphafold3.common import folding_input
    from alphafold3.constants import chemical_components
    from alphafold3.model import model as af3_model
    from alphapulldown.folding_backend.alphafold3_backend import (
        _make_viewer_compatible_inference_result,
    )

    single_residue_input = folding_input.Input(
        name="many_tokens_one_residue",
        chains=[
            folding_input.ProteinChain(
                id="L",
                sequence="A",
                ptms=[],
                residue_ids=[1],
                unpaired_msa="",
                paired_msa="",
                templates=[],
            )
        ],
        rng_seeds=[1],
    )
    structure = single_residue_input.to_structure(ccd=chemical_components.Ccd())

    token_count = 40
    raw_result = af3_model.InferenceResult(
        predicted_structure=structure,
        metadata={
            "token_chain_ids": ["L"] * token_count,
            "token_res_ids": [1] * token_count,
        },
    )
    viewer_metadata = _make_viewer_compatible_inference_result(raw_result).metadata

    # 1 plain token + 26 lettered tokens are covered by insertion codes.
    lettered_limit = 27
    letters = [chr(code) for code in range(ord("A"), ord("A") + 26)]
    ins_codes = viewer_metadata["token_pdb_ins_codes"]
    labels = viewer_metadata["token_auth_res_labels"]

    self.assertEqual(
        viewer_metadata["token_res_ids"],
        list(range(1, token_count + 1)),
    )
    self.assertEqual(viewer_metadata["token_auth_res_ids"], ["1"] * token_count)
    self.assertEqual(ins_codes[:lettered_limit], ["."] + letters)
    self.assertEqual(
        ins_codes[lettered_limit:],
        ["."] * (token_count - lettered_limit),
    )
    self.assertEqual(
        labels[:lettered_limit],
        ["1"] + [f"1{letter}" for letter in letters],
    )
    self.assertEqual(labels[lettered_limit], "1[28]")
    self.assertEqual(labels[-1], "1[40]")
def test_af3_keeps_discontinuous_chopped_regions_in_one_gapped_chain(self):
    """AF3 must keep multi-region chopped inputs as one gapped protein chain.

    For ``TEST+A0A075B6L2:1-10:2-5:12-15`` the prepared fold input and the
    written ``*_data.json`` must both contain two protein chains, one of
    which is the concatenation of the three regions with residue IDs
    1-10, 2-5, 12-15.
    """
    from alphapulldown.folding_backend.alphafold3_backend import (
        AlphaFold3Backend,
        process_fold_input,
    )

    parsed = parse_fold(
        ["TEST+A0A075B6L2:1-10:2-5:12-15"],
        [str(self.test_features_dir)],
        "+",
    )
    custom_info = create_custom_info(parsed)
    all_interactors = create_interactors(custom_info, [str(self.test_features_dir)])
    self.assertLen(all_interactors, 1)
    self.assertLen(all_interactors[0], 2)

    object_to_model = MultimericObject(interactors=all_interactors[0], pair_msa=True)
    objects_to_model = [{"object": object_to_model, "output_dir": str(self.output_dir)}]
    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=objects_to_model,
        random_seed=42,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))

    chopped_region_sequences = self._get_region_sequences(
        "A0A075B6L2",
        [(1, 10), (2, 5), (12, 15)],
    )
    concatenated_chopped_sequence = "".join(chopped_region_sequences)
    expected_sequences = [
        self._get_sequence_for_protein("TEST"),
        concatenated_chopped_sequence,
    ]
    # Region 2-5 overlaps region 1-10, so IDs 2..5 appear twice.
    expected_chopped_residue_ids = (
        list(range(1, 11))
        + [2, 3, 4, 5]
        + list(range(12, 16))
    )

    # --- in-memory fold input ---------------------------------------- #
    actual_sequences = [chain.sequence for chain in fold_input_obj.chains]
    self.assertCountEqual(actual_sequences, expected_sequences)
    self.assertLen(actual_sequences, 2)
    chopped_chains = [
        chain for chain in fold_input_obj.chains
        if chain.sequence == concatenated_chopped_sequence
    ]
    self.assertLen(chopped_chains, 1)
    self.assertEqual(
        list(chopped_chains[0].residue_ids),
        expected_chopped_residue_ids,
    )

    # --- written *_data.json ----------------------------------------- #
    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )
    input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
    # Explicit UTF-8: JSON must not be decoded with the locale default.
    with open(input_json, "rt", encoding="utf-8") as handle:
        written = json.load(handle)
    protein_entries = [
        sequence_entry["protein"]
        for sequence_entry in written.get("sequences", [])
        if "protein" in sequence_entry
    ]
    self.assertLen(protein_entries, 2)
    self.assertCountEqual(
        [entry["sequence"] for entry in protein_entries],
        expected_sequences,
    )
    chopped_entries = [
        entry for entry in protein_entries
        if entry["sequence"] == concatenated_chopped_sequence
    ]
    self.assertLen(chopped_entries, 1)
    self.assertEqual(
        chopped_entries[0]["residueIds"],
        expected_chopped_residue_ids,
    )
    print("✓ AF3 input keeps discontinuous chopped regions as one gapped chain")
def test_af3_keeps_two_out_of_order_gapped_copies_as_two_chains(self):
    """AF3 must keep two copied out-of-order gapped regions as two chains.

    For ``A0A075B6L2:2:8-10:2-5`` the prepared fold input must hold chains
    ``A`` and ``B`` with identical sequence and residue IDs
    ``[8, 9, 10, 2, 3, 4, 5]``; the written ``*_data.json`` must hold one
    protein entry whose ``id`` is ``["A", "B"]``.
    """
    from alphapulldown.folding_backend.alphafold3_backend import (
        AlphaFold3Backend,
        process_fold_input,
    )

    parsed = parse_fold(
        ["A0A075B6L2:2:8-10:2-5"],
        [str(self.test_features_dir)],
        "+",
    )
    custom_info = create_custom_info(parsed)
    all_interactors = create_interactors(custom_info, [str(self.test_features_dir)])
    self.assertLen(all_interactors, 1)
    self.assertLen(all_interactors[0], 2)

    # NOTE(review): the raw interactor list is passed as the object here,
    # whereas sibling tests wrap it in MultimericObject - presumably
    # intentional; confirm against AlphaFold3Backend.prepare_input.
    objects_to_model = [{"object": all_interactors[0], "output_dir": str(self.output_dir)}]
    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=objects_to_model,
        random_seed=42,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))

    expected_regions = [(8, 10), (2, 5)]
    expected_sequence = "".join(
        self._get_region_sequences("A0A075B6L2", expected_regions)
    )
    # Regions are requested out of order, so the IDs are not monotonic.
    expected_residue_ids = [8, 9, 10, 2, 3, 4, 5]

    self.assertEqual(
        [chain.id for chain in fold_input_obj.chains],
        ["A", "B"],
    )
    self.assertEqual(
        [chain.sequence for chain in fold_input_obj.chains],
        [expected_sequence, expected_sequence],
    )
    self.assertEqual(
        [list(chain.residue_ids) for chain in fold_input_obj.chains],
        [expected_residue_ids, expected_residue_ids],
    )

    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )
    input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
    # Explicit UTF-8: JSON must not be decoded with the locale default.
    with open(input_json, "rt", encoding="utf-8") as handle:
        written = json.load(handle)
    protein_entries = [
        sequence_entry["protein"]
        for sequence_entry in written.get("sequences", [])
        if "protein" in sequence_entry
    ]
    # Both copies are expected in a single entry carrying both chain IDs.
    self.assertLen(protein_entries, 1)
    self.assertEqual(protein_entries[0]["id"], ["A", "B"])
    self.assertEqual(protein_entries[0]["sequence"], expected_sequence)
    self.assertEqual(protein_entries[0]["residueIds"], expected_residue_ids)
    print("✓ AF3 input keeps two copied out-of-order gapped regions as two chains")
def test_af3_json_feature_ranges_collapse_into_one_gapped_chain(self):
    """AF3 JSON feature files with ranges must collapse into one gapped chain.

    ``A0A024R1R8_af3_input.json:2-5:8-10`` must parse into one JSON input
    with two regions, produce a single chain named ``A0A024R1R8__2-5_8-10``
    with residue IDs ``[2, 3, 4, 5, 8, 9, 10]``, and be written the same way
    to ``*_data.json``.
    """
    from alphapulldown.folding_backend.alphafold3_backend import (
        AlphaFold3Backend,
        process_fold_input,
    )

    feature_dir = self.test_features_dir / "af3_features" / "protein"
    json_filename = "A0A024R1R8_af3_input.json"
    parsed = parse_fold(
        [f"{json_filename}:2-5:8-10"],
        [str(feature_dir)],
        "+",
    )
    self.assertEqual(
        parsed,
        [[
            {
                "json_input": str(feature_dir / json_filename),
                "regions": [(2, 5), (8, 10)],
            }
        ]],
    )

    custom_info = create_custom_info(parsed)
    all_interactors = create_interactors(custom_info, [str(feature_dir)])
    self.assertLen(all_interactors, 1)
    self.assertLen(all_interactors[0], 1)
    self.assertIsInstance(all_interactors[0][0], dict)

    objects_to_model = [{"object": all_interactors[0][0], "output_dir": str(self.output_dir)}]
    mappings = AlphaFold3Backend.prepare_input(
        objects_to_model=objects_to_model,
        random_seed=42,
    )
    self.assertLen(mappings, 1)
    fold_input_obj, _ = next(iter(mappings[0].items()))

    json_sequences = self._get_sequence_from_json(
        "af3_features/protein/A0A024R1R8_af3_input.json"
    )
    self.assertLen(json_sequences, 1)
    full_sequence = json_sequences[0][1]
    # 1-based inclusive regions 2-5 and 8-10 as 0-based half-open slices.
    expected_sequence = full_sequence[1:5] + full_sequence[7:10]
    expected_residue_ids = [2, 3, 4, 5, 8, 9, 10]

    self.assertEqual(
        [chain.sequence for chain in fold_input_obj.chains],
        [expected_sequence],
    )
    self.assertEqual(
        fold_input_obj.sanitised_name(),
        "A0A024R1R8__2-5_8-10",
    )
    self.assertEqual(
        [list(chain.residue_ids) for chain in fold_input_obj.chains],
        [expected_residue_ids],
    )

    process_fold_input(
        fold_input=fold_input_obj,
        model_runner=None,
        output_dir=str(self.output_dir),
        buckets=(512,),
    )
    input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
    # Explicit UTF-8: JSON must not be decoded with the locale default.
    with open(input_json, "rt", encoding="utf-8") as handle:
        written = json.load(handle)
    protein_entries = [
        sequence_entry["protein"]
        for sequence_entry in written.get("sequences", [])
        if "protein" in sequence_entry
    ]
    self.assertLen(protein_entries, 1)
    self.assertEqual(protein_entries[0]["sequence"], expected_sequence)
    self.assertEqual(protein_entries[0]["residueIds"], expected_residue_ids)
    print("✓ AF3 JSON feature ranges collapse into one gapped chain")
def test_af3_predicts_json_feature_ranges_as_one_gapped_chain(self):
    """Run AF3 on a Snakefile-style AF3 JSON feature input with explicit ranges.

    Launches the single-prediction script as a subprocess and checks that the
    first predicted CIF holds exactly one chain ``A`` with the sliced
    sequence and residue numbers ``[2, 3, 4, 5, 8, 9, 10]``.
    """
    self._require_af3_functional_environment()
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()
    feature_dir = self.test_features_dir / "af3_features" / "protein"

    command = [
        sys.executable,
        str(self.script_single),
        "--input=A0A024R1R8_af3_input.json:2-5:8-10",
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={feature_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
    ]
    completed = subprocess.run(command, capture_output=True, text=True, env=env)
    self._runCommonTests(completed)

    json_sequences = self._get_sequence_from_json(
        "af3_features/protein/A0A024R1R8_af3_input.json"
    )
    self.assertLen(json_sequences, 1)
    full_sequence = json_sequences[0][1]
    # 1-based inclusive regions 2-5 and 8-10 as 0-based half-open slices.
    first_region = full_sequence[1:5]
    second_region = full_sequence[7:10]
    expected_sequence = first_region + second_region
    expected_residue_ids = [2, 3, 4, 5, 8, 9, 10]

    result_dir = self._resolve_single_af3_result_dir()
    cif_files = list(result_dir.glob("*_model.cif"))
    self.assertTrue(cif_files, f"No predicted CIF files found in {result_dir}")

    model_cif = cif_files[0]
    chains_and_sequences = self._extract_cif_chains_and_sequences(model_cif)
    predicted_sequences = [sequence for _, sequence in chains_and_sequences]
    predicted_numbering = self._extract_cif_chain_residue_numbers(model_cif)

    self.assertEqual(predicted_sequences, [expected_sequence])
    self.assertEqual(predicted_numbering, [("A", expected_residue_ids)])
    print("✓ AF3 prediction keeps AF3 JSON feature ranges as one gapped chain")
def test_af3_predicts_discontinuous_chopped_regions_as_one_gapped_chain(self):
    """Run AF3 inference and ensure discontinuous chopped regions remain one chain.

    The first predicted CIF must contain two chains; the chopped one must be
    numbered 1-10, ``2A``-``5A`` (the repeated residues) and 12-15.
    """
    self._require_af3_functional_environment()
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()

    command = [
        sys.executable,
        str(self.script_single),
        "--input=TEST+A0A075B6L2:1-10:2-5:12-15",
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={self.test_features_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
    ]
    completed = subprocess.run(command, capture_output=True, text=True, env=env)
    self._runCommonTests(completed)

    region_sequences = self._get_region_sequences(
        "A0A075B6L2",
        [(1, 10), (2, 5), (12, 15)],
    )
    chopped_sequence = "".join(region_sequences)
    expected_sequences = [
        self._get_sequence_for_protein("TEST"),
        chopped_sequence,
    ]
    # Mixed ints and strings: repeated residues carry an insertion-code suffix.
    expected_chopped_residue_ids = [
        *range(1, 11),
        "2A", "3A", "4A", "5A",
        *range(12, 16),
    ]

    result_dir = self._resolve_single_af3_result_dir()
    cif_files = list(result_dir.glob("*_model.cif"))
    self.assertTrue(cif_files, f"No predicted CIF files found in {result_dir}")

    model_cif = cif_files[0]
    chains_and_sequences = self._extract_cif_chains_and_sequences(model_cif)
    predicted_sequences = [sequence for _, sequence in chains_and_sequences]
    numbering_by_chain = dict(self._extract_cif_chain_residue_numbers(model_cif))
    sequence_by_chain = dict(chains_and_sequences)

    self.assertLen(predicted_sequences, 2)
    self.assertCountEqual(predicted_sequences, expected_sequences)

    # Find the chopped chain by its sequence, then check its numbering.
    chopped_chain_ids = []
    for chain_id, sequence in sequence_by_chain.items():
        if sequence == chopped_sequence:
            chopped_chain_ids.append(chain_id)
    self.assertLen(chopped_chain_ids, 1)
    self.assertEqual(
        numbering_by_chain[chopped_chain_ids[0]],
        expected_chopped_residue_ids,
    )
    print("✓ AF3 prediction keeps discontinuous chopped regions as one gapped chain")
def test_af3_predicts_two_out_of_order_gapped_copies_as_two_chains(self):
    """Run AF3 inference and ensure copied out-of-order gapped regions remain two chains.

    The first predicted CIF must contain chains ``A`` and ``B`` with the same
    sequence and residue numbers ``[8, 9, 10, 2, 3, 4, 5]``.
    """
    self._require_af3_functional_environment()
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()

    command = [
        sys.executable,
        str(self.script_single),
        "--input=A0A075B6L2:2:8-10:2-5",
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={self.test_features_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
    ]
    completed = subprocess.run(command, capture_output=True, text=True, env=env)
    self._runCommonTests(completed)

    expected_regions = [(8, 10), (2, 5)]
    region_sequences = self._get_region_sequences("A0A075B6L2", expected_regions)
    expected_sequence = "".join(region_sequences)
    expected_residue_ids = [8, 9, 10, 2, 3, 4, 5]

    result_dir = self._resolve_single_af3_result_dir()
    cif_files = list(result_dir.glob("*_model.cif"))
    self.assertTrue(cif_files, f"No predicted CIF files found in {result_dir}")

    model_cif = cif_files[0]
    chains_and_sequences = self._extract_cif_chains_and_sequences(model_cif)
    numbering_by_chain = dict(self._extract_cif_chain_residue_numbers(model_cif))

    predicted_sequences = [sequence for _, sequence in chains_and_sequences]
    self.assertEqual(predicted_sequences, [expected_sequence, expected_sequence])
    predicted_chain_ids = [chain_id for chain_id, _ in chains_and_sequences]
    self.assertEqual(predicted_chain_ids, ["A", "B"])

    self.assertEqual(numbering_by_chain["A"], expected_residue_ids)
    self.assertEqual(numbering_by_chain["B"], expected_residue_ids)
    print("✓ AF3 prediction keeps two copied out-of-order gapped regions as two chains")
def test_dimer_chopped_expected_sequences_are_concatenated_per_chain(self):
    """The expectation helper must report a chopped protein as one concatenated chain.

    For ``test_dimer_chopped.txt`` exactly two sequences are expected: the
    sequence of ``TEST`` and the regions 1-10, 2-5 and 12-15 of
    ``A0A075B6L2`` joined into a single string.
    """
    expected_sequences = self._extract_expected_sequences("test_dimer_chopped.txt")
    chopped_regions = [(1, 10), (2, 5), (12, 15)]
    region_pieces = self._get_region_sequences("A0A075B6L2", chopped_regions)
    chopped_sequence = "".join(region_pieces)

    reported = [sequence for _, sequence in expected_sequences]
    wanted = [self._get_sequence_for_protein("TEST"), chopped_sequence]
    # Chain order is irrelevant here, only the multiset of sequences matters.
    self.assertCountEqual(reported, wanted)
    self.assertLen(expected_sequences, 2)
def test_multi_seeds_samples_sequence_extraction(self):
    """The multi_seeds_samples protein list must resolve to the single chain ('A', 'PLVV')."""
    expected_sequences = self._extract_expected_sequences("test_multi_seeds_samples.txt")
    # The list entry is A0A075B6L2 limited to residues 2-5, i.e. one chain "A".
    failure_message = f"Expected [('A', 'PLVV')], got {expected_sequences}"
    self.assertEqual(expected_sequences, [('A', 'PLVV')], failure_message)
def test_multi_seeds_samples_output_validation(self):
    """Validate the files written for the multi_seeds_samples case.

    If ``ranking_scores.csv`` is not yet present in the output directory the
    prediction is launched first, so the test can run on its own. Afterwards
    the result directory must contain the top-level AF3 artefacts, 12
    ``seed-*`` sample directories with their three per-sample files, and a
    ``ranking_scores.csv`` made of one header plus 12 well-formed rows.
    """
    if not (self.output_dir / "ranking_scores.csv").exists():
        # Nothing to validate yet: produce the outputs so this test stays
        # independently runnable under isolated temp dirs.
        env = self._make_af3_test_env()
        command = self._args(
            plist="test_multi_seeds_samples.txt",
            script="run_structure_prediction.py",
        )
        res = subprocess.run(command, capture_output=True, text=True, env=env)
        self._runCommonTests(res)

    result_dir = self._resolve_single_af3_result_dir()
    files = list(result_dir.iterdir())
    top_level_names = {f.name for f in files}
    for required_name in ("TERMS_OF_USE.md", "ranking_scores.csv"):
        self.assertIn(required_name, top_level_names)

    conf_files = [f for f in files if f.name.endswith("_confidences.json")]
    summary_conf_files = [f for f in files if f.name.endswith("_summary_confidences.json")]
    model_files = [f for f in files if f.name.endswith("_model.cif")]
    for matches, message in (
        (conf_files, "No confidences.json files found"),
        (summary_conf_files, "No summary_confidences.json files found"),
        (model_files, "No model.cif files found"),
    ):
        self.assertTrue(len(matches) > 0, message)

    sample_dirs = [f for f in files if f.is_dir() and f.name.startswith("seed-")]
    self.assertEqual(
        len(sample_dirs),
        12,
        f"Expected 12 sample directories, found {len(sample_dirs)}",
    )
    for sample_dir in sample_dirs:
        sample_names = {entry.name for entry in sample_dir.iterdir()}
        for required_name in ("confidences.json", "model.cif", "summary_confidences.json"):
            self.assertIn(required_name, sample_names)

    with open(result_dir / "ranking_scores.csv") as f:
        lines = f.readlines()
    self.assertTrue(len(lines) > 1, "ranking_scores.csv should have header and data")
    header_columns = lines[0].strip().split(",")
    self.assertEqual(len(header_columns), 3, "ranking_scores.csv should have 3 columns")
    # One header line plus one row per expected sample directory.
    expected_lines = 13
    self.assertEqual(
        len(lines),
        expected_lines,
        f"Expected {expected_lines} lines in ranking_scores.csv, found {len(lines)}",
    )
    for i, line in enumerate(lines[1:], 1):
        parts = line.strip().split(",")
        self.assertEqual(
            len(parts),
            3,
            f"Line {i+1} should have 3 columns: seed,sample,ranking_score",
        )
        try:
            # seed and sample must parse as integers, the score as a float.
            for convert, field in zip((int, int, float), parts):
                convert(field)
        except ValueError:
            self.fail(f"Line {i+1} has invalid format: {line.strip()}")

    self._check_chain_counts_and_sequences("test_multi_seeds_samples.txt")
    print(
        f"✓ Verified multi_seeds_samples output with {len(sample_dirs)} sample "
        f"directories and {len(lines)-1} ranking score entries"
    )
def test_af3_run_structure_prediction_keeps_single_explicit_output_dir_flat_for_json(self):
    """With one JSON job and --use_ap_style, outputs must stay in the given output directory.

    The single-job script is run on ``protein_with_ptms.json``; afterwards no
    ``protein_ptms`` sub-directory may exist below the output directory.
    """
    self._require_af3_functional_environment()
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()
    json_input = self.test_features_dir / "protein_with_ptms.json"
    command = [
        sys.executable,
        str(self.script_single),
        f"--input={json_input}",
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={self.test_features_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
        "--use_ap_style",
    ]
    res = subprocess.run(command, capture_output=True, text=True, env=env)
    self._runCommonTests(res)

    # NOTE(review): "protein_ptms" is presumably the job name that would be
    # used for a nested directory - confirm against the JSON's name field.
    nested_dir = self.output_dir / "protein_ptms"
    self.assertFalse(
        nested_dir.exists(),
        "Single-job AF3 runs should keep outputs directly in the explicitly provided output directory.",
    )
def test_af3_run_multimer_jobs_multiple_jobs_create_per_job_subdirs(self):
    """Several jobs sharing one wrapper output root must each get their own sub-directory.

    The multimer wrapper is run in custom mode on
    ``test_multiple_monomers.txt``. The shared root must not contain a flat
    ``ranking_scores.csv``, and each expected job directory must exist and
    hold AF3 outputs.
    """
    self._require_af3_functional_environment()
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()
    protein_list = self.test_protein_lists_dir / "test_multiple_monomers.txt"
    command = [
        sys.executable,
        str(self.script_multimer),
        "--num_cycle=1",
        "--num_predictions_per_model=1",
        f"--data_dir={DATA_DIR}",
        f"--monomer_objects_dir={self.test_features_dir}",
        f"--output_path={self.output_dir}",
        "--mode=custom",
        f"--protein_lists={protein_list}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
    ]
    res = subprocess.run(command, capture_output=True, text=True, env=env)
    # Echo both streams so a failing run can be diagnosed from the test log.
    for stream in (res.stdout, res.stderr):
        print(stream)
    self.assertEqual(res.returncode, 0, "sub-process failed")

    flat_ranking_file = self.output_dir / "ranking_scores.csv"
    self.assertFalse(
        flat_ranking_file.exists(),
        "Shared wrapper output root should not contain flattened AF3 outputs.",
    )
    # Per the original author's note: AF3 currently merges all objects passed
    # to one run_structure_prediction invocation into a single combined fold
    # input, which is why multi-job isolation is exercised through the
    # wrapper script here.
    expected_job_dirs = ("A0A024R1R8_1-5", "A0A075B6L2_2-5")
    for job_dir in expected_job_dirs:
        current_output_dir = self.output_dir / job_dir
        self.assertTrue(
            current_output_dir.is_dir(),
            f"Expected per-job output directory {current_output_dir} to be created.",
        )
        self._assert_af3_outputs_present(current_output_dir)
def test_af3_run_multimer_jobs_multiple_json_jobs_create_per_job_subdirs(self):
    """Combined JSON folds sharing one wrapper output root must land in per-job sub-directories.

    A protein list with two folds (two JSON inputs each) is written into the
    output directory and handed to the multimer wrapper with
    ``--use_ap_style``. The shared root must stay free of a flat
    ``ranking_scores.csv`` and of ``*_data.json`` files, and both expected job
    directories must exist and hold AF3 outputs.
    """
    self._require_af3_functional_environment()
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()

    features = self.test_features_dir
    json_folds = [
        [features / "protein_with_ptms.json", features / "P61626_af3_input.json"],
        [features / "P01308_af3_input.json", features / "P61626_af3_input.json"],
    ]
    # One fold per line; its members are listed by bare file name, ';'-separated.
    fold_lines = [
        ";".join(json_input.name for json_input in json_fold)
        for json_fold in json_folds
    ]
    protein_list = self.output_dir / "test_multiple_json_jobs.txt"
    protein_list.write_text("\n".join(fold_lines) + "\n", encoding="utf-8")

    command = [
        sys.executable,
        str(self.script_multimer),
        "--num_cycle=1",
        "--num_predictions_per_model=1",
        f"--data_dir={DATA_DIR}",
        f"--monomer_objects_dir={self.test_features_dir}",
        f"--output_path={self.output_dir}",
        "--mode=custom",
        f"--protein_lists={protein_list}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
        "--use_ap_style",
    ]
    res = subprocess.run(command, capture_output=True, text=True, env=env)
    # Echo both streams so a failing run can be diagnosed from the test log.
    for stream in (res.stdout, res.stderr):
        print(stream)
    self.assertEqual(res.returncode, 0, "sub-process failed")

    flat_ranking_file = self.output_dir / "ranking_scores.csv"
    self.assertFalse(
        flat_ranking_file.exists(),
        "Shared wrapper output root should not contain flattened AF3 JSON outputs.",
    )
    self.assertFalse(
        any(self.output_dir.glob("*_data.json")),
        "Combined JSON folds should not write flat AF3 input JSONs into the shared root.",
    )
    expected_job_dirs = ("protein_with_ptms_and_p61626", "p01308_and_p61626")
    for output_dir_name in expected_job_dirs:
        current_output_dir = self.output_dir / output_dir_name
        self.assertTrue(
            current_output_dir.is_dir(),
            f"Expected per-job output directory {current_output_dir} to be created.",
        )
        self._assert_af3_outputs_present(current_output_dir)
@parameterized.named_parameters(
    # Plain AlphaPulldown protein lists through the single-job script.
    {
        "testcase_name": "monomer",
        "protein_list": "test_monomer.txt",
        "script": "run_structure_prediction.py",
    },
    {
        "testcase_name": "dimer",
        "protein_list": "test_dimer.txt",
        "script": "run_structure_prediction.py",
    },
    {
        "testcase_name": "trimer",
        "protein_list": "test_trimer.txt",
        "script": "run_structure_prediction.py",
    },
    {
        "testcase_name": "homo_oligomer",
        "protein_list": "test_homooligomer.txt",
        "script": "run_structure_prediction.py",
    },
    {
        "testcase_name": "chopped_dimer",
        "protein_list": "test_dimer_chopped.txt",
        "script": "run_structure_prediction.py",
    },
    {
        "testcase_name": "long_name",
        "protein_list": "test_long_name.txt",
        "script": "run_structure_prediction.py",
    },
    # AF3 must also work when launched via the multimer wrapper script.
    {
        "testcase_name": "monomer_via_multimer_wrapper",
        "protein_list": "test_monomer.txt",
        "script": "run_multimer_jobs.py",
    },
    {
        "testcase_name": "chopped_dimer_via_multimer_wrapper",
        "protein_list": "test_dimer_chopped.txt",
        "script": "run_multimer_jobs.py",
    },
    # An AlphaPulldown monomer combined with different JSON inputs.
    {
        "testcase_name": "monomer_with_rna",
        "protein_list": "test_monomer_with_rna.txt",
        "script": "run_structure_prediction.py",
    },
    {
        "testcase_name": "monomer_with_dna",
        "protein_list": "test_monomer_with_dna.txt",
        "script": "run_structure_prediction.py",
    },
    {
        "testcase_name": "monomer_with_ligand",
        "protein_list": "test_monomer_with_ligand.txt",
        "script": "run_structure_prediction.py",
    },
    # A protein with PTMs taken from JSON.
    {
        "testcase_name": "protein_with_ptms",
        "protein_list": "test_protein_with_ptms.txt",
        "script": "run_structure_prediction.py",
    },
    # Multiple seeds and diffusion samples.
    {
        "testcase_name": "multi_seeds_samples",
        "protein_list": "test_multi_seeds_samples.txt",
        "script": "run_structure_prediction.py",
    },
    # A homodimer built from AF3 JSON features.
    {
        "testcase_name": "homodimer_from_json_features",
        "protein_list": "test_homodimer_from_json_features.txt",
        "script": "run_structure_prediction.py",
    },
)
def test_(self, protein_list, script):
    """Run one protein list through the given script and check the common expectations.

    Args:
        protein_list: File name of the protein list to predict.
        script: File name of the entry-point script used to launch the run.
    """
    env = self._make_af3_test_env()

    # Diagnostics only: show the accelerator-related settings handed to the child.
    print("\nEnvironment variables:")
    for variable in (
        "XLA_FLAGS",
        "XLA_PYTHON_CLIENT_PREALLOCATE",
        "XLA_CLIENT_MEM_FRACTION",
        "JAX_FLASH_ATTENTION_IMPL",
    ):
        print(f"{variable}: {env.get(variable)}")

    # Diagnostics only: a failing or missing JAX/GPU probe is reported, never fatal.
    try:
        import jax
        print("\nJAX GPU devices:")
        print(jax.devices())
        print("JAX GPU local devices:")
        print(jax.local_devices(backend='gpu'))
    except Exception as e:
        print(f"\nError checking JAX GPU: {e}")

    command = self._args(plist=protein_list, script=script)
    res = subprocess.run(command, capture_output=True, text=True, env=env)
    self._runCommonTests(res)
    # Compare predicted chains and sequences with the protein list's expectations.
    self._check_chain_counts_and_sequences(protein_list)
def test_af3_writes_embeddings_and_distogram(self):
    """With --save_embeddings and --save_distogram the per-seed artefacts must be written.

    After the run, the number of ``seed-*_*embeddings`` and
    ``seed-*_*distogram`` directories must each equal the number of distinct
    seeds listed in ``ranking_scores.csv``, and every ``.npz`` inside must
    contain the expected non-empty arrays.
    """
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()
    args = [
        sys.executable,
        str(self.script_single),
        "--input=A0A075B6L2:1:2-5",  # small chopped example
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={self.test_features_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--save_embeddings",
        "--save_distogram",
        "--num_diffusion_samples=1",
    ]
    res = subprocess.run(args, capture_output=True, text=True, env=env)
    self._runCommonTests(res)

    # Per-seed artefact directories live directly in the output directory.
    seed_emb_dirs = list(self.output_dir.glob("seed-*_*embeddings"))
    seed_dist_dirs = list(self.output_dir.glob("seed-*_*distogram"))
    self.assertTrue(len(seed_emb_dirs) >= 1, "No embeddings directories written")
    self.assertTrue(len(seed_dist_dirs) >= 1, "No distogram directories written")

    # The first CSV column is the seed; there must be one directory per distinct seed.
    with open(self.output_dir / "ranking_scores.csv") as f:
        next(f, None)  # skip the header row
        data_rows = [ln.strip() for ln in f if ln.strip()]
    seeds_in_csv = {row.split(",")[0] for row in data_rows}
    self.assertEqual(len(seed_emb_dirs), len(seeds_in_csv),
                     f"Embeddings dirs ({len(seed_emb_dirs)}) != seeds ({len(seeds_in_csv)})")
    self.assertEqual(len(seed_dist_dirs), len(seeds_in_csv),
                     f"Distogram dirs ({len(seed_dist_dirs)}) != seeds ({len(seeds_in_csv)})")

    embedding_keys = ('single_embeddings', 'pair_embeddings')
    for emb_dir in seed_emb_dirs:
        npz_files = list(emb_dir.glob("*.npz"))
        self.assertTrue(len(npz_files) >= 1, f"No embeddings npz in {emb_dir}")
        for npz in npz_files:
            with np.load(npz) as data:
                # Presence of both arrays is checked before their sizes.
                for key in embedding_keys:
                    self.assertIn(key, data.files, f"{key} missing in {npz}")
                for key in embedding_keys:
                    self.assertGreater(data[key].size, 0, f"{key} empty in {npz}")

    for d_dir in seed_dist_dirs:
        npz_files = list(d_dir.glob("*_distogram.npz"))
        self.assertTrue(len(npz_files) >= 1, f"No distogram npz in {d_dir}")
        for npz in npz_files:
            with np.load(npz) as data:
                self.assertIn('distogram', data.files, f"distogram key missing in {npz}")
                self.assertGreater(data['distogram'].size, 0, f"distogram array empty in {npz}")
def test_af3_num_recycles_affects_runtime(self):
    """Check that --num_recycles=1 lowers the per-seed inference time.

    Opt-in benchmark: skipped unless ``AF3_RUN_PERF_TESTS`` is set to ``1``,
    ``true`` or ``yes``. The same job is run twice, once with the default
    recycle count and once with ``--num_recycles=1``. For each run the
    inference time of the last seed is read from the
    ``Model inference for seed ... took N seconds.`` log lines, and the
    reduced-recycle run must be more than 5% faster.
    """
    if os.getenv("AF3_RUN_PERF_TESTS", "").lower() not in ("1", "true", "yes"):
        self.skipTest(
            "Set AF3_RUN_PERF_TESTS=1 to run AF3 runtime benchmarks."
        )
    self._require_af3_functional_environment()
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()
    timing_pattern = re.compile(
        r"Model inference for seed .* took ([0-9.]+) seconds\."
    )

    def run_and_time(output_directory, label, extra_flags=()):
        """Run one prediction and return the logged inference time of its last seed."""
        command = [
            sys.executable,
            str(self.script_single),
            "--input=A0A075B6L2:1",
            f"--output_directory={output_directory}",
            f"--data_directory={DATA_DIR}",
            f"--features_directory={self.test_features_dir}",
            "--fold_backend=alphafold3",
            f"--flash_attention_implementation={flash_impl}",
            "--num_diffusion_samples=1",
            # Two seeds: the second one reuses the compiled XLA program, so
            # its timing reflects compute rather than compilation.
            "--num_seeds=2",
            *extra_flags,
        ]
        res = subprocess.run(command, capture_output=True, text=True, env=env)
        self._runCommonTests(res)
        # The timing line may be logged on either stream, so search both.
        timings = timing_pattern.findall(res.stdout + "\n" + res.stderr)
        self.assertTrue(
            len(timings) >= 1,
            f"Couldn't parse {label} inference time from logs",
        )
        return float(timings[-1])

    # Baseline: default num_recycles, timed on the last seed.
    default_time = run_and_time(self.output_dir, "default")

    # Same job with a single recycle, written to its own sub-directory.
    faster_dir = self.output_dir / "fewer_recycles"
    faster_dir.mkdir(parents=True, exist_ok=True)
    fast_time = run_and_time(faster_dir, "fast", ["--num_recycles=1"])

    # Allow some jitter: require the reduced-recycle run to be more than 5%
    # faster, which is what the 0.95 factor below enforces.
    self.assertLess(
        fast_time,
        0.95 * default_time,
        f"num_recycles=1 not faster enough (default {default_time:.2f}s vs {fast_time:.2f}s)",
    )
def test_af3_rejects_alphafold2_flag(self):
    """An AlphaFold2-only flag combined with the AF3 backend must make the run fail.

    The command includes ``--num_predictions_per_model=1``. The process must
    exit non-zero and its combined output must mention that the flag is not
    supported by backend 'alphafold3'.
    """
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()
    valid_af3_flags = [
        "--input=A0A075B6L2:1:2-5",
        f"--output_directory={self.output_dir}",
        f"--data_directory={DATA_DIR}",
        f"--features_directory={self.test_features_dir}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
    ]
    # Deliberately invalid for the AF3 backend.
    af2_only_flags = ["--num_predictions_per_model=1"]
    args = [sys.executable, str(self.script_single), *valid_af3_flags, *af2_only_flags]
    res = subprocess.run(args, capture_output=True, text=True, env=env)

    # The run has to fail, and the reason has to be stated clearly.
    self.assertNotEqual(res.returncode, 0, "AF3 run unexpectedly succeeded with AF2 flag")
    combined_output = res.stderr + res.stdout
    self.assertRegex(combined_output, r"not supported by backend 'alphafold3'")
# --------------------------------------------------------------------------- #
def _parse_test_args():
    """Report whether temp-dir mode was requested and strip the flag from ``sys.argv``.

    Temp-dir mode is requested by passing ``--use-temp-dir`` on the command
    line or by setting ``USE_TEMP_DIR`` to ``1``, ``true`` or ``yes``
    (case-insensitive).

    Returns:
        True if either source requests temp-dir mode, otherwise False.
    """
    flag = '--use-temp-dir'
    requested_on_cli = flag in sys.argv
    requested_via_env = os.getenv('USE_TEMP_DIR', '').lower() in ('1', 'true', 'yes')
    # Drop every occurrence in place so later argument parsing never sees the
    # custom flag; slice assignment keeps the sys.argv list object itself.
    sys.argv[:] = [argument for argument in sys.argv if argument != flag]
    return requested_on_cli or requested_via_env
# The temp-dir switch is resolved once, at import time, so that it takes effect
# whether this module is loaded by absltest or by pytest.
_TestBase.use_temp_dir = _parse_test_args()


if __name__ == "__main__":
    absltest.main()