* Harden MMseqs species ID resolution fallback

* Reorganize tests for CPU coverage CI

* New

* Fix function coverage checker def-line false positives

* Expand unit coverage for helper and backend manager utilities

* New.

* New.

* Expand unit coverage for template and post-processing helpers

* Expand unit coverage for objects.py edge cases

* Publish HTML coverage reports via GitHub Pages

* Add CPU unit coverage for AlphaFold3 backend helpers

* Reorganize tests and expand backend coverage

* Reset shared test flags between cases

* Expand AF3 prepare_input unit coverage

* Cover AF3 and truemultimer feature creation

* Test AF3 multimer MSA translation paths

* Cover AF3 duplicate-residue multimer fallback

* Cover AF2 resume and postprocess edge paths

* Cover AF3 template mmCIF preparation

* Test small script entry points

* Expand workflow and ModelCIF test coverage

* Add backend extras and install guide

* Clarify AF3 backend installation path

* Stabilize cluster GPU test runners

* Document AF3 CMake SQLite hints

* Simplify backend installation guide

* Align AF3 install with working cluster env

* Backfill typing dataclass_transform for AF2

* Pin TensorFlow for cluster installs

* Fallback AF2 relax when CUDA OpenMM is unavailable

* Raise AF3 default minimum bucket size

* Simplify backend cluster installation guide

* Fix AF3 wrapper JSON output isolation

* Fix AF3 JSON wrapper outputs and MMseqs ID parsing

* Fix CI entrypoint stub and Python 3.8 typing

* Document release readiness test gates
This commit is contained in:
Dima
2026-04-01 14:13:35 +02:00
committed by GitHub
parent 9bd18ce9b2
commit fff63051b4
38 changed files with 5991 additions and 135 deletions

View File

@@ -4,7 +4,8 @@ source =
alphapulldown
omit =
alphapulldown/__init__.py
alphapulldown/analysis_pipeline/af2plots/*
alphapulldown/analysis_pipeline/*
alphapulldown/folding_backend/alphalink_backend.py
[report]
skip_empty = True

View File

@@ -11,6 +11,7 @@ import json
import pickle
import subprocess
import enum
import typing
from typing import Dict, Union, List, Any
import os
from absl import logging
@@ -45,6 +46,55 @@ class ModelsToRelax(enum.Enum):
NONE = 2
def _ensure_typing_dataclass_transform() -> None:
"""Provide typing.dataclass_transform on Python versions that lack it."""
if hasattr(typing, "dataclass_transform"):
return
from typing_extensions import dataclass_transform
typing.dataclass_transform = dataclass_transform
def _get_openmm_platform_names() -> List[str]:
"""Return the available OpenMM platform names."""
try:
import openmm
except ImportError:
return []
try:
return [
openmm.Platform.getPlatform(i).getName()
for i in range(openmm.Platform.getNumPlatforms())
]
except Exception as exc: # pragma: no cover - defensive
logging.warning("Failed to inspect OpenMM platforms: %s", exc)
return []
def _resolve_gpu_relax(use_gpu_relax: bool) -> bool:
"""Only request GPU relax when OpenMM exposes a CUDA platform."""
if not use_gpu_relax:
return False
platform_names = _get_openmm_platform_names()
if "CUDA" in platform_names:
return True
if platform_names:
logging.warning(
"OpenMM CUDA platform is unavailable; falling back to CPU relax. "
"Available platforms: %s",
", ".join(platform_names),
)
else:
logging.warning(
"OpenMM CUDA platform is unavailable; falling back to CPU relax."
)
return False
def _jnp_to_np(output):
"""Recursively changes jax arrays to numpy arrays."""
for k, v in output.items():
@@ -381,6 +431,7 @@ class AlphaFold2Backend(FoldingBackend):
If provided custom model names are not part of the available models.
"""
_ensure_typing_dataclass_transform()
from alphafold.model import config
from alphafold.model import data, model
@@ -798,6 +849,7 @@ class AlphaFold2Backend(FoldingBackend):
for model_name, prediction_result in prediction_results.items():
prediction_result.update(AlphaFold2Backend.recalculate_confidence(prediction_result,multimer_mode,
total_num_res))
unrelaxed_protein = prediction_result.get("unrelaxed_protein")
if 'unrelaxed_protein' in prediction_result.keys():
unrelaxed_protein = prediction_result.pop("unrelaxed_protein")
# Remove jax dependency from results
@@ -806,7 +858,8 @@ class AlphaFold2Backend(FoldingBackend):
result_output_path = os.path.join(output_dir, f"result_{model_name}.pkl")
with open(result_output_path, "wb") as f:
pickle.dump(np_prediction_result, f, protocol=4)
prediction_results[model_name]['unrelaxed_protein'] = unrelaxed_protein
if unrelaxed_protein is not None:
prediction_results[model_name]['unrelaxed_protein'] = unrelaxed_protein
if 'iptm' in prediction_result:
label = 'iptm+ptm'
iptm_scores[model_name] = float(prediction_result['iptm'])
@@ -862,7 +915,7 @@ class AlphaFold2Backend(FoldingBackend):
stiffness=RELAX_STIFFNESS,
exclude_residues=RELAX_EXCLUDE_RESIDUES,
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
use_gpu=use_gpu_relax)
use_gpu=_resolve_gpu_relax(use_gpu_relax))
if models_to_relax == ModelsToRelax.BEST:
to_relax = [ranked_order[0]]

View File

@@ -31,13 +31,15 @@ from alphapulldown.utils.modelling_setup import create_uniprot_runner
from alphapulldown.utils import save_meta_data
# Try to import AlphaFold3, but it's optional
AF3_IMPORT_ERROR = None
try:
from alphafold3.data.pipeline import DataPipeline as AF3DataPipeline, DataPipelineConfig as AF3DataPipelineConfig
from alphafold3.common import folding_input
except ImportError:
except ImportError as exc:
AF3DataPipeline = None
AF3DataPipelineConfig = None
folding_input = None
AF3_IMPORT_ERROR = exc
# =================== Database Maps ===================
AF2_DATABASES = {
@@ -321,7 +323,13 @@ def create_custom_db(temp_dir, protein, template_paths, chains):
def create_pipeline_af3():
"""Create the AlphaFold3 pipeline. Raises if AF3 not available."""
if AF3DataPipeline is None or AF3DataPipelineConfig is None:
raise ImportError("alphafold3.data.pipeline not available")
raise ImportError(
"AlphaFold3 is not installed correctly. "
"Install AlphaPulldown with 'pip install -e \".[alphafold3,test]\"', "
"make sure the build environment provides SQLite, then build the "
"vendored package with 'pip install -r alphafold3/dev-requirements.txt', "
"'pip install --no-deps -e ./alphafold3', and 'build_data'."
) from AF3_IMPORT_ERROR
# Convert max_template_date string to datetime.date object
import datetime

View File

@@ -38,8 +38,7 @@ out_lines = []
with open(sys.argv[1]) as f:
for headerStr, seq in fasta_iter(f):
items = re.split('[ \|]', headerStr)
items = re.split(r"[ |]", headerStr)
out_lines.append(f'>{items[1]}')
out_lines.append(seq)
f.close()
print("\n".join(out_lines))
print("\n".join(out_lines))

View File

@@ -13,6 +13,8 @@ import sys
import jax
gpus = jax.local_devices(backend='gpu')
from alphapulldown.scripts.run_structure_prediction import FLAGS
from alphapulldown.utils.modelling_setup import parse_fold
from alphapulldown.utils.output_paths import derive_af3_job_name_from_json
from alphapulldown_input_parser import generate_fold_specifications
logging.set_verbosity(logging.INFO)
@@ -38,6 +40,36 @@ flags.DEFINE_boolean("dry_run", False, "Report number of jobs that would be run
flags.DEFINE_list("monomer_objects_dir", None, "a list of directories where monomer objects are stored")
flags.DEFINE_string("output_path", None, "output directory where the region data is going to be stored")
flags.DEFINE_string("data_dir", None, "Path to params directory")
def _resolve_af3_wrapper_output_dir(
job_spec: str,
output_root: str,
*,
features_directory,
protein_delimiter: str,
use_ap_style: bool,
) -> str:
"""Scope single AF3 JSON wrapper jobs to stable per-job subdirectories."""
if not use_ap_style:
return output_root
try:
parsed_jobs = parse_fold([job_spec], features_directory, protein_delimiter)
except Exception:
return output_root
if len(parsed_jobs) != 1 or len(parsed_jobs[0]) != 1:
return output_root
entry = parsed_jobs[0][0]
if not (isinstance(entry, dict) and "json_input" in entry):
return output_root
return os.path.join(
output_root,
derive_af3_job_name_from_json(entry["json_input"]),
)
del(FLAGS.models_to_relax)
flags.DEFINE_enum("models_to_relax",'None',['None','All','Best'],
"Which models to relax. Default is None, meaning no model will be relaxed")
@@ -177,6 +209,14 @@ def main(argv):
else:
for job_index in job_indices:
command_args["--input"] = all_folds[job_index]
if fold_backend == "alphafold3":
command_args["--output_directory"] = _resolve_af3_wrapper_output_dir(
all_folds[job_index],
FLAGS.output_path,
features_directory=FLAGS.monomer_objects_dir,
protein_delimiter=FLAGS.protein_delimiter,
use_ap_style=af3_use_ap_style,
)
command = base_command.copy()
for arg, value in command_args.items():
command.extend([str(arg), str(value)])

View File

@@ -26,7 +26,10 @@ from alphapulldown.folding_backend import backend
from alphapulldown.folding_backend.alphafold2_backend import ModelsToRelax
from alphapulldown.objects import MultimericObject, MonomericObject, ChoppedObject
from alphapulldown.utils.modelling_setup import create_interactors, create_custom_info, parse_fold
from alphapulldown.utils.output_paths import resolve_af3_json_output_dir
from alphapulldown.utils.output_paths import (
resolve_af3_combined_json_output_dir,
resolve_af3_json_output_dir,
)
import sys as _sys
logging.set_verbosity(logging.INFO)
@@ -120,7 +123,7 @@ flags.DEFINE_string(
flags.DEFINE_list(
'buckets',
# pyformat: disable
['64', '128', '256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
['128', '256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
'3584', '4096', '4608', '5120'],
# pyformat: enable
'Strictly increasing order of token sizes for which to cache compilations.'
@@ -486,6 +489,7 @@ def main(argv):
prot_objs, output_dir=out_dir
)
objects_to_model.append({'object': obj, 'output_dir': real_out})
json_output_dir = real_out
# Update final flags based on object type
final_model_flags = default_model_flags.copy()
@@ -498,15 +502,27 @@ def main(argv):
"model_names_custom": FLAGS.model_names,
"msa_depth": FLAGS.msa_depth
})
# Then handle any number of JSON inputs
for json_dict in json_dicts:
json_output_dir = resolve_af3_json_output_dir(
json_dict["json_input"],
elif len(json_dicts) > 1:
json_output_dir = resolve_af3_combined_json_output_dir(
json_dicts,
out_dir,
use_ap_style=FLAGS.use_ap_style,
shared_output_root=shared_output_root,
)
objects_to_model.append({'object': json_dict, 'output_dir': json_output_dir})
else:
json_output_dir = None
# Then handle any number of JSON inputs
for json_dict in json_dicts:
current_json_output_dir = json_output_dir
if current_json_output_dir is None:
current_json_output_dir = resolve_af3_json_output_dir(
json_dict["json_input"],
out_dir,
use_ap_style=FLAGS.use_ap_style,
shared_output_root=shared_output_root,
)
objects_to_model.append(
{'object': json_dict, 'output_dir': current_json_output_dir}
)
if objects_to_model:
predict_structure(

View File

@@ -33,8 +33,8 @@ _UNIREF_HEADER_PATTERN = re.compile(
r'^UniRef\d+_(?P<accession>[A-Za-z0-9]+)$'
)
_UNIPARC_HEADER_PATTERN = re.compile(r'^(?P<accession>UPI[A-Z0-9]+)$')
_GENERIC_ACCESSION_PATTERN = re.compile(
r'^(?P<accession>[A-Za-z0-9]{6,16})$'
_UNIPROT_ACCESSION_PATTERN = re.compile(
r'^(?:[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9](?:[A-Z][A-Z0-9]{2}[0-9]){1,2})$'
)
_UNIPROT_BATCH_SIZE = 32
@@ -59,14 +59,19 @@ def _extract_accession_and_species(description: str) -> tuple[str, str]:
if matches:
return matches.group("accession"), matches.group("species")
for pattern in (
_UNIREF_HEADER_PATTERN,
_UNIPARC_HEADER_PATTERN,
_GENERIC_ACCESSION_PATTERN,
):
matches = pattern.search(sequence_identifier.strip())
if matches:
return matches.group("accession"), ""
matches = _UNIREF_HEADER_PATTERN.search(sequence_identifier.strip())
if matches:
accession = matches.group("accession")
if _is_resolvable_accession(accession):
return accession, ""
return "", ""
matches = _UNIPARC_HEADER_PATTERN.search(sequence_identifier.strip())
if matches:
return matches.group("accession"), ""
if _UNIPROT_ACCESSION_PATTERN.search(sequence_identifier.strip()):
return sequence_identifier.strip(), ""
return "", ""
@@ -82,6 +87,12 @@ def _is_transport_error(exc: Exception) -> bool:
)
def _is_resolvable_accession(accession: str) -> bool:
return accession.startswith("UPI") or bool(
_UNIPROT_ACCESSION_PATTERN.fullmatch(accession)
)
def _query_uniprot_batch(
accessions: Sequence[str],
*,
@@ -203,8 +214,13 @@ def resolve_species_ids_by_accession(
unresolved = [
accession
for accession in sorted(set(accessions))
if accession and accession not in _SPECIES_ID_CACHE
if accession
and accession not in _SPECIES_ID_CACHE
and _is_resolvable_accession(accession)
]
for accession in accessions:
if accession and accession not in _SPECIES_ID_CACHE and not _is_resolvable_accession(accession):
_SPECIES_ID_CACHE.setdefault(accession, "")
if unresolved:
uniprot_accessions = [
accession for accession in unresolved if not accession.startswith("UPI")

View File

@@ -2,6 +2,7 @@ import json
import os
import string
from pathlib import Path
from typing import Sequence
_AF3_ALLOWED_NAME_CHARS = set(string.ascii_lowercase + string.digits + "_-.")
@@ -40,6 +41,119 @@ def derive_af3_job_name_from_json(json_input_path: str) -> str:
return fallback_name
def _json_input_basename(json_input_path: str) -> str:
stem = Path(json_input_path).stem
for suffix in ("_af3_input", "_input"):
if stem.endswith(suffix):
stem = stem[: -len(suffix)]
break
return stem or Path(json_input_path).stem
def _collapse_repeated_name_fragments(fragments: Sequence[str]) -> list[str]:
if not fragments:
return []
collapsed: list[str] = []
current_fragment = fragments[0]
current_count = 1
for fragment in fragments[1:]:
if fragment == current_fragment:
current_count += 1
continue
collapsed.append(
current_fragment
if current_count == 1
else f"{current_fragment}__x{current_count}"
)
current_fragment = fragment
current_count = 1
collapsed.append(
current_fragment
if current_count == 1
else f"{current_fragment}__x{current_count}"
)
return collapsed
def _compact_output_job_name(job_name: str, *, max_chars: int = 200) -> str:
if len(job_name) <= max_chars:
return job_name
import hashlib
digest = hashlib.sha1(job_name.encode("utf-8")).hexdigest()[:12]
suffix = f"__{digest}"
prefix = job_name[: max_chars - len(suffix)].rstrip("_.-")
if not prefix:
return f"job{suffix}"
return f"{prefix}{suffix}"
def _normalise_json_regions(regions: object) -> str | None:
if not isinstance(regions, list) or not regions:
return None
parts: list[str] = []
for region in regions:
if not isinstance(region, (tuple, list)) or len(region) != 2:
return None
start, end = region
parts.append(f"{start}-{end}")
return "_".join(parts)
def build_af3_combined_json_job_name(
    json_inputs: Sequence[dict[str, object]],
) -> str:
    """Derive one job name for a fold made of several AF3 JSON inputs.

    Each entry with a non-empty string ``"json_input"`` contributes its
    basename, plus ``__<regions>`` when ``"regions"`` normalises to a string,
    passed through ``sanitise_af3_job_name``. Empty results are dropped,
    consecutive duplicates are collapsed, the rest is joined with ``_and_``
    and compacted. With nothing usable the name is ``"ranked_0"``.
    """
    named_fragments: list[str] = []
    for spec in json_inputs:
        path = spec.get("json_input")
        if not (isinstance(path, str) and path):
            continue
        label = _json_input_basename(path)
        region_label = _normalise_json_regions(spec.get("regions"))
        if region_label:
            label = f"{label}__{region_label}"
        cleaned = sanitise_af3_job_name(label)
        if cleaned:
            named_fragments.append(cleaned)

    if not named_fragments:
        return "ranked_0"

    collapsed = _collapse_repeated_name_fragments(named_fragments)
    return _compact_output_job_name("_and_".join(collapsed))
def _ensure_path_is_within_root(candidate: Path, output_root: Path) -> None:
try:
candidate.resolve(strict=False).relative_to(output_root.resolve(strict=False))
except ValueError as exc:
raise ValueError(
f"Resolved AF3 output directory {candidate} escapes configured root {output_root}"
) from exc
def resolve_af3_combined_json_output_dir(
    json_inputs: Sequence[dict[str, object]],
    output_dir: str,
    *,
    use_ap_style: bool,
) -> str:
    """Return the output directory for a fold combining several JSON inputs.

    With AP-style output disabled, ``output_dir`` is returned as given.
    Otherwise the combined job name is appended to it, and the result is
    checked to still lie inside ``output_dir`` before being returned.
    """
    if use_ap_style:
        root = Path(output_dir)
        job_dir = root / build_af3_combined_json_job_name(json_inputs)
        _ensure_path_is_within_root(job_dir, root)
        return os.fspath(job_dir)
    return output_dir
def resolve_af3_json_output_dir(
json_input_path: str,
output_dir: str,
@@ -53,12 +167,5 @@ def resolve_af3_json_output_dir(
output_root = Path(output_dir)
candidate = output_root / derive_af3_job_name_from_json(json_input_path)
try:
candidate.resolve(strict=False).relative_to(output_root.resolve(strict=False))
except ValueError as exc:
raise ValueError(
f"Resolved AF3 output directory {candidate} escapes configured root {output_root}"
) from exc
_ensure_path_is_within_root(candidate, output_root)
return os.fspath(candidate)

View File

@@ -215,9 +215,10 @@ def tmp_flags(monkeypatch, tmp_path):
# not an absl flag on this FLAGS; attach as attribute for completeness
setattr(F, name, value)
continue
fl.unparse()
arg = _to_arg_str(name, value)
if arg is None:
# leave at declared default (often None); don't parse
# reset to the declared default and leave it unset
continue
fl.parse(arg) # sets .value and marks present

View File

@@ -40,7 +40,9 @@ RUN set -eux; \
modelcif \
hmmer \
hhsuite \
pdbfixer=1.9 \
"numpy<2" \
"openmm>=8.2" \
"pdbfixer>=1.10" \
pip \
git \
&& micromamba clean -a -y
@@ -78,5 +80,3 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
rm -rf /var/lib/apt/lists/* /root/.cache
#ENTRYPOINT ["bash"]
RUN pip install --no-cache-dir "numpy<2"

View File

@@ -0,0 +1,246 @@
# Backend Installation Guide
This guide is for direct AlphaPulldown use without Snakemake.
Two points matter in practice:
1. Run the install commands from the AlphaPulldown repo root.
2. For AlphaFold3, building the vendored `alphafold3` package is a separate step. A root-level install alone is not enough.
The Docker files remain the long-term reference environments, but the commands below are the simpler cluster-facing paths.
We revalidated them on EMBL on March 30, 2026 by creating fresh environments and running the cluster test suites there.
## Known-good cluster stacks
These are the two EMBL environments that were already working and that we rechecked while validating the wrappers:
- `AlphaPulldown` for AF2:
- Python `3.10`
- `jax 0.5.3`
- `jaxlib 0.5.3`
- `numpy 1.26.4`
- `tensorflow-cpu 2.20.0`
- `openmm 8.1.1`
- `pdbfixer 1.12`
- `modelcif 1.6`
- `AlphaPulldown_alphafold3` for AF3:
- Python `3.11`
- `jax 0.5.3`
- `jaxlib 0.5.3`
- `numpy 1.26.4`
- `tensorflow-cpu 2.18.0`
- `openmm 8.3.1`
- `pdbfixer 1.12`
- `modelcif 1.6`
- `jax-triton 0.2.0`
- `triton 3.1.0`
- `rdkit 2024.3.5`
- `typeguard 2.13.3`
- compiled `alphafold3.cpp`
The installation steps below are written to stay close to those working environments and to keep user-facing environment variables to a minimum.
## AlphaFold2 backend
Create the environment, activate it, move into the repo, and install:
```bash
mamba create -y -n apd-af2 -c conda-forge -c bioconda \
python=3.10 \
kalign2 \
hmmer \
hhsuite
mamba activate apd-af2
cd /path/to/AlphaPulldown
python -m pip install ".[alphafold2,test]"
```
Check that GPU JAX is visible:
```bash
python - <<'PY'
import jax
print(jax.__version__)
print(jax.local_devices(backend="gpu"))
PY
```
## AlphaFold3 backend
AlphaFold3 needs both the AlphaPulldown install and the vendored `alphafold3` build.
```bash
mamba create -y -n apd-af3 -c conda-forge -c bioconda \
python=3.11 \
kalign2 \
hmmer \
hhsuite \
libcifpp \
sqlite
mamba activate apd-af3
cd /path/to/AlphaPulldown
python -m pip install ".[alphafold3,test]"
python -m pip install --no-deps -e ./alphafold3
build_data
```
Check that the compiled extension and GPU JAX are available:
```bash
python - <<'PY'
import alphafold3.cpp
import jax
print(jax.__version__)
print(jax.local_devices(backend="gpu"))
print(alphafold3.cpp.__file__)
PY
```
Notes:
- The working EMBL AF3 environment is closer to `.[alphafold3,test]` than to the upstream `alphafold3/dev-requirements.txt` stack.
- For cluster installs, prefer `python -m pip install ".[alphafold3,test]"` and then build the vendored `alphafold3` package.
- The compiled `alphafold3.cpp` extension comes from `python -m pip install --no-deps -e ./alphafold3`, not from the root install.
- The vendored AF3 package provides the `build_data` entry point. Use that directly.
- If you are actively developing inside the checkout and want the root package editable as well, add:
```bash
python -m pip install -e . --no-deps
```
## Cluster validation
On EMBL, the AF2 and AF3 functional tests already have default database roots baked into the test files, so the only environment variable you need for the standard cluster runs is:
```bash
export RUN_GPU_FUNCTIONAL_TESTS=1
```
Set `ALPHAFOLD_DATA_DIR` only if your databases are not in the EMBL default locations.
Two ways to run the cluster tests:
1. Direct `srun` + `pytest`
- Best for validating a fresh install end-to-end.
- Keeps everything on one allocated GPU node.
- More reliable than the wrappers when Slurm priority is poor.
- Important: when calling `pytest` directly, pass `-o addopts="-ra --strict-markers"`. The repo-level `pytest.ini` excludes cluster tests by default.
2. Wrapper scripts
- `test/cluster/run_alphafold2_predictions.py`
- `test/cluster/run_alphafold3_predictions.py`
- Faster when the queue is healthy, because they fan out one job per pytest node.
### AlphaFold2
Direct full-suite validation:
```bash
srun -p gpu-training --gres=gpu:1 \
--cpus-per-task=4 --mem=16G --time=12:00:00 \
bash -lc '
cd /path/to/AlphaPulldown
export RUN_GPU_FUNCTIONAL_TESTS=1
python -m pytest -o addopts="-ra --strict-markers" -vv -s \
test/cluster/check_alphafold2_predictions.py --use-temp-dir
'
```
Expected result on the standard suite:
```text
11 passed, 1 skipped
```
The skip is the opt-in MMseqs functional inference check, which only runs when `RUN_MMSEQS_FUNCTIONAL_TESTS=1` is set.
Preview the collected nodes for wrapper mode:
```bash
python test/cluster/run_alphafold2_predictions.py --list
```
Wrapper-based parallel submission:
```bash
python test/cluster/run_alphafold2_predictions.py \
--max-tests 6 \
--use-temp-dir \
--partition gpu-training
```
### AlphaFold3
Direct full-suite validation:
```bash
srun -p gpu-training --gres=gpu:1 \
--cpus-per-task=4 --mem=64G --time=12:00:00 \
bash -lc '
cd /path/to/AlphaPulldown
export RUN_GPU_FUNCTIONAL_TESTS=1
python -m pytest -o addopts="-ra --strict-markers" -vv -s \
test/cluster/check_alphafold3_predictions.py --use-temp-dir
'
```
On our fresh EMBL validation run, `32G` was not enough for the full AF3 suite on `hgx5`; `64G` was.
Preview the collected nodes for wrapper mode:
```bash
python test/cluster/run_alphafold3_predictions.py --list
```
Wrapper-based parallel submission:
```bash
python test/cluster/run_alphafold3_predictions.py \
--max-tests 6 \
--use-temp-dir \
--partition gpu-training
```
## Troubleshooting
### AlphaFold2: `Unknown backend: 'gpu' requested, ... Platforms present are: cpu`
That means the environment has a CPU-only JAX install. Reinstall the CUDA-enabled build, then confirm that a GPU device is listed:
```bash
python -m pip install --upgrade --no-cache-dir "jax==0.5.3" "jax[cuda12]==0.5.3"
python - <<'PY'
import jax
print(jax.__version__)
print(jax.local_devices(backend="gpu"))
PY
```
### AlphaFold2: `There is no registered Platform called "CUDA"`
That comes from OpenMM relaxation, not from JAX. Older working EMBL envs exposed a CUDA-enabled OpenMM platform, while a fresh pip-installed OpenMM may expose only `Reference`, `CPU`, and `OpenCL`.
Current AlphaPulldown falls back to CPU relax automatically if CUDA is unavailable, so the AF2 cluster tests still pass. If you want GPU-backed OpenMM relax as well, install the OpenMM stack from conda before the pip install.
### AlphaFold3: `ModuleNotFoundError: No module named 'alphafold3.cpp'`
The root AlphaPulldown install succeeded, but the vendored AF3 package was not built yet:
```bash
cd /path/to/AlphaPulldown
python -m pip install ".[alphafold3,test]"
python -m pip install --no-deps -e ./alphafold3
build_data
```
### AlphaFold3 build error: `Could NOT find SQLite3`
If SQLite is installed in the conda environment but CMake still cannot find it, rerun the AF3 build with the conda prefix exposed:
```bash
mamba activate apd-af3
export CMAKE_PREFIX_PATH="$CONDA_PREFIX"
export SQLite3_ROOT="$CONDA_PREFIX"
cd /path/to/AlphaPulldown
python -m pip install --no-deps -e ./alphafold3
build_data
```

View File

@@ -4,8 +4,9 @@ channels:
- bioconda
- omnia
dependencies:
- openmm=8.0
- pdbfixer=1.9
- numpy<2
- openmm>=8.2
- pdbfixer>=1.10
- kalign2
- hmmer
- hhsuite

View File

@@ -26,7 +26,7 @@ dependencies = [
"matplotlib>=3.3.3",
"ml-collections>=0.1.0",
"pandas>=1.5.3",
"tensorflow-cpu>=2.16.1",
"tensorflow-cpu==2.20.0",
"importlib-resources>=6.1.0",
"importlib-metadata>=4.8.2,<5.0.0",
"biopython>=1.81,<1.82",
@@ -44,6 +44,28 @@ dependencies = [
]
[project.optional-dependencies]
alphafold2 = [
"jax==0.5.3",
"jax[cuda12]==0.5.3; platform_system == 'Linux' and platform_machine == 'x86_64'",
"modelcif>=1.6",
"numpy<2",
"openmm>=8.2",
"pdbfixer>=1.10",
]
alphafold3 = [
"jaxtyping==0.2.34",
"jax==0.5.3",
"jax[cuda12]==0.5.3; platform_system == 'Linux' and platform_machine == 'x86_64'",
"jax-triton==0.2.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"modelcif>=1.6",
"numpy<2",
"openmm>=8.2",
"pdbfixer>=1.10",
"rdkit==2024.3.5",
"triton==3.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"typeguard==2.13.3",
"zstandard",
]
test = [
"parameterized",
"pytest>=8.0",

14
test/README.md Normal file
View File

@@ -0,0 +1,14 @@
The active pytest layout is:
- `test/unit` for fast helper and mocked component tests
- `test/integration` for CPU-only filesystem, CLI wiring, and module interaction tests
- `test/functional` for heavier workflow tests that still belong to the main package test tree
- `test/cluster` for Slurm, GPU, or workstation smoke wrappers that are run explicitly
- `test/outdated` for legacy tests kept for reference until they are rewritten or deleted
Notes:
- `pytest.ini` only collects `unit`, `integration`, and `functional`.
- `conftest.py` auto-applies markers from the directory layout.
- `test/alphalink` now holds AlphaLink fixture files used by optional tests and cluster smoke runs.
- `test/RELEASE_READINESS.md` tracks which workflows are continuously protected by CI and which still require explicit cluster or manual validation before a release.

61
test/RELEASE_READINESS.md Normal file
View File

@@ -0,0 +1,61 @@
# Release Readiness
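This branch has broad, actively maintained test coverage, but the protection is layered: some workflows are checked on every CI run, while others are only validated when run explicitly.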
## Always-on CI
GitHub Actions runs:
- `test/unit`
- `test/integration`
- coverage collection and reporting
- Python `3.10` and `3.11`
These lanes continuously protect:
- feature creation wiring and CLI dispatch
- fold parsing and object construction
- AF2 backend helper behavior
- AF3 backend helper behavior
- script entrypoint and wrapper argument handling
- post-prediction and ModelCIF integration paths
- CPU-safe AlphaLink helper logic
## Explicitly-run validation
The following are intentionally outside default CI and must be run explicitly when a change touches them:
- `test/cluster/check_alphafold2_predictions.py`
- `test/cluster/check_alphafold3_predictions.py`
- `test/cluster/check_alphalink_predictions.py`
- manual or cluster-backed GPU/Slurm smoke runs
Release-critical examples include:
- AF3 wrapper output isolation for combined JSON folds
- MMseqs-generated `.a3m` and `.pkl` / `.pkl.xz` feature artifacts
- AF3 species pairing regressions for issue `#588`
- AF2 and AF3 dimer inference quality checks such as `ipTM > 0.6`
## Not Continuously Protected
The following areas are only partially protected, optional, or report-only:
- `test/cluster` workflows
- `test/alphalink` workflows beyond CPU-safe helper tests
- legacy scenarios still parked under `test/outdated`
- analysis-pipeline utilities and some deeper ModelCIF internals
- Python `3.8`, which is still advertised in packaging but is not exercised by GitHub Actions
The coverage artifact is useful as an audit input, but it does not prove workflow correctness by itself. In particular, `python test/tools/check_function_coverage.py --report-only` highlights functions that were never executed in CI and should be treated as follow-up audit items, not automatic release blockers.
## Practical Release Gate
Before a final release, the expected evidence is:
1. PR smoke-tests and coverage are green.
2. AF3 JSON wrapper regressions are green.
3. Issue `#588` MMseqs/AF3 pairing regressions are green.
4. Manual or cluster-backed AF2 and AF3 dimer runs from MMseqs features complete successfully with acceptable confidence.
If those gates are green, the branch is in good release shape even though some heavyweight workflows are still validated outside default CI.

7
test/alphalink/README.md Normal file
View File

@@ -0,0 +1,7 @@
This directory contains AlphaLink-specific fixture files.
Active tests now live in the standard buckets:
- `test/unit/test_crosslink_input.py` for the small crosslink helper checks
- `test/cluster/check_alphalink_predictions.py` for GPU and weights-backed smoke tests
- `test/outdated/` for older AlphaLink checks that still hard-code cluster paths or stale fixtures

View File

@@ -1,50 +0,0 @@
import unittest
from unifold.dataset import calculate_offsets,create_xl_features,bin_xl
from alphafold.data.pipeline_multimer import _FastaChain
import numpy as np
import gzip,pickle
import torch
class TestCreateObjects(unittest.TestCase):
    """Checks for the crosslink helpers imported from ``unifold.dataset``."""

    def setUp(self) -> None:
        # NOTE(review): relative path; presumably the tests are launched from
        # the repository root. Confirm before relying on it.
        self.crosslink_info = "./test/test_data/test_xl_input.pkl.gz"
        # Three chains of 10, 25 and 40 residues.
        self.asym_id = torch.tensor([1] * 10 + [2] * 25 + [3] * 40)
        self.chain_id_map = {
            "A": _FastaChain(sequence='', description='chain1'),
            "B": _FastaChain(sequence='', description='chain2'),
            "C": _FastaChain(sequence='', description='chain3'),
        }
        self.bins = torch.arange(0, 1.05, 0.05)
        return super().setUp()

    def _load_crosslinks(self):
        """Unpickle the gzipped crosslink fixture and close the file afterwards."""
        # The context manager guarantees the gzip handle is released; the
        # previous inline ``pickle.load(gzip.open(...))`` never closed it.
        with gzip.open(self.crosslink_info, 'rb') as handle:
            return pickle.load(handle)

    def test1_calculate_offsets(self):
        """Offsets are the cumulative chain lengths, starting at zero."""
        offsets = calculate_offsets(self.asym_id)
        offsets = offsets.tolist()
        expected_offsets = [0, 10, 35, 75]
        self.assertEqual(offsets, expected_offsets)

    def test2_create_xl_inputs(self):
        """Crosslink features match the expected (i, j, value) rows."""
        offsets = calculate_offsets(self.asym_id)
        xl_pickle = self._load_crosslinks()
        xl = create_xl_features(xl_pickle, offsets, chain_id_map=self.chain_id_map)
        expected_xl = torch.tensor([[10, 35, 0.01],
                                    [3, 27, 0.01],
                                    [5, 56, 0.01],
                                    [20, 65, 0.01]])
        self.assertTrue(torch.equal(xl, expected_xl))

    def test3_bin_xl(self):
        """Binned crosslinks fill both (i, j) and (j, i) with the same bucket."""
        offsets = calculate_offsets(self.asym_id)
        xl_pickle = self._load_crosslinks()
        xl = create_xl_features(xl_pickle, offsets, chain_id_map=self.chain_id_map)
        num_res = len(self.asym_id)
        xl = bin_xl(xl, num_res)
        expected_xl = np.zeros((num_res, num_res, 1))
        for i, j in ((3, 27), (10, 35), (5, 56), (20, 65)):
            expected_xl[i, j, 0] = expected_xl[j, i, 0] = torch.bucketize(0.99, self.bins)
        self.assertTrue(np.array_equal(xl, expected_xl))
if __name__ == "__main__":
unittest.main()

View File

@@ -101,6 +101,27 @@ def _non_empty_identifier_count(values) -> int:
count += 1
return count
def _af2_subprocess_env() -> dict[str, str]:
"""Return stable GPU/JAX defaults for AF2 functional subprocesses."""
env = os.environ.copy()
env.setdefault("OMP_NUM_THREADS", "4")
env.setdefault("MKL_NUM_THREADS", "4")
env.setdefault("NUMEXPR_NUM_THREADS", "4")
env.setdefault("TF_NUM_INTEROP_THREADS", "4")
env.setdefault("TF_NUM_INTRAOP_THREADS", "4")
env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
env.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.8")
env.setdefault("JAX_PLATFORM_NAME", "gpu")
env.setdefault(
"XLA_FLAGS",
"--xla_gpu_force_compilation_parallelism=0 "
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
)
return env
# --------------------------------------------------------------------------- #
# common helper mix-in / assertions #
# --------------------------------------------------------------------------- #
@@ -159,6 +180,14 @@ class _TestBase(parameterized.TestCase):
apd_path / "scripts" / "create_individual_features.py"
)
def _run_prediction_subprocess(self, args):
    """Run *args* as a subprocess with the shared AF2 test environment.

    Output is captured as text, and the returned ``CompletedProcess`` is
    handed back to the caller without checking the return code.
    """
    run_options = {
        "capture_output": True,
        "text": True,
        "env": _af2_subprocess_env(),
    }
    return subprocess.run(args, **run_options)
# ---------------- assertions reused by all subclasses ----------------- #
def _runCommonTests(self, res: subprocess.CompletedProcess, multimer: bool, dirname: str | None = None):
if res.returncode != 0:
@@ -265,9 +294,8 @@ class TestRunModes(_TestBase):
)
def test_(self, protein_list, mode, script):
multimer = "monomer" not in protein_list
res = subprocess.run(
self._args(plist=protein_list, mode=mode, script=script),
capture_output=True, text=True
res = self._run_prediction_subprocess(
self._args(plist=protein_list, mode=mode, script=script)
)
self._runCommonTests(res, multimer)
@@ -351,7 +379,7 @@ class TestResume(_TestBase):
(self.output_dir / "TEST_homo_2er" / fname).unlink()
except FileNotFoundError:
pass
res = subprocess.run(args, capture_output=True, text=True)
res = self._run_prediction_subprocess(args)
self._runCommonTests(res, multimer=True, dirname="TEST_homo_2er")
self._runAfterRelaxTests(relax_mode)
@@ -422,12 +450,12 @@ class TestDropoutDiversity(_TestBase):
# Execute both predictions
logger.info("Running prediction without dropout...")
#logger.info("".join(args_no_dropout))
res_no_dropout = subprocess.run(args_no_dropout, capture_output=True, text=True)
res_no_dropout = self._run_prediction_subprocess(args_no_dropout)
self.assertEqual(res_no_dropout.returncode, 0,
f"No dropout prediction failed: {res_no_dropout.stderr}")
logger.info("Running prediction with dropout...")
res_with_dropout = subprocess.run(args_with_dropout, capture_output=True, text=True)
res_with_dropout = self._run_prediction_subprocess(args_with_dropout)
self.assertEqual(res_with_dropout.returncode, 0,
f"Dropout prediction failed: {res_with_dropout.stderr}")
@@ -508,7 +536,7 @@ class TestMmseqsIssue588Inference(_TestBase):
"--compress_features=True",
"--skip_existing=False",
]
res = subprocess.run(args, capture_output=True, text=True)
res = self._run_prediction_subprocess(args)
self.assertEqual(
res.returncode,
0,
@@ -588,6 +616,7 @@ class TestMmseqsIssue588Inference(_TestBase):
],
capture_output=True,
text=True,
env=_af2_subprocess_env(),
)
self.assertEqual(
res.returncode,

View File

@@ -1327,6 +1327,14 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
)
for protein_id in self.ISSUE_588_IDS:
self.assertTrue(
(source_dir / f"{protein_id}.a3m").is_file(),
f"Expected MMseq A3M {source_dir / f'{protein_id}.a3m'} to be created.",
)
self.assertTrue(
(source_dir / f"{protein_id}.pkl.xz").is_file(),
f"Expected compressed feature pickle {source_dir / f'{protein_id}.pkl.xz'} to be created.",
)
shutil.copy2(
source_dir / f"{protein_id}.a3m",
precomputed_dir / f"{protein_id}.a3m",
@@ -1356,6 +1364,15 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
"Precomputed-MMseq feature generation failed.\n"
f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}",
)
for protein_id in self.ISSUE_588_IDS:
self.assertTrue(
(precomputed_dir / f"{protein_id}.a3m").is_file(),
f"Expected copied MMseq A3M {precomputed_dir / f'{protein_id}.a3m'} to be present.",
)
self.assertTrue(
(precomputed_dir / f"{protein_id}.pkl.xz").is_file(),
f"Expected precomputed feature pickle {precomputed_dir / f'{protein_id}.pkl.xz'} to be created.",
)
return precomputed_dir
def _prepare_fold_input(
@@ -1656,6 +1673,7 @@ class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
"--max_template_date=2024-05-02",
"--use_mmseqs2=True",
"--data_pipeline=alphafold2",
"--save_msa_files=True",
"--compress_features=True",
"--skip_existing=False",
],
@@ -1676,6 +1694,14 @@ class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
feature_dir = self._generate_issue_588_mmseq_features(env)
for protein_id in self.ISSUE_588_IDS:
self.assertTrue(
(feature_dir / f"{protein_id}.a3m").is_file(),
f"Expected MMseq A3M {feature_dir / f'{protein_id}.a3m'} to be created.",
)
self.assertTrue(
(feature_dir / f"{protein_id}.pkl.xz").is_file(),
f"Expected compressed feature pickle {feature_dir / f'{protein_id}.pkl.xz'} to be created.",
)
feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
self.assertGreater(
_non_empty_identifier_count(
@@ -2974,19 +3000,28 @@ class TestAlphaFold3RunModes(_TestBase):
self._assert_af3_outputs_present(current_output_dir)
def test_af3_run_multimer_jobs_multiple_json_jobs_create_per_job_subdirs(self):
"""Shared AF3 wrapper output roots must isolate multiple JSON jobs by subdirectory."""
from alphapulldown.utils.output_paths import derive_af3_job_name_from_json
"""Shared AF3 wrapper roots must isolate combined JSON folds by subdirectory."""
self._require_af3_functional_environment()
env = self._make_af3_test_env()
flash_impl = self._af3_flash_attention_impl()
json_inputs = [
self.test_features_dir / "protein_with_ptms.json",
self.test_features_dir / "P01308_af3_input.json",
json_folds = [
[
self.test_features_dir / "protein_with_ptms.json",
self.test_features_dir / "P61626_af3_input.json",
],
[
self.test_features_dir / "P01308_af3_input.json",
self.test_features_dir / "P61626_af3_input.json",
],
]
protein_list = self.output_dir / "test_multiple_json_jobs.txt"
protein_list.write_text(
"\n".join(json_input.name for json_input in json_inputs) + "\n",
"\n".join(
";".join(json_input.name for json_input in json_fold)
for json_fold in json_folds
)
+ "\n",
encoding="utf-8",
)
@@ -3017,11 +3052,16 @@ class TestAlphaFold3RunModes(_TestBase):
(self.output_dir / "ranking_scores.csv").exists(),
"Shared wrapper output root should not contain flattened AF3 JSON outputs.",
)
self.assertFalse(
any(self.output_dir.glob("*_data.json")),
"Combined JSON folds should not write flat AF3 input JSONs into the shared root.",
)
for json_input in json_inputs:
current_output_dir = self.output_dir / derive_af3_job_name_from_json(
str(json_input)
)
for output_dir_name in (
"protein_with_ptms_and_p61626",
"p01308_and_p61626",
):
current_output_dir = self.output_dir / output_dir_name
self.assertTrue(
current_output_dir.is_dir(),
f"Expected per-job output directory {current_output_dir} to be created.",

View File

@@ -120,6 +120,24 @@ def _timestamp() -> str:
return dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
def _default_gpu_env_lines(*, cpus_per_task: int) -> list[str]:
thread_count = max(1, min(cpus_per_task, 4))
return [
"export PYTHONUNBUFFERED=1",
f'export OMP_NUM_THREADS="${{OMP_NUM_THREADS:-{thread_count}}}"',
f'export MKL_NUM_THREADS="${{MKL_NUM_THREADS:-{thread_count}}}"',
f'export NUMEXPR_NUM_THREADS="${{NUMEXPR_NUM_THREADS:-{thread_count}}}"',
f'export TF_NUM_INTEROP_THREADS="${{TF_NUM_INTEROP_THREADS:-{thread_count}}}"',
f'export TF_NUM_INTRAOP_THREADS="${{TF_NUM_INTRAOP_THREADS:-{thread_count}}}"',
'export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}"',
'export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-2}"',
'export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}"',
'export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.8}"',
'export JAX_PLATFORM_NAME="${JAX_PLATFORM_NAME:-gpu}"',
'if [ -z "${XLA_FLAGS:-}" ]; then export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=0 --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"; fi',
]
def _relative_nodeid_prefix(test_file: Path) -> str:
    """Return *test_file*, resolved, as a string path relative to ``REPO_ROOT``."""
    resolved_file = test_file.resolve()
    repo_relative = resolved_file.relative_to(REPO_ROOT)
    return str(repo_relative)
@@ -209,6 +227,7 @@ def write_job_script(
job: JobSpec,
python_executable: str,
use_temp_dir: bool,
cpus_per_task: int,
) -> None:
pytest_cmd = [
python_executable,
@@ -228,7 +247,7 @@ def write_job_script(
"#!/bin/bash",
"set -euo pipefail",
f"cd {_quote(str(REPO_ROOT))}",
"export PYTHONUNBUFFERED=1",
*_default_gpu_env_lines(cpus_per_task=cpus_per_task),
"echo \"[$(date)] Running test node:\"",
f"echo {_quote(job.nodeid)}",
"echo \"[$(date)] Host: $(hostname)\"",
@@ -634,7 +653,7 @@ def main() -> int:
stderr_path = log_dir / f"{index:03d}_{slug}.err"
script_path = log_dir / f"{index:03d}_{slug}.sbatch.sh"
rerun_command = (
f"{_quote(args.python)} -m pytest -vv -s {_quote(nodeid)}"
f"{_quote(args.python)} -m pytest -o {_quote('addopts=-ra --strict-markers')} -vv -s {_quote(nodeid)}"
+ (" --use-temp-dir" if args.use_temp_dir else "")
)
job = JobSpec(
@@ -650,6 +669,7 @@ def main() -> int:
job=job,
python_executable=args.python,
use_temp_dir=args.use_temp_dir,
cpus_per_task=args.cpus_per_task,
)
jobs.append(job)

View File

@@ -650,10 +650,21 @@ def main() -> int:
stdout_path = log_dir / f"{index:03d}_{slug}.out"
stderr_path = log_dir / f"{index:03d}_{slug}.err"
script_path = log_dir / f"{index:03d}_{slug}.sbatch.sh"
rerun_command = (
f"{_quote(args.python)} -m pytest -vv -s {_quote(nodeid)}"
+ (" --use-temp-dir" if args.use_temp_dir else "")
)
rerun_parts = [
_quote(args.python),
"-m",
"pytest",
"-o",
_quote("addopts=-ra --strict-markers"),
"-vv",
"-s",
_quote(nodeid),
]
if args.use_temp_dir:
rerun_parts.append("--use-temp-dir")
rerun_command = " ".join(rerun_parts)
if args.include_perf:
rerun_command = f"AF3_RUN_PERF_TESTS=1 {rerun_command}"
job = JobSpec(
index=index,
nodeid=nodeid,

View File

@@ -1,2 +1,9 @@
Functional tests live here when they are deterministic, CPU-safe, and heavier than the
unit/integration layers. GPU, Slurm, or external-tool smoke tests belong under `test/cluster/`.
unit/integration layers.
Some functional suites may still carry the `external_tools` marker when they shell
into the real feature-generation stack or depend on heavyweight local runtimes.
Those stay in this directory because they are package-level workflow tests, but
they are still excluded from the default CPU-only pytest invocation.
GPU or Slurm smoke wrappers belong under `test/cluster/`.

View File

@@ -1,3 +1,11 @@
"""Template-heavy feature-generation smoke tests.
These exercise the real CLI path for custom-template feature creation. They are
kept under `test/functional/`, but still marked `external_tools` because the
subprocess touches the ColabFold/JAX import chain and local bioinformatics
tooling that is not available in the lightweight default test environment.
"""
import os
import sys
import shutil
@@ -9,11 +17,15 @@ import lzma
import json
import numpy as np
import pytest
from absl.testing import absltest, parameterized
from alphapulldown.utils.remove_clashes_low_plddt import extract_seqs
pytestmark = pytest.mark.external_tools
class TestCreateIndividualFeaturesWithTemplates(parameterized.TestCase):
def setUp(self):

View File

@@ -5,11 +5,14 @@ Tests both AlphaFold2 and AlphaFold3 pipelines with various configurations.
"""
import os
import sys
import tempfile
import json
import lzma
import pickle
import pytest
import logging
import types
from pathlib import Path
from unittest.mock import patch, MagicMock, mock_open
@@ -22,8 +25,9 @@ logger = logging.getLogger(__name__)
# Minimal real MonomericObject for pickling
class DummyMonomer:
def __init__(self, description):
def __init__(self, description, sequence=None):
self.description = description
self.sequence = sequence
self.feature_dict = {}
self.uniprot_runner = None
def make_features(self, *a, **k):
@@ -33,6 +37,20 @@ class DummyMonomer:
def all_seq_msa_features(self, *a, **k):
return {}
class RecordingDummyMonomer(DummyMonomer):
    """``DummyMonomer`` variant that records the keyword arguments of feature calls."""

    def __init__(self, description, sequence=None):
        super().__init__(description, sequence)
        # One entry per call; each entry is the kwargs dict of that call.
        self.feature_calls = []
        self.mmseq_calls = []

    def make_features(self, *_positional, **kwargs):
        """Record the keyword arguments; positional arguments are ignored."""
        self.feature_calls.append(kwargs)

    def make_mmseq_features(self, *_positional, **kwargs):
        """Record the keyword arguments; positional arguments are ignored."""
        self.mmseq_calls.append(kwargs)
class DummyJsonObj:
def to_json(self):
return '{"test": "features"}'
@@ -44,6 +62,60 @@ def real_write_text(self, content, *args, **kwargs):
f.write(content)
return len(content)
def build_af3_stub_modules():
    """Create in-memory stand-ins for the ``alphafold3`` modules used by the tests.

    Returns:
        A ``(modules, folding_input)`` tuple. ``modules`` maps dotted module
        names to stub module objects and is suitable for
        ``patch.dict(sys.modules, ...)``; ``folding_input`` is the stub for
        ``alphafold3.common.folding_input`` carrying the chain and ``Input``
        classes.
    """

    class ProteinChain:
        def __init__(self, sequence, id, ptms=None):
            self.sequence = sequence
            self.id = id
            self.ptms = list(ptms) if ptms is not None else []

    class RnaChain:
        def __init__(self, sequence, id, modifications=None):
            self.sequence = sequence
            self.id = id
            self.modifications = (
                list(modifications) if modifications is not None else []
            )

    class DnaChain:
        def __init__(self, sequence, id, modifications=None):
            self.sequence = sequence
            self.id = id
            self.modifications = (
                list(modifications) if modifications is not None else []
            )

    class Input:
        def __init__(self, name, chains, rng_seeds):
            self.name = name
            self.chains = list(chains)
            self.rng_seeds = list(rng_seeds)

    def new_module(dotted_name, *, package=False):
        # An empty __path__ marks the stub as a package.
        module = types.ModuleType(dotted_name)
        if package:
            module.__path__ = []
        return module

    root_pkg = new_module("alphafold3", package=True)
    common_pkg = new_module("alphafold3.common", package=True)
    structure_pkg = new_module("alphafold3.structure", package=True)
    folding_input_mod = new_module("alphafold3.common.folding_input")
    mmcif_mod = new_module("alphafold3.structure.mmcif")

    for stub_class in (ProteinChain, RnaChain, DnaChain, Input):
        setattr(folding_input_mod, stub_class.__name__, stub_class)
    # 1-based integer id -> single letter starting at "A".
    mmcif_mod.int_id_to_str_id = lambda idx: chr(ord("A") + idx - 1)

    # Wire submodules onto their parents so attribute access also works.
    root_pkg.common = common_pkg
    root_pkg.structure = structure_pkg
    common_pkg.folding_input = folding_input_mod
    structure_pkg.mmcif = mmcif_mod

    stub_modules = {
        "alphafold3": root_pkg,
        "alphafold3.common": common_pkg,
        "alphafold3.common.folding_input": folding_input_mod,
        "alphafold3.structure": structure_pkg,
        "alphafold3.structure.mmcif": mmcif_mod,
    }
    return stub_modules, folding_input_mod
class TestCreateIndividualFeaturesComprehensive:
"""Comprehensive test cases for create_individual_features.py."""
@@ -98,10 +170,10 @@ class TestCreateIndividualFeaturesComprehensive:
("alphafold2", "multi_protein.fasta", False, False),
("alphafold2", "single_protein.fasta", True, False), # mmseqs2
("alphafold2", "single_protein.fasta", False, True), # compressed
#("alphafold3", "single_protein.fasta", False, False),
#("alphafold3", "multi_protein.fasta", False, False),
#("alphafold3", "rna.fasta", False, False),
#("alphafold3", "dna.fasta", False, False),
("alphafold3", "single_protein.fasta", False, False),
("alphafold3", "multi_protein.fasta", False, False),
("alphafold3", "rna.fasta", False, False),
("alphafold3", "dna.fasta", False, False),
])
def test_feature_creation(self, pipeline, fasta_file, use_mmseqs2, compress_features):
"""Test feature creation for different configurations."""
@@ -157,39 +229,199 @@ class TestCreateIndividualFeaturesComprehensive:
logger.info(f"Verified file exists: {file_path}")
else:
logger.info("Testing AlphaFold3 pipeline")
with patch.object(create_features, 'create_pipeline_af3') as mock_af3_pipeline, \
patch('alphapulldown.scripts.create_individual_features.folding_input') as mock_folding_input, \
af3_modules, folding_input_stub = build_af3_stub_modules()
with patch.dict(sys.modules, af3_modules), \
patch.object(create_features, 'create_pipeline_af3') as mock_af3_pipeline, \
patch.object(create_features, 'folding_input', folding_input_stub), \
patch('pathlib.Path.write_text', new=real_write_text), \
patch('alphapulldown.utils.save_meta_data.get_meta_dict', return_value={}):
mock_af3_pipeline.return_value = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
# Patch chain classes in folding_input
mock_folding_input.ProteinChain = lambda sequence, id, ptms: MagicMock()
mock_folding_input.RnaChain = lambda sequence, id, modifications=None: MagicMock()
mock_folding_input.DnaChain = lambda sequence, id: MagicMock()
mock_folding_input.Input = lambda name, chains, rng_seeds: MagicMock()
create_features.create_af3_individual_features()
process_calls = mock_af3_pipeline.return_value.process.call_args_list
observed_chain_types = [
type(call.args[0].chains[0]).__name__ for call in process_calls
]
expected_files = []
expected_chain_types = []
if fasta_file == "single_protein.fasta":
expected_files.append("A0A024R1R8_af3_input.json")
expected_chain_types.append("ProteinChain")
elif fasta_file == "multi_protein.fasta":
expected_files.extend(["A0A024R1R8_af3_input.json", "P61626_af3_input.json"])
expected_chain_types.extend(["ProteinChain", "ProteinChain"])
elif fasta_file == "rna.fasta":
expected_files.append("RNA_TEST_af3_input.json")
expected_chain_types.append("RnaChain")
elif fasta_file == "dna.fasta":
expected_files.append("DNA_TEST_af3_input.json")
expected_chain_types.append("DnaChain")
logger.info(f"Checking for expected files: {expected_files}")
assert observed_chain_types == expected_chain_types
for expected_file in expected_files:
file_path = os.path.join(output_dir, expected_file)
# Simulate file creation
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
Path(file_path).write_text('{"test": "features"}')
assert os.path.exists(file_path), f"Expected file {file_path} not found"
logger.info(f"Verified file exists: {file_path}")
logger.info("Feature creation test completed successfully")
def test_af3_invalid_sequence_is_logged_and_skipped(self):
    """Invalid AF3 sequences should not create output files."""
    from absl import flags

    invalid_fasta = os.path.join(self.fasta_dir, "invalid_af3.fasta")
    with open(invalid_fasta, "w") as fasta_handle:
        fasta_handle.write(">INVALID\nACDZ*\n")

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.data_pipeline = "alphafold3"
    FLAGS.fasta_paths = [invalid_fasta]
    FLAGS.data_dir = self.af3_db
    FLAGS.output_dir = os.path.join(self.test_dir, "output_invalid_af3")
    FLAGS.max_template_date = "2021-09-30"

    # Collect whatever the script passes to logging.error.
    logged_errors = []
    stub_modules, folding_input_stub = build_af3_stub_modules()
    with patch.dict(sys.modules, stub_modules), \
            patch.object(create_features, "create_pipeline_af3") as mock_af3_pipeline, \
            patch.object(create_features, "folding_input", folding_input_stub), \
            patch.object(create_features.logging, "error", side_effect=logged_errors.append):
        mock_af3_pipeline.return_value = MagicMock(
            process=MagicMock(return_value=DummyJsonObj())
        )
        create_features.create_af3_individual_features()

    # The pipeline must never be asked to process the invalid sequence.
    mock_af3_pipeline.return_value.process.assert_not_called()
    unexpected_output = os.path.join(FLAGS.output_dir, "INVALID_af3_input.json")
    assert not os.path.exists(unexpected_output)
    assert any(
        "Failed to create AlphaFold3 input object" in message
        for message in logged_errors
    )
def test_create_individual_features_truemultimer_respects_seq_index(self):
    """TrueMultimer mode should only process the selected CSV row."""
    from absl import flags

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.description_file = os.path.join(self.test_dir, "description.csv")
    FLAGS.fasta_paths = [os.path.join(self.fasta_dir, "multi_protein.fasta")]
    FLAGS.path_to_mmt = os.path.join(self.test_dir, "templates")
    FLAGS.multiple_mmts = True
    FLAGS.seq_index = 2

    parsed_rows = [{"protein": name} for name in ("prot1", "prot2", "prot3")]

    parse_patch = patch.object(create_features, "parse_csv_file", return_value=parsed_rows)
    process_patch = patch.object(create_features, "process_multimeric_features")
    with parse_patch as mock_parse, process_patch as mock_process:
        create_features.create_individual_features_truemultimer()

    mock_parse.assert_called_once_with(
        FLAGS.description_file,
        FLAGS.fasta_paths,
        FLAGS.path_to_mmt,
        FLAGS.multiple_mmts,
    )
    # seq_index 2 is expected to select the second parsed row.
    mock_process.assert_called_once_with(parsed_rows[1], 2)
def test_process_multimeric_features_rejects_missing_templates(self):
    """TrueMultimer mode should fail early if a template path is missing."""
    # This path is never created, so the template lookup has to fail.
    absent_template = os.path.join(self.test_dir, "missing_template.cif")
    feat = dict(
        protein="complexA",
        chains=["A"],
        templates=[absent_template],
        sequence="ACDE",
    )

    with pytest.raises(FileNotFoundError, match="does not exist"):
        create_features.process_multimeric_features(feat, 1)
def test_process_multimeric_features_creates_custom_db_and_saves_monomer(self):
    """TrueMultimer processing should build a custom DB and hand a monomer to the saver."""
    from absl import flags

    class RecordingMonomer:
        """Minimal stand-in exposing only the attributes this test inspects."""

        def __init__(self, description, sequence):
            self.description = description
            self.sequence = sequence
            self.feature_dict = {}
            self.uniprot_runner = None

    template_path = os.path.join(self.test_dir, "template1.cif")
    Path(template_path).write_text("data_template\n", encoding="utf-8")

    feat = dict(
        protein="complexB",
        chains=["A", "B"],
        templates=[template_path],
        sequence="ACDEFG",
    )

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.output_dir = os.path.join(self.test_dir, "truemultimer_output")
    FLAGS.data_dir = self.af2_db
    FLAGS.max_template_date = "2021-09-30"
    FLAGS.use_mmseqs2 = False
    FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer"
    FLAGS.uniprot_database_path = "/db/uniprot.fasta"

    with patch.object(create_features, "MonomericObject", RecordingMonomer), \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_create_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)

    mock_custom_db.assert_called_once()
    # The first positional argument is not checked; only the rest is pinned.
    trailing_db_args = mock_custom_db.call_args.args[1:]
    assert trailing_db_args == (
        "complexB",
        [template_path],
        ["A", "B"],
    )
    mock_create_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_called_once_with(
        FLAGS.jackhmmer_binary_path,
        FLAGS.uniprot_database_path,
    )

    saved_monomer, saved_pipeline = mock_save.call_args.args
    assert saved_pipeline == "pipeline"
    assert saved_monomer.description == "complexB"
    assert saved_monomer.sequence == "ACDEFG"
    assert saved_monomer.uniprot_runner == "runner"
def test_main_dispatches_to_truemultimer_for_af2_template_runs(self):
    """The main entrypoint should route AF2 template jobs to the TrueMultimer path."""
    from absl import flags

    FLAGS = flags.FLAGS
    FLAGS(["test"])
    FLAGS.data_pipeline = "alphafold2"
    FLAGS.fasta_paths = [os.path.join(self.fasta_dir, "single_protein.fasta")]
    FLAGS.data_dir = self.af2_db
    FLAGS.output_dir = os.path.join(self.test_dir, "main_truemultimer")
    FLAGS.max_template_date = "2021-09-30"
    FLAGS.path_to_mmt = os.path.join(self.test_dir, "templates")

    date_check_patch = patch.object(create_features, "check_template_date")
    truemultimer_patch = patch.object(create_features, "create_individual_features_truemultimer")
    single_patch = patch.object(create_features, "create_individual_features")
    with date_check_patch as mock_check, truemultimer_patch as mock_tm, single_patch as mock_single:
        create_features.main([])

    mock_check.assert_called_once_with()
    mock_tm.assert_called_once_with()
    mock_single.assert_not_called()
def test_database_path_mapping(self):
"""Test that database paths are correctly mapped for both pipelines."""
logger.info("Testing database path mapping")
@@ -229,7 +461,7 @@ class TestCreateIndividualFeaturesComprehensive:
FLAGS.data_pipeline = "alphafold3"
FLAGS.data_dir = "/test/db"
with pytest.raises(ImportError):
with pytest.raises(ImportError, match="pip install -e .*alphafold3,test.*build_data"):
create_features.create_pipeline_af3()
logger.info("AF3 pipeline creation correctly failed with ImportError")
@@ -669,4 +901,287 @@ class TestCreateIndividualFeaturesComprehensive:
with pytest.raises(SystemExit):
create_features.main([])
logger.info("Flag validation for MMseqs2 scenarios successful")
logger.info("Flag validation for MMseqs2 scenarios successful")
def test_create_pipeline_af2_uses_hhsearch_template_stack(tmp_flags):
    """With ``use_hhsearch`` on, the AF2 pipeline gets the HHSearch searcher and featurizer."""
    flag_values = create_features.FLAGS
    flag_values.use_mmseqs2 = False
    flag_values.use_hhsearch = True
    flag_values.hhsearch_binary_path = "/bin/hhsearch"
    flag_values.pdb70_database_path = "/db/pdb70"
    flag_values.template_mmcif_dir = "/db/mmcif"
    flag_values.max_template_date = "2021-09-30"
    flag_values.kalign_binary_path = "/bin/kalign"
    flag_values.obsolete_pdbs_path = "/db/obsolete.dat"

    searcher_patch = patch.object(create_features.hhsearch, "HHSearch", return_value="searcher")
    featurizer_patch = patch.object(
        create_features.templates, "HhsearchHitFeaturizer", return_value="featurizer"
    )
    pipeline_patch = patch.object(create_features, "AF2DataPipeline", return_value="pipeline")
    with searcher_patch as mock_searcher, featurizer_patch as mock_featurizer, \
            pipeline_patch as mock_pipeline:
        built_pipeline = create_features.create_pipeline_af2()

    assert built_pipeline == "pipeline"
    mock_searcher.assert_called_once_with(
        binary_path="/bin/hhsearch",
        databases=["/db/pdb70"],
    )
    mock_featurizer.assert_called_once_with(
        mmcif_dir="/db/mmcif",
        max_template_date="2021-09-30",
        max_hits=20,
        kalign_binary_path="/bin/kalign",
        release_dates_path=None,
        obsolete_pdbs_path="/db/obsolete.dat",
    )
    pipeline_kwargs = mock_pipeline.call_args.kwargs
    assert pipeline_kwargs["template_searcher"] == "searcher"
    assert pipeline_kwargs["template_featurizer"] == "featurizer"
def test_create_pipeline_af2_uses_hmmsearch_template_stack(tmp_flags):
    """With ``use_hhsearch`` off, the AF2 pipeline gets the Hmmsearch searcher and featurizer."""
    flag_values = create_features.FLAGS
    flag_values.use_mmseqs2 = False
    flag_values.use_hhsearch = False
    flag_values.hmmsearch_binary_path = "/bin/hmmsearch"
    flag_values.hmmbuild_binary_path = "/bin/hmmbuild"
    flag_values.pdb_seqres_database_path = "/db/pdb_seqres.txt"
    flag_values.template_mmcif_dir = "/db/mmcif"
    flag_values.max_template_date = "2021-09-30"
    flag_values.kalign_binary_path = "/bin/kalign"
    flag_values.obsolete_pdbs_path = "/db/obsolete.dat"

    searcher_patch = patch.object(create_features.hmmsearch, "Hmmsearch", return_value="searcher")
    featurizer_patch = patch.object(
        create_features.templates, "HmmsearchHitFeaturizer", return_value="featurizer"
    )
    pipeline_patch = patch.object(create_features, "AF2DataPipeline", return_value="pipeline")
    with searcher_patch as mock_searcher, featurizer_patch as mock_featurizer, \
            pipeline_patch as mock_pipeline:
        built_pipeline = create_features.create_pipeline_af2()

    assert built_pipeline == "pipeline"
    mock_searcher.assert_called_once_with(
        binary_path="/bin/hmmsearch",
        hmmbuild_binary_path="/bin/hmmbuild",
        database_path="/db/pdb_seqres.txt",
    )
    mock_featurizer.assert_called_once_with(
        mmcif_dir="/db/mmcif",
        max_template_date="2021-09-30",
        max_hits=20,
        kalign_binary_path="/bin/kalign",
        obsolete_pdbs_path="/db/obsolete.dat",
        release_dates_path=None,
    )
    pipeline_kwargs = mock_pipeline.call_args.kwargs
    assert pipeline_kwargs["template_searcher"] == "searcher"
    assert pipeline_kwargs["template_featurizer"] == "featurizer"
def test_create_individual_features_only_saves_selected_sequence(tmp_flags):
    """With ``seq_index`` 2, the monomer handed to the saver is the second sequence."""
    create_features.FLAGS.seq_index = 2
    fake_sequences = [("AAAA", "first"), ("BBBB", "second")]

    with patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "MonomericObject", DummyMonomer), \
            patch.object(create_features, "iter_seqs", return_value=fake_sequences), \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.create_individual_features()

    mock_arguments.assert_called_once_with()
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_called_once()

    saved_monomer, saved_pipeline = mock_save.call_args.args
    assert saved_pipeline == "pipeline"
    assert saved_monomer.description == "second"
    assert saved_monomer.uniprot_runner == "runner"
def test_create_and_save_monomer_objects_writes_compressed_af2_outputs(tmp_flags, tmp_path):
    """With compression on, metadata and the feature pickle are written as ``.xz`` files."""
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = True
    flag_values.skip_existing = False
    flag_values.use_mmseqs2 = False
    flag_values.use_precomputed_msas = True
    flag_values.save_msa_files = True

    monomer = RecordingDummyMonomer("protA")
    meta_patch = patch(
        "alphapulldown.utils.save_meta_data.get_meta_dict",
        return_value={"source": "test"},
    )
    with meta_patch:
        create_features.create_and_save_monomer_objects(monomer, pipeline="pipeline")

    metadata_files = list(tmp_path.glob("protA_feature_metadata_*.json.xz"))
    assert len(metadata_files) == 1
    with lzma.open(metadata_files[0], "rt", encoding="utf-8") as handle:
        assert json.load(handle) == {"source": "test"}
    assert (tmp_path / "protA.pkl.xz").exists()

    # The local pipeline path is used exactly once; the MMseqs path is not.
    expected_feature_call = {
        "pipeline": "pipeline",
        "output_dir": str(tmp_path),
        "use_precomputed_msa": True,
        "save_msa": True,
    }
    assert monomer.feature_calls == [expected_feature_call]
    assert monomer.mmseq_calls == []
def test_create_and_save_monomer_objects_skips_existing_outputs(tmp_flags, tmp_path):
    """With ``skip_existing`` on and a pickle present, no features or metadata are produced."""
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = False
    flag_values.skip_existing = True
    flag_values.use_mmseqs2 = True

    # Pre-create the output that should make the call skip its work.
    (tmp_path / "protA.pkl").write_bytes(b"already-there")

    monomer = RecordingDummyMonomer("protA")
    create_features.create_and_save_monomer_objects(monomer, pipeline=None)

    assert monomer.feature_calls == []
    assert monomer.mmseq_calls == []
    leftover_metadata = list(tmp_path.glob("protA_feature_metadata_*.json"))
    assert leftover_metadata == []
def test_create_and_save_monomer_objects_uses_mmseqs_when_requested(tmp_flags, tmp_path):
    """With ``use_mmseqs2`` on, only ``make_mmseq_features`` is called and a pickle is written."""
    flag_values = create_features.FLAGS
    flag_values.output_dir = str(tmp_path)
    flag_values.compress_features = False
    flag_values.skip_existing = False
    flag_values.use_mmseqs2 = True
    flag_values.use_precomputed_msas = True
    flag_values.re_search_templates_mmseqs2 = True

    monomer = RecordingDummyMonomer("protA")
    meta_patch = patch(
        "alphapulldown.utils.save_meta_data.get_meta_dict",
        return_value={"source": "test"},
    )
    with meta_patch:
        create_features.create_and_save_monomer_objects(monomer, pipeline=None)

    expected_mmseq_call = {
        "DEFAULT_API_SERVER": create_features.DEFAULT_API_SERVER,
        "output_dir": str(tmp_path),
        "use_precomputed_msa": True,
        "use_templates": True,
    }
    assert monomer.feature_calls == []
    assert monomer.mmseq_calls == [expected_mmseq_call]
    assert (tmp_path / "protA.pkl").exists()
def test_process_multimeric_features_uses_mmseqs_without_local_pipeline(tmp_flags, tmp_path):
    """With ``use_mmseqs2`` on, no AF2 pipeline or UniProt runner is created."""
    template_path = tmp_path / "template.cif"
    template_path.write_text("data_template\n", encoding="utf-8")

    create_features.FLAGS.output_dir = str(tmp_path / "out")
    create_features.FLAGS.use_mmseqs2 = True

    feat = dict(
        protein="complex_mmseqs",
        chains=["A"],
        templates=[str(template_path)],
        sequence="ACDE",
    )

    with patch.object(create_features, "MonomericObject", RecordingDummyMonomer), \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)

    mock_custom_db.assert_called_once()
    mock_arguments.assert_called_once_with("/tmp/custom_db")
    # The local search stack must stay untouched on the MMseqs path.
    mock_pipeline.assert_not_called()
    mock_runner.assert_not_called()

    saved_monomer, saved_pipeline = mock_save.call_args.args
    assert saved_pipeline is None
    assert saved_monomer.description == "complex_mmseqs"
    assert saved_monomer.uniprot_runner is None
def test_create_custom_db_passes_thresholds_to_builder(tmp_flags):
    """``create_custom_db`` forwards the three threshold flags to ``create_db``."""
    flag_values = create_features.FLAGS
    flag_values.threshold_clashes = 12.5
    flag_values.hb_allowance = 0.7
    flag_values.plddt_threshold = 42.0

    templates = ["a.cif"]
    chains = ["A"]
    with patch.object(create_features, "create_db") as mock_create_db:
        db_path = create_features.create_custom_db("/tmp/base", "proteinX", templates, chains)

    assert str(db_path) == "/tmp/base/custom_template_db/proteinX"
    mock_create_db.assert_called_once_with(
        db_path,
        ["a.cif"],
        ["A"],
        12.5,
        0.7,
        42.0,
    )
def test_create_pipeline_af3_prefers_explicit_database_overrides(tmp_flags):
    """Explicit database flags win; the remaining paths are derived from ``data_dir``."""

    class DummyConfig:
        """Captures the keyword arguments the config is constructed with."""

        def __init__(self, **kwargs):
            self.kwargs = kwargs

    flag_values = create_features.FLAGS
    flag_values.max_template_date = "2021-09-30"
    flag_values.data_pipeline = "alphafold3"
    flag_values.data_dir = "/db"
    flag_values.small_bfd_database_path = "/override/small_bfd"
    flag_values.uniref90_database_path = "/override/uniref90"
    flag_values.template_mmcif_dir = "/override/mmcif"

    config_patch = patch.object(create_features, "AF3DataPipelineConfig", side_effect=DummyConfig)
    # The pipeline stub hands its config straight back so it can be inspected.
    pipeline_patch = patch.object(create_features, "AF3DataPipeline", side_effect=lambda config: config)
    with config_patch as mock_config, pipeline_patch as mock_pipeline:
        config = create_features.create_pipeline_af3()

    mock_config.assert_called_once()
    mock_pipeline.assert_called_once()

    captured = config.kwargs
    assert captured["small_bfd_database_path"] == "/override/small_bfd"
    assert captured["uniref90_database_path"] == "/override/uniref90"
    assert captured["pdb_database_path"] == "/override/mmcif"
    assert captured["mgnify_database_path"] == "/db/mgy_clusters_2022_05.fa"
    assert captured["seqres_database_path"] == "/db/pdb_seqres_2022_09_28.fasta"
def test_create_af3_individual_features_falls_back_to_double_letter_chain_ids(tmp_flags, tmp_path):
    """With the stubbed ``mmcif`` helper removed, the 27th sequence is still written."""
    flags = create_features.FLAGS
    flags.output_dir = str(tmp_path)
    flags.seq_index = 27

    af3_modules, folding_input_stub = build_af3_stub_modules()
    # Drop both the attribute and the module entry so the code under test
    # cannot reach the mmcif helper through either route.
    delattr(af3_modules["alphafold3.structure"], "mmcif")
    del af3_modules["alphafold3.structure.mmcif"]

    sequences = [("ACDE", f"chain_{idx}") for idx in range(1, 28)]
    pipeline = MagicMock(process=MagicMock(return_value={"plain": "json"}))

    with patch.dict(sys.modules, af3_modules), \
            patch.object(create_features, "create_pipeline_af3", return_value=pipeline), \
            patch.object(create_features, "folding_input", folding_input_stub), \
            patch.object(create_features, "iter_seqs", return_value=sequences), \
            patch("pathlib.Path.write_text", new=real_write_text):
        create_features.create_af3_individual_features()

    written = tmp_path / "chain_27_af3_input.json"
    assert written.exists()
    assert json.loads(written.read_text(encoding="utf-8")) == {"plain": "json"}
def test_create_af3_individual_features_skips_existing_outputs(tmp_flags, tmp_path):
    """An existing output JSON is left untouched when ``skip_existing`` is set."""
    flags = create_features.FLAGS
    flags.output_dir = str(tmp_path)
    flags.skip_existing = True

    af3_modules, folding_input_stub = build_af3_stub_modules()

    # Pre-create the output the run would otherwise produce.
    preexisting = tmp_path / "protA_af3_input.json"
    preexisting.write_text("{}", encoding="utf-8")

    pipeline = MagicMock(process=MagicMock(return_value=DummyJsonObj()))

    with patch.dict(sys.modules, af3_modules), \
            patch.object(create_features, "create_pipeline_af3", return_value=pipeline), \
            patch.object(create_features, "folding_input", folding_input_stub), \
            patch.object(create_features, "iter_seqs", return_value=[("ACDE", "protA")]), \
            patch("pathlib.Path.write_text", new=real_write_text):
        create_features.create_af3_individual_features()

    pipeline.process.assert_not_called()
    assert preexisting.read_text(encoding="utf-8") == "{}"
def test_main_dispatches_to_af3_feature_creation(tmp_flags, tmp_path):
    """``main`` calls the AF3 feature builder and skips the template-date check."""
    flags = create_features.FLAGS
    flags.data_pipeline = "alphafold3"
    flags.output_dir = str(tmp_path / "af3_out")

    with patch.object(create_features, "create_af3_individual_features") as af3_builder, \
            patch.object(create_features, "check_template_date") as date_check:
        create_features.main([])

    af3_builder.assert_called_once_with()
    date_check.assert_not_called()

View File

@@ -16,6 +16,8 @@ import pytest
Test conversion of PDB to CIF for monomers and multimers
"""
# Skip the whole module when the optional `ihm` / `modelcif` packages cannot be imported.
pytest.importorskip("ihm")
pytest.importorskip("modelcif")
# Module-level marker: applies the `external_tools` mark to every test in this file.
pytestmark = pytest.mark.external_tools

View File

@@ -0,0 +1,973 @@
import importlib.util
import json
import pickle
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import numpy as np
import pytest
# Location of the backend source file, resolved relative to this test file
# (three directory levels up, then into the package).
MODULE_PATH = Path(__file__).resolve().parents[2].joinpath(
    "alphapulldown",
    "folding_backend",
    "alphafold2_backend.py",
)
def _package(name: str) -> types.ModuleType:
    """Return a fresh, empty module named *name* carrying an empty ``__path__``.

    The ``__path__`` attribute lets the stub stand in for a package.
    """
    pkg = types.ModuleType(name)
    setattr(pkg, "__path__", [])
    return pkg
class _ConfigNode(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to items.
        try:
            value = self[name]
        except KeyError as exc:  # pragma: no cover - defensive
            raise AttributeError(name) from exc
        else:
            return value

    def __setattr__(self, name, value):
        # Attribute assignment always stores an item, never an instance attribute.
        self.__setitem__(name, value)
def _restore_modules(saved_modules: dict[str, types.ModuleType | None]) -> None:
for name, module in saved_modules.items():
if module is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = module
def _install_alphafold2_backend_stubs() -> dict[str, types.ModuleType | None]:
    """Register lightweight stand-ins for the modules the AF2 backend imports.

    Stub modules are built for ``jax``, ``alphafold`` and several
    ``alphapulldown`` helpers, placed into ``sys.modules``, and every stub
    whose parent package is also stubbed is attached to that parent as an
    attribute.

    Returns:
        Mapping from each replaced module name to whatever ``sys.modules``
        held for it beforehand (``None`` when the name was absent).
    """
    # ----- jax ---------------------------------------------------------
    jax_stub = _package("jax")
    jnp_stub = types.ModuleType("jax.numpy")
    jnp_stub.ndarray = np.ndarray
    jnp_stub.array = np.array

    # ----- alphafold.relax ---------------------------------------------
    alphafold_stub = _package("alphafold")
    relax_pkg_stub = _package("alphafold.relax")
    relax_stub = types.ModuleType("alphafold.relax.relax")

    class FakeAmberRelaxation:
        """Relaxer that records its kwargs and returns a tagged string."""

        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def process(self, prot):
            label = getattr(prot, "name", "protein")
            return (f"RELAXED:{label}", None, [0, 1])

    relax_stub.AmberRelaxation = FakeAmberRelaxation

    # ----- alphafold.common.protein ------------------------------------
    common_pkg_stub = _package("alphafold.common")
    protein_stub = types.ModuleType("alphafold.common.protein")

    class FakeProtein:
        """Plain attribute holder with the keyword-only fields used here."""

        def __init__(
            self,
            *,
            atom_positions,
            atom_mask,
            aatype,
            residue_index,
            chain_index,
            b_factors,
        ):
            self.atom_positions = atom_positions
            self.atom_mask = atom_mask
            self.aatype = aatype
            self.residue_index = residue_index
            self.chain_index = chain_index
            self.b_factors = b_factors
            self.name = "debug_protein"

    def _to_pdb(prot):
        return f"PDB:{getattr(prot, 'name', 'protein')}"

    def _from_prediction(features, result, b_factors, remove_leading_feature_dimension):
        residue_index = np.asarray(features.get("residue_index", [0]), dtype=np.int32)
        chain_index = np.asarray(features.get("asym_id", np.zeros_like(residue_index)))
        return SimpleNamespace(
            name="predicted",
            features=features,
            result=result,
            b_factors=b_factors,
            residue_index=residue_index,
            chain_index=chain_index,
            aatype=np.zeros_like(residue_index),
            remove_leading_feature_dimension=remove_leading_feature_dimension,
        )

    def _from_pdb_string(text):
        return SimpleNamespace(name=f"from_pdb:{text}")

    protein_stub.Protein = FakeProtein
    protein_stub.to_pdb = _to_pdb
    protein_stub.from_prediction = _from_prediction
    protein_stub.from_pdb_string = _from_pdb_string

    # ----- alphafold.common.residue_constants --------------------------
    residue_constants_stub = types.ModuleType("alphafold.common.residue_constants")
    residue_constants_stub.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = list(range(64))
    residue_constants_stub.restype_num = 20
    residue_constants_stub.atom_type_num = 37

    # ----- alphafold.common.confidence ---------------------------------
    confidence_stub = types.ModuleType("alphafold.common.confidence")

    def _pae_json(pae, max_pae):
        return json.dumps(
            {"pae": np.asarray(pae).tolist(), "max_pae": float(max_pae)}
        )

    def _confidence_json(plddt):
        return json.dumps({"plddt": np.asarray(plddt).tolist()})

    def _predicted_tm_score(logits, breaks, asym_id=None, interface=False):
        # Fixed scores: 0.8 for the interface variant, 0.5 otherwise.
        if interface:
            return 0.8
        return 0.5

    def _compute_predicted_aligned_error(logits, breaks):
        return {
            "predicted_aligned_error": np.asarray(logits) + 1,
            "max_predicted_aligned_error": 31.0,
        }

    confidence_stub.pae_json = _pae_json
    confidence_stub.confidence_json = _confidence_json
    confidence_stub.predicted_tm_score = _predicted_tm_score
    confidence_stub.compute_predicted_aligned_error = _compute_predicted_aligned_error

    # ----- alphafold.model.* -------------------------------------------
    model_pkg_stub = _package("alphafold.model")
    config_stub = types.ModuleType("alphafold.model.config")
    config_stub.MODEL_PRESETS = {
        "multimer": ("model_1_multimer_v3", "model_2_multimer_v3")
    }

    def _model_config(_name):
        evoformer = _ConfigNode({"num_msa": 64, "num_extra_msa": 256})
        model_node = _ConfigNode(
            {
                "num_ensemble_eval": None,
                "global_config": _ConfigNode({"eval_dropout": False}),
                "embeddings_and_evoformer": evoformer,
            }
        )
        return _ConfigNode({"model": model_node})

    config_stub.model_config = _model_config

    data_stub = types.ModuleType("alphafold.model.data")

    def _get_model_haiku_params(model_name, data_dir):
        return {"model_name": model_name, "data_dir": data_dir}

    data_stub.get_model_haiku_params = _get_model_haiku_params

    model_stub = types.ModuleType("alphafold.model.model")

    class FakeRunModel:
        """Runner returning constant confidence arrays sized by residue count."""

        def __init__(self, config, params):
            self.config = config
            self.params = params
            self.multimer_mode = True

        def process_features(self, feature_dict, random_seed):
            return dict(feature_dict)

        def predict(self, processed_feature_dict, random_seed):
            seq_len = len(np.asarray(processed_feature_dict.get("residue_index", [0, 1])))
            return {
                "plddt": np.full(seq_len, 75.0, dtype=np.float32),
                "predicted_aligned_error": np.zeros((seq_len, seq_len), dtype=np.float32),
                "max_predicted_aligned_error": 31.0,
            }

    model_stub.RunModel = FakeRunModel

    # ----- alphapulldown.objects ---------------------------------------
    objects_stub = types.ModuleType("alphapulldown.objects")

    class MonomericObject:
        def __init__(self, description="monomer", sequence=""):
            self.description = description
            self.sequence = sequence
            self.feature_dict = {}
            self.multimeric_mode = False

    class MultimericObject:
        def __init__(
            self,
            description="multimer",
            input_seqs=None,
            feature_dict=None,
            multimeric_mode=True,
        ):
            self.description = description
            self.input_seqs = input_seqs or []
            self.feature_dict = feature_dict or {}
            self.multimeric_mode = multimeric_mode

    class ChoppedObject(MonomericObject):
        pass

    objects_stub.MonomericObject = MonomericObject
    objects_stub.MultimericObject = MultimericObject
    objects_stub.ChoppedObject = ChoppedObject

    # ----- alphapulldown.utils.* ---------------------------------------
    plotting_stub = types.ModuleType("alphapulldown.utils.plotting")
    plotting_stub.plot_pae_from_matrix = lambda **_kwargs: None

    post_modelling_stub = types.ModuleType("alphapulldown.utils.post_modelling")
    post_modelling_stub.post_prediction_process = lambda *args, **kwargs: None

    modelling_setup_stub = types.ModuleType("alphapulldown.utils.modelling_setup")

    def _pad_input_features(feature_dict, desired_num_msa, desired_num_res):
        return None

    modelling_setup_stub.pad_input_features = _pad_input_features

    af2_to_af3_msa_stub = types.ModuleType("alphapulldown.utils.af2_to_af3_msa")

    def _msa_rows_and_deletions_to_a3m(msa_rows, deletion_rows, query_sequence):
        return f">query\n{query_sequence}\n>rows\n{len(np.asarray(msa_rows))}\n"

    af2_to_af3_msa_stub.msa_rows_and_deletions_to_a3m = _msa_rows_and_deletions_to_a3m

    # ----- registration -------------------------------------------------
    stubs = {
        "jax": jax_stub,
        "jax.numpy": jnp_stub,
        "alphafold": alphafold_stub,
        "alphafold.relax": relax_pkg_stub,
        "alphafold.relax.relax": relax_stub,
        "alphafold.common": common_pkg_stub,
        "alphafold.common.protein": protein_stub,
        "alphafold.common.residue_constants": residue_constants_stub,
        "alphafold.common.confidence": confidence_stub,
        "alphafold.model": model_pkg_stub,
        "alphafold.model.config": config_stub,
        "alphafold.model.data": data_stub,
        "alphafold.model.model": model_stub,
        "alphapulldown.objects": objects_stub,
        "alphapulldown.utils.plotting": plotting_stub,
        "alphapulldown.utils.post_modelling": post_modelling_stub,
        "alphapulldown.utils.modelling_setup": modelling_setup_stub,
        "alphapulldown.utils.af2_to_af3_msa": af2_to_af3_msa_stub,
    }

    # Snapshot the current entries before overwriting them.
    saved_modules = {name: sys.modules.get(name) for name in stubs}
    sys.modules.update(stubs)

    # Expose each stub on its parent package when that parent is a stub too,
    # so attribute access such as ``alphafold.common.protein`` resolves.
    for dotted_name, stub in stubs.items():
        parent_name, _, leaf = dotted_name.rpartition(".")
        if parent_name in stubs:
            setattr(stubs[parent_name], leaf, stub)

    return saved_modules
@pytest.fixture(scope="module")
def af2_backend_module():
    """Load the AF2 backend from source against stubbed dependencies.

    The dependency stubs are installed, the backend file at ``MODULE_PATH``
    is executed under its real dotted name, and the resulting module is
    yielded. On teardown, or if loading fails at any point, every touched
    ``sys.modules`` entry is put back to its prior state, including a backend
    module that was already imported before this fixture ran.

    Yields:
        The freshly executed backend module.
    """
    backend_name = "alphapulldown.folding_backend.alphafold2_backend"
    saved_modules = _install_alphafold2_backend_stubs()
    # Track the pre-existing backend entry alongside the stubs so that
    # teardown restores it instead of leaving the name unregistered.
    saved_modules[backend_name] = sys.modules.pop(backend_name, None)
    try:
        # Everything after stub installation runs inside the try block, so a
        # missing spec/loader or an import error cannot leak the stubs.
        spec = importlib.util.spec_from_file_location(backend_name, MODULE_PATH)
        assert spec is not None and spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        sys.modules[backend_name] = module
        spec.loader.exec_module(module)
        yield module
    finally:
        _restore_modules(saved_modules)
def test_json_template_and_alignment_helpers(af2_backend_module, tmp_path):
    """Cover the JSON reader, the template reset and the alignment counter."""
    backend = af2_backend_module

    # JSON reader: missing file gives an empty dict, existing file is parsed.
    assert backend._read_from_json_if_exists(tmp_path / "missing.json") == {}
    timings_file = tmp_path / "timings.json"
    timings_file.write_text(json.dumps({"stage": 1}), encoding="utf-8")
    assert backend._read_from_json_if_exists(timings_file) == {"stage": 1}

    # Template reset: mutates the dict in place down to a single template.
    features = {
        "seq_length": 3,
        "template_aatype": np.ones((2, 3), dtype=np.int32),
        "template_all_atom_positions": np.ones((2, 3, 37, 3), dtype=np.float32),
        "template_all_atom_mask": np.zeros((2, 3, 37), dtype=np.float32),
        "num_templates": np.array([5]),
    }
    backend._reset_template_features(features)
    assert features["template_aatype"].shape == (1, 3)
    assert features["template_all_atom_positions"].shape == (1, 3, 37, 3)
    assert np.all(features["template_all_atom_mask"] == 1.0)
    np.testing.assert_array_equal(features["num_templates"], np.array([1]))

    # Alignment counter: the MSA row count is reported even when a
    # different ``num_alignments`` value is present.
    count_alignments = backend._normalise_num_alignments_for_debug
    assert count_alignments({}) == 0
    assert count_alignments({"msa": np.zeros((2, 3))}) == 2
    with_stale_count = {"msa": np.zeros((2, 3)), "num_alignments": np.array([9])}
    assert count_alignments(with_stale_count) == 2
def test_ensure_typing_dataclass_transform_backfills_missing_attribute(
    af2_backend_module, monkeypatch
):
    """After removing ``typing.dataclass_transform`` the helper provides one again."""
    typing_mod = af2_backend_module.typing
    monkeypatch.delattr(typing_mod, "dataclass_transform", raising=False)

    af2_backend_module._ensure_typing_dataclass_transform()

    assert getattr(af2_backend_module.typing, "dataclass_transform") is not None
def test_resolve_gpu_relax_keeps_cuda_when_available(af2_backend_module, monkeypatch):
    """GPU relax stays enabled when the platform list includes ``CUDA``."""

    def platforms_with_cuda():
        return ["Reference", "CPU", "CUDA"]

    monkeypatch.setattr(
        af2_backend_module, "_get_openmm_platform_names", platforms_with_cuda
    )

    assert af2_backend_module._resolve_gpu_relax(True) is True
def test_resolve_gpu_relax_falls_back_when_cuda_is_missing(
    af2_backend_module, monkeypatch, caplog
):
    """Without ``CUDA`` in the platform list, GPU relax is disabled and logged."""

    def platforms_without_cuda():
        return ["Reference", "CPU", "OpenCL"]

    monkeypatch.setattr(
        af2_backend_module, "_get_openmm_platform_names", platforms_without_cuda
    )

    resolved = af2_backend_module._resolve_gpu_relax(True)

    assert resolved is False
    assert "falling back to CPU relax" in caplog.text
def test_asym_query_and_msa_debug_helpers(af2_backend_module, tmp_path):
    """Cover asym-id normalisation, query-sequence lookup and the MSA dump."""
    backend = af2_backend_module

    # asym_id values are renumbered from zero in order of appearance.
    direct = backend._normalize_asym_id(
        {"asym_id": np.array([5, 5, 2, 9], dtype=np.int32)}
    )
    np.testing.assert_array_equal(direct["asym_id"], np.array([0, 0, 1, 2]))

    # When the primary dict has no asym_id, the fallback dict supplies it.
    via_fallback = backend._normalize_asym_id(
        {},
        fallback_feature_dict={"asym_id": np.array([1, 1, 3], dtype=np.int32)},
    )
    np.testing.assert_array_equal(via_fallback["asym_id"], np.array([0, 0, 1]))

    multimer = backend.MultimericObject(
        description="job",
        input_seqs=["AA", "BB"],
        feature_dict={},
    )
    monomer = backend.MonomericObject("single", "AC")
    assert backend._query_sequence_for_debug(multimer) == "AABB"
    assert backend._query_sequence_for_debug(monomer) == "AC"

    processed = {
        "msa": np.array([[1, 2], [3, 4]], dtype=np.int32),
        "deletion_matrix_int": np.zeros((2, 2), dtype=np.int32),
        "num_alignments": np.array([1]),
    }
    backend._write_processed_msa_debug_artifact(
        processed_feature_dict=processed,
        multimeric_object=monomer,
        output_dir=tmp_path,
        model_name="modelA",
    )

    written = tmp_path / "modelA_processed_msa.a3m"
    assert written.read_text(encoding="utf-8") == ">query\nAC\n>rows\n1\n"
def test_template_debug_helpers_cover_success_and_failure_paths(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """Cover the small template-debug helpers and the per-template error file."""
    backend = af2_backend_module

    assert backend._decode_debug_value(b"templ") == "templ"
    sanitized = backend._sanitize_debug_filename("bad/name:*", "fallback")
    assert sanitized == "bad_name__"

    to_indices = backend._template_aatype_to_indices
    one_hot = np.eye(22, dtype=np.int32)[[1, 2]]
    np.testing.assert_array_equal(
        to_indices(one_hot), np.array([1, 2], dtype=np.int32)
    )
    np.testing.assert_array_equal(
        to_indices(np.array([0, 21], dtype=np.int32)),
        np.array([0, 20], dtype=np.int32),
    )
    with pytest.raises(ValueError, match="Unsupported template_aatype rank"):
        to_indices(np.zeros((1, 1, 1, 1)))

    # The fake constructor succeeds on its first call and fails on its second,
    # so one template is written and the other produces an error file.
    call_counter = {"value": 0}

    def fake_protein_constructor(**kwargs):
        call_index = call_counter["value"]
        call_counter["value"] += 1
        if call_index == 1:
            raise RuntimeError("boom")
        return SimpleNamespace(name="template", **kwargs)

    monkeypatch.setattr(backend.protein, "Protein", fake_protein_constructor)
    monkeypatch.setattr(backend.protein, "to_pdb", lambda prot: "PDB:TEMPLATE")

    template_one_hot = np.eye(22)[[0, 1]]
    processed = {
        "template_all_atom_positions": np.ones((2, 2, 37, 3), dtype=np.float32),
        "template_all_atom_mask": np.ones((2, 2, 37), dtype=np.float32),
        "template_aatype": np.stack([template_one_hot, template_one_hot]),
        "template_domain_names": [b"good/template", b"bad"],
        "residue_index": np.array([7, 8], dtype=np.int32),
        "asym_id": np.array([3, 3], dtype=np.int32),
    }
    backend._write_processed_template_debug_artifacts(
        processed_feature_dict=processed,
        output_dir=tmp_path,
        model_name="modelB",
    )

    debug_dir = tmp_path / "templates_debug"
    assert (debug_dir / "modelB_good_template_idx0.pdb").is_file()
    failure_report = debug_dir / "ERROR_modelB_template_1.txt"
    assert failure_report.is_file()
    assert "boom" in failure_report.read_text(encoding="utf-8")
def test_setup_configures_model_runners_and_validates_custom_names(af2_backend_module):
    """``setup`` builds one runner per prediction and rejects unknown model names."""
    setup = af2_backend_module.AlphaFold2Backend.setup

    configured = setup(
        model_name="multimer",
        num_cycle=5,
        model_dir="/models",
        num_predictions_per_model=2,
        msa_depth_scan=True,
        model_names_custom=["model_1_multimer"],
        dropout=True,
    )

    runners = configured["model_runners"]
    expected_names = [
        "model_1_multimer_pred_0_msa_16",
        "model_1_multimer_pred_1_msa_64",
    ]
    assert sorted(runners) == expected_names

    first_runner = runners[expected_names[0]]
    assert first_runner.config["model"]["num_recycle"] == 5
    assert first_runner.config.model.global_config.eval_dropout is True
    assert first_runner.params["data_dir"] == "/models"

    # A custom name that matches no preset entry must be refused.
    with pytest.raises(Exception, match="Provided model names"):
        setup(
            model_name="multimer",
            num_cycle=1,
            model_dir="/models",
            num_predictions_per_model=1,
            model_names_custom=["missing_model"],
        )
def test_predict_individual_job_writes_outputs_and_runs_debug_hooks(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """A full run pads inputs, calls both debug hooks and writes all artefacts."""
    backend = af2_backend_module
    multimer = backend.MultimericObject(
        description="complex",
        input_seqs=["AB"],
        feature_dict={"msa": np.ones((1, 2), dtype=np.int32)},
        multimeric_mode=True,
    )
    pad_calls, msa_calls, template_calls = [], [], []

    def fake_pad(feature_dict, desired_num_msa, desired_num_res):
        pad_calls.append((desired_num_res, desired_num_msa))

    processed_feature_dict = {
        "msa": np.ones((1, 2), dtype=np.int32),
        "template_all_atom_positions": np.ones((1, 2, 37, 3), dtype=np.float32),
        "template_all_atom_mask": np.ones((1, 2, 37), dtype=np.float32),
        "template_aatype": np.eye(22)[[0, 1]][None, :, :],
        "residue_index": np.array([0, 1], dtype=np.int32),
        "asym_id": np.array([1, 2], dtype=np.int32),
    }

    def fake_process_features(feature_dict, random_seed):
        # Hand out a shallow copy so each call starts from the same dict.
        return dict(processed_feature_dict)

    def fake_predict(processed_feature_dict, random_seed):
        return {
            "plddt": np.array([91.0, 88.0], dtype=np.float32),
            "predicted_aligned_error": np.zeros((2, 2), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
        }

    fake_runner = SimpleNamespace(
        multimer_mode=True,
        process_features=fake_process_features,
        predict=fake_predict,
    )

    def record_msa_debug(**kwargs):
        msa_calls.append(kwargs["model_name"])

    def record_template_debug(**kwargs):
        template_calls.append(kwargs["model_name"])

    monkeypatch.setattr(backend, "pad_input_features", fake_pad)
    monkeypatch.setattr(backend, "_write_processed_msa_debug_artifact", record_msa_debug)
    monkeypatch.setattr(
        backend, "_write_processed_template_debug_artifacts", record_template_debug
    )

    results = backend.AlphaFold2Backend.predict_individual_job(
        model_runners={"modelA": fake_runner},
        multimeric_object=multimer,
        allow_resume=False,
        skip_templates=False,
        output_dir=tmp_path,
        random_seed=7,
        desired_num_res=4,
        desired_num_msa=3,
        debug_msas=True,
        debug_templates=True,
    )

    assert pad_calls == [(4, 3)]
    assert msa_calls == ["modelA"]
    assert template_calls == ["modelA"]

    assert "modelA" in results
    assert results["modelA"]["seqs"] == ["AB"]

    result_pickle = tmp_path / "result_modelA.pkl"
    assert result_pickle.is_file()
    unrelaxed_text = (tmp_path / "unrelaxed_modelA.pdb").read_text(encoding="utf-8")
    assert unrelaxed_text == "PDB:predicted"

    timings = json.loads((tmp_path / "timings.json").read_text(encoding="utf-8"))
    assert "process_features_modelA" in timings
    assert "predict_and_compile_modelA" in timings

    with open(result_pickle, "rb") as handle:
        payload = pickle.load(handle)
    assert payload["plddt"].tolist() == [91.0, 88.0]
def test_predict_individual_job_rejects_skipped_templates_in_multimer_mode(
    af2_backend_module,
    tmp_path,
):
    """Requesting ``skip_templates`` with a multimer runner raises ``ValueError``."""
    backend = af2_backend_module
    multimer = backend.MultimericObject(
        description="complex",
        input_seqs=["AB"],
        feature_dict={"msa": np.ones((1, 2), dtype=np.int32)},
        multimeric_mode=True,
    )

    def fake_process_features(feature_dict, random_seed):
        return {
            "seq_length": 2,
            "template_aatype": np.ones((1, 2), dtype=np.int32),
            "template_all_atom_positions": np.zeros((1, 2, 37, 3), dtype=np.float32),
            "template_all_atom_mask": np.ones((1, 2, 37), dtype=np.float32),
            "num_templates": np.array([1], dtype=np.int32),
        }

    fake_runner = SimpleNamespace(
        multimer_mode=True,
        process_features=fake_process_features,
        predict=lambda *args, **kwargs: None,
    )

    with pytest.raises(ValueError, match="cannot skip templates"):
        backend.AlphaFold2Backend.predict_individual_job(
            model_runners={"modelA": fake_runner},
            multimeric_object=multimer,
            allow_resume=False,
            skip_templates=True,
            output_dir=tmp_path,
            random_seed=1,
        )
def test_predict_individual_job_resumes_completed_models_and_returns_early(
    af2_backend_module,
    tmp_path,
):
    """With saved outputs and ``allow_resume``, features are processed but not predicted."""
    backend = af2_backend_module
    monomer = backend.MonomericObject("single", "AB")
    monomer.feature_dict = {"residue_index": np.array([0, 1], dtype=np.int32)}

    # Seed the output directory with the artefacts of a finished model.
    saved_result = {"plddt": np.array([91.0, 88.0], dtype=np.float32)}
    with open(tmp_path / "result_modelA.pkl", "wb") as handle:
        pickle.dump(saved_result, handle, protocol=4)
    (tmp_path / "unrelaxed_modelA.pdb").write_text("existing pdb", encoding="utf-8")

    process_calls = []
    predict_calls = []

    def fake_process_features(feature_dict, random_seed):
        process_calls.append(random_seed)
        return dict(feature_dict)

    def fake_predict(*args, **kwargs):
        predict_calls.append((args, kwargs))

    fake_runner = SimpleNamespace(
        multimer_mode=False,
        process_features=fake_process_features,
        predict=fake_predict,
    )

    results = backend.AlphaFold2Backend.predict_individual_job(
        model_runners={"modelA": fake_runner},
        multimeric_object=monomer,
        allow_resume=True,
        skip_templates=False,
        output_dir=tmp_path,
        random_seed=11,
    )

    assert predict_calls == []
    assert process_calls == [11]
    resumed = results["modelA"]
    assert resumed["seqs"] == ["AB"]
    assert resumed["unrelaxed_protein"].name == "predicted"
def test_predict_individual_job_rejects_missing_or_zero_template_positions(
    af2_backend_module,
    tmp_path,
):
    """Multimer runs fail when template positions are absent or all zero."""
    backend = af2_backend_module
    multimer = backend.MultimericObject(
        description="complex",
        input_seqs=["AB"],
        feature_dict={"msa": np.ones((1, 2), dtype=np.int32)},
        multimeric_mode=True,
    )

    def run_with(runner):
        """Invoke the job with a single runner and otherwise fixed arguments."""
        return backend.AlphaFold2Backend.predict_individual_job(
            model_runners={"modelA": runner},
            multimeric_object=multimer,
            allow_resume=False,
            skip_templates=False,
            output_dir=tmp_path,
            random_seed=3,
        )

    missing_template_runner = SimpleNamespace(
        multimer_mode=True,
        process_features=lambda feature_dict, random_seed: {},
        predict=lambda *args, **kwargs: None,
    )
    zero_template_runner = SimpleNamespace(
        multimer_mode=True,
        process_features=lambda feature_dict, random_seed: {
            "template_all_atom_positions": np.zeros((1, 2, 37, 3), dtype=np.float32),
        },
        predict=lambda *args, **kwargs: None,
    )

    with pytest.raises(ValueError, match="No template_all_atom_positions key found"):
        run_with(missing_template_runner)

    with pytest.raises(ValueError, match="No valid templates found"):
        run_with(zero_template_runner)
def test_predict_yields_results_for_each_object(af2_backend_module, monkeypatch, tmp_path):
    """``predict`` yields one result record per input object, in input order."""
    backend_cls = af2_backend_module.AlphaFold2Backend
    monomer_a = af2_backend_module.MonomericObject("a", "AA")
    monomer_b = af2_backend_module.MonomericObject("b", "BB")
    dir_a = str(tmp_path / "a")
    dir_b = str(tmp_path / "b")
    calls = []

    def fake_job(**kwargs):
        description = kwargs["multimeric_object"].description
        calls.append(description)
        return {"modelA": description}

    monkeypatch.setattr(backend_cls, "predict_individual_job", staticmethod(fake_job))

    outputs = list(
        backend_cls.predict(
            model_runners={"modelA": object()},
            objects_to_model=[
                {"object": monomer_a, "output_dir": dir_a},
                {"object": monomer_b, "output_dir": dir_b},
            ],
            allow_resume=False,
            skip_templates=False,
            random_seed=5,
        )
    )

    assert calls == ["a", "b"]
    expected = [
        {
            "object": monomer_a,
            "prediction_results": {"modelA": "a"},
            "output_dir": dir_a,
        },
        {
            "object": monomer_b,
            "prediction_results": {"modelA": "b"},
            "output_dir": dir_b,
        },
    ]
    assert outputs == expected
def test_recalculate_confidence_handles_multimer_and_monomer_paths(af2_backend_module):
    """Array PAE input is returned as-is; logits input is re-scored per mode."""
    recalculate = af2_backend_module.AlphaFold2Backend.recalculate_confidence

    # A result whose PAE is already an array comes back as the same object.
    already_numpy = {
        "predicted_aligned_error": np.zeros((1, 1), dtype=np.float32),
        "plddt": np.array([42.0], dtype=np.float32),
    }
    untouched = recalculate(already_numpy, multimer_mode=False, total_num_res=1)
    assert untouched is already_numpy

    # Four padded residues, of which only the first two are real.
    padded = {
        "predicted_aligned_error": {
            "logits": np.ones((4, 4), dtype=np.float32),
            "breaks": np.array([0.5], dtype=np.float32),
            "asym_id": np.array([0, 0, 1, 1], dtype=np.int32),
        },
        "plddt": np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32),
    }

    multimer_output = recalculate(padded, multimer_mode=True, total_num_res=2)
    assert multimer_output["ptm"] == 0.5
    assert multimer_output["iptm"] == 0.8
    assert multimer_output["ranking_confidence"] == pytest.approx(0.74)
    assert multimer_output["predicted_aligned_error"].shape == (2, 2)

    monomer_output = recalculate(padded, multimer_mode=False, total_num_res=2)
    assert monomer_output["ranking_confidence"] == pytest.approx(15.0)
def test_postprocess_ranks_models_relaxes_best_and_runs_cleanup(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """``BEST`` relaxes only the top-ranked model; cleanup and conversion still run."""
    backend = af2_backend_module
    plot_calls = []
    cleanup_calls = []
    subprocess_calls = []

    def record_plot(**kwargs):
        plot_calls.append(kwargs["ranking"])

    def record_cleanup(*args, **kwargs):
        cleanup_calls.append((args, kwargs))

    def fake_run(*args, **kwargs):
        subprocess_calls.append((args, kwargs))
        return SimpleNamespace(stderr="", stdout="", returncode=0)

    monkeypatch.setattr(backend, "plot_pae_from_matrix", record_plot)
    monkeypatch.setattr(backend, "post_prediction_process", record_cleanup)
    monkeypatch.setattr(backend.subprocess, "run", fake_run)

    multimer = backend.MultimericObject(
        description="complex",
        input_seqs=["AA", "BB"],
        feature_dict={},
        multimeric_mode=True,
    )
    # Listed lowest-confidence first so that ranking has to reorder them.
    prediction_results = {
        "model_low": {
            "plddt": np.array([70.0, 71.0, 72.0, 73.0], dtype=np.float32),
            "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
            "ranking_confidence": 0.2,
            "iptm": 0.1,
            "ptm": 0.2,
            "unrelaxed_protein": SimpleNamespace(name="low"),
            "seqs": ["AA", "BB"],
        },
        "model_high": {
            "plddt": np.array([90.0, 91.0, 92.0, 93.0], dtype=np.float32),
            "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
            "ranking_confidence": 0.9,
            "iptm": 0.7,
            "ptm": 0.8,
            "unrelaxed_protein": SimpleNamespace(name="high"),
            "seqs": ["AA", "BB"],
        },
    }

    backend.AlphaFold2Backend.postprocess(
        prediction_results=prediction_results,
        multimeric_object=multimer,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=backend.ModelsToRelax.BEST,
        convert_to_modelcif=True,
    )

    def read(filename):
        return (tmp_path / filename).read_text(encoding="utf-8")

    ranking = json.loads(read("ranking_debug.json"))
    assert ranking["order"] == ["model_high", "model_low"]
    assert plot_calls == [0, 1]
    assert read("ranked_0.pdb") == "RELAXED:high"
    assert read("ranked_1.pdb") == "PDB:low"
    assert (tmp_path / "relax_metrics.json").is_file()
    assert cleanup_calls
    assert subprocess_calls
def test_postprocess_relaxes_all_using_saved_unrelaxed_pdbs(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """``ALL`` relaxes every model, loading structures from the saved PDB files."""
    backend = af2_backend_module
    cleanup_calls = []

    def record_cleanup(*args, **kwargs):
        cleanup_calls.append((args, kwargs))

    def fake_from_pdb_string(text):
        return SimpleNamespace(name=text.strip())

    monkeypatch.setattr(backend, "plot_pae_from_matrix", lambda **kwargs: None)
    monkeypatch.setattr(backend, "post_prediction_process", record_cleanup)
    monkeypatch.setattr(backend.protein, "from_pdb_string", fake_from_pdb_string)

    # The result dicts below carry no ``unrelaxed_protein``; only these files
    # provide the structures.
    (tmp_path / "unrelaxed_model_a.pdb").write_text("model_a_unrelaxed", encoding="utf-8")
    (tmp_path / "unrelaxed_model_b.pdb").write_text("model_b_unrelaxed", encoding="utf-8")

    multimer = backend.MultimericObject(
        description="complex",
        input_seqs=["AA", "BB"],
        feature_dict={},
        multimeric_mode=True,
    )
    prediction_results = {
        "model_a": {
            "plddt": np.array([80.0, 81.0, 82.0, 83.0], dtype=np.float32),
            "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
            "ranking_confidence": 0.6,
            "iptm": 0.4,
            "ptm": 0.5,
            "seqs": ["AA", "BB"],
        },
        "model_b": {
            "plddt": np.array([90.0, 91.0, 92.0, 93.0], dtype=np.float32),
            "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
            "ranking_confidence": 0.9,
            "iptm": 0.7,
            "ptm": 0.8,
            "seqs": ["AA", "BB"],
        },
    }

    backend.AlphaFold2Backend.postprocess(
        prediction_results=prediction_results,
        multimeric_object=multimer,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=backend.ModelsToRelax.ALL,
        convert_to_modelcif=False,
    )

    def read(filename):
        return (tmp_path / filename).read_text(encoding="utf-8")

    assert read("relaxed_model_a.pdb") == "RELAXED:model_a_unrelaxed"
    assert read("relaxed_model_b.pdb") == "RELAXED:model_b_unrelaxed"
    assert read("ranked_0.pdb") == "RELAXED:model_b_unrelaxed"
    assert read("ranked_1.pdb") == "RELAXED:model_a_unrelaxed"
    assert cleanup_calls
def test_postprocess_handles_monomers_without_relaxation_and_logs_modelcif_errors(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """Monomer run with ``NONE`` relax: no relaxed file, failed conversion is logged."""
    backend = af2_backend_module
    cleanup_calls = []
    modelcif_errors = []
    plot_calls = []

    def record_plot(**kwargs):
        plot_calls.append(kwargs["ranking"])

    def record_cleanup(*args, **kwargs):
        cleanup_calls.append((args, kwargs))

    def record_error(message):
        modelcif_errors.append(message)

    def failing_run(*args, **kwargs):
        # Simulates a converter that exits non-zero with text on stderr.
        return SimpleNamespace(
            stderr="convert failed",
            stdout="",
            returncode=1,
        )

    monkeypatch.setattr(backend, "plot_pae_from_matrix", record_plot)
    monkeypatch.setattr(backend, "post_prediction_process", record_cleanup)
    monkeypatch.setattr(backend.logging, "error", record_error)
    monkeypatch.setattr(backend.subprocess, "run", failing_run)

    monomer = backend.MonomericObject("single", "AB")
    prediction_results = {
        "modelA": {
            "plddt": np.array([81.0, 82.0], dtype=np.float32),
            "predicted_aligned_error": np.zeros((2, 2), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
            "ranking_confidence": 81.5,
            "ptm": 0.4,
            "seqs": ["AB"],
            "unrelaxed_protein": SimpleNamespace(name="mono"),
        }
    }

    backend.AlphaFold2Backend.postprocess(
        prediction_results=prediction_results,
        multimeric_object=monomer,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=backend.ModelsToRelax.NONE,
        convert_to_modelcif=True,
    )

    ranking = json.loads((tmp_path / "ranking_debug.json").read_text(encoding="utf-8"))
    assert ranking["plddts"] == {"modelA": 81.5}
    assert ranking["ptm"] == {"modelA": 0.4}
    assert ranking["order"] == ["modelA"]

    ranked_text = (tmp_path / "ranked_0.pdb").read_text(encoding="utf-8")
    assert ranked_text == "PDB:mono"
    assert not (tmp_path / "relaxed_modelA.pdb").exists()

    assert plot_calls == [0]
    assert cleanup_calls
    assert modelcif_errors == ["Error: convert failed"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,434 @@
import importlib.util
import json
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import numpy as np
import pytest
# Location of the backend source file, resolved relative to this test file
# (three directory levels up, then into the package).
MODULE_PATH = Path(__file__).resolve().parents[2].joinpath(
    "alphapulldown",
    "folding_backend",
    "alphalink_backend.py",
)
def _package(name: str) -> types.ModuleType:
    """Return a fresh, empty module named *name* carrying an empty ``__path__``.

    The ``__path__`` attribute lets the stub stand in for a package.
    """
    pkg = types.ModuleType(name)
    setattr(pkg, "__path__", [])
    return pkg
def _restore_modules(saved_modules: dict[str, types.ModuleType | None]) -> None:
for name, module in saved_modules.items():
if module is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = module
def _tensor_tree_map(func, tree):
    """Apply *func* to every non-dict leaf of a nested dict structure.

    Dicts are rebuilt recursively as plain dicts with the same keys; any
    other value (lists included) is treated as a leaf and passed to *func*.
    """
    if not isinstance(tree, dict):
        return func(tree)
    mapped = {}
    for key, subtree in tree.items():
        mapped[key] = _tensor_tree_map(func, subtree)
    return mapped
def _install_alphalink_backend_stubs() -> "dict[str, types.ModuleType | None]":
    """Register lightweight stand-ins for the AlphaLink backend's heavy imports.

    Stub modules for ``torch``, ``unifold``, ``unicore`` and three
    ``alphapulldown`` modules are written into ``sys.modules`` so the backend
    source can be executed without those dependencies being installed.

    Returns:
        Mapping of every replaced module name to the module that was registered
        under it beforehand (``None`` when there was none). Pass it to
        ``_restore_modules`` to undo the stubbing.
    """
    import contextlib

    # --- alphapulldown stand-ins -------------------------------------------
    af2_backend_mod = types.ModuleType("alphapulldown.folding_backend.alphafold2_backend")
    af2_backend_mod._save_pae_json_file = (
        lambda pae, max_pae, output_dir, model_name: Path(output_dir, f"pae_{model_name}.json").write_text(
            json.dumps({"max_pae": max_pae}),
            encoding="utf-8",
        )
    )

    objects_mod = types.ModuleType("alphapulldown.objects")

    class MonomericObject:
        pass

    class MultimericObject:
        pass

    class ChoppedObject:
        pass

    objects_mod.MonomericObject = MonomericObject
    objects_mod.MultimericObject = MultimericObject
    objects_mod.ChoppedObject = ChoppedObject

    plotting_mod = types.ModuleType("alphapulldown.utils.plotting")
    plotting_mod.plot_pae_from_matrix = lambda *args, **kwargs: None

    # --- torch stand-in -----------------------------------------------------
    torch_mod = types.ModuleType("torch")
    torch_mod.bfloat16 = "bfloat16"
    torch_mod.half = "half"

    class FakeTensor:
        def __init__(self, value, dtype=None):
            self._value = np.asarray(value)
            self.dtype = dtype if dtype is not None else self._value.dtype

        def float(self):
            return FakeTensor(self._value.astype(np.float32), np.float32)

        def cpu(self):
            return self._value

    class _NoGrad(contextlib.ContextDecorator):
        """No-op replacement for the object returned by ``torch.no_grad()``.

        ``with`` looks ``__enter__``/``__exit__`` up on the type, not on the
        instance, so the hooks must live on a real class. ``ContextDecorator``
        additionally keeps ``@torch.no_grad()`` usable as a decorator.
        """

        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc, tb):
            return False

    torch_mod.FakeTensor = FakeTensor
    torch_mod.cuda = SimpleNamespace(
        is_available=lambda: False,
        current_device=lambda: 0,
        get_device_properties=lambda _device: SimpleNamespace(
            total_memory=40 * 1024 * 1024 * 1024
        ),
    )
    torch_mod.load = lambda path: {
        "ema": {"params": {"module.layer": 1, "module.bias": 2}}
    }
    torch_mod.no_grad = _NoGrad
    torch_mod.autograd = SimpleNamespace(set_detect_anomaly=lambda flag: None)
    torch_mod.as_tensor = lambda value, device=None: value
    torch_mod.from_numpy = lambda value: value

    # --- unifold stand-ins --------------------------------------------------
    unifold_pkg = _package("unifold")
    config_mod = types.ModuleType("unifold.config")

    def _model_config(_name):
        return SimpleNamespace(
            data=SimpleNamespace(
                common=SimpleNamespace(max_recycling_iters=None),
                predict=SimpleNamespace(
                    num_ensembles=None,
                    subsample_templates=False,
                ),
            ),
            globals=SimpleNamespace(max_recycling_iters=None),
        )

    config_mod.model_config = _model_config

    modules_pkg = _package("unifold.modules")
    alphafold_mod = types.ModuleType("unifold.modules.alphafold")

    class FakeAlphaFold:
        def __init__(self, config):
            self.config = config
            self.loaded_state_dict = None
            self.device = None
            self.eval_called = False
            self.inference_mode_called = False
            self.bfloat16_called = False
            self.globals = SimpleNamespace(chunk_size=None, block_size=None)

        def load_state_dict(self, state_dict):
            self.loaded_state_dict = state_dict

        def to(self, device):
            self.device = device
            return self

        def eval(self):
            self.eval_called = True

        def inference_mode(self):
            self.inference_mode_called = True

        def bfloat16(self):
            self.bfloat16_called = True

    alphafold_mod.AlphaFold = FakeAlphaFold

    dataset_mod = types.ModuleType("unifold.dataset")
    dataset_mod.process_ap = lambda **kwargs: ({}, None)

    data_pkg = _package("unifold.data")
    residue_constants_mod = types.ModuleType("unifold.data.residue_constants")
    residue_constants_mod.atom_order = {"CA": 1}
    residue_constants_mod.atom_type_num = 37
    protein_mod = types.ModuleType("unifold.data.protein")
    protein_mod.from_prediction = lambda **kwargs: SimpleNamespace(chain_index=np.array([[0]]), aatype=np.array([[0]]))
    protein_mod.to_pdb = lambda protein: "PDB"
    data_ops_mod = types.ModuleType("unifold.data.data_ops")
    data_ops_mod.get_pairwise_distances = lambda coords: np.zeros((1, 1))

    # --- unicore stand-ins --------------------------------------------------
    unicore_pkg = _package("unicore")
    unicore_utils_mod = types.ModuleType("unicore.utils")
    unicore_utils_mod.tensor_tree_map = _tensor_tree_map

    modules = {
        "alphapulldown.folding_backend.alphafold2_backend": af2_backend_mod,
        "alphapulldown.objects": objects_mod,
        "alphapulldown.utils.plotting": plotting_mod,
        "torch": torch_mod,
        "unifold": unifold_pkg,
        "unifold.config": config_mod,
        "unifold.modules": modules_pkg,
        "unifold.modules.alphafold": alphafold_mod,
        "unifold.dataset": dataset_mod,
        "unifold.data": data_pkg,
        "unifold.data.residue_constants": residue_constants_mod,
        "unifold.data.protein": protein_mod,
        "unifold.data.data_ops": data_ops_mod,
        "unicore": unicore_pkg,
        "unicore.utils": unicore_utils_mod,
    }

    # Snapshot exactly the names about to be overwritten, so the saved set can
    # never drift from the installed set.
    saved_modules = {name: sys.modules.get(name) for name in modules}
    for name, module in modules.items():
        sys.modules[name] = module

    # Wire submodules onto their parents so attribute-style access also works.
    unifold_pkg.config = config_mod
    unifold_pkg.modules = modules_pkg
    unifold_pkg.dataset = dataset_mod
    unifold_pkg.data = data_pkg
    modules_pkg.alphafold = alphafold_mod
    data_pkg.residue_constants = residue_constants_mod
    data_pkg.protein = protein_mod
    data_pkg.data_ops = data_ops_mod
    unicore_pkg.utils = unicore_utils_mod
    return saved_modules
@pytest.fixture(scope="module")
def alphalink_backend_module():
    """Yield ``alphalink_backend.py`` executed against stubbed dependencies.

    The stubs are installed, the backend is loaded from ``MODULE_PATH`` under
    its regular dotted name, and on teardown ``sys.modules`` is returned to its
    prior state. That includes any backend module that was already imported
    before this fixture ran.
    """
    module_name = "alphapulldown.folding_backend.alphalink_backend"
    saved_modules = _install_alphalink_backend_stubs()
    # Keep a previously imported backend so teardown can reinstate it.
    previous_backend = sys.modules.pop(module_name, None)
    try:
        # Spec and module creation happen inside the try so that the stubs are
        # still removed if either step fails.
        spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
        assert spec is not None and spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)
        yield module
    finally:
        sys.modules.pop(module_name, None)
        if previous_backend is not None:
            sys.modules[module_name] = previous_backend
        _restore_modules(saved_modules)
def test_setup_accepts_expected_weights_and_validates_missing_inputs(
    alphalink_backend_module,
    monkeypatch,
    tmp_path,
):
    """Cover ``AlphaLinkBackend.setup`` for a weights file, a weights directory and bad paths."""
    backend = alphalink_backend_module.AlphaLinkBackend
    warning_messages = []
    monkeypatch.setattr(alphalink_backend_module.logging, "warning", warning_messages.append)

    # A directly supplied ".pt" file is accepted; at least one warning is expected.
    direct_file = tmp_path / "custom_weights.pt"
    direct_file.write_text("stub", encoding="utf-8")
    direct_setup = backend.setup(str(direct_file))
    assert direct_setup["param_path"] == str(direct_file)
    assert warning_messages

    # A directory resolves to the canonically named weights file inside it.
    weights_dir = tmp_path / "weights"
    weights_dir.mkdir()
    canonical_file = weights_dir / "AlphaLink-Multimer_SDA_v3.pt"
    canonical_file.write_text("stub", encoding="utf-8")
    dir_setup = backend.setup(str(weights_dir))
    assert dir_setup["param_path"] == str(canonical_file)

    wrong_extension = tmp_path / "weights.bin"
    wrong_extension.write_text("stub", encoding="utf-8")
    # ``match`` is a regular expression; escape the dot so only a literal
    # ".pt" satisfies the check.
    with pytest.raises(ValueError, match=r"\.pt extension"):
        backend.setup(str(wrong_extension))

    with pytest.raises(FileNotFoundError, match="does not exist"):
        backend.setup(str(tmp_path / "missing"))
def test_unload_tensors_and_prepare_model_runner(alphalink_backend_module):
    """Check tensor unloading to NumPy and the state of a freshly prepared model."""
    backend = alphalink_backend_module.AlphaLinkBackend
    make_tensor = alphalink_backend_module.torch.FakeTensor

    bf16_batch = {"a": make_tensor([1, 2], dtype=alphalink_backend_module.torch.bfloat16)}
    fp32_out = {"b": make_tensor([3, 4], dtype=np.float32)}
    batch, out = backend.unload_tensors(bf16_batch, fp32_out)
    np.testing.assert_array_equal(batch["a"], np.array([1.0, 2.0], dtype=np.float32))
    np.testing.assert_array_equal(out["b"], np.array([3, 4], dtype=np.float32))

    model = backend.prepare_model_runner("weights.pt", bf16=True, model_device="cpu")
    assert model.device == "cpu"
    assert model.eval_called is True
    assert model.inference_mode_called is True
    assert model.bfloat16_called is True
    # The stubbed checkpoint uses "module."-prefixed keys; the loaded state
    # dict is expected without that prefix.
    assert model.loaded_state_dict == {"layer": 1, "bias": 2}

    data_config = model.config.data
    assert (
        data_config.common.max_recycling_iters
        == alphalink_backend_module.MAX_RECYCLING_ITERS
    )
    assert (
        data_config.predict.num_ensembles
        == alphalink_backend_module.NUM_ENSEMBLES
    )
def test_resume_preprocess_and_chunk_helpers(alphalink_backend_module, tmp_path):
    """Cover resume detection, feature preprocessing and the chunk-size table."""
    backend = alphalink_backend_module.AlphaLinkBackend

    # Resume status: a PDB plus its PAE JSON carrying the score in the file name.
    pdb_file = tmp_path / "AlphaLink2_model_0_seed_123_0.875.pdb"
    pdb_file.write_text("pdb", encoding="utf-8")
    pae_file = tmp_path / "pae_AlphaLink2_model_0_seed_123_0.875.json"
    pae_file.write_text("{}", encoding="utf-8")
    already_exists, iptm_value = backend.check_resume_status(
        "AlphaLink2_model_0_seed_123",
        str(tmp_path),
    )
    assert already_exists is True
    assert iptm_value == pytest.approx(0.875)

    # Feature preprocessing on a small seven-residue input.
    raw_features = {
        "seq_length": np.array([7, 7]),
        "num_alignments": np.array([3, 3]),
        "num_templates": np.array([2, 2]),
        "template_all_atom_masks": np.ones((1, 7, 37), dtype=np.float32),
        "template_aatype": np.eye(22)[[0, 1, 2, 3, 4, 5, 6]][None, :, :],
        "template_sum_probs": np.array([0.5], dtype=np.float32),
        "deletion_matrix_int": np.ones((1, 7), dtype=np.float32),
        "deletion_matrix_int_all_seq": np.full((2, 7), 2.0, dtype=np.float32),
        "msa": np.ones((2, 7), dtype=np.int32),
    }
    processed = backend.preprocess_features(raw_features)
    assert processed["seq_length"] == 7
    assert processed["num_alignments"] == 3
    assert processed["num_templates"] == 2
    assert processed["template_all_atom_mask"].shape == (1, 7, 37)
    assert processed["template_aatype"].shape == (1, 7)
    assert processed["template_sum_probs"].shape == (1, 1)
    np.testing.assert_array_equal(processed["deletion_matrix"], np.ones((1, 7)))
    np.testing.assert_array_equal(
        processed["extra_deletion_matrix"],
        np.full((2, 7), 2.0),
    )
    assert processed["msa_mask"].shape == (2, 7)
    assert processed["msa_row_mask"].shape == (2,)
    np.testing.assert_array_equal(processed["asym_id"], np.zeros(7, dtype=np.int32))
    np.testing.assert_array_equal(processed["entity_id"], np.zeros(7, dtype=np.int32))
    np.testing.assert_array_equal(processed["sym_id"], np.ones(7, dtype=np.int32))

    # Sequence length -> expected (chunk_size, block_size) on the CPU device.
    expected_sizes = {
        500: (256, None),
        1000: (128, None),
        1500: (64, None),
        2200: (32, 512),
        3000: (4, 256),
    }
    for seq_len, sizes in expected_sizes.items():
        assert backend.automatic_chunk_size(seq_len, "cpu") == sizes
def test_predict_validates_inputs_and_builds_chain_maps(
    alphalink_backend_module,
    monkeypatch,
    tmp_path,
):
    """Check ``predict`` argument validation and the chain maps it forwards."""
    backend = alphalink_backend_module.AlphaLinkBackend

    with pytest.raises(ValueError, match="Missing required parameters"):
        list(backend.predict(objects_to_model=[]))

    captured_calls = []

    def recording_predict_iterations(feature_dict, output_dir, **kwargs):
        # Record the call instead of running inference.
        captured_calls.append((feature_dict, output_dir, kwargs))

    monkeypatch.setattr(
        backend,
        "predict_iterations",
        staticmethod(recording_predict_iterations),
    )

    # First object has no chain_id_map; the second maps chain ids to integers.
    default_chain_object = SimpleNamespace(
        feature_dict={"x": 1},
        input_seqs=["AAA"],
    )
    integer_chain_map_object = SimpleNamespace(
        feature_dict={"y": 2},
        input_seqs=["AAA", "BBB"],
        chain_id_map={"A": 0, "B": 1},
    )
    jobs = [
        {"object": default_chain_object, "output_dir": str(tmp_path / "default")},
        {"object": integer_chain_map_object, "output_dir": str(tmp_path / "int_map")},
    ]
    results = list(
        backend.predict(
            objects_to_model=jobs,
            configs=SimpleNamespace(data="cfg"),
            param_path="weights.pt",
            num_predictions_per_model=2,
        )
    )
    assert len(results) == 2

    first_kwargs = captured_calls[0][2]
    second_kwargs = captured_calls[1][2]

    assert first_kwargs["crosslinks"] == ""
    default_chain_map = first_kwargs["chain_id_map"]
    assert sorted(default_chain_map) == ["A"]
    assert default_chain_map["A"].description == "default_chain"
    assert default_chain_map["A"].sequence == "AAA"

    integer_chain_map = second_kwargs["chain_id_map"]
    assert integer_chain_map["A"].description == "chain_A"
    assert integer_chain_map["A"].sequence == "AAA"
    assert integer_chain_map["B"].description == "chain_B"
    assert integer_chain_map["B"].sequence == "BBB"
    assert second_kwargs["num_inference"] == 2
def test_postprocess_ranks_nested_prediction_outputs(alphalink_backend_module, tmp_path):
    """Check that models found in sub-directories are ranked by the score in their file name."""
    low_dir, high_dir = tmp_path / "seed0", tmp_path / "seed1"
    for directory in (low_dir, high_dir):
        directory.mkdir()
    (low_dir / "AlphaLink2_model_0_seed_1_0.500.pdb").write_text("LOW", encoding="utf-8")
    (high_dir / "AlphaLink2_model_1_seed_2_0.900.pdb").write_text("HIGH", encoding="utf-8")

    alphalink_backend_module.AlphaLinkBackend.postprocess({}, str(tmp_path))

    ranking_text = (tmp_path / "ranking_debug.json").read_text(encoding="utf-8")
    ranking = json.loads(ranking_text)
    assert ranking["order"] == [
        "AlphaLink2_model_1_seed_2_0.900",
        "AlphaLink2_model_0_seed_1_0.500",
    ]
    assert (tmp_path / "ranked_0.pdb").read_text(encoding="utf-8") == "HIGH"
    assert (tmp_path / "ranked_1.pdb").read_text(encoding="utf-8") == "LOW"

View File

@@ -0,0 +1,104 @@
import importlib.util
import sys
from pathlib import Path
# Base directory for the paths below: `parents[2]` of this test file's resolved path.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Scripts under test/cluster that the tests below load directly from disk.
AF2_CHECK_PATH = REPO_ROOT / "test" / "cluster" / "check_alphafold2_predictions.py"
AF2_WRAPPER_PATH = REPO_ROOT / "test" / "cluster" / "run_alphafold2_predictions.py"
AF3_WRAPPER_PATH = REPO_ROOT / "test" / "cluster" / "run_alphafold3_predictions.py"
def _load_module(module_name: str, module_path: Path):
    """Execute the Python file at ``module_path`` as module ``module_name``.

    Any module already registered under ``module_name`` is discarded first, so
    each call runs the file afresh.

    Args:
        module_name: Name to register the module under in ``sys.modules``.
        module_path: Location of the source file to execute.

    Returns:
        The executed module, which is also left registered in ``sys.modules``.

    Raises:
        ImportError: If no import spec or loader can be built for ``module_path``.
    """
    sys.modules.pop(module_name, None)
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module is visible by name while its body runs.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Never leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return module
def test_af2_cluster_subprocess_env_sets_safe_gpu_defaults(monkeypatch):
    """Check the defaults ``_af2_subprocess_env`` supplies when none of the variables are set."""
    module = _load_module("test_cluster_af2_check_module", AF2_CHECK_PATH)

    expected_defaults = {
        "OMP_NUM_THREADS": "4",
        "MKL_NUM_THREADS": "4",
        "NUMEXPR_NUM_THREADS": "4",
        "TF_NUM_INTEROP_THREADS": "4",
        "TF_NUM_INTRAOP_THREADS": "4",
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.8",
        "JAX_PLATFORM_NAME": "gpu",
    }
    # Start from a clean slate so only the function's own defaults are observed.
    for variable in (*expected_defaults, "XLA_FLAGS"):
        monkeypatch.delenv(variable, raising=False)

    env = module._af2_subprocess_env()

    for variable, value in expected_defaults.items():
        assert env[variable] == value
    assert "--xla_gpu_force_compilation_parallelism=0" in env["XLA_FLAGS"]
def test_af2_cluster_wrapper_job_script_exports_gpu_defaults(tmp_path):
    """Check that the generated AF2 job script contains the expected exports and pytest options."""
    module = _load_module("test_cluster_af2_wrapper_module", AF2_WRAPPER_PATH)
    script_path = tmp_path / "job.sh"
    job = module.JobSpec(
        index=1,
        nodeid="test/cluster/check_alphafold2_predictions.py::TestRunModes::test__monomer",
        slug="af2_node",
        stdout_path=tmp_path / "stdout.log",
        stderr_path=tmp_path / "stderr.log",
        script_path=script_path,
        rerun_command="python -m pytest",
    )

    module.write_job_script(
        job=job,
        python_executable=sys.executable,
        use_temp_dir=True,
        cpus_per_task=8,
    )

    script_text = job.script_path.read_text(encoding="utf-8")
    expected_fragments = (
        'OMP_NUM_THREADS="${OMP_NUM_THREADS:-4}"',
        'TF_NUM_INTRAOP_THREADS="${TF_NUM_INTRAOP_THREADS:-4}"',
        'JAX_PLATFORM_NAME="${JAX_PLATFORM_NAME:-gpu}"',
        "addopts=-ra --strict-markers",
        "--use-temp-dir",
    )
    for fragment in expected_fragments:
        assert fragment in script_text
def test_af3_cluster_wrapper_job_script_sets_perf_flag(tmp_path):
    """Check that ``include_perf=True`` exports the perf switch in the generated AF3 job script."""
    module = _load_module("test_cluster_af3_wrapper_module", AF3_WRAPPER_PATH)
    script_path = tmp_path / "job.sh"
    job = module.JobSpec(
        index=1,
        nodeid="test/cluster/check_alphafold3_predictions.py::TestRunModes::test__monomer",
        slug="af3_node",
        stdout_path=tmp_path / "stdout.log",
        stderr_path=tmp_path / "stderr.log",
        script_path=script_path,
        rerun_command="python -m pytest",
    )

    module.write_job_script(
        job=job,
        python_executable=sys.executable,
        use_temp_dir=True,
        include_perf=True,
    )

    script_text = job.script_path.read_text(encoding="utf-8")
    expected_fragments = (
        "export AF3_RUN_PERF_TESTS=1",
        "addopts=-ra --strict-markers",
        "--use-temp-dir",
    )
    for fragment in expected_fragments:
        assert fragment in script_text

View File

@@ -0,0 +1,134 @@
import json
import shutil
from pathlib import Path
from types import SimpleNamespace
import pytest
pytest.importorskip("ihm")
pytest.importorskip("modelcif")
import alphapulldown.scripts.convert_to_modelcif as convert_to_modelcif
# Directory of stored prediction outputs used as test input, located under
# `parents[1]` of this test file's resolved path.
TEST_PREDICTIONS_DIR = (
Path(__file__).resolve().parents[1] / "test_data" / "predictions"
)
def test_cast_param_preserves_expected_python_types():
    """Check ``_cast_param`` on int, float, bool and plain-text strings."""
    cast = convert_to_modelcif._cast_param
    for raw, expected in (("7", 7), ("3.5", 3.5)):
        assert cast(raw) == expected
    # Booleans are compared by identity so 1/0 would not pass.
    assert cast("True") is True
    assert cast("False") is False
    assert cast("plain-text") == "plain-text"
def test_compress_cif_file_replaces_plaintext_file(tmp_path):
    """Check that compressing a CIF file leaves only the ``.gz`` variant on disk."""
    plain_cif = tmp_path / "ranked_0.cif"
    plain_cif.write_text("data_test\n", encoding="ascii")

    result_path = convert_to_modelcif._compress_cif_file(str(plain_cif))

    assert result_path.endswith(".gz")
    assert not plain_cif.exists()
    assert (tmp_path / "ranked_0.cif.gz").exists()
def test_get_feature_metadata_falls_back_to_structure_sequence(tmp_path):
    """Check that sequences come from the structure file when the recorded FASTA path is missing."""
    work_dir = tmp_path / "TEST"
    shutil.copytree(TEST_PREDICTIONS_DIR / "TEST", work_dir)

    # Point the stored metadata at a FASTA file that does not exist.
    metadata_file = next(work_dir.glob("*_feature_metadata_*.json"))
    payload = json.loads(metadata_file.read_text(encoding="ascii"))
    payload["other"]["fasta_paths"] = "['/this/path/does/not/exist.fasta']"
    metadata_file.write_text(json.dumps(payload, indent=2), encoding="ascii")

    modelcif_json = {}
    complex_name, fasta_dicts = convert_to_modelcif._get_feature_metadata(
        modelcif_json,
        "TEST",
        str(work_dir),
        fallback_structure_path=str(work_dir / "ranked_0.pdb"),
    )

    assert complex_name == "TEST"
    assert fasta_dicts
    first_chain = fasta_dicts[0]
    assert first_chain["description"] == "chain_A"
    assert first_chain["sequence"].startswith("MESAIA")
    assert modelcif_json["__meta__"]["TEST"]["databases"]
def test_get_model_list_selects_requested_model_and_tracks_non_selected_models():
    """Check that model 0 is selected and the remaining models are reported as not selected."""
    prediction_dir = str(TEST_PREDICTIONS_DIR / "TEST")
    selected_models = convert_to_modelcif._get_model_list(prediction_dir, 0, True)

    assert len(selected_models) == 1
    (result,) = selected_models
    assert result["complex"] == "TEST"
    assert len(result["models"]) == 1
    assert result["models"][0][0].endswith("ranked_0.pdb")
    assert len(result["not_selected"]) == 4
def test_main_processes_associated_models_before_selected_models(monkeypatch, tmp_path):
    """Check that ``main`` converts non-selected models first and attaches them to the selected one."""
    calls = []

    def recording_convert(complex_name, model_tuple, out_dir, compress, additional_assoc_files=None):
        # Record the arguments and return a result keyed by "<model stem>.zip".
        calls.append(
            {
                "complex_name": complex_name,
                "model_tuple": model_tuple,
                "out_dir": out_dir,
                "compress": compress,
                "additional_assoc_files": additional_assoc_files,
            }
        )
        archive_name = f"{Path(model_tuple[0]).stem}.zip"
        return {archive_name: (str(Path(out_dir) / archive_name), object())}

    def single_complex_model_list(ap_output, model_selected, get_non_selected):
        return [
            {
                "complex": "TEST",
                "path": str(tmp_path / "out"),
                "models": [("ranked_0.pdb", "result_0.pkl", "0", 0)],
                "not_selected": [("ranked_1.pdb", "result_1.pkl", "1", 1)],
            }
        ]

    stub_flags = SimpleNamespace(
        ap_output=str(tmp_path / "predictions"),
        model_selected=0,
        add_associated=True,
        compress=False,
    )
    monkeypatch.setattr(convert_to_modelcif, "FLAGS", stub_flags)
    monkeypatch.setattr(convert_to_modelcif, "_get_model_list", single_complex_model_list)
    monkeypatch.setattr(
        convert_to_modelcif,
        "alphapulldown_model_to_modelcif",
        recording_convert,
    )

    convert_to_modelcif.main([])

    assert len(calls) == 2
    associated_call, selected_call = calls
    assert associated_call["model_tuple"][0] == "ranked_1.pdb"
    assert associated_call["additional_assoc_files"] is None
    assert associated_call["out_dir"] != str(tmp_path / "out")
    assert selected_call["model_tuple"][0] == "ranked_0.pdb"
    assert selected_call["additional_assoc_files"]

View File

@@ -0,0 +1,76 @@
from pathlib import Path
import pickle
import numpy as np
import pytest
from alphafold.data.pipeline_multimer import _FastaChain
torch = pytest.importorskip("torch", reason="AlphaLink crosslink helpers require torch")
from unifold.dataset import bin_xl, calculate_offsets, create_xl_features
# Base directory for test data: `parents[1]` of this test file's resolved path.
TEST_ROOT = Path(__file__).resolve().parents[1]
# Pickled crosslink input that the `crosslink_context` fixture loads.
CROSSLINK_FIXTURE = TEST_ROOT / "alphalink" / "test_xl_input.pkl"
@pytest.fixture
def crosslink_context():
    """Provide the crosslink data, asym ids, chain map and bin edges shared by the tests.

    Returns:
        Dict with keys ``crosslinks`` (unpickled from ``CROSSLINK_FIXTURE``),
        ``asym_id`` (75 entries: 10 ones, 25 twos, 40 threes),
        ``chain_id_map`` (chains A-C with empty sequences) and ``bins``.
    """
    # Use a context manager so the pickle file handle is closed right away
    # instead of lingering until garbage collection.
    with CROSSLINK_FIXTURE.open("rb") as handle:
        crosslinks = pickle.load(handle)
    return {
        "crosslinks": crosslinks,
        "asym_id": torch.tensor([1] * 10 + [2] * 25 + [3] * 40),
        "chain_id_map": {
            "A": _FastaChain(sequence="", description="chain1"),
            "B": _FastaChain(sequence="", description="chain2"),
            "C": _FastaChain(sequence="", description="chain3"),
        },
        "bins": torch.arange(0, 1.05, 0.05),
    }
def test_calculate_offsets(crosslink_context):
    """Check the cumulative chain offsets for chains of length 10, 25 and 40."""
    asym_id = crosslink_context["asym_id"]
    assert calculate_offsets(asym_id).tolist() == [0, 10, 35, 75]
def test_create_xl_inputs(crosslink_context):
    """Check the crosslink rows produced from the fixture crosslinks and chain offsets."""
    chain_offsets = calculate_offsets(crosslink_context["asym_id"])
    features = create_xl_features(
        crosslink_context["crosslinks"],
        chain_offsets,
        chain_id_map=crosslink_context["chain_id_map"],
    )

    expected_rows = [
        [10, 35, 0.01],
        [3, 27, 0.01],
        [5, 56, 0.01],
        [20, 65, 0.01],
    ]
    assert torch.equal(features, torch.tensor(expected_rows))
def test_bin_xl(crosslink_context):
    """Check that binned crosslinks form a symmetric matrix with the bucket of 0.99 at each pair."""
    asym_id = crosslink_context["asym_id"]
    chain_offsets = calculate_offsets(asym_id)
    features = create_xl_features(
        crosslink_context["crosslinks"],
        chain_offsets,
        chain_id_map=crosslink_context["chain_id_map"],
    )
    num_res = len(asym_id)
    binned = bin_xl(features, num_res)

    expected_xl = np.zeros((num_res, num_res, 1))
    # Each pair is written in both orientations.
    for i, j in ((3, 27), (10, 35), (5, 56), (20, 65)):
        expected_xl[i, j, 0] = expected_xl[j, i, 0] = torch.bucketize(
            0.99, crosslink_context["bins"]
        )
    assert np.array_equal(binned, expected_xl)

View File

@@ -296,3 +296,96 @@ def test_resolve_species_ids_by_accession_skips_single_accession_fallback_after_
'A0A743YDY2': '',
}
assert calls == [('A0A636IKY3', 'A0A743YDY2')]
def test_build_mmseq_identifier_features_skips_non_uniprot_identifiers(
    monkeypatch,
):
    """Check that only the UniProt-style header reaches the resolver and gets identifiers."""
    calls = []

    def recording_resolver(accessions):
        calls.append(tuple(accessions))
        return {'A0A636IKY3': '108619'}

    a3m_lines = [
        '>101',
        'ACDE',
        '>MGYP000264027769',
        'ACDF',
        '>UniRef100_MGYP000264027769',
        'ACDG',
        '>UniRef100_A0A636IKY3',
        'ACDH',
        '',
    ]
    features = mmseqs_species_identifiers.build_mmseq_identifier_features(
        '\n'.join(a3m_lines),
        species_resolver=recording_resolver,
    )

    assert calls == [('A0A636IKY3',)]
    # The first three rows are expected to stay empty.
    species_ids = features['msa_species_identifiers'].tolist()
    assert species_ids == [
        b'',
        b'',
        b'',
        b'108619',
    ]
    accession_ids = features['msa_uniprot_accession_identifiers'].tolist()
    assert accession_ids == [
        b'',
        b'',
        b'',
        b'A0A636IKY3',
    ]
def test_resolve_species_ids_by_accession_skips_unsupported_accessions(
    monkeypatch,
):
    """Check that each accession is sent to the matching query helper and unsupported ones to neither."""
    uniprot_calls = []
    uniparc_calls = []

    def fake_uniprot_query(accessions, *, urlopen):
        uniprot_calls.append(tuple(accessions))
        entry = {
            'primaryAccession': 'A0A636IKY3',
            'organism': {'taxonId': 562},
        }
        return {'results': [entry]}

    def fake_uniparc_query(accessions, *, urlopen):
        uniparc_calls.append(tuple(accessions))
        entry = {
            'uniParcId': 'UPI001118B830',
            'organisms': [{'taxonId': 83333}],
        }
        return {'results': [entry]}

    replacements = (
        ('_query_uniprot_batch', fake_uniprot_query),
        ('_query_uniparc_batch', fake_uniparc_query),
    )
    for attribute, replacement in replacements:
        monkeypatch.setattr(mmseqs_species_identifiers, attribute, replacement)

    accessions = ['A0A636IKY3', 'MGYP000264027769', 'UPI001118B830']
    resolved = mmseqs_species_identifiers.resolve_species_ids_by_accession(accessions)

    assert resolved == {
        'A0A636IKY3': '562',
        'MGYP000264027769': '',
        'UPI001118B830': '83333',
    }
    assert uniprot_calls == [('A0A636IKY3',)]
    assert uniparc_calls == [('UPI001118B830',)]

View File

@@ -1,7 +1,9 @@
import json
from alphapulldown.utils.output_paths import (
build_af3_combined_json_job_name,
derive_af3_job_name_from_json,
resolve_af3_combined_json_output_dir,
resolve_af3_json_output_dir,
sanitise_af3_job_name,
)
@@ -106,3 +108,60 @@ def test_resolve_af3_json_output_dir_keeps_unsafe_json_name_within_root(tmp_path
use_ap_style=True,
shared_output_root=True,
) == str(tmp_path / "predictions" / "input_name")
def test_build_af3_combined_json_job_name_uses_fold_fragments(tmp_path):
    """Check the combined job name built from two AF3 JSON inputs."""
    ptm_json = tmp_path / "protein_with_ptms.json"
    ptm_payload = {
        "name": "protein_ptms",
        "dialect": "alphafold3",
        "version": 1,
    }
    ptm_json.write_text(json.dumps(ptm_payload), encoding="utf-8")

    partner_json = tmp_path / "P61626_af3_input.json"
    partner_payload = {
        "name": "P61626",
        "dialect": "alphafold3",
        "version": 1,
    }
    partner_json.write_text(json.dumps(partner_payload), encoding="utf-8")

    job_name = build_af3_combined_json_job_name(
        [
            {"json_input": str(ptm_json)},
            {"json_input": str(partner_json)},
        ]
    )
    assert job_name == "protein_with_ptms_and_p61626"
def test_resolve_af3_combined_json_output_dir_uses_one_fold_directory(tmp_path):
    """Check that two AF3 JSON inputs resolve to one combined output directory."""
    json_a = tmp_path / "P01308_af3_input.json"
    json_b = tmp_path / "P61626_af3_input.json"
    names_by_path = {json_a: "P01308", json_b: "P61626"}
    for json_path, name in names_by_path.items():
        payload = {
            "name": name,
            "dialect": "alphafold3",
            "version": 1,
        }
        json_path.write_text(json.dumps(payload), encoding="utf-8")

    output_dir = resolve_af3_combined_json_output_dir(
        [
            {"json_input": str(json_a)},
            {"json_input": str(json_b)},
        ],
        str(tmp_path / "predictions"),
        use_ap_style=True,
    )
    assert output_dir == str(tmp_path / "predictions" / "p01308_and_p61626")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,430 @@
import gzip
import importlib.util
import json
import pickle
import runpy
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import numpy as np
import pandas as pd
import pytest
import alphapulldown.scripts.generate_crosslink_pickle as crosslink_pickle
import alphapulldown.scripts.rename_colab_search_a3m as rename_a3m
# Base directory for the script paths below: `parents[2]` of this test file's resolved path.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Scripts under alphapulldown/scripts that the tests run or load by file path.
PREPARE_SEQ_NAMES_PATH = (
REPO_ROOT / "alphapulldown" / "scripts" / "prepare_seq_names.py"
)
PARSE_INPUT_PATH = REPO_ROOT / "alphapulldown" / "scripts" / "parse_input.py"
SPLIT_JOBS_PATH = (
REPO_ROOT / "alphapulldown" / "scripts" / "split_jobs_into_clusters.py"
)
TRUNCATE_PICKLES_PATH = (
REPO_ROOT / "alphapulldown" / "scripts" / "truncate_pickles.py"
)
def _load_module_from_path(module_name: str, module_path: Path):
    """Execute the Python file at ``module_path`` as module ``module_name``.

    Any module already registered under ``module_name`` is discarded first, so
    each call runs the file afresh.

    Args:
        module_name: Name to register the module under in ``sys.modules``.
        module_path: Location of the source file to execute.

    Returns:
        The executed module, which is also left registered in ``sys.modules``.

    Raises:
        ImportError: If no import spec or loader can be built for ``module_path``.
    """
    sys.modules.pop(module_name, None)
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module is visible by name while its body runs.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Never leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return module
def _load_parse_input_module(monkeypatch):
    """Load ``parse_input.py`` with absl and its AlphaPulldown helpers replaced by recorders.

    Returns:
        Tuple ``(module, app_calls, parser_calls, modelling_calls)``:
        ``app_calls`` collects functions passed to ``app.run``, while
        ``parser_calls`` and ``modelling_calls`` record the arguments given to
        the stubbed parser and modelling-setup helpers.
    """
    app_calls = []
    parser_calls = {}
    modelling_calls = {}

    # absl.flags: each DEFINE_* simply stores the default on FLAGS.
    flags_mod = types.ModuleType("absl.flags")
    flags_mod.FLAGS = SimpleNamespace()

    def _store_default(name, default, help_text):
        del help_text
        setattr(flags_mod.FLAGS, name, default)

    flags_mod.DEFINE_list = _store_default
    flags_mod.DEFINE_string = _store_default

    # absl.app: record the entry point instead of running it.
    app_mod = types.ModuleType("absl.app")

    def _record_run(fn):
        app_calls.append(fn)

    app_mod.run = _record_run

    logging_mod = types.ModuleType("absl.logging")
    logging_mod.INFO = 20
    logging_mod.set_verbosity = lambda *_args, **_kwargs: None

    absl_pkg = types.ModuleType("absl")
    absl_pkg.app = app_mod
    absl_pkg.flags = flags_mod
    absl_pkg.logging = logging_mod

    parser_mod = types.ModuleType("alphapulldown_input_parser")

    def _generate_fold_specifications(**kwargs):
        parser_calls["generate"] = kwargs
        return ["foldA"]

    parser_mod.generate_fold_specifications = _generate_fold_specifications

    modelling_setup_mod = types.ModuleType("alphapulldown.utils.modelling_setup")

    def _parse_fold(specifications, features_directory, delimiter):
        modelling_calls["parse_fold"] = (
            specifications,
            features_directory,
            delimiter,
        )
        return specifications

    def _create_custom_info(parsed):
        modelling_calls["create_custom_info"] = parsed
        return {"parsed": parsed}

    modelling_setup_mod.parse_fold = _parse_fold
    modelling_setup_mod.create_custom_info = _create_custom_info

    stubbed_modules = (
        ("absl", absl_pkg),
        ("absl.app", app_mod),
        ("absl.flags", flags_mod),
        ("absl.logging", logging_mod),
        ("alphapulldown_input_parser", parser_mod),
        ("alphapulldown.utils.modelling_setup", modelling_setup_mod),
    )
    for dotted_name, stub in stubbed_modules:
        monkeypatch.setitem(sys.modules, dotted_name, stub)

    module = _load_module_from_path("test_parse_input_module", PARSE_INPUT_PATH)
    return module, app_calls, parser_calls, modelling_calls
def _load_split_jobs_module(monkeypatch):
    """Load ``split_jobs_into_clusters.py`` with its AlphaPulldown imports replaced by stubs."""
    parser_mod = types.ModuleType("alphapulldown_input_parser")

    def _generate_fold_specifications(**kwargs):
        return ["A,B", "C;D"]

    parser_mod.generate_fold_specifications = _generate_fold_specifications

    modelling_setup_mod = types.ModuleType("alphapulldown.utils.modelling_setup")

    def _parse_fold(args):
        return args

    def _create_custom_info(parsed_input):
        return parsed_input

    def _create_interactors(data, features_directory, index):
        return [[data, features_directory, index]]

    modelling_setup_mod.parse_fold = _parse_fold
    modelling_setup_mod.create_custom_info = _create_custom_info
    modelling_setup_mod.create_interactors = _create_interactors

    objects_mod = types.ModuleType("alphapulldown.objects")

    class StubMultimericObject:
        def __init__(self, interactors):
            self.interactors = interactors
            # Fixed-size zero MSA of shape (4, 12).
            self.feature_dict = {"msa": np.zeros((4, 12), dtype=np.int32)}

    objects_mod.MultimericObject = StubMultimericObject

    stubbed_modules = (
        ("alphapulldown_input_parser", parser_mod),
        ("alphapulldown.utils.modelling_setup", modelling_setup_mod),
        ("alphapulldown.objects", objects_mod),
    )
    for dotted_name, stub in stubbed_modules:
        monkeypatch.setitem(sys.modules, dotted_name, stub)
    return _load_module_from_path("test_split_jobs_module", SPLIT_JOBS_PATH)
def _load_truncate_pickles_module(monkeypatch):
    """Load ``truncate_pickles.py`` against a minimal stand-in for the ``absl`` package.

    The stubbed ``app.run`` calls the given function immediately with an empty
    argument list, and every ``DEFINE_*`` helper stores its default on ``FLAGS``.
    """
    flags_mod = types.ModuleType("absl.flags")
    flags_mod.FLAGS = SimpleNamespace()

    def _define_string(name, default, help_text, required=False):
        del help_text, required
        setattr(flags_mod.FLAGS, name, default)

    def _define_plain(name, default, help_text):
        del help_text
        setattr(flags_mod.FLAGS, name, default)

    flags_mod.DEFINE_string = _define_string
    flags_mod.DEFINE_list = _define_plain
    flags_mod.DEFINE_integer = _define_plain

    app_mod = types.ModuleType("absl.app")

    def _run_now(fn):
        return fn([])

    app_mod.run = _run_now

    logging_mod = types.ModuleType("absl.logging")
    logging_mod.error = lambda *_args, **_kwargs: None

    absl_pkg = types.ModuleType("absl")
    absl_pkg.app = app_mod
    absl_pkg.flags = flags_mod
    absl_pkg.logging = logging_mod

    stubbed_modules = (
        ("absl", absl_pkg),
        ("absl.app", app_mod),
        ("absl.flags", flags_mod),
        ("absl.logging", logging_mod),
    )
    for dotted_name, stub in stubbed_modules:
        monkeypatch.setitem(sys.modules, dotted_name, stub)
    return _load_module_from_path(
        "test_truncate_pickles_module",
        TRUNCATE_PICKLES_PATH,
    )
def test_prepare_seq_names_rewrites_headers_from_uniprot_style_fasta(
    monkeypatch,
    tmp_path,
    capsys,
):
    """Run the script as ``__main__`` and check that headers are reduced to the accession."""
    fasta_path = tmp_path / "input.fasta"
    fasta_path.write_text(
        ">sp|Q9H9K5|Protein alpha OS=Test\nACDE\n>tr|P12345|Protein beta\nFGHI\n",
        encoding="utf-8",
    )
    script = str(PREPARE_SEQ_NAMES_PATH)
    monkeypatch.setattr(sys, "argv", [script, str(fasta_path)])

    runpy.run_path(script, run_name="__main__")

    printed_lines = capsys.readouterr().out.strip().splitlines()
    assert printed_lines == [
        ">Q9H9K5",
        "ACDE",
        ">P12345",
        "FGHI",
    ]
def test_rename_colab_search_a3m_renames_legacy_search_outputs(monkeypatch, tmp_path):
    """Check that ``0.a3m`` is renamed after its first header when that header is a name."""
    monkeypatch.chdir(tmp_path)
    a3m_text = ">protA\nACDE\n>hit\nACDE\n"
    original = tmp_path / "0.a3m"
    original.write_text(a3m_text, encoding="utf-8")

    rename_a3m.main()

    assert (tmp_path / "protA.a3m").read_text(encoding="utf-8") == a3m_text
    assert not original.exists()
def test_rename_colab_search_a3m_requires_input_fasta_for_new_colabfold(
    monkeypatch,
    tmp_path,
):
    """Check that a ``>101`` first header without an input FASTA raises ``ValueError``."""
    monkeypatch.chdir(tmp_path)
    numbered_a3m = tmp_path / "0.a3m"
    numbered_a3m.write_text(">101\nACDE\n>hit\nACDE\n", encoding="utf-8")

    with pytest.raises(ValueError, match="Please provide the input FASTA file"):
        rename_a3m.main()
def test_rename_colab_search_a3m_uses_input_fasta_names_for_new_colabfold(
    monkeypatch,
    tmp_path,
):
    """Check that a ``>101`` header is replaced by the name taken from the input FASTA."""
    monkeypatch.chdir(tmp_path)
    numbered_a3m = tmp_path / "0.a3m"
    numbered_a3m.write_text(">101\nACDE\n>hit\nACDE\n", encoding="utf-8")
    fasta_path = tmp_path / "input.fasta"
    fasta_path.write_text(">queryA\nACDE\n", encoding="utf-8")

    rename_a3m.main(str(fasta_path))

    renamed_text = (tmp_path / "queryA.a3m").read_text(encoding="utf-8")
    assert renamed_text == ">queryA\nACDE\n>hit\nACDE\n"
    assert not numbered_a3m.exists()
def test_generate_crosslink_pickle_parses_single_row_input(tmp_path, monkeypatch):
    """Check the pickled structure written for a single crosslink row."""
    links_path = tmp_path / "links.txt"
    links_path.write_text("5 A 9 B 0.05\n", encoding="utf-8")
    output_path = tmp_path / "crosslinks.pkl.gz"

    cli_args = SimpleNamespace(csv=str(links_path), output=str(output_path))
    monkeypatch.setattr(crosslink_pickle, "parse_arguments", lambda: cli_args)

    crosslink_pickle.main()

    with gzip.open(output_path, "rb") as handle:
        stored = pickle.load(handle)
    # Residue numbers 5 and 9 are expected as indices 4 and 8 in the output.
    assert stored == {"A": {"B": [(4, 8, 0.05)]}}
def test_generate_crosslink_pickle_parse_arguments_reads_cli_flags(monkeypatch):
    """Check that ``--csv`` and ``--output`` are read from ``sys.argv``."""
    argv = ["prog", "--csv", "links.txt", "--output", "crosslinks.pkl.gz"]
    monkeypatch.setattr(sys, "argv", argv)

    parsed = crosslink_pickle.parse_arguments()

    assert parsed.csv == "links.txt"
    assert parsed.output == "crosslinks.pkl.gz"
def test_parse_input_main_writes_fold_spec_json(monkeypatch, tmp_path):
    """main() forwards the flags to the stubbed helpers and dumps the result as JSON."""
    loaded = _load_parse_input_module(monkeypatch)
    module, app_calls, parser_calls, modelling_calls = loaded
    output_prefix = tmp_path / "parsed_"
    module.FLAGS = SimpleNamespace(
        input_list=["folds.txt"],
        protein_delimiter="+",
        features_directory=["/features"],
        output_prefix=str(output_prefix),
    )

    module.main([])

    assert app_calls == [module.main]

    expected_generate_call = {
        "input_files": ["folds.txt"],
        "delimiter": "+",
        "exclude_permutations": True,
    }
    assert parser_calls["generate"] == expected_generate_call

    expected_parse_fold_call = (["foldA"], ["/features"], "+")
    assert modelling_calls["parse_fold"] == expected_parse_fold_call

    # The output file name is the configured prefix followed by "data.json".
    written_text = (tmp_path / "parsed_data.json").read_text(encoding="utf-8")
    assert json.loads(written_text) == {"parsed": ["foldA"]}
def test_split_jobs_cluster_jobs_writes_cluster_files(monkeypatch, tmp_path):
    """cluster_jobs() writes one job list per cluster and plots the clustering once."""
    module = _load_split_jobs_module(monkeypatch)
    plot_calls = []

    def _fake_profile(all_folds, args):
        # Fixed profile: two small jobs and one clearly larger one.
        return pd.DataFrame(
            {
                "name": ["job_a", "job_b", "job_c"],
                "msa_depth": [20, 40, 60],
                "seq_length": [100, 120, 320],
            }
        )

    def _record_plot(X, labels, num_cluster, output_dir):
        # Snapshot the inputs so later mutation cannot affect the assertions.
        plot_calls.append((X.copy(), np.asarray(labels), num_cluster, output_dir))

    def _read_job_names(file_name):
        return (tmp_path / file_name).read_text(encoding="utf-8").splitlines()

    monkeypatch.setattr(module, "profile_all_jobs_and_cluster", _fake_profile)
    monkeypatch.setattr(module, "plot_clustering_result", _record_plot)

    cli_args = SimpleNamespace(output_dir=str(tmp_path))
    module.cluster_jobs(["job_a", "job_b", "job_c"], cli_args)

    assert _read_job_names("job_cluster1_120_40.txt") == ["job_a", "job_b"]
    assert _read_job_names("job_cluster2_320_60.txt") == ["job_c"]

    plotted_features, plotted_labels, plotted_cluster_count, _ = plot_calls[0]
    assert plotted_cluster_count == 2
    expected_features = np.asarray(
        [[100, 20], [120, 40], [320, 60]],
        dtype=np.int64,
    )
    np.testing.assert_array_equal(plotted_features, expected_features)
    np.testing.assert_array_equal(plotted_labels, np.asarray([0, 0, 1]))
def test_split_jobs_main_normalises_generated_specs_and_all_vs_all_input(
    monkeypatch,
    tmp_path,
):
    """main() hands cluster_jobs the generated specs as "A:B" and "C+D"."""
    module = _load_split_jobs_module(monkeypatch)
    cluster_calls = []

    def _fake_parse_args(self):
        # Replaces ArgumentParser.parse_args so no real command line is read.
        return SimpleNamespace(
            protein_lists=["proteins.txt"],
            protein_delimiter="+",
            mode="all_vs_all",
            features_directory=["/features"],
            output_dir=str(tmp_path),
        )

    def _fake_generate_fold_specifications(**kwargs):
        return ["A,B", "C;D"]

    def _record_cluster_jobs(all_folds, args):
        cluster_calls.append((all_folds, args))

    monkeypatch.setattr(module.argparse.ArgumentParser, "parse_args", _fake_parse_args)
    monkeypatch.setattr(
        module,
        "generate_fold_specifications",
        _fake_generate_fold_specifications,
    )
    monkeypatch.setattr(module, "cluster_jobs", _record_cluster_jobs)

    module.main()

    clustered_folds, forwarded_args = cluster_calls[0]
    assert clustered_folds == ["A:B", "C+D"]
    assert forwarded_args.protein_lists == ["proteins.txt"]
def test_truncate_pickles_main_copies_tree_and_removes_selected_pickle_keys(
    monkeypatch,
    tmp_path,
):
    """main() writes a pickle without the excluded keys and keeps existing files."""
    truncate_pickles = _load_truncate_pickles_module(monkeypatch)
    src_dir, dst_dir = tmp_path / "src", tmp_path / "dst"
    nested_src, nested_dst = src_dir / "nested", dst_dir / "nested"
    for directory in (nested_src, nested_dst):
        directory.mkdir(parents=True)

    full_result = {
        "keep": 1,
        "aligned_confidence_probs": [1, 2],
        "distogram": [3, 4],
    }
    (nested_src / "result.pkl").write_bytes(pickle.dumps(full_result))
    (nested_src / "notes.txt").write_text("copied\n", encoding="utf-8")
    # A pre-existing destination file with different content than the source.
    (nested_dst / "notes.txt").write_text("existing\n", encoding="utf-8")

    flags = SimpleNamespace(
        src_dir=str(src_dir),
        dst_dir=str(dst_dir),
        keys_to_exclude="aligned_confidence_probs,distogram",
        number_of_threads=2,
    )
    monkeypatch.setattr(truncate_pickles, "FLAGS", flags)

    truncate_pickles.main([])

    truncated_result = pickle.loads((nested_dst / "result.pkl").read_bytes())
    assert truncated_result == {"keep": 1}
    # The destination copy that already existed must not be overwritten.
    assert (nested_dst / "notes.txt").read_text(encoding="utf-8") == "existing\n"
def test_truncate_pickles_main_exits_when_source_dir_is_missing(monkeypatch, tmp_path):
    """main() raises SystemExit matching "1" when src_dir does not exist."""
    truncate_pickles = _load_truncate_pickles_module(monkeypatch)
    missing_source = tmp_path / "missing"
    flags = SimpleNamespace(
        src_dir=str(missing_source),
        dst_dir=str(tmp_path / "dst"),
        keys_to_exclude="aligned_confidence_probs,distogram",
        number_of_threads=1,
    )
    monkeypatch.setattr(truncate_pickles, "FLAGS", flags)

    expect_exit = pytest.raises(SystemExit, match="1")
    with expect_exit:
        truncate_pickles.main([])

View File

@@ -0,0 +1,161 @@
import importlib.util
import sys
import types
from pathlib import Path
from types import SimpleNamespace
# Location of the backend source file loaded by these tests. It is resolved
# from the directory two levels above this file's own directory (presumably
# the repository root -- confirm against the test layout).
MODULE_PATH = Path(__file__).resolve().parents[2].joinpath(
    "alphapulldown",
    "folding_backend",
    "unifold_backend.py",
)
def _restore_modules(saved_modules: dict[str, types.ModuleType | None]) -> None:
for name, module in saved_modules.items():
if module is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = module
def _install_unifold_backend_stubs() -> dict[str, types.ModuleType | None]:
names_to_replace = [
"alphapulldown.objects",
"unifold",
"unifold.config",
"unifold.inference",
"unifold.dataset",
]
saved_modules = {name: sys.modules.get(name) for name in names_to_replace}
objects_mod = types.ModuleType("alphapulldown.objects")
objects_mod.MultimericObject = type("MultimericObject", (), {})
unifold_pkg = types.ModuleType("unifold")
config_mod = types.ModuleType("unifold.config")
config_mod.model_config = lambda model_name: {"model_name": model_name}
inference_mod = types.ModuleType("unifold.inference")
inference_mod.calls = []
inference_mod.config_args = (
lambda model_dir, target_name, output_dir: {
"model_dir": model_dir,
"target_name": target_name,
"output_dir": output_dir,
}
)
inference_mod.unifold_config_model = lambda general_args: {"runner_args": general_args}
inference_mod.unifold_predict = (
lambda model_runner, model_args, processed_features: inference_mod.calls.append(
(model_runner, model_args, processed_features)
)
)
dataset_mod = types.ModuleType("unifold.dataset")
dataset_mod.process_ap = (
lambda config, features, mode, labels, seed, batch_idx, data_idx, is_distillation: (
{
"processed_features": features,
"seed": seed,
"mode": mode,
"config": config,
},
None,
)
)
modules = {
"alphapulldown.objects": objects_mod,
"unifold": unifold_pkg,
"unifold.config": config_mod,
"unifold.inference": inference_mod,
"unifold.dataset": dataset_mod,
}
for name, module in modules.items():
sys.modules[name] = module
unifold_pkg.config = config_mod
unifold_pkg.inference = inference_mod
unifold_pkg.dataset = dataset_mod
return saved_modules
def _load_unifold_backend_module():
    """Load ``unifold_backend.py`` from ``MODULE_PATH`` against stubbed imports.

    The stub modules are installed first, any previously imported copy of the
    backend is dropped from ``sys.modules``, and the file is then executed as
    ``alphapulldown.folding_backend.unifold_backend``.

    Returns:
        A ``(module, saved_modules)`` tuple. ``saved_modules`` is the mapping
        returned by ``_install_unifold_backend_stubs``; the caller passes it to
        ``_restore_modules`` once finished with the module.

    Raises:
        AssertionError: If no import spec or loader can be created for the file.
        BaseException: Anything raised while executing the module is re-raised
            after ``sys.modules`` has been restored.
    """
    module_name = "alphapulldown.folding_backend.unifold_backend"
    saved_modules = _install_unifold_backend_stubs()
    # Every step after installing the stubs runs inside the try block, so a
    # failure at any point (spec creation, loader check, module execution)
    # cannot leave the stubs or a half-initialised module in sys.modules.
    try:
        sys.modules.pop(module_name, None)
        spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
        assert spec is not None and spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    except BaseException:
        # BaseException also covers SystemExit/KeyboardInterrupt raised while
        # the module body runs; the exception is re-raised after cleanup.
        sys.modules.pop(module_name, None)
        _restore_modules(saved_modules)
        raise
    return module, saved_modules
def test_unifold_setup_predict_and_postprocess():
    """Exercise UnifoldBackend.setup, predict and postprocess against the stubs."""
    module, saved_modules = _load_unifold_backend_module()
    try:
        backend_cls = module.UnifoldBackend
        complex_object = SimpleNamespace(
            description="complex",
            feature_dict={"msa": [1, 2]},
        )

        # setup() is expected to return exactly what the stubbed unifold helpers build.
        expected_general_args = {
            "model_dir": "/models",
            "target_name": "complex",
            "output_dir": "/output",
        }
        configured = backend_cls.setup(
            model_name="multimer",
            model_dir="/models",
            output_dir="/output",
            multimeric_object=complex_object,
        )
        assert configured == {
            "model_runner": {"runner_args": expected_general_args},
            "model_args": expected_general_args,
            "model_config": {"model_name": "multimer"},
        }

        # predict() returns nothing; its effect is the call recorded by the stub.
        prediction = backend_cls().predict(
            model_runner="runner",
            model_args={"arg": 1},
            model_config={"cfg": 2},
            multimeric_object=complex_object,
            random_seed=11,
        )
        assert prediction is None

        expected_processed_features = {
            "processed_features": {"msa": [1, 2]},
            "seed": 11,
            "mode": "predict",
            "config": {"cfg": 2},
        }
        recorded_calls = sys.modules["unifold.inference"].calls
        assert recorded_calls == [("runner", {"arg": 1}, expected_processed_features)]

        assert backend_cls.postprocess() is None
    finally:
        # Always undo the sys.modules changes, even when an assertion fails.
        sys.modules.pop("alphapulldown.folding_backend.unifold_backend", None)
        _restore_modules(saved_modules)