mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 05:58:11 +08:00
Tests (#600)
* Harden MMseqs species ID resolution fallback * Reorganize tests for CPU coverage CI * New * Fix function coverage checker def-line false positives * Expand unit coverage for helper and backend manager utilities * New. * New. * Expand unit coverage for template and post-processing helpers * Expand unit coverage for objects.py edge cases * Publish HTML coverage reports via GitHub Pages * Add CPU unit coverage for AlphaFold3 backend helpers * Reorganize tests and expand backend coverage * Reset shared test flags between cases * Expand AF3 prepare_input unit coverage * Cover AF3 and truemultimer feature creation * Test AF3 multimer MSA translation paths * Cover AF3 duplicate-residue multimer fallback * Cover AF2 resume and postprocess edge paths * Cover AF3 template mmCIF preparation * Test small script entry points * Expand workflow and ModelCIF test coverage * Add backend extras and install guide * Clarify AF3 backend installation path * Stabilize cluster GPU test runners * Document AF3 CMake SQLite hints * Simplify backend installation guide * Align AF3 install with working cluster env * Backfill typing dataclass_transform for AF2 * Pin TensorFlow for cluster installs * Fallback AF2 relax when CUDA OpenMM is unavailable * Raise AF3 default minimum bucket size * Simplify backend cluster installation guide * Fix AF3 wrapper JSON output isolation * Fix AF3 JSON wrapper outputs and MMseqs ID parsing * Fix CI entrypoint stub and Python 3.8 typing * Document release readiness test gates
This commit is contained in:
@@ -4,7 +4,8 @@ source =
|
||||
alphapulldown
|
||||
omit =
|
||||
alphapulldown/__init__.py
|
||||
alphapulldown/analysis_pipeline/af2plots/*
|
||||
alphapulldown/analysis_pipeline/*
|
||||
alphapulldown/folding_backend/alphalink_backend.py
|
||||
|
||||
[report]
|
||||
skip_empty = True
|
||||
|
||||
@@ -11,6 +11,7 @@ import json
|
||||
import pickle
|
||||
import subprocess
|
||||
import enum
|
||||
import typing
|
||||
from typing import Dict, Union, List, Any
|
||||
import os
|
||||
from absl import logging
|
||||
@@ -45,6 +46,55 @@ class ModelsToRelax(enum.Enum):
|
||||
NONE = 2
|
||||
|
||||
|
||||
def _ensure_typing_dataclass_transform() -> None:
|
||||
"""Provide typing.dataclass_transform on Python versions that lack it."""
|
||||
if hasattr(typing, "dataclass_transform"):
|
||||
return
|
||||
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
typing.dataclass_transform = dataclass_transform
|
||||
|
||||
|
||||
def _get_openmm_platform_names() -> List[str]:
|
||||
"""Return the available OpenMM platform names."""
|
||||
try:
|
||||
import openmm
|
||||
except ImportError:
|
||||
return []
|
||||
|
||||
try:
|
||||
return [
|
||||
openmm.Platform.getPlatform(i).getName()
|
||||
for i in range(openmm.Platform.getNumPlatforms())
|
||||
]
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.warning("Failed to inspect OpenMM platforms: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
def _resolve_gpu_relax(use_gpu_relax: bool) -> bool:
|
||||
"""Only request GPU relax when OpenMM exposes a CUDA platform."""
|
||||
if not use_gpu_relax:
|
||||
return False
|
||||
|
||||
platform_names = _get_openmm_platform_names()
|
||||
if "CUDA" in platform_names:
|
||||
return True
|
||||
|
||||
if platform_names:
|
||||
logging.warning(
|
||||
"OpenMM CUDA platform is unavailable; falling back to CPU relax. "
|
||||
"Available platforms: %s",
|
||||
", ".join(platform_names),
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"OpenMM CUDA platform is unavailable; falling back to CPU relax."
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _jnp_to_np(output):
|
||||
"""Recursively changes jax arrays to numpy arrays."""
|
||||
for k, v in output.items():
|
||||
@@ -381,6 +431,7 @@ class AlphaFold2Backend(FoldingBackend):
|
||||
If provided custom model names are not part of the available models.
|
||||
"""
|
||||
|
||||
_ensure_typing_dataclass_transform()
|
||||
from alphafold.model import config
|
||||
from alphafold.model import data, model
|
||||
|
||||
@@ -798,6 +849,7 @@ class AlphaFold2Backend(FoldingBackend):
|
||||
for model_name, prediction_result in prediction_results.items():
|
||||
prediction_result.update(AlphaFold2Backend.recalculate_confidence(prediction_result,multimer_mode,
|
||||
total_num_res))
|
||||
unrelaxed_protein = prediction_result.get("unrelaxed_protein")
|
||||
if 'unrelaxed_protein' in prediction_result.keys():
|
||||
unrelaxed_protein = prediction_result.pop("unrelaxed_protein")
|
||||
# Remove jax dependency from results
|
||||
@@ -806,7 +858,8 @@ class AlphaFold2Backend(FoldingBackend):
|
||||
result_output_path = os.path.join(output_dir, f"result_{model_name}.pkl")
|
||||
with open(result_output_path, "wb") as f:
|
||||
pickle.dump(np_prediction_result, f, protocol=4)
|
||||
prediction_results[model_name]['unrelaxed_protein'] = unrelaxed_protein
|
||||
if unrelaxed_protein is not None:
|
||||
prediction_results[model_name]['unrelaxed_protein'] = unrelaxed_protein
|
||||
if 'iptm' in prediction_result:
|
||||
label = 'iptm+ptm'
|
||||
iptm_scores[model_name] = float(prediction_result['iptm'])
|
||||
@@ -862,7 +915,7 @@ class AlphaFold2Backend(FoldingBackend):
|
||||
stiffness=RELAX_STIFFNESS,
|
||||
exclude_residues=RELAX_EXCLUDE_RESIDUES,
|
||||
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
|
||||
use_gpu=use_gpu_relax)
|
||||
use_gpu=_resolve_gpu_relax(use_gpu_relax))
|
||||
|
||||
if models_to_relax == ModelsToRelax.BEST:
|
||||
to_relax = [ranked_order[0]]
|
||||
|
||||
@@ -31,13 +31,15 @@ from alphapulldown.utils.modelling_setup import create_uniprot_runner
|
||||
from alphapulldown.utils import save_meta_data
|
||||
|
||||
# Try to import AlphaFold3, but it's optional
|
||||
AF3_IMPORT_ERROR = None
|
||||
try:
|
||||
from alphafold3.data.pipeline import DataPipeline as AF3DataPipeline, DataPipelineConfig as AF3DataPipelineConfig
|
||||
from alphafold3.common import folding_input
|
||||
except ImportError:
|
||||
except ImportError as exc:
|
||||
AF3DataPipeline = None
|
||||
AF3DataPipelineConfig = None
|
||||
folding_input = None
|
||||
AF3_IMPORT_ERROR = exc
|
||||
|
||||
# =================== Database Maps ===================
|
||||
AF2_DATABASES = {
|
||||
@@ -321,7 +323,13 @@ def create_custom_db(temp_dir, protein, template_paths, chains):
|
||||
def create_pipeline_af3():
|
||||
"""Create the AlphaFold3 pipeline. Raises if AF3 not available."""
|
||||
if AF3DataPipeline is None or AF3DataPipelineConfig is None:
|
||||
raise ImportError("alphafold3.data.pipeline not available")
|
||||
raise ImportError(
|
||||
"AlphaFold3 is not installed correctly. "
|
||||
"Install AlphaPulldown with 'pip install -e \".[alphafold3,test]\"', "
|
||||
"make sure the build environment provides SQLite, then build the "
|
||||
"vendored package with 'pip install -r alphafold3/dev-requirements.txt', "
|
||||
"'pip install --no-deps -e ./alphafold3', and 'build_data'."
|
||||
) from AF3_IMPORT_ERROR
|
||||
|
||||
# Convert max_template_date string to datetime.date object
|
||||
import datetime
|
||||
|
||||
@@ -38,8 +38,7 @@ out_lines = []
|
||||
|
||||
with open(sys.argv[1]) as f:
|
||||
for headerStr, seq in fasta_iter(f):
|
||||
items = re.split('[ \|]', headerStr)
|
||||
items = re.split(r"[ |]", headerStr)
|
||||
out_lines.append(f'>{items[1]}')
|
||||
out_lines.append(seq)
|
||||
f.close()
|
||||
print("\n".join(out_lines))
|
||||
print("\n".join(out_lines))
|
||||
|
||||
@@ -13,6 +13,8 @@ import sys
|
||||
import jax
|
||||
gpus = jax.local_devices(backend='gpu')
|
||||
from alphapulldown.scripts.run_structure_prediction import FLAGS
|
||||
from alphapulldown.utils.modelling_setup import parse_fold
|
||||
from alphapulldown.utils.output_paths import derive_af3_job_name_from_json
|
||||
from alphapulldown_input_parser import generate_fold_specifications
|
||||
|
||||
logging.set_verbosity(logging.INFO)
|
||||
@@ -38,6 +40,36 @@ flags.DEFINE_boolean("dry_run", False, "Report number of jobs that would be run
|
||||
flags.DEFINE_list("monomer_objects_dir", None, "a list of directories where monomer objects are stored")
|
||||
flags.DEFINE_string("output_path", None, "output directory where the region data is going to be stored")
|
||||
flags.DEFINE_string("data_dir", None, "Path to params directory")
|
||||
|
||||
|
||||
def _resolve_af3_wrapper_output_dir(
|
||||
job_spec: str,
|
||||
output_root: str,
|
||||
*,
|
||||
features_directory,
|
||||
protein_delimiter: str,
|
||||
use_ap_style: bool,
|
||||
) -> str:
|
||||
"""Scope single AF3 JSON wrapper jobs to stable per-job subdirectories."""
|
||||
if not use_ap_style:
|
||||
return output_root
|
||||
|
||||
try:
|
||||
parsed_jobs = parse_fold([job_spec], features_directory, protein_delimiter)
|
||||
except Exception:
|
||||
return output_root
|
||||
|
||||
if len(parsed_jobs) != 1 or len(parsed_jobs[0]) != 1:
|
||||
return output_root
|
||||
|
||||
entry = parsed_jobs[0][0]
|
||||
if not (isinstance(entry, dict) and "json_input" in entry):
|
||||
return output_root
|
||||
|
||||
return os.path.join(
|
||||
output_root,
|
||||
derive_af3_job_name_from_json(entry["json_input"]),
|
||||
)
|
||||
del(FLAGS.models_to_relax)
|
||||
flags.DEFINE_enum("models_to_relax",'None',['None','All','Best'],
|
||||
"Which models to relax. Default is None, meaning no model will be relaxed")
|
||||
@@ -177,6 +209,14 @@ def main(argv):
|
||||
else:
|
||||
for job_index in job_indices:
|
||||
command_args["--input"] = all_folds[job_index]
|
||||
if fold_backend == "alphafold3":
|
||||
command_args["--output_directory"] = _resolve_af3_wrapper_output_dir(
|
||||
all_folds[job_index],
|
||||
FLAGS.output_path,
|
||||
features_directory=FLAGS.monomer_objects_dir,
|
||||
protein_delimiter=FLAGS.protein_delimiter,
|
||||
use_ap_style=af3_use_ap_style,
|
||||
)
|
||||
command = base_command.copy()
|
||||
for arg, value in command_args.items():
|
||||
command.extend([str(arg), str(value)])
|
||||
|
||||
@@ -26,7 +26,10 @@ from alphapulldown.folding_backend import backend
|
||||
from alphapulldown.folding_backend.alphafold2_backend import ModelsToRelax
|
||||
from alphapulldown.objects import MultimericObject, MonomericObject, ChoppedObject
|
||||
from alphapulldown.utils.modelling_setup import create_interactors, create_custom_info, parse_fold
|
||||
from alphapulldown.utils.output_paths import resolve_af3_json_output_dir
|
||||
from alphapulldown.utils.output_paths import (
|
||||
resolve_af3_combined_json_output_dir,
|
||||
resolve_af3_json_output_dir,
|
||||
)
|
||||
import sys as _sys
|
||||
|
||||
logging.set_verbosity(logging.INFO)
|
||||
@@ -120,7 +123,7 @@ flags.DEFINE_string(
|
||||
flags.DEFINE_list(
|
||||
'buckets',
|
||||
# pyformat: disable
|
||||
['64', '128', '256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
|
||||
['128', '256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
|
||||
'3584', '4096', '4608', '5120'],
|
||||
# pyformat: enable
|
||||
'Strictly increasing order of token sizes for which to cache compilations.'
|
||||
@@ -486,6 +489,7 @@ def main(argv):
|
||||
prot_objs, output_dir=out_dir
|
||||
)
|
||||
objects_to_model.append({'object': obj, 'output_dir': real_out})
|
||||
json_output_dir = real_out
|
||||
|
||||
# Update final flags based on object type
|
||||
final_model_flags = default_model_flags.copy()
|
||||
@@ -498,15 +502,27 @@ def main(argv):
|
||||
"model_names_custom": FLAGS.model_names,
|
||||
"msa_depth": FLAGS.msa_depth
|
||||
})
|
||||
# Then handle any number of JSON inputs
|
||||
for json_dict in json_dicts:
|
||||
json_output_dir = resolve_af3_json_output_dir(
|
||||
json_dict["json_input"],
|
||||
elif len(json_dicts) > 1:
|
||||
json_output_dir = resolve_af3_combined_json_output_dir(
|
||||
json_dicts,
|
||||
out_dir,
|
||||
use_ap_style=FLAGS.use_ap_style,
|
||||
shared_output_root=shared_output_root,
|
||||
)
|
||||
objects_to_model.append({'object': json_dict, 'output_dir': json_output_dir})
|
||||
else:
|
||||
json_output_dir = None
|
||||
# Then handle any number of JSON inputs
|
||||
for json_dict in json_dicts:
|
||||
current_json_output_dir = json_output_dir
|
||||
if current_json_output_dir is None:
|
||||
current_json_output_dir = resolve_af3_json_output_dir(
|
||||
json_dict["json_input"],
|
||||
out_dir,
|
||||
use_ap_style=FLAGS.use_ap_style,
|
||||
shared_output_root=shared_output_root,
|
||||
)
|
||||
objects_to_model.append(
|
||||
{'object': json_dict, 'output_dir': current_json_output_dir}
|
||||
)
|
||||
|
||||
if objects_to_model:
|
||||
predict_structure(
|
||||
|
||||
@@ -33,8 +33,8 @@ _UNIREF_HEADER_PATTERN = re.compile(
|
||||
r'^UniRef\d+_(?P<accession>[A-Za-z0-9]+)$'
|
||||
)
|
||||
_UNIPARC_HEADER_PATTERN = re.compile(r'^(?P<accession>UPI[A-Z0-9]+)$')
|
||||
_GENERIC_ACCESSION_PATTERN = re.compile(
|
||||
r'^(?P<accession>[A-Za-z0-9]{6,16})$'
|
||||
_UNIPROT_ACCESSION_PATTERN = re.compile(
|
||||
r'^(?:[OPQ][0-9][A-Z0-9]{3}[0-9]|[A-NR-Z][0-9](?:[A-Z][A-Z0-9]{2}[0-9]){1,2})$'
|
||||
)
|
||||
|
||||
_UNIPROT_BATCH_SIZE = 32
|
||||
@@ -59,14 +59,19 @@ def _extract_accession_and_species(description: str) -> tuple[str, str]:
|
||||
if matches:
|
||||
return matches.group("accession"), matches.group("species")
|
||||
|
||||
for pattern in (
|
||||
_UNIREF_HEADER_PATTERN,
|
||||
_UNIPARC_HEADER_PATTERN,
|
||||
_GENERIC_ACCESSION_PATTERN,
|
||||
):
|
||||
matches = pattern.search(sequence_identifier.strip())
|
||||
if matches:
|
||||
return matches.group("accession"), ""
|
||||
matches = _UNIREF_HEADER_PATTERN.search(sequence_identifier.strip())
|
||||
if matches:
|
||||
accession = matches.group("accession")
|
||||
if _is_resolvable_accession(accession):
|
||||
return accession, ""
|
||||
return "", ""
|
||||
|
||||
matches = _UNIPARC_HEADER_PATTERN.search(sequence_identifier.strip())
|
||||
if matches:
|
||||
return matches.group("accession"), ""
|
||||
|
||||
if _UNIPROT_ACCESSION_PATTERN.search(sequence_identifier.strip()):
|
||||
return sequence_identifier.strip(), ""
|
||||
|
||||
return "", ""
|
||||
|
||||
@@ -82,6 +87,12 @@ def _is_transport_error(exc: Exception) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _is_resolvable_accession(accession: str) -> bool:
|
||||
return accession.startswith("UPI") or bool(
|
||||
_UNIPROT_ACCESSION_PATTERN.fullmatch(accession)
|
||||
)
|
||||
|
||||
|
||||
def _query_uniprot_batch(
|
||||
accessions: Sequence[str],
|
||||
*,
|
||||
@@ -203,8 +214,13 @@ def resolve_species_ids_by_accession(
|
||||
unresolved = [
|
||||
accession
|
||||
for accession in sorted(set(accessions))
|
||||
if accession and accession not in _SPECIES_ID_CACHE
|
||||
if accession
|
||||
and accession not in _SPECIES_ID_CACHE
|
||||
and _is_resolvable_accession(accession)
|
||||
]
|
||||
for accession in accessions:
|
||||
if accession and accession not in _SPECIES_ID_CACHE and not _is_resolvable_accession(accession):
|
||||
_SPECIES_ID_CACHE.setdefault(accession, "")
|
||||
if unresolved:
|
||||
uniprot_accessions = [
|
||||
accession for accession in unresolved if not accession.startswith("UPI")
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
import string
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
_AF3_ALLOWED_NAME_CHARS = set(string.ascii_lowercase + string.digits + "_-.")
|
||||
@@ -40,6 +41,119 @@ def derive_af3_job_name_from_json(json_input_path: str) -> str:
|
||||
return fallback_name
|
||||
|
||||
|
||||
def _json_input_basename(json_input_path: str) -> str:
|
||||
stem = Path(json_input_path).stem
|
||||
for suffix in ("_af3_input", "_input"):
|
||||
if stem.endswith(suffix):
|
||||
stem = stem[: -len(suffix)]
|
||||
break
|
||||
return stem or Path(json_input_path).stem
|
||||
|
||||
|
||||
def _collapse_repeated_name_fragments(fragments: Sequence[str]) -> list[str]:
|
||||
if not fragments:
|
||||
return []
|
||||
|
||||
collapsed: list[str] = []
|
||||
current_fragment = fragments[0]
|
||||
current_count = 1
|
||||
|
||||
for fragment in fragments[1:]:
|
||||
if fragment == current_fragment:
|
||||
current_count += 1
|
||||
continue
|
||||
|
||||
collapsed.append(
|
||||
current_fragment
|
||||
if current_count == 1
|
||||
else f"{current_fragment}__x{current_count}"
|
||||
)
|
||||
current_fragment = fragment
|
||||
current_count = 1
|
||||
|
||||
collapsed.append(
|
||||
current_fragment
|
||||
if current_count == 1
|
||||
else f"{current_fragment}__x{current_count}"
|
||||
)
|
||||
return collapsed
|
||||
|
||||
|
||||
def _compact_output_job_name(job_name: str, *, max_chars: int = 200) -> str:
|
||||
if len(job_name) <= max_chars:
|
||||
return job_name
|
||||
|
||||
import hashlib
|
||||
|
||||
digest = hashlib.sha1(job_name.encode("utf-8")).hexdigest()[:12]
|
||||
suffix = f"__{digest}"
|
||||
prefix = job_name[: max_chars - len(suffix)].rstrip("_.-")
|
||||
if not prefix:
|
||||
return f"job{suffix}"
|
||||
return f"{prefix}{suffix}"
|
||||
|
||||
|
||||
def _normalise_json_regions(regions: object) -> str | None:
|
||||
if not isinstance(regions, list) or not regions:
|
||||
return None
|
||||
|
||||
parts: list[str] = []
|
||||
for region in regions:
|
||||
if not isinstance(region, (tuple, list)) or len(region) != 2:
|
||||
return None
|
||||
start, end = region
|
||||
parts.append(f"{start}-{end}")
|
||||
return "_".join(parts)
|
||||
|
||||
|
||||
def build_af3_combined_json_job_name(
    json_inputs: Sequence[dict[str, object]],
) -> str:
    """Build one job name from several JSON input descriptors.

    For each descriptor with a non-empty string under ``"json_input"``, a
    fragment is formed from the file's basename, followed by ``"__"`` and the
    rendered ``"regions"`` when those are valid, and then sanitised.  Empty
    fragments are dropped, adjacent repeats are collapsed, the remainder is
    joined with ``"_and_"`` and the result is length-compacted.

    Returns:
        The combined job name, or ``"ranked_0"`` when no usable fragment
        remains.
    """
    usable: list[str] = []

    for descriptor in json_inputs:
        path_value = descriptor.get("json_input")
        if isinstance(path_value, str) and path_value:
            label = _json_input_basename(path_value)
            region_label = _normalise_json_regions(descriptor.get("regions"))
            if region_label:
                label = f"{label}__{region_label}"
            cleaned = sanitise_af3_job_name(label)
            if cleaned:
                usable.append(cleaned)

    if usable:
        collapsed = _collapse_repeated_name_fragments(usable)
        return _compact_output_job_name("_and_".join(collapsed))
    return "ranked_0"
|
||||
|
||||
|
||||
def _ensure_path_is_within_root(candidate: Path, output_root: Path) -> None:
|
||||
try:
|
||||
candidate.resolve(strict=False).relative_to(output_root.resolve(strict=False))
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"Resolved AF3 output directory {candidate} escapes configured root {output_root}"
|
||||
) from exc
|
||||
|
||||
|
||||
def resolve_af3_combined_json_output_dir(
    json_inputs: Sequence[dict[str, object]],
    output_dir: str,
    *,
    use_ap_style: bool,
) -> str:
    """Return the output directory for a fold that combines several JSON inputs.

    Args:
        json_inputs: The JSON input descriptors that make up the fold.
        output_dir: The configured output directory.
        use_ap_style: When false, ``output_dir`` is returned unchanged.

    Returns:
        ``output_dir`` itself, or a subdirectory of it named after the
        combined job name.

    Raises:
        ValueError: If the derived subdirectory resolves outside
            ``output_dir``.
    """
    if use_ap_style:
        root = Path(output_dir)
        target = root / build_af3_combined_json_job_name(json_inputs)
        _ensure_path_is_within_root(target, root)
        return os.fspath(target)
    return output_dir
|
||||
|
||||
|
||||
def resolve_af3_json_output_dir(
|
||||
json_input_path: str,
|
||||
output_dir: str,
|
||||
@@ -53,12 +167,5 @@ def resolve_af3_json_output_dir(
|
||||
|
||||
output_root = Path(output_dir)
|
||||
candidate = output_root / derive_af3_job_name_from_json(json_input_path)
|
||||
|
||||
try:
|
||||
candidate.resolve(strict=False).relative_to(output_root.resolve(strict=False))
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"Resolved AF3 output directory {candidate} escapes configured root {output_root}"
|
||||
) from exc
|
||||
|
||||
_ensure_path_is_within_root(candidate, output_root)
|
||||
return os.fspath(candidate)
|
||||
|
||||
@@ -215,9 +215,10 @@ def tmp_flags(monkeypatch, tmp_path):
|
||||
# not an absl flag on this FLAGS; attach as attribute for completeness
|
||||
setattr(F, name, value)
|
||||
continue
|
||||
fl.unparse()
|
||||
arg = _to_arg_str(name, value)
|
||||
if arg is None:
|
||||
# leave at declared default (often None); don't parse
|
||||
# reset to the declared default and leave it unset
|
||||
continue
|
||||
fl.parse(arg) # sets .value and marks present
|
||||
|
||||
|
||||
@@ -40,7 +40,9 @@ RUN set -eux; \
|
||||
modelcif \
|
||||
hmmer \
|
||||
hhsuite \
|
||||
pdbfixer=1.9 \
|
||||
"numpy<2" \
|
||||
"openmm>=8.2" \
|
||||
"pdbfixer>=1.10" \
|
||||
pip \
|
||||
git \
|
||||
&& micromamba clean -a -y
|
||||
@@ -78,5 +80,3 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
rm -rf /var/lib/apt/lists/* /root/.cache
|
||||
|
||||
#ENTRYPOINT ["bash"]
|
||||
RUN pip install --no-cache-dir "numpy<2"
|
||||
|
||||
|
||||
246
docs/backend_installation.md
Normal file
246
docs/backend_installation.md
Normal file
@@ -0,0 +1,246 @@
|
||||
# Backend Installation Guide
|
||||
|
||||
This guide is for direct AlphaPulldown use without Snakemake.
|
||||
|
||||
Two points matter in practice:
|
||||
|
||||
1. Run the install commands from the AlphaPulldown repo root.
|
||||
2. For AlphaFold3, building the vendored `alphafold3` package is a separate step. A root-level install alone is not enough.
|
||||
|
||||
The Docker files remain the long-term reference environments, but the commands below are the simpler cluster-facing paths.
|
||||
We revalidated them on EMBL on March 30, 2026 by creating fresh environments and running the cluster test suites there.
|
||||
|
||||
## Known-good cluster stacks
|
||||
|
||||
These are the two EMBL environments that were already working and that we rechecked while validating the wrappers:
|
||||
|
||||
- `AlphaPulldown` for AF2:
|
||||
- Python `3.10`
|
||||
- `jax 0.5.3`
|
||||
- `jaxlib 0.5.3`
|
||||
- `numpy 1.26.4`
|
||||
- `tensorflow-cpu 2.20.0`
|
||||
- `openmm 8.1.1`
|
||||
- `pdbfixer 1.12`
|
||||
- `modelcif 1.6`
|
||||
- `AlphaPulldown_alphafold3` for AF3:
|
||||
- Python `3.11`
|
||||
- `jax 0.5.3`
|
||||
- `jaxlib 0.5.3`
|
||||
- `numpy 1.26.4`
|
||||
- `tensorflow-cpu 2.18.0`
|
||||
- `openmm 8.3.1`
|
||||
- `pdbfixer 1.12`
|
||||
- `modelcif 1.6`
|
||||
- `jax-triton 0.2.0`
|
||||
- `triton 3.1.0`
|
||||
- `rdkit 2024.3.5`
|
||||
- `typeguard 2.13.3`
|
||||
- compiled `alphafold3.cpp`
|
||||
|
||||
The installation steps below are written to stay close to those working environments and to keep user-facing environment variables to a minimum.
|
||||
|
||||
## AlphaFold2 backend
|
||||
|
||||
Create the environment, activate it, move into the repo, and install:
|
||||
|
||||
```bash
|
||||
mamba create -y -n apd-af2 -c conda-forge -c bioconda \
|
||||
python=3.10 \
|
||||
kalign2 \
|
||||
hmmer \
|
||||
hhsuite
|
||||
mamba activate apd-af2
|
||||
cd /path/to/AlphaPulldown
|
||||
python -m pip install ".[alphafold2,test]"
|
||||
```
|
||||
|
||||
Check that GPU JAX is visible:
|
||||
|
||||
```bash
|
||||
python - <<'PY'
|
||||
import jax
|
||||
print(jax.__version__)
|
||||
print(jax.local_devices(backend="gpu"))
|
||||
PY
|
||||
```
|
||||
|
||||
## AlphaFold3 backend
|
||||
|
||||
AlphaFold3 needs both the AlphaPulldown install and the vendored `alphafold3` build.
|
||||
|
||||
```bash
|
||||
mamba create -y -n apd-af3 -c conda-forge -c bioconda \
|
||||
python=3.11 \
|
||||
kalign2 \
|
||||
hmmer \
|
||||
hhsuite \
|
||||
libcifpp \
|
||||
sqlite
|
||||
mamba activate apd-af3
|
||||
cd /path/to/AlphaPulldown
|
||||
python -m pip install ".[alphafold3,test]"
|
||||
python -m pip install --no-deps -e ./alphafold3
|
||||
build_data
|
||||
```
|
||||
|
||||
Check that the compiled extension and GPU JAX are available:
|
||||
|
||||
```bash
|
||||
python - <<'PY'
|
||||
import alphafold3.cpp
|
||||
import jax
|
||||
print(jax.__version__)
|
||||
print(jax.local_devices(backend="gpu"))
|
||||
print(alphafold3.cpp.__file__)
|
||||
PY
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- The working EMBL AF3 environment is closer to `.[alphafold3,test]` than to the upstream `alphafold3/dev-requirements.txt` stack.
|
||||
- For cluster installs, prefer `python -m pip install ".[alphafold3,test]"` and then build the vendored `alphafold3` package.
|
||||
- The compiled `alphafold3.cpp` extension comes from `python -m pip install --no-deps -e ./alphafold3`, not from the root install.
|
||||
- The vendored AF3 package provides the `build_data` entry point. Use that directly.
|
||||
- If you are actively developing inside the checkout and want the root package editable as well, add:
|
||||
|
||||
```bash
|
||||
python -m pip install -e . --no-deps
|
||||
```
|
||||
|
||||
## Cluster validation
|
||||
|
||||
On EMBL, the AF2 and AF3 functional tests already have default database roots baked into the test files, so the only environment variable you need for the standard cluster runs is:
|
||||
|
||||
```bash
|
||||
export RUN_GPU_FUNCTIONAL_TESTS=1
|
||||
```
|
||||
|
||||
Set `ALPHAFOLD_DATA_DIR` only if your databases are not in the EMBL default locations.
|
||||
|
||||
Two ways to run the cluster tests:
|
||||
|
||||
1. Direct `srun` + `pytest`
|
||||
- Best for validating a fresh install end-to-end.
|
||||
- Keeps everything on one allocated GPU node.
|
||||
- More reliable than the wrappers when Slurm priority is poor.
|
||||
- Important: when calling `pytest` directly, pass `-o addopts="-ra --strict-markers"`. The repo-level `pytest.ini` excludes cluster tests by default.
|
||||
2. Wrapper scripts
|
||||
- `test/cluster/run_alphafold2_predictions.py`
|
||||
- `test/cluster/run_alphafold3_predictions.py`
|
||||
- Faster when the queue is healthy, because they fan out one job per collected pytest test (node ID), so the tests run in parallel rather than one after another.
|
||||
|
||||
### AlphaFold2
|
||||
|
||||
Direct full-suite validation:
|
||||
|
||||
```bash
|
||||
srun -p gpu-training --gres=gpu:1 \
|
||||
--cpus-per-task=4 --mem=16G --time=12:00:00 \
|
||||
bash -lc '
|
||||
cd /path/to/AlphaPulldown
|
||||
export RUN_GPU_FUNCTIONAL_TESTS=1
|
||||
python -m pytest -o addopts="-ra --strict-markers" -vv -s \
|
||||
test/cluster/check_alphafold2_predictions.py --use-temp-dir
|
||||
'
|
||||
```
|
||||
|
||||
Expected result on the standard suite:
|
||||
|
||||
```text
|
||||
11 passed, 1 skipped
|
||||
```
|
||||
|
||||
The skip is the opt-in MMseqs functional inference check, which only runs when `RUN_MMSEQS_FUNCTIONAL_TESTS=1` is set.
|
||||
|
||||
Preview the collected nodes for wrapper mode:
|
||||
|
||||
```bash
|
||||
python test/cluster/run_alphafold2_predictions.py --list
|
||||
```
|
||||
|
||||
Wrapper-based parallel submission:
|
||||
|
||||
```bash
|
||||
python test/cluster/run_alphafold2_predictions.py \
|
||||
--max-tests 6 \
|
||||
--use-temp-dir \
|
||||
--partition gpu-training
|
||||
```
|
||||
|
||||
### AlphaFold3
|
||||
|
||||
Direct full-suite validation:
|
||||
|
||||
```bash
|
||||
srun -p gpu-training --gres=gpu:1 \
|
||||
--cpus-per-task=4 --mem=64G --time=12:00:00 \
|
||||
bash -lc '
|
||||
cd /path/to/AlphaPulldown
|
||||
export RUN_GPU_FUNCTIONAL_TESTS=1
|
||||
python -m pytest -o addopts="-ra --strict-markers" -vv -s \
|
||||
test/cluster/check_alphafold3_predictions.py --use-temp-dir
|
||||
'
|
||||
```
|
||||
|
||||
On our fresh EMBL validation run, `32G` was not enough for the full AF3 suite on `hgx5`; `64G` was sufficient, which is why the command above requests `--mem=64G`.
|
||||
|
||||
Preview the collected nodes for wrapper mode:
|
||||
|
||||
```bash
|
||||
python test/cluster/run_alphafold3_predictions.py --list
|
||||
```
|
||||
|
||||
Wrapper-based parallel submission:
|
||||
|
||||
```bash
|
||||
python test/cluster/run_alphafold3_predictions.py \
|
||||
--max-tests 6 \
|
||||
--use-temp-dir \
|
||||
--partition gpu-training
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### AlphaFold2: `Unknown backend: 'gpu' requested, ... Platforms present are: cpu`
|
||||
|
||||
That means the environment has CPU-only JAX:
|
||||
|
||||
```bash
|
||||
python -m pip install --upgrade --no-cache-dir "jax==0.5.3" "jax[cuda12]==0.5.3"
|
||||
python - <<'PY'
|
||||
import jax
|
||||
print(jax.__version__)
|
||||
print(jax.local_devices(backend="gpu"))
|
||||
PY
|
||||
```
|
||||
|
||||
### AlphaFold2: `There is no registered Platform called "CUDA"`
|
||||
|
||||
That comes from OpenMM relaxation, not from JAX. Older working EMBL envs exposed a CUDA-enabled OpenMM platform, while a fresh pip-installed OpenMM may expose only `Reference`, `CPU`, and `OpenCL`.
|
||||
|
||||
Current AlphaPulldown falls back to CPU relax automatically if CUDA is unavailable, so the AF2 cluster tests still pass. If you want GPU-backed OpenMM relax as well, install the OpenMM stack from conda before the pip install.
|
||||
|
||||
### AlphaFold3: `ModuleNotFoundError: No module named 'alphafold3.cpp'`
|
||||
|
||||
The root AlphaPulldown install succeeded, but the vendored AF3 package was not built yet:
|
||||
|
||||
```bash
|
||||
cd /path/to/AlphaPulldown
|
||||
python -m pip install ".[alphafold3,test]"
|
||||
python -m pip install --no-deps -e ./alphafold3
|
||||
build_data
|
||||
```
|
||||
|
||||
### AlphaFold3 build error: `Could NOT find SQLite3`
|
||||
|
||||
If SQLite is installed in the conda environment but CMake still cannot find it, rerun the AF3 build with the conda prefix exposed:
|
||||
|
||||
```bash
|
||||
mamba activate apd-af3
|
||||
export CMAKE_PREFIX_PATH="$CONDA_PREFIX"
|
||||
export SQLite3_ROOT="$CONDA_PREFIX"
|
||||
cd /path/to/AlphaPulldown
|
||||
python -m pip install --no-deps -e ./alphafold3
|
||||
build_data
|
||||
```
|
||||
@@ -4,8 +4,9 @@ channels:
|
||||
- bioconda
|
||||
- omnia
|
||||
dependencies:
|
||||
- openmm=8.0
|
||||
- pdbfixer=1.9
|
||||
- numpy<2
|
||||
- openmm>=8.2
|
||||
- pdbfixer>=1.10
|
||||
- kalign2
|
||||
- hmmer
|
||||
- hhsuite
|
||||
|
||||
@@ -26,7 +26,7 @@ dependencies = [
|
||||
"matplotlib>=3.3.3",
|
||||
"ml-collections>=0.1.0",
|
||||
"pandas>=1.5.3",
|
||||
"tensorflow-cpu>=2.16.1",
|
||||
"tensorflow-cpu==2.20.0",
|
||||
"importlib-resources>=6.1.0",
|
||||
"importlib-metadata>=4.8.2,<5.0.0",
|
||||
"biopython>=1.81,<1.82",
|
||||
@@ -44,6 +44,28 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
alphafold2 = [
|
||||
"jax==0.5.3",
|
||||
"jax[cuda12]==0.5.3; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||
"modelcif>=1.6",
|
||||
"numpy<2",
|
||||
"openmm>=8.2",
|
||||
"pdbfixer>=1.10",
|
||||
]
|
||||
alphafold3 = [
|
||||
"jaxtyping==0.2.34",
|
||||
"jax==0.5.3",
|
||||
"jax[cuda12]==0.5.3; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||
"jax-triton==0.2.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||
"modelcif>=1.6",
|
||||
"numpy<2",
|
||||
"openmm>=8.2",
|
||||
"pdbfixer>=1.10",
|
||||
"rdkit==2024.3.5",
|
||||
"triton==3.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||
"typeguard==2.13.3",
|
||||
"zstandard",
|
||||
]
|
||||
test = [
|
||||
"parameterized",
|
||||
"pytest>=8.0",
|
||||
|
||||
14
test/README.md
Normal file
14
test/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
The active pytest layout is:
|
||||
|
||||
- `test/unit` for fast helper and mocked component tests
|
||||
- `test/integration` for CPU-only filesystem, CLI wiring, and module interaction tests
|
||||
- `test/functional` for heavier workflow tests that still belong to the main package test tree
|
||||
- `test/cluster` for Slurm, GPU, or workstation smoke wrappers that are run explicitly
|
||||
- `test/outdated` for legacy tests kept for reference until they are rewritten or deleted
|
||||
|
||||
Notes:
|
||||
|
||||
- `pytest.ini` only collects `unit`, `integration`, and `functional`.
|
||||
- `conftest.py` auto-applies markers from the directory layout.
|
||||
- `test/alphalink` now holds AlphaLink fixture files used by optional tests and cluster smoke runs.
|
||||
- `test/RELEASE_READINESS.md` tracks which workflows are continuously protected by CI and which still require explicit cluster or manual validation before a release.
|
||||
61
test/RELEASE_READINESS.md
Normal file
61
test/RELEASE_READINESS.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Release Readiness
|
||||
|
||||
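This branch has broad, actively maintained test coverage, but the protection is layered: some workflows are checked on every change, while others are only validated when run explicitly.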
|
||||
|
||||
## Always-on CI
|
||||
|
||||
GitHub Actions runs:
|
||||
|
||||
- `test/unit`
|
||||
- `test/integration`
|
||||
- coverage collection and reporting
|
||||
- Python `3.10` and `3.11`
|
||||
|
||||
These lanes continuously protect:
|
||||
|
||||
- feature creation wiring and CLI dispatch
|
||||
- fold parsing and object construction
|
||||
- AF2 backend helper behavior
|
||||
- AF3 backend helper behavior
|
||||
- script entrypoint and wrapper argument handling
|
||||
- post-prediction and ModelCIF integration paths
|
||||
- CPU-safe AlphaLink helper logic
|
||||
|
||||
## Explicitly-run validation
|
||||
|
||||
The following are intentionally outside default CI and must be run explicitly when a change touches the workflows they validate:
|
||||
|
||||
- `test/cluster/check_alphafold2_predictions.py`
|
||||
- `test/cluster/check_alphafold3_predictions.py`
|
||||
- `test/cluster/check_alphalink_predictions.py`
|
||||
- manual or cluster-backed GPU/Slurm smoke runs
|
||||
|
||||
Release-critical examples include:
|
||||
|
||||
- AF3 wrapper output isolation for combined JSON folds
|
||||
- MMseqs-generated `.a3m` and `.pkl` / `.pkl.xz` feature artifacts
|
||||
- AF3 species pairing regressions for issue `#588`
|
||||
- AF2 and AF3 dimer inference quality checks such as `ipTM > 0.6`
|
||||
|
||||
## Not Continuously Protected
|
||||
|
||||
The following areas are only partially protected, optional, or report-only:
|
||||
|
||||
- `test/cluster` workflows
|
||||
- `test/alphalink` workflows beyond CPU-safe helper tests
|
||||
- legacy scenarios still parked under `test/outdated`
|
||||
- analysis-pipeline utilities and some deeper ModelCIF internals
|
||||
- Python `3.8`, which is still advertised in packaging but is not exercised by GitHub Actions
|
||||
|
||||
The coverage artifact is useful as an audit input, but it does not prove workflow correctness by itself. In particular, `python test/tools/check_function_coverage.py --report-only` highlights functions that were never executed in CI and should be treated as follow-up audit items, not automatic release blockers.
|
||||
|
||||
## Practical Release Gate
|
||||
|
||||
Before a final release, the expected evidence is:
|
||||
|
||||
1. PR smoke tests and coverage checks are green.
|
||||
2. AF3 JSON wrapper regressions are green.
|
||||
3. Issue `#588` MMseqs/AF3 pairing regressions are green.
|
||||
4. Manual or cluster-backed AF2 and AF3 dimer runs from MMseqs features complete successfully with acceptable confidence.
|
||||
|
||||
If those gates are green, the branch is in good release shape even though some heavyweight workflows are still validated outside default CI.
|
||||
7
test/alphalink/README.md
Normal file
7
test/alphalink/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
This directory contains AlphaLink-specific fixture files.
|
||||
|
||||
Active tests now live in the standard buckets:
|
||||
|
||||
- `test/unit/test_crosslink_input.py` for the small crosslink helper checks
|
||||
- `test/cluster/check_alphalink_predictions.py` for GPU and weights-backed smoke tests
|
||||
- `test/outdated/` for older AlphaLink checks that still hard-code cluster paths or stale fixtures
|
||||
@@ -1,50 +0,0 @@
|
||||
import unittest
|
||||
from unifold.dataset import calculate_offsets,create_xl_features,bin_xl
|
||||
from alphafold.data.pipeline_multimer import _FastaChain
|
||||
import numpy as np
|
||||
import gzip,pickle
|
||||
import torch
|
||||
|
||||
class TestCreateObjects(unittest.TestCase):
    """Checks for the crosslink helpers imported from ``unifold.dataset``.

    Covers ``calculate_offsets``, ``create_xl_features`` and ``bin_xl`` against
    a three-chain layout of 10, 25 and 40 residues.
    """

    def setUp(self) -> None:
        # NOTE(review): relative path, so the suite presumably has to be
        # started from the repository root -- confirm.
        self.crosslink_info = "./test/test_data/test_xl_input.pkl.gz"
        self.asym_id = torch.tensor([1] * 10 + [2] * 25 + [3] * 40)
        self.chain_id_map = {
            "A": _FastaChain(sequence='', description='chain1'),
            "B": _FastaChain(sequence='', description='chain2'),
            "C": _FastaChain(sequence='', description='chain3'),
        }
        self.bins = torch.arange(0, 1.05, 0.05)
        return super().setUp()

    def _load_xl_features(self):
        """Return crosslink features built from the gzipped pickle fixture.

        The fixture is opened in a ``with`` block so the gzip handle is closed
        deterministically instead of being left to the garbage collector.
        """
        offsets = calculate_offsets(self.asym_id)
        with gzip.open(self.crosslink_info, 'rb') as handle:
            xl_pickle = pickle.load(handle)
        return create_xl_features(xl_pickle, offsets, chain_id_map=self.chain_id_map)

    def test1_calculate_offsets(self):
        offsets = calculate_offsets(self.asym_id).tolist()
        # Cumulative chain boundaries for chain lengths 10, 25 and 40.
        expected_offsets = [0, 10, 35, 75]
        self.assertEqual(offsets, expected_offsets)

    def test2_create_xl_inputs(self):
        xl = self._load_xl_features()
        expected_xl = torch.tensor([[10, 35, 0.01],
                                    [3, 27, 0.01],
                                    [5, 56, 0.01],
                                    [20, 65, 0.01]])
        self.assertTrue(torch.equal(xl, expected_xl))

    def test3_bin_xl(self):
        xl = self._load_xl_features()
        num_res = len(self.asym_id)
        xl = bin_xl(xl, num_res)
        expected_xl = np.zeros((num_res, num_res, 1))
        # Every crosslink is expected in the same bin, mirrored across the
        # diagonal.
        expected_bin = torch.bucketize(0.99, self.bins)
        for i, j in ((3, 27), (10, 35), (5, 56), (20, 65)):
            expected_xl[i, j, 0] = expected_xl[j, i, 0] = expected_bin
        self.assertTrue(np.array_equal(xl, expected_xl))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -101,6 +101,27 @@ def _non_empty_identifier_count(values) -> int:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def _af2_subprocess_env() -> dict[str, str]:
    """Return stable GPU/JAX defaults for AF2 functional subprocesses.

    The result is a copy of ``os.environ``; the parent process environment is
    never modified. Each default is applied with ``setdefault``, so a value
    that the caller already exported always takes precedence.

    Returns:
        A new environment mapping suitable for ``subprocess.run(..., env=...)``.
    """
    defaults = (
        ("OMP_NUM_THREADS", "4"),
        ("MKL_NUM_THREADS", "4"),
        ("NUMEXPR_NUM_THREADS", "4"),
        ("TF_NUM_INTEROP_THREADS", "4"),
        ("TF_NUM_INTRAOP_THREADS", "4"),
        ("TF_FORCE_GPU_ALLOW_GROWTH", "true"),
        ("TF_CPP_MIN_LOG_LEVEL", "2"),
        ("XLA_PYTHON_CLIENT_PREALLOCATE", "false"),
        ("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.8"),
        ("JAX_PLATFORM_NAME", "gpu"),
        (
            "XLA_FLAGS",
            "--xla_gpu_force_compilation_parallelism=0 "
            "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
        ),
    )
    env = os.environ.copy()
    for name, value in defaults:
        env.setdefault(name, value)
    return env
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# common helper mix-in / assertions #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -159,6 +180,14 @@ class _TestBase(parameterized.TestCase):
|
||||
apd_path / "scripts" / "create_individual_features.py"
|
||||
)
|
||||
|
||||
def _run_prediction_subprocess(self, args):
|
||||
return subprocess.run(
|
||||
args,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=_af2_subprocess_env(),
|
||||
)
|
||||
|
||||
# ---------------- assertions reused by all subclasses ----------------- #
|
||||
def _runCommonTests(self, res: subprocess.CompletedProcess, multimer: bool, dirname: str | None = None):
|
||||
if res.returncode != 0:
|
||||
@@ -265,9 +294,8 @@ class TestRunModes(_TestBase):
|
||||
)
|
||||
def test_(self, protein_list, mode, script):
|
||||
multimer = "monomer" not in protein_list
|
||||
res = subprocess.run(
|
||||
self._args(plist=protein_list, mode=mode, script=script),
|
||||
capture_output=True, text=True
|
||||
res = self._run_prediction_subprocess(
|
||||
self._args(plist=protein_list, mode=mode, script=script)
|
||||
)
|
||||
self._runCommonTests(res, multimer)
|
||||
|
||||
@@ -351,7 +379,7 @@ class TestResume(_TestBase):
|
||||
(self.output_dir / "TEST_homo_2er" / fname).unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
res = subprocess.run(args, capture_output=True, text=True)
|
||||
res = self._run_prediction_subprocess(args)
|
||||
self._runCommonTests(res, multimer=True, dirname="TEST_homo_2er")
|
||||
self._runAfterRelaxTests(relax_mode)
|
||||
|
||||
@@ -422,12 +450,12 @@ class TestDropoutDiversity(_TestBase):
|
||||
# Execute both predictions
|
||||
logger.info("Running prediction without dropout...")
|
||||
#logger.info("".join(args_no_dropout))
|
||||
res_no_dropout = subprocess.run(args_no_dropout, capture_output=True, text=True)
|
||||
res_no_dropout = self._run_prediction_subprocess(args_no_dropout)
|
||||
self.assertEqual(res_no_dropout.returncode, 0,
|
||||
f"No dropout prediction failed: {res_no_dropout.stderr}")
|
||||
|
||||
logger.info("Running prediction with dropout...")
|
||||
res_with_dropout = subprocess.run(args_with_dropout, capture_output=True, text=True)
|
||||
res_with_dropout = self._run_prediction_subprocess(args_with_dropout)
|
||||
self.assertEqual(res_with_dropout.returncode, 0,
|
||||
f"Dropout prediction failed: {res_with_dropout.stderr}")
|
||||
|
||||
@@ -508,7 +536,7 @@ class TestMmseqsIssue588Inference(_TestBase):
|
||||
"--compress_features=True",
|
||||
"--skip_existing=False",
|
||||
]
|
||||
res = subprocess.run(args, capture_output=True, text=True)
|
||||
res = self._run_prediction_subprocess(args)
|
||||
self.assertEqual(
|
||||
res.returncode,
|
||||
0,
|
||||
@@ -588,6 +616,7 @@ class TestMmseqsIssue588Inference(_TestBase):
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=_af2_subprocess_env(),
|
||||
)
|
||||
self.assertEqual(
|
||||
res.returncode,
|
||||
|
||||
@@ -1327,6 +1327,14 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
|
||||
)
|
||||
|
||||
for protein_id in self.ISSUE_588_IDS:
|
||||
self.assertTrue(
|
||||
(source_dir / f"{protein_id}.a3m").is_file(),
|
||||
f"Expected MMseq A3M {source_dir / f'{protein_id}.a3m'} to be created.",
|
||||
)
|
||||
self.assertTrue(
|
||||
(source_dir / f"{protein_id}.pkl.xz").is_file(),
|
||||
f"Expected compressed feature pickle {source_dir / f'{protein_id}.pkl.xz'} to be created.",
|
||||
)
|
||||
shutil.copy2(
|
||||
source_dir / f"{protein_id}.a3m",
|
||||
precomputed_dir / f"{protein_id}.a3m",
|
||||
@@ -1356,6 +1364,15 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
|
||||
"Precomputed-MMseq feature generation failed.\n"
|
||||
f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}",
|
||||
)
|
||||
for protein_id in self.ISSUE_588_IDS:
|
||||
self.assertTrue(
|
||||
(precomputed_dir / f"{protein_id}.a3m").is_file(),
|
||||
f"Expected copied MMseq A3M {precomputed_dir / f'{protein_id}.a3m'} to be present.",
|
||||
)
|
||||
self.assertTrue(
|
||||
(precomputed_dir / f"{protein_id}.pkl.xz").is_file(),
|
||||
f"Expected precomputed feature pickle {precomputed_dir / f'{protein_id}.pkl.xz'} to be created.",
|
||||
)
|
||||
return precomputed_dir
|
||||
|
||||
def _prepare_fold_input(
|
||||
@@ -1656,6 +1673,7 @@ class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
|
||||
"--max_template_date=2024-05-02",
|
||||
"--use_mmseqs2=True",
|
||||
"--data_pipeline=alphafold2",
|
||||
"--save_msa_files=True",
|
||||
"--compress_features=True",
|
||||
"--skip_existing=False",
|
||||
],
|
||||
@@ -1676,6 +1694,14 @@ class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
|
||||
feature_dir = self._generate_issue_588_mmseq_features(env)
|
||||
|
||||
for protein_id in self.ISSUE_588_IDS:
|
||||
self.assertTrue(
|
||||
(feature_dir / f"{protein_id}.a3m").is_file(),
|
||||
f"Expected MMseq A3M {feature_dir / f'{protein_id}.a3m'} to be created.",
|
||||
)
|
||||
self.assertTrue(
|
||||
(feature_dir / f"{protein_id}.pkl.xz").is_file(),
|
||||
f"Expected compressed feature pickle {feature_dir / f'{protein_id}.pkl.xz'} to be created.",
|
||||
)
|
||||
feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
|
||||
self.assertGreater(
|
||||
_non_empty_identifier_count(
|
||||
@@ -2974,19 +3000,28 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
self._assert_af3_outputs_present(current_output_dir)
|
||||
|
||||
def test_af3_run_multimer_jobs_multiple_json_jobs_create_per_job_subdirs(self):
|
||||
"""Shared AF3 wrapper output roots must isolate multiple JSON jobs by subdirectory."""
|
||||
from alphapulldown.utils.output_paths import derive_af3_job_name_from_json
|
||||
"""Shared AF3 wrapper roots must isolate combined JSON folds by subdirectory."""
|
||||
|
||||
self._require_af3_functional_environment()
|
||||
env = self._make_af3_test_env()
|
||||
flash_impl = self._af3_flash_attention_impl()
|
||||
json_inputs = [
|
||||
self.test_features_dir / "protein_with_ptms.json",
|
||||
self.test_features_dir / "P01308_af3_input.json",
|
||||
json_folds = [
|
||||
[
|
||||
self.test_features_dir / "protein_with_ptms.json",
|
||||
self.test_features_dir / "P61626_af3_input.json",
|
||||
],
|
||||
[
|
||||
self.test_features_dir / "P01308_af3_input.json",
|
||||
self.test_features_dir / "P61626_af3_input.json",
|
||||
],
|
||||
]
|
||||
protein_list = self.output_dir / "test_multiple_json_jobs.txt"
|
||||
protein_list.write_text(
|
||||
"\n".join(json_input.name for json_input in json_inputs) + "\n",
|
||||
"\n".join(
|
||||
";".join(json_input.name for json_input in json_fold)
|
||||
for json_fold in json_folds
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
@@ -3017,11 +3052,16 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
(self.output_dir / "ranking_scores.csv").exists(),
|
||||
"Shared wrapper output root should not contain flattened AF3 JSON outputs.",
|
||||
)
|
||||
self.assertFalse(
|
||||
any(self.output_dir.glob("*_data.json")),
|
||||
"Combined JSON folds should not write flat AF3 input JSONs into the shared root.",
|
||||
)
|
||||
|
||||
for json_input in json_inputs:
|
||||
current_output_dir = self.output_dir / derive_af3_job_name_from_json(
|
||||
str(json_input)
|
||||
)
|
||||
for output_dir_name in (
|
||||
"protein_with_ptms_and_p61626",
|
||||
"p01308_and_p61626",
|
||||
):
|
||||
current_output_dir = self.output_dir / output_dir_name
|
||||
self.assertTrue(
|
||||
current_output_dir.is_dir(),
|
||||
f"Expected per-job output directory {current_output_dir} to be created.",
|
||||
|
||||
@@ -120,6 +120,24 @@ def _timestamp() -> str:
|
||||
return dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
||||
def _default_gpu_env_lines(*, cpus_per_task: int) -> list[str]:
    """Return shell ``export`` lines with default GPU/threading settings.

    Every variable except ``PYTHONUNBUFFERED`` is written so that a value
    already present in the job environment is kept and the default is only a
    fallback (``${VAR:-default}``, or an explicit emptiness test for
    ``XLA_FLAGS``).

    Args:
        cpus_per_task: CPU count of the job; the thread default is this value
            clamped to the range 1..4.

    Returns:
        The lines in the order they should appear in the generated job script.
    """
    thread_count = max(1, min(cpus_per_task, 4))

    def _export_with_default(name: str, default: object) -> str:
        # Renders: export NAME="${NAME:-default}"
        return f'export {name}="${{{name}:-{default}}}"'

    thread_vars = (
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "TF_NUM_INTEROP_THREADS",
        "TF_NUM_INTRAOP_THREADS",
    )
    fixed_defaults = (
        ("TF_FORCE_GPU_ALLOW_GROWTH", "true"),
        ("TF_CPP_MIN_LOG_LEVEL", "2"),
        ("XLA_PYTHON_CLIENT_PREALLOCATE", "false"),
        ("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.8"),
        ("JAX_PLATFORM_NAME", "gpu"),
    )

    lines = ["export PYTHONUNBUFFERED=1"]
    lines.extend(_export_with_default(name, thread_count) for name in thread_vars)
    lines.extend(_export_with_default(name, value) for name, value in fixed_defaults)
    # XLA_FLAGS contains spaces and its own quoting, so it keeps the explicit
    # "only if unset or empty" form rather than the ${VAR:-default} template.
    lines.append(
        'if [ -z "${XLA_FLAGS:-}" ]; then export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=0 --xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"; fi'
    )
    return lines
|
||||
|
||||
|
||||
def _relative_nodeid_prefix(test_file: Path) -> str:
|
||||
return str(test_file.resolve().relative_to(REPO_ROOT))
|
||||
|
||||
@@ -209,6 +227,7 @@ def write_job_script(
|
||||
job: JobSpec,
|
||||
python_executable: str,
|
||||
use_temp_dir: bool,
|
||||
cpus_per_task: int,
|
||||
) -> None:
|
||||
pytest_cmd = [
|
||||
python_executable,
|
||||
@@ -228,7 +247,7 @@ def write_job_script(
|
||||
"#!/bin/bash",
|
||||
"set -euo pipefail",
|
||||
f"cd {_quote(str(REPO_ROOT))}",
|
||||
"export PYTHONUNBUFFERED=1",
|
||||
*_default_gpu_env_lines(cpus_per_task=cpus_per_task),
|
||||
"echo \"[$(date)] Running test node:\"",
|
||||
f"echo {_quote(job.nodeid)}",
|
||||
"echo \"[$(date)] Host: $(hostname)\"",
|
||||
@@ -634,7 +653,7 @@ def main() -> int:
|
||||
stderr_path = log_dir / f"{index:03d}_{slug}.err"
|
||||
script_path = log_dir / f"{index:03d}_{slug}.sbatch.sh"
|
||||
rerun_command = (
|
||||
f"{_quote(args.python)} -m pytest -vv -s {_quote(nodeid)}"
|
||||
f"{_quote(args.python)} -m pytest -o {_quote('addopts=-ra --strict-markers')} -vv -s {_quote(nodeid)}"
|
||||
+ (" --use-temp-dir" if args.use_temp_dir else "")
|
||||
)
|
||||
job = JobSpec(
|
||||
@@ -650,6 +669,7 @@ def main() -> int:
|
||||
job=job,
|
||||
python_executable=args.python,
|
||||
use_temp_dir=args.use_temp_dir,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
)
|
||||
jobs.append(job)
|
||||
|
||||
|
||||
@@ -650,10 +650,21 @@ def main() -> int:
|
||||
stdout_path = log_dir / f"{index:03d}_{slug}.out"
|
||||
stderr_path = log_dir / f"{index:03d}_{slug}.err"
|
||||
script_path = log_dir / f"{index:03d}_{slug}.sbatch.sh"
|
||||
rerun_command = (
|
||||
f"{_quote(args.python)} -m pytest -vv -s {_quote(nodeid)}"
|
||||
+ (" --use-temp-dir" if args.use_temp_dir else "")
|
||||
)
|
||||
rerun_parts = [
|
||||
_quote(args.python),
|
||||
"-m",
|
||||
"pytest",
|
||||
"-o",
|
||||
_quote("addopts=-ra --strict-markers"),
|
||||
"-vv",
|
||||
"-s",
|
||||
_quote(nodeid),
|
||||
]
|
||||
if args.use_temp_dir:
|
||||
rerun_parts.append("--use-temp-dir")
|
||||
rerun_command = " ".join(rerun_parts)
|
||||
if args.include_perf:
|
||||
rerun_command = f"AF3_RUN_PERF_TESTS=1 {rerun_command}"
|
||||
job = JobSpec(
|
||||
index=index,
|
||||
nodeid=nodeid,
|
||||
|
||||
@@ -1,2 +1,9 @@
|
||||
Functional tests live here when they are deterministic, CPU-safe, and heavier than the
|
||||
unit/integration layers. GPU, Slurm, or external-tool smoke tests belong under `test/cluster/`.
|
||||
unit/integration layers.
|
||||
|
||||
Some functional suites may still carry the `external_tools` marker when they shell
|
||||
out to the real feature-generation stack or depend on heavyweight local runtimes.
|
||||
Those stay in this directory because they are package-level workflow tests, but
|
||||
they are still excluded from the default CPU-only pytest invocation.
|
||||
|
||||
GPU or Slurm smoke wrappers belong under `test/cluster/`.
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
"""Template-heavy feature-generation smoke tests.
|
||||
|
||||
These exercise the real CLI path for custom-template feature creation. They are
|
||||
kept under `test/functional/`, but still marked `external_tools` because the
|
||||
subprocess touches the ColabFold/JAX import chain and local bioinformatics
|
||||
tooling that is not available in the lightweight default test environment.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
@@ -9,11 +17,15 @@ import lzma
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
from alphapulldown.utils.remove_clashes_low_plddt import extract_seqs
|
||||
|
||||
|
||||
pytestmark = pytest.mark.external_tools
|
||||
|
||||
|
||||
class TestCreateIndividualFeaturesWithTemplates(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -5,11 +5,14 @@ Tests both AlphaFold2 and AlphaFold3 pipelines with various configurations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import json
|
||||
import lzma
|
||||
import pickle
|
||||
import pytest
|
||||
import logging
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, mock_open
|
||||
|
||||
@@ -22,8 +25,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimal real MonomericObject for pickling
|
||||
class DummyMonomer:
|
||||
def __init__(self, description):
|
||||
def __init__(self, description, sequence=None):
|
||||
self.description = description
|
||||
self.sequence = sequence
|
||||
self.feature_dict = {}
|
||||
self.uniprot_runner = None
|
||||
def make_features(self, *a, **k):
|
||||
@@ -33,6 +37,20 @@ class DummyMonomer:
|
||||
def all_seq_msa_features(self, *a, **k):
|
||||
return {}
|
||||
|
||||
|
||||
class RecordingDummyMonomer(DummyMonomer):
    """``DummyMonomer`` variant that records how feature builders were called.

    Only keyword arguments are recorded; positional arguments are accepted and
    ignored.
    """

    def __init__(self, description, sequence=None):
        super().__init__(description, sequence)
        # One kwargs dict is appended per call, in call order.
        self.feature_calls = []
        self.mmseq_calls = []

    def make_features(self, *args, **kwargs):
        """Record the keyword arguments of a ``make_features`` call."""
        self.feature_calls.append(kwargs)

    def make_mmseq_features(self, *args, **kwargs):
        """Record the keyword arguments of a ``make_mmseq_features`` call."""
        self.mmseq_calls.append(kwargs)
|
||||
|
||||
|
||||
class DummyJsonObj:
|
||||
def to_json(self):
|
||||
return '{"test": "features"}'
|
||||
@@ -44,6 +62,60 @@ def real_write_text(self, content, *args, **kwargs):
|
||||
f.write(content)
|
||||
return len(content)
|
||||
|
||||
|
||||
def build_af3_stub_modules():
    """Build lightweight stand-ins for the ``alphafold3`` modules used in tests.

    The stubs expose ``alphafold3.common.folding_input`` (with ``ProteinChain``,
    ``RnaChain``, ``DnaChain`` and ``Input``) and ``alphafold3.structure.mmcif``
    (with ``int_id_to_str_id``). The chain and input classes only store their
    constructor arguments.

    Returns:
        A ``(modules, folding_input_mod)`` tuple. ``modules`` maps dotted module
        names to the stub module objects, in a form that can be handed to
        ``patch.dict(sys.modules, modules)``; ``folding_input_mod`` is the same
        object as ``modules["alphafold3.common.folding_input"]``.
    """

    def _package(name):
        module = types.ModuleType(name)
        # An (empty) __path__ is what marks a module object as a package.
        module.__path__ = []
        return module

    alphafold3_pkg = _package("alphafold3")
    common_pkg = _package("alphafold3.common")
    structure_pkg = _package("alphafold3.structure")
    folding_input_mod = types.ModuleType("alphafold3.common.folding_input")
    mmcif_mod = types.ModuleType("alphafold3.structure.mmcif")

    class ProteinChain:
        def __init__(self, sequence, id, ptms=None):
            self.sequence = sequence
            self.id = id
            self.ptms = [] if ptms is None else list(ptms)

    class RnaChain:
        def __init__(self, sequence, id, modifications=None):
            self.sequence = sequence
            self.id = id
            self.modifications = [] if modifications is None else list(modifications)

    class DnaChain:
        def __init__(self, sequence, id, modifications=None):
            self.sequence = sequence
            self.id = id
            self.modifications = [] if modifications is None else list(modifications)

    class Input:
        def __init__(self, name, chains, rng_seeds):
            self.name = name
            self.chains = list(chains)
            self.rng_seeds = list(rng_seeds)

    folding_input_mod.ProteinChain = ProteinChain
    folding_input_mod.RnaChain = RnaChain
    folding_input_mod.DnaChain = DnaChain
    folding_input_mod.Input = Input
    # 1-based index to a single letter: 1 -> "A", 2 -> "B", ...
    mmcif_mod.int_id_to_str_id = lambda idx: chr(ord("A") + idx - 1)

    # Wire submodules onto their parents so attribute access works as well as
    # lookups by dotted name.
    alphafold3_pkg.common = common_pkg
    alphafold3_pkg.structure = structure_pkg
    common_pkg.folding_input = folding_input_mod
    structure_pkg.mmcif = mmcif_mod

    return {
        "alphafold3": alphafold3_pkg,
        "alphafold3.common": common_pkg,
        "alphafold3.common.folding_input": folding_input_mod,
        "alphafold3.structure": structure_pkg,
        "alphafold3.structure.mmcif": mmcif_mod,
    }, folding_input_mod
|
||||
|
||||
class TestCreateIndividualFeaturesComprehensive:
|
||||
"""Comprehensive test cases for create_individual_features.py."""
|
||||
|
||||
@@ -98,10 +170,10 @@ class TestCreateIndividualFeaturesComprehensive:
|
||||
("alphafold2", "multi_protein.fasta", False, False),
|
||||
("alphafold2", "single_protein.fasta", True, False), # mmseqs2
|
||||
("alphafold2", "single_protein.fasta", False, True), # compressed
|
||||
#("alphafold3", "single_protein.fasta", False, False),
|
||||
#("alphafold3", "multi_protein.fasta", False, False),
|
||||
#("alphafold3", "rna.fasta", False, False),
|
||||
#("alphafold3", "dna.fasta", False, False),
|
||||
("alphafold3", "single_protein.fasta", False, False),
|
||||
("alphafold3", "multi_protein.fasta", False, False),
|
||||
("alphafold3", "rna.fasta", False, False),
|
||||
("alphafold3", "dna.fasta", False, False),
|
||||
])
|
||||
def test_feature_creation(self, pipeline, fasta_file, use_mmseqs2, compress_features):
|
||||
"""Test feature creation for different configurations."""
|
||||
@@ -157,39 +229,199 @@ class TestCreateIndividualFeaturesComprehensive:
|
||||
logger.info(f"Verified file exists: {file_path}")
|
||||
else:
|
||||
logger.info("Testing AlphaFold3 pipeline")
|
||||
with patch.object(create_features, 'create_pipeline_af3') as mock_af3_pipeline, \
|
||||
patch('alphapulldown.scripts.create_individual_features.folding_input') as mock_folding_input, \
|
||||
af3_modules, folding_input_stub = build_af3_stub_modules()
|
||||
with patch.dict(sys.modules, af3_modules), \
|
||||
patch.object(create_features, 'create_pipeline_af3') as mock_af3_pipeline, \
|
||||
patch.object(create_features, 'folding_input', folding_input_stub), \
|
||||
patch('pathlib.Path.write_text', new=real_write_text), \
|
||||
patch('alphapulldown.utils.save_meta_data.get_meta_dict', return_value={}):
|
||||
mock_af3_pipeline.return_value = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
|
||||
# Patch chain classes in folding_input
|
||||
mock_folding_input.ProteinChain = lambda sequence, id, ptms: MagicMock()
|
||||
mock_folding_input.RnaChain = lambda sequence, id, modifications=None: MagicMock()
|
||||
mock_folding_input.DnaChain = lambda sequence, id: MagicMock()
|
||||
mock_folding_input.Input = lambda name, chains, rng_seeds: MagicMock()
|
||||
create_features.create_af3_individual_features()
|
||||
|
||||
|
||||
process_calls = mock_af3_pipeline.return_value.process.call_args_list
|
||||
observed_chain_types = [
|
||||
type(call.args[0].chains[0]).__name__ for call in process_calls
|
||||
]
|
||||
|
||||
expected_files = []
|
||||
expected_chain_types = []
|
||||
if fasta_file == "single_protein.fasta":
|
||||
expected_files.append("A0A024R1R8_af3_input.json")
|
||||
expected_chain_types.append("ProteinChain")
|
||||
elif fasta_file == "multi_protein.fasta":
|
||||
expected_files.extend(["A0A024R1R8_af3_input.json", "P61626_af3_input.json"])
|
||||
expected_chain_types.extend(["ProteinChain", "ProteinChain"])
|
||||
elif fasta_file == "rna.fasta":
|
||||
expected_files.append("RNA_TEST_af3_input.json")
|
||||
expected_chain_types.append("RnaChain")
|
||||
elif fasta_file == "dna.fasta":
|
||||
expected_files.append("DNA_TEST_af3_input.json")
|
||||
expected_chain_types.append("DnaChain")
|
||||
|
||||
logger.info(f"Checking for expected files: {expected_files}")
|
||||
assert observed_chain_types == expected_chain_types
|
||||
for expected_file in expected_files:
|
||||
file_path = os.path.join(output_dir, expected_file)
|
||||
# Simulate file creation
|
||||
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(file_path).write_text('{"test": "features"}')
|
||||
assert os.path.exists(file_path), f"Expected file {file_path} not found"
|
||||
logger.info(f"Verified file exists: {file_path}")
|
||||
|
||||
logger.info("Feature creation test completed successfully")
|
||||
|
||||
def test_af3_invalid_sequence_is_logged_and_skipped(self):
|
||||
"""Invalid AF3 sequences should not create output files."""
|
||||
invalid_fasta = os.path.join(self.fasta_dir, "invalid_af3.fasta")
|
||||
with open(invalid_fasta, "w") as handle:
|
||||
handle.write(">INVALID\nACDZ*\n")
|
||||
|
||||
from absl import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS(["test"])
|
||||
FLAGS.data_pipeline = "alphafold3"
|
||||
FLAGS.fasta_paths = [invalid_fasta]
|
||||
FLAGS.data_dir = self.af3_db
|
||||
FLAGS.output_dir = os.path.join(self.test_dir, "output_invalid_af3")
|
||||
FLAGS.max_template_date = "2021-09-30"
|
||||
|
||||
error_messages = []
|
||||
af3_modules, folding_input_stub = build_af3_stub_modules()
|
||||
|
||||
with patch.dict(sys.modules, af3_modules), \
|
||||
patch.object(create_features, "create_pipeline_af3") as mock_af3_pipeline, \
|
||||
patch.object(create_features, "folding_input", folding_input_stub), \
|
||||
patch.object(create_features.logging, "error", side_effect=error_messages.append):
|
||||
mock_af3_pipeline.return_value = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
|
||||
|
||||
create_features.create_af3_individual_features()
|
||||
|
||||
mock_af3_pipeline.return_value.process.assert_not_called()
|
||||
assert not os.path.exists(
|
||||
os.path.join(FLAGS.output_dir, "INVALID_af3_input.json")
|
||||
)
|
||||
assert any("Failed to create AlphaFold3 input object" in message for message in error_messages)
|
||||
|
||||
def test_create_individual_features_truemultimer_respects_seq_index(self):
|
||||
"""TrueMultimer mode should only process the selected CSV row."""
|
||||
from absl import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS(["test"])
|
||||
FLAGS.description_file = os.path.join(self.test_dir, "description.csv")
|
||||
FLAGS.fasta_paths = [os.path.join(self.fasta_dir, "multi_protein.fasta")]
|
||||
FLAGS.path_to_mmt = os.path.join(self.test_dir, "templates")
|
||||
FLAGS.multiple_mmts = True
|
||||
FLAGS.seq_index = 2
|
||||
|
||||
feats = [
|
||||
{"protein": "prot1"},
|
||||
{"protein": "prot2"},
|
||||
{"protein": "prot3"},
|
||||
]
|
||||
|
||||
with patch.object(create_features, "parse_csv_file", return_value=feats) as mock_parse, \
|
||||
patch.object(create_features, "process_multimeric_features") as mock_process:
|
||||
create_features.create_individual_features_truemultimer()
|
||||
|
||||
mock_parse.assert_called_once_with(
|
||||
FLAGS.description_file,
|
||||
FLAGS.fasta_paths,
|
||||
FLAGS.path_to_mmt,
|
||||
FLAGS.multiple_mmts,
|
||||
)
|
||||
mock_process.assert_called_once_with(feats[1], 2)
|
||||
|
||||
def test_process_multimeric_features_rejects_missing_templates(self):
|
||||
"""TrueMultimer mode should fail early if a template path is missing."""
|
||||
feat = {
|
||||
"protein": "complexA",
|
||||
"chains": ["A"],
|
||||
"templates": [os.path.join(self.test_dir, "missing_template.cif")],
|
||||
"sequence": "ACDE",
|
||||
}
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="does not exist"):
|
||||
create_features.process_multimeric_features(feat, 1)
|
||||
|
||||
def test_process_multimeric_features_creates_custom_db_and_saves_monomer(self):
|
||||
"""TrueMultimer processing should build a custom DB and hand a monomer to the saver."""
|
||||
template_path = os.path.join(self.test_dir, "template1.cif")
|
||||
Path(template_path).write_text("data_template\n", encoding="utf-8")
|
||||
|
||||
class RecordingMonomer:
|
||||
def __init__(self, description, sequence):
|
||||
self.description = description
|
||||
self.sequence = sequence
|
||||
self.feature_dict = {}
|
||||
self.uniprot_runner = None
|
||||
|
||||
feat = {
|
||||
"protein": "complexB",
|
||||
"chains": ["A", "B"],
|
||||
"templates": [template_path],
|
||||
"sequence": "ACDEFG",
|
||||
}
|
||||
|
||||
from absl import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS(["test"])
|
||||
FLAGS.output_dir = os.path.join(self.test_dir, "truemultimer_output")
|
||||
FLAGS.data_dir = self.af2_db
|
||||
FLAGS.max_template_date = "2021-09-30"
|
||||
FLAGS.use_mmseqs2 = False
|
||||
FLAGS.jackhmmer_binary_path = "/usr/bin/jackhmmer"
|
||||
FLAGS.uniprot_database_path = "/db/uniprot.fasta"
|
||||
|
||||
with patch.object(create_features, "MonomericObject", RecordingMonomer), \
|
||||
patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
|
||||
patch.object(create_features, "create_arguments") as mock_create_arguments, \
|
||||
patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
|
||||
patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
|
||||
patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
|
||||
create_features.process_multimeric_features(feat, 1)
|
||||
|
||||
mock_custom_db.assert_called_once()
|
||||
custom_db_args = mock_custom_db.call_args.args
|
||||
assert custom_db_args[1:] == (
|
||||
"complexB",
|
||||
[template_path],
|
||||
["A", "B"],
|
||||
)
|
||||
mock_create_arguments.assert_called_once_with("/tmp/custom_db")
|
||||
mock_pipeline.assert_called_once_with()
|
||||
mock_runner.assert_called_once_with(
|
||||
FLAGS.jackhmmer_binary_path,
|
||||
FLAGS.uniprot_database_path,
|
||||
)
|
||||
|
||||
saved_monomer, saved_pipeline = mock_save.call_args.args
|
||||
assert saved_pipeline == "pipeline"
|
||||
assert saved_monomer.description == "complexB"
|
||||
assert saved_monomer.sequence == "ACDEFG"
|
||||
assert saved_monomer.uniprot_runner == "runner"
|
||||
|
||||
def test_main_dispatches_to_truemultimer_for_af2_template_runs(self):
    """The main entrypoint should route AF2 template jobs to the TrueMultimer path."""
    from absl import flags

    FLAGS = flags.FLAGS
    # Parse a minimal argv so flag values can be assigned below.
    FLAGS(["test"])
    FLAGS.data_pipeline = "alphafold2"
    FLAGS.fasta_paths = [os.path.join(self.fasta_dir, "single_protein.fasta")]
    FLAGS.data_dir = self.af2_db
    FLAGS.output_dir = os.path.join(self.test_dir, "main_truemultimer")
    FLAGS.max_template_date = "2021-09-30"
    FLAGS.path_to_mmt = os.path.join(self.test_dir, "templates")

    with patch.object(create_features, "check_template_date") as mock_check, \
            patch.object(create_features, "create_individual_features_truemultimer") as mock_tm, \
            patch.object(create_features, "create_individual_features") as mock_single:
        create_features.main([])

    # Only the TrueMultimer entry point may run; the plain path must stay untouched.
    mock_check.assert_called_once_with()
    mock_tm.assert_called_once_with()
    mock_single.assert_not_called()
|
||||
|
||||
def test_database_path_mapping(self):
|
||||
"""Test that database paths are correctly mapped for both pipelines."""
|
||||
logger.info("Testing database path mapping")
|
||||
@@ -229,7 +461,7 @@ class TestCreateIndividualFeaturesComprehensive:
|
||||
|
||||
FLAGS.data_pipeline = "alphafold3"
|
||||
FLAGS.data_dir = "/test/db"
|
||||
with pytest.raises(ImportError):
|
||||
with pytest.raises(ImportError, match="pip install -e .*alphafold3,test.*build_data"):
|
||||
create_features.create_pipeline_af3()
|
||||
logger.info("AF3 pipeline creation correctly failed with ImportError")
|
||||
|
||||
@@ -669,4 +901,287 @@ class TestCreateIndividualFeaturesComprehensive:
|
||||
with pytest.raises(SystemExit):
|
||||
create_features.main([])
|
||||
|
||||
logger.info("Flag validation for MMseqs2 scenarios successful")
|
||||
logger.info("Flag validation for MMseqs2 scenarios successful")
|
||||
|
||||
|
||||
def test_create_pipeline_af2_uses_hhsearch_template_stack(tmp_flags):
    """With ``use_hhsearch`` set, ``create_pipeline_af2`` wires in the HHSearch searcher and featurizer."""
    create_features.FLAGS.use_mmseqs2 = False
    create_features.FLAGS.use_hhsearch = True
    create_features.FLAGS.hhsearch_binary_path = "/bin/hhsearch"
    create_features.FLAGS.pdb70_database_path = "/db/pdb70"
    create_features.FLAGS.template_mmcif_dir = "/db/mmcif"
    create_features.FLAGS.max_template_date = "2021-09-30"
    create_features.FLAGS.kalign_binary_path = "/bin/kalign"
    create_features.FLAGS.obsolete_pdbs_path = "/db/obsolete.dat"

    with patch.object(create_features.hhsearch, "HHSearch", return_value="searcher") as mock_searcher, \
            patch.object(create_features.templates, "HhsearchHitFeaturizer", return_value="featurizer") as mock_featurizer, \
            patch.object(create_features, "AF2DataPipeline", return_value="pipeline") as mock_pipeline:
        pipeline = create_features.create_pipeline_af2()

    assert pipeline == "pipeline"
    mock_searcher.assert_called_once_with(
        binary_path="/bin/hhsearch",
        databases=["/db/pdb70"],
    )
    mock_featurizer.assert_called_once_with(
        mmcif_dir="/db/mmcif",
        max_template_date="2021-09-30",
        max_hits=20,
        kalign_binary_path="/bin/kalign",
        release_dates_path=None,
        obsolete_pdbs_path="/db/obsolete.dat",
    )
    # The patched searcher/featurizer instances must be forwarded to the pipeline.
    assert mock_pipeline.call_args.kwargs["template_searcher"] == "searcher"
    assert mock_pipeline.call_args.kwargs["template_featurizer"] == "featurizer"
|
||||
|
||||
|
||||
def test_create_pipeline_af2_uses_hmmsearch_template_stack(tmp_flags):
    """With ``use_hhsearch`` unset, ``create_pipeline_af2`` wires in the Hmmsearch searcher and featurizer."""
    create_features.FLAGS.use_mmseqs2 = False
    create_features.FLAGS.use_hhsearch = False
    create_features.FLAGS.hmmsearch_binary_path = "/bin/hmmsearch"
    create_features.FLAGS.hmmbuild_binary_path = "/bin/hmmbuild"
    create_features.FLAGS.pdb_seqres_database_path = "/db/pdb_seqres.txt"
    create_features.FLAGS.template_mmcif_dir = "/db/mmcif"
    create_features.FLAGS.max_template_date = "2021-09-30"
    create_features.FLAGS.kalign_binary_path = "/bin/kalign"
    create_features.FLAGS.obsolete_pdbs_path = "/db/obsolete.dat"

    with patch.object(create_features.hmmsearch, "Hmmsearch", return_value="searcher") as mock_searcher, \
            patch.object(create_features.templates, "HmmsearchHitFeaturizer", return_value="featurizer") as mock_featurizer, \
            patch.object(create_features, "AF2DataPipeline", return_value="pipeline") as mock_pipeline:
        pipeline = create_features.create_pipeline_af2()

    assert pipeline == "pipeline"
    mock_searcher.assert_called_once_with(
        binary_path="/bin/hmmsearch",
        hmmbuild_binary_path="/bin/hmmbuild",
        database_path="/db/pdb_seqres.txt",
    )
    mock_featurizer.assert_called_once_with(
        mmcif_dir="/db/mmcif",
        max_template_date="2021-09-30",
        max_hits=20,
        kalign_binary_path="/bin/kalign",
        obsolete_pdbs_path="/db/obsolete.dat",
        release_dates_path=None,
    )
    # The patched searcher/featurizer instances must be forwarded to the pipeline.
    assert mock_pipeline.call_args.kwargs["template_searcher"] == "searcher"
    assert mock_pipeline.call_args.kwargs["template_featurizer"] == "featurizer"
|
||||
|
||||
|
||||
def test_create_individual_features_only_saves_selected_sequence(tmp_flags):
    """With ``seq_index = 2`` and two input sequences, the monomer passed to the saver is the second one."""
    create_features.FLAGS.seq_index = 2

    with patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2", return_value="pipeline") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner", return_value="runner") as mock_runner, \
            patch.object(create_features, "MonomericObject", DummyMonomer), \
            patch.object(create_features, "iter_seqs", return_value=[("AAAA", "first"), ("BBBB", "second")]), \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.create_individual_features()

    mock_arguments.assert_called_once_with()
    mock_pipeline.assert_called_once_with()
    mock_runner.assert_called_once()
    # Inspect the most recent save call: (monomer, pipeline) positional arguments.
    saved_monomer, saved_pipeline = mock_save.call_args.args
    assert saved_pipeline == "pipeline"
    assert saved_monomer.description == "second"
    assert saved_monomer.uniprot_runner == "runner"
|
||||
|
||||
|
||||
def test_create_and_save_monomer_objects_writes_compressed_af2_outputs(tmp_flags, tmp_path):
    """With ``compress_features`` on, metadata is written as ``.json.xz`` and features as ``.pkl.xz``."""
    create_features.FLAGS.output_dir = str(tmp_path)
    create_features.FLAGS.compress_features = True
    create_features.FLAGS.skip_existing = False
    create_features.FLAGS.use_mmseqs2 = False
    create_features.FLAGS.use_precomputed_msas = True
    create_features.FLAGS.save_msa_files = True

    monomer = RecordingDummyMonomer("protA")
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
        create_features.create_and_save_monomer_objects(monomer, pipeline="pipeline")

    metadata_files = list(tmp_path.glob("protA_feature_metadata_*.json.xz"))
    assert len(metadata_files) == 1
    with lzma.open(metadata_files[0], "rt", encoding="utf-8") as handle:
        assert json.load(handle) == {"source": "test"}
    assert (tmp_path / "protA.pkl.xz").exists()
    # The local AF2 feature path was used exactly once; the MMseqs2 path not at all.
    assert monomer.feature_calls == [
        {
            "pipeline": "pipeline",
            "output_dir": str(tmp_path),
            "use_precomputed_msa": True,
            "save_msa": True,
        }
    ]
    assert monomer.mmseq_calls == []
|
||||
|
||||
|
||||
def test_create_and_save_monomer_objects_skips_existing_outputs(tmp_flags, tmp_path):
    """With ``skip_existing`` on and ``protA.pkl`` already present, no features or metadata are produced."""
    create_features.FLAGS.output_dir = str(tmp_path)
    create_features.FLAGS.compress_features = False
    create_features.FLAGS.skip_existing = True
    create_features.FLAGS.use_mmseqs2 = True

    existing_pickle = tmp_path / "protA.pkl"
    existing_pickle.write_bytes(b"already-there")
    monomer = RecordingDummyMonomer("protA")
    create_features.create_and_save_monomer_objects(monomer, pipeline=None)

    assert monomer.feature_calls == []
    assert monomer.mmseq_calls == []
    assert list(tmp_path.glob("protA_feature_metadata_*.json")) == []
|
||||
|
||||
|
||||
def test_create_and_save_monomer_objects_uses_mmseqs_when_requested(tmp_flags, tmp_path):
    """With ``use_mmseqs2`` on, only the MMseqs2 feature path is invoked and ``protA.pkl`` is written."""
    create_features.FLAGS.output_dir = str(tmp_path)
    create_features.FLAGS.compress_features = False
    create_features.FLAGS.skip_existing = False
    create_features.FLAGS.use_mmseqs2 = True
    create_features.FLAGS.use_precomputed_msas = True
    create_features.FLAGS.re_search_templates_mmseqs2 = True

    monomer = RecordingDummyMonomer("protA")
    with patch("alphapulldown.utils.save_meta_data.get_meta_dict", return_value={"source": "test"}):
        create_features.create_and_save_monomer_objects(monomer, pipeline=None)

    assert monomer.feature_calls == []
    assert monomer.mmseq_calls == [
        {
            "DEFAULT_API_SERVER": create_features.DEFAULT_API_SERVER,
            "output_dir": str(tmp_path),
            "use_precomputed_msa": True,
            "use_templates": True,
        }
    ]
    assert (tmp_path / "protA.pkl").exists()
|
||||
|
||||
|
||||
def test_process_multimeric_features_uses_mmseqs_without_local_pipeline(tmp_flags, tmp_path):
    """With ``use_mmseqs2`` on, no local AF2 pipeline or UniProt runner is created for a multimeric entry."""
    template_path = tmp_path / "template.cif"
    template_path.write_text("data_template\n", encoding="utf-8")

    create_features.FLAGS.output_dir = str(tmp_path / "out")
    create_features.FLAGS.use_mmseqs2 = True

    feat = {
        "protein": "complex_mmseqs",
        "chains": ["A"],
        "templates": [str(template_path)],
        "sequence": "ACDE",
    }

    with patch.object(create_features, "MonomericObject", RecordingDummyMonomer), \
            patch.object(create_features, "create_custom_db", return_value="/tmp/custom_db") as mock_custom_db, \
            patch.object(create_features, "create_arguments") as mock_arguments, \
            patch.object(create_features, "create_pipeline_af2") as mock_pipeline, \
            patch.object(create_features, "create_uniprot_runner") as mock_runner, \
            patch.object(create_features, "create_and_save_monomer_objects") as mock_save:
        create_features.process_multimeric_features(feat, 1)

    mock_custom_db.assert_called_once()
    mock_arguments.assert_called_once_with("/tmp/custom_db")
    mock_pipeline.assert_not_called()
    mock_runner.assert_not_called()
    # Inspect the most recent save call: (monomer, pipeline) positional arguments.
    saved_monomer, saved_pipeline = mock_save.call_args.args
    assert saved_pipeline is None
    assert saved_monomer.description == "complex_mmseqs"
    assert saved_monomer.uniprot_runner is None
|
||||
|
||||
|
||||
def test_create_custom_db_passes_thresholds_to_builder(tmp_flags):
    """``create_custom_db`` forwards the clash, H-bond and pLDDT flag values to ``create_db``."""
    create_features.FLAGS.threshold_clashes = 12.5
    create_features.FLAGS.hb_allowance = 0.7
    create_features.FLAGS.plddt_threshold = 42.0

    with patch.object(create_features, "create_db") as mock_create_db:
        db_path = create_features.create_custom_db("/tmp/base", "proteinX", ["a.cif"], ["A"])

    assert str(db_path) == "/tmp/base/custom_template_db/proteinX"
    mock_create_db.assert_called_once_with(
        db_path,
        ["a.cif"],
        ["A"],
        12.5,
        0.7,
        42.0,
    )
|
||||
|
||||
|
||||
def test_create_pipeline_af3_prefers_explicit_database_overrides(tmp_flags):
    """Explicitly set database flags reach the AF3 config; unset ones resolve under ``data_dir``."""

    class DummyConfig:
        """Captures the keyword arguments the config would have been built with."""

        def __init__(self, **kwargs):
            self.kwargs = kwargs

    create_features.FLAGS.max_template_date = "2021-09-30"
    create_features.FLAGS.data_pipeline = "alphafold3"
    create_features.FLAGS.data_dir = "/db"
    create_features.FLAGS.small_bfd_database_path = "/override/small_bfd"
    create_features.FLAGS.uniref90_database_path = "/override/uniref90"
    create_features.FLAGS.template_mmcif_dir = "/override/mmcif"

    # The patched pipeline returns its config argument so the kwargs can be inspected.
    with patch.object(create_features, "AF3DataPipelineConfig", side_effect=DummyConfig) as mock_config, \
            patch.object(create_features, "AF3DataPipeline", side_effect=lambda config: config) as mock_pipeline:
        config = create_features.create_pipeline_af3()

    mock_config.assert_called_once()
    mock_pipeline.assert_called_once()
    assert config.kwargs["small_bfd_database_path"] == "/override/small_bfd"
    assert config.kwargs["uniref90_database_path"] == "/override/uniref90"
    assert config.kwargs["pdb_database_path"] == "/override/mmcif"
    assert config.kwargs["mgnify_database_path"] == "/db/mgy_clusters_2022_05.fa"
    assert config.kwargs["seqres_database_path"] == "/db/pdb_seqres_2022_09_28.fasta"
|
||||
|
||||
|
||||
def test_create_af3_individual_features_falls_back_to_double_letter_chain_ids(tmp_flags, tmp_path):
    """Selecting the 27th sequence with the ``alphafold3.structure.mmcif`` stub removed still writes its JSON."""
    create_features.FLAGS.output_dir = str(tmp_path)
    create_features.FLAGS.seq_index = 27

    af3_modules, folding_input_stub = build_af3_stub_modules()
    # Drop the mmcif stub both as an attribute and as a registered module.
    del af3_modules["alphafold3.structure"].mmcif
    af3_modules.pop("alphafold3.structure.mmcif")

    sequences = [("ACDE", f"chain_{idx}") for idx in range(1, 28)]
    with patch.dict(sys.modules, af3_modules), \
            patch.object(create_features, "create_pipeline_af3", return_value=MagicMock(process=MagicMock(return_value={"plain": "json"}))), \
            patch.object(create_features, "folding_input", folding_input_stub), \
            patch.object(create_features, "iter_seqs", return_value=sequences), \
            patch("pathlib.Path.write_text", new=real_write_text):
        create_features.create_af3_individual_features()

    outpath = tmp_path / "chain_27_af3_input.json"
    assert outpath.exists()
    assert json.loads(outpath.read_text(encoding="utf-8")) == {"plain": "json"}
|
||||
|
||||
|
||||
def test_create_af3_individual_features_skips_existing_outputs(tmp_flags, tmp_path):
    """With ``skip_existing`` on and the output JSON present, the AF3 pipeline is never run."""
    create_features.FLAGS.output_dir = str(tmp_path)
    create_features.FLAGS.skip_existing = True

    af3_modules, folding_input_stub = build_af3_stub_modules()
    existing_output = tmp_path / "protA_af3_input.json"
    existing_output.write_text("{}", encoding="utf-8")

    pipeline = MagicMock(process=MagicMock(return_value=DummyJsonObj()))
    with patch.dict(sys.modules, af3_modules), \
            patch.object(create_features, "create_pipeline_af3", return_value=pipeline), \
            patch.object(create_features, "folding_input", folding_input_stub), \
            patch.object(create_features, "iter_seqs", return_value=[("ACDE", "protA")]), \
            patch("pathlib.Path.write_text", new=real_write_text):
        create_features.create_af3_individual_features()

    pipeline.process.assert_not_called()
    # The pre-existing file must be left exactly as it was.
    assert existing_output.read_text(encoding="utf-8") == "{}"
|
||||
|
||||
|
||||
def test_main_dispatches_to_af3_feature_creation(tmp_flags, tmp_path):
    """With ``data_pipeline = "alphafold3"``, ``main`` calls the AF3 path and skips the template date check."""
    create_features.FLAGS.data_pipeline = "alphafold3"
    create_features.FLAGS.output_dir = str(tmp_path / "af3_out")

    with patch.object(create_features, "create_af3_individual_features") as mock_af3, \
            patch.object(create_features, "check_template_date") as mock_check:
        create_features.main([])

    mock_af3.assert_called_once_with()
    mock_check.assert_not_called()
|
||||
|
||||
@@ -16,6 +16,8 @@ import pytest
|
||||
Test conversion of PDB to CIF for monomers and multimers
|
||||
"""
|
||||
|
||||
pytest.importorskip("ihm")
|
||||
pytest.importorskip("modelcif")
|
||||
pytestmark = pytest.mark.external_tools
|
||||
|
||||
|
||||
|
||||
973
test/unit/test_alphafold2_backend_helpers.py
Normal file
973
test/unit/test_alphafold2_backend_helpers.py
Normal file
@@ -0,0 +1,973 @@
|
||||
import importlib.util
|
||||
import json
|
||||
import pickle
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# Location of the backend under test, resolved relative to this test file
# (two directory levels above this file's folder).
MODULE_PATH = (
    Path(__file__).resolve().parents[2]
    / "alphapulldown"
    / "folding_backend"
    / "alphafold2_backend.py"
)
|
||||
|
||||
|
||||
def _package(name: str) -> types.ModuleType:
    """Return a new empty module named *name* that the import system treats as a package.

    Giving the module a ``__path__`` attribute is what marks it as a package,
    so dotted submodules can be registered beneath it.
    """
    module = types.ModuleType(name)
    module.__path__ = []  # type: ignore[attr-defined]
    return module
|
||||
|
||||
|
||||
class _ConfigNode(dict):
    """Dictionary whose keys can also be read and written as attributes."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to the key.
        try:
            return self[name]
        except KeyError as exc:  # pragma: no cover - defensive
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        # Attribute assignment always stores into the mapping itself.
        self[name] = value
|
||||
|
||||
|
||||
def _restore_modules(saved_modules: "dict[str, types.ModuleType | None]") -> None:
    """Put ``sys.modules`` entries back to a previously captured state.

    Args:
        saved_modules: Mapping of module name to the module that should be
            registered under it again, or ``None`` when the name should not
            be registered at all.
    """
    # The annotation is quoted so it is never evaluated: ``X | None`` on a
    # class raises TypeError at definition time on interpreters before 3.10.
    for name, module in saved_modules.items():
        if module is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = module
|
||||
|
||||
|
||||
def _install_alphafold2_backend_stubs() -> "dict[str, types.ModuleType | None]":
    """Register lightweight stand-ins for the modules listed below in ``sys.modules``.

    Stubs are created for ``jax``, the ``alphafold`` package tree and several
    ``alphapulldown`` modules, each exposing only the attributes assigned here.

    Returns:
        Mapping of every replaced module name to whatever was registered under
        that name beforehand (``None`` when nothing was), in the form expected
        by ``_restore_modules``.
    """
    # The return annotation is quoted so it is never evaluated: ``X | None`` on
    # a class raises TypeError at definition time on interpreters before 3.10.
    jax_pkg = _package("jax")
    jax_numpy_mod = types.ModuleType("jax.numpy")
    jax_numpy_mod.ndarray = np.ndarray
    jax_numpy_mod.array = np.array

    alphafold_pkg = _package("alphafold")
    relax_pkg = _package("alphafold.relax")
    relax_mod = types.ModuleType("alphafold.relax.relax")

    class FakeAmberRelaxation:
        """Records its constructor kwargs and returns a canned relax result."""

        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def process(self, prot):
            name = getattr(prot, "name", "protein")
            return (f"RELAXED:{name}", None, [0, 1])

    relax_mod.AmberRelaxation = FakeAmberRelaxation

    common_pkg = _package("alphafold.common")
    protein_mod = types.ModuleType("alphafold.common.protein")

    class FakeProtein:
        """Keyword-only container for the protein fields, with a fixed ``name``."""

        def __init__(
            self,
            *,
            atom_positions,
            atom_mask,
            aatype,
            residue_index,
            chain_index,
            b_factors,
        ):
            self.atom_positions = atom_positions
            self.atom_mask = atom_mask
            self.aatype = aatype
            self.residue_index = residue_index
            self.chain_index = chain_index
            self.b_factors = b_factors
            self.name = "debug_protein"

    def _to_pdb(prot):
        return f"PDB:{getattr(prot, 'name', 'protein')}"

    def _from_prediction(features, result, b_factors, remove_leading_feature_dimension):
        residue_index = np.asarray(features.get("residue_index", [0]), dtype=np.int32)
        chain_index = np.asarray(features.get("asym_id", np.zeros_like(residue_index)))
        return SimpleNamespace(
            name="predicted",
            features=features,
            result=result,
            b_factors=b_factors,
            residue_index=residue_index,
            chain_index=chain_index,
            aatype=np.zeros_like(residue_index),
            remove_leading_feature_dimension=remove_leading_feature_dimension,
        )

    protein_mod.Protein = FakeProtein
    protein_mod.to_pdb = _to_pdb
    protein_mod.from_prediction = _from_prediction
    protein_mod.from_pdb_string = lambda text: SimpleNamespace(name=f"from_pdb:{text}")

    residue_constants_mod = types.ModuleType("alphafold.common.residue_constants")
    residue_constants_mod.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = list(range(64))
    residue_constants_mod.restype_num = 20
    residue_constants_mod.atom_type_num = 37

    confidence_mod = types.ModuleType("alphafold.common.confidence")
    confidence_mod.pae_json = lambda pae, max_pae: json.dumps(
        {"pae": np.asarray(pae).tolist(), "max_pae": float(max_pae)}
    )
    confidence_mod.confidence_json = lambda plddt: json.dumps(
        {"plddt": np.asarray(plddt).tolist()}
    )
    confidence_mod.predicted_tm_score = (
        lambda logits, breaks, asym_id=None, interface=False: 0.8 if interface else 0.5
    )
    confidence_mod.compute_predicted_aligned_error = lambda logits, breaks: {
        "predicted_aligned_error": np.asarray(logits) + 1,
        "max_predicted_aligned_error": 31.0,
    }

    model_pkg = _package("alphafold.model")
    config_mod = types.ModuleType("alphafold.model.config")
    config_mod.MODEL_PRESETS = {
        "multimer": ("model_1_multimer_v3", "model_2_multimer_v3")
    }

    def _model_config(_name):
        # A fresh config tree per call, so callers can mutate it freely.
        return _ConfigNode(
            {
                "model": _ConfigNode(
                    {
                        "num_ensemble_eval": None,
                        "global_config": _ConfigNode({"eval_dropout": False}),
                        "embeddings_and_evoformer": _ConfigNode(
                            {"num_msa": 64, "num_extra_msa": 256}
                        ),
                    }
                )
            }
        )

    config_mod.model_config = _model_config

    data_mod = types.ModuleType("alphafold.model.data")
    data_mod.get_model_haiku_params = (
        lambda model_name, data_dir: {"model_name": model_name, "data_dir": data_dir}
    )

    model_mod = types.ModuleType("alphafold.model.model")

    class FakeRunModel:
        """Stores config/params and returns constant-valued predictions."""

        def __init__(self, config, params):
            self.config = config
            self.params = params
            self.multimer_mode = True

        def process_features(self, feature_dict, random_seed):
            return dict(feature_dict)

        def predict(self, processed_feature_dict, random_seed):
            seq_len = len(np.asarray(processed_feature_dict.get("residue_index", [0, 1])))
            return {
                "plddt": np.full(seq_len, 75.0, dtype=np.float32),
                "predicted_aligned_error": np.zeros((seq_len, seq_len), dtype=np.float32),
                "max_predicted_aligned_error": 31.0,
            }

    model_mod.RunModel = FakeRunModel

    objects_mod = types.ModuleType("alphapulldown.objects")

    class MonomericObject:
        def __init__(self, description="monomer", sequence=""):
            self.description = description
            self.sequence = sequence
            self.feature_dict = {}
            self.multimeric_mode = False

    class MultimericObject:
        def __init__(
            self,
            description="multimer",
            input_seqs=None,
            feature_dict=None,
            multimeric_mode=True,
        ):
            self.description = description
            self.input_seqs = input_seqs or []
            self.feature_dict = feature_dict or {}
            self.multimeric_mode = multimeric_mode

    class ChoppedObject(MonomericObject):
        pass

    objects_mod.MonomericObject = MonomericObject
    objects_mod.MultimericObject = MultimericObject
    objects_mod.ChoppedObject = ChoppedObject

    plotting_mod = types.ModuleType("alphapulldown.utils.plotting")
    plotting_mod.plot_pae_from_matrix = lambda **_kwargs: None

    post_modelling_mod = types.ModuleType("alphapulldown.utils.post_modelling")
    post_modelling_mod.post_prediction_process = lambda *args, **kwargs: None

    modelling_setup_mod = types.ModuleType("alphapulldown.utils.modelling_setup")
    modelling_setup_mod.pad_input_features = (
        lambda feature_dict, desired_num_msa, desired_num_res: None
    )

    af2_to_af3_msa_mod = types.ModuleType("alphapulldown.utils.af2_to_af3_msa")
    af2_to_af3_msa_mod.msa_rows_and_deletions_to_a3m = (
        lambda msa_rows, deletion_rows, query_sequence: (
            f">query\n{query_sequence}\n>rows\n{len(np.asarray(msa_rows))}\n"
        )
    )

    modules = {
        "jax": jax_pkg,
        "jax.numpy": jax_numpy_mod,
        "alphafold": alphafold_pkg,
        "alphafold.relax": relax_pkg,
        "alphafold.relax.relax": relax_mod,
        "alphafold.common": common_pkg,
        "alphafold.common.protein": protein_mod,
        "alphafold.common.residue_constants": residue_constants_mod,
        "alphafold.common.confidence": confidence_mod,
        "alphafold.model": model_pkg,
        "alphafold.model.config": config_mod,
        "alphafold.model.data": data_mod,
        "alphafold.model.model": model_mod,
        "alphapulldown.objects": objects_mod,
        "alphapulldown.utils.plotting": plotting_mod,
        "alphapulldown.utils.post_modelling": post_modelling_mod,
        "alphapulldown.utils.modelling_setup": modelling_setup_mod,
        "alphapulldown.utils.af2_to_af3_msa": af2_to_af3_msa_mod,
    }

    # Snapshot exactly the names about to be overwritten, before touching
    # sys.modules, so the snapshot can never drift from the installed set.
    saved_modules = {name: sys.modules.get(name) for name in modules}

    for name, module in modules.items():
        sys.modules[name] = module

    # Link submodules onto their parent packages as attributes so dotted
    # attribute access works, not just ``import a.b`` lookups.
    jax_pkg.numpy = jax_numpy_mod
    alphafold_pkg.relax = relax_pkg
    alphafold_pkg.common = common_pkg
    alphafold_pkg.model = model_pkg
    relax_pkg.relax = relax_mod
    common_pkg.protein = protein_mod
    common_pkg.residue_constants = residue_constants_mod
    common_pkg.confidence = confidence_mod
    model_pkg.config = config_mod
    model_pkg.data = data_mod
    model_pkg.model = model_mod

    return saved_modules
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def af2_backend_module():
    """Yield ``alphafold2_backend`` loaded from ``MODULE_PATH`` against stubbed dependencies.

    The stubs are installed first and the backend is then executed from its
    source file. On teardown, including when loading fails, the freshly loaded
    module is unregistered, any backend module that was registered before the
    fixture ran is put back, and the stubbed ``sys.modules`` entries are restored.
    """
    backend_name = "alphapulldown.folding_backend.alphafold2_backend"
    saved_modules = _install_alphafold2_backend_stubs()
    # Keep a handle on any already-imported backend so teardown can restore it.
    previous_backend = sys.modules.pop(backend_name, None)
    try:
        spec = importlib.util.spec_from_file_location(backend_name, MODULE_PATH)
        assert spec is not None and spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        # Register before executing so the module is importable by name during exec.
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)
        yield module
    finally:
        sys.modules.pop(backend_name, None)
        if previous_backend is not None:
            sys.modules[backend_name] = previous_backend
        _restore_modules(saved_modules)
|
||||
|
||||
|
||||
def test_json_template_and_alignment_helpers(af2_backend_module, tmp_path):
    """Exercise ``_read_from_json_if_exists``, ``_reset_template_features`` and ``_normalise_num_alignments_for_debug``."""
    missing = af2_backend_module._read_from_json_if_exists(tmp_path / "missing.json")
    assert missing == {}

    existing_path = tmp_path / "timings.json"
    existing_path.write_text(json.dumps({"stage": 1}), encoding="utf-8")
    assert af2_backend_module._read_from_json_if_exists(existing_path) == {"stage": 1}

    feature_dict = {
        "seq_length": 3,
        "template_aatype": np.ones((2, 3), dtype=np.int32),
        "template_all_atom_positions": np.ones((2, 3, 37, 3), dtype=np.float32),
        "template_all_atom_mask": np.zeros((2, 3, 37), dtype=np.float32),
        "num_templates": np.array([5]),
    }
    # The helper mutates the dict in place; two templates go in, one comes out.
    af2_backend_module._reset_template_features(feature_dict)
    assert feature_dict["template_aatype"].shape == (1, 3)
    assert feature_dict["template_all_atom_positions"].shape == (1, 3, 37, 3)
    assert np.all(feature_dict["template_all_atom_mask"] == 1.0)
    np.testing.assert_array_equal(feature_dict["num_templates"], np.array([1]))

    assert af2_backend_module._normalise_num_alignments_for_debug({}) == 0
    assert (
        af2_backend_module._normalise_num_alignments_for_debug({"msa": np.zeros((2, 3))})
        == 2
    )
    # Expected value follows the MSA row count (2), not the declared count (9).
    assert (
        af2_backend_module._normalise_num_alignments_for_debug(
            {"msa": np.zeros((2, 3)), "num_alignments": np.array([9])}
        )
        == 2
    )
|
||||
|
||||
|
||||
def test_ensure_typing_dataclass_transform_backfills_missing_attribute(
    af2_backend_module, monkeypatch
):
    """After ``dataclass_transform`` is removed from ``typing``, the helper makes it available again."""
    # raising=False: the attribute may already be absent on this interpreter.
    monkeypatch.delattr(af2_backend_module.typing, "dataclass_transform", raising=False)

    af2_backend_module._ensure_typing_dataclass_transform()

    assert af2_backend_module.typing.dataclass_transform is not None
|
||||
|
||||
|
||||
def test_resolve_gpu_relax_keeps_cuda_when_available(af2_backend_module, monkeypatch):
    """``_resolve_gpu_relax(True)`` stays ``True`` when the reported platforms include CUDA."""
    monkeypatch.setattr(
        af2_backend_module,
        "_get_openmm_platform_names",
        lambda: ["Reference", "CPU", "CUDA"],
    )

    assert af2_backend_module._resolve_gpu_relax(True) is True
|
||||
|
||||
|
||||
def test_resolve_gpu_relax_falls_back_when_cuda_is_missing(
    af2_backend_module, monkeypatch, caplog
):
    """``_resolve_gpu_relax(True)`` returns ``False`` and logs a fallback message when CUDA is not reported."""
    monkeypatch.setattr(
        af2_backend_module,
        "_get_openmm_platform_names",
        lambda: ["Reference", "CPU", "OpenCL"],
    )

    assert af2_backend_module._resolve_gpu_relax(True) is False
    assert "falling back to CPU relax" in caplog.text
|
||||
|
||||
|
||||
def test_asym_query_and_msa_debug_helpers(af2_backend_module, tmp_path):
    """Exercise ``_normalize_asym_id``, ``_query_sequence_for_debug`` and the processed-MSA debug writer."""
    normalized = af2_backend_module._normalize_asym_id(
        {"asym_id": np.array([5, 5, 2, 9], dtype=np.int32)}
    )
    # Expected ids are renumbered from 0 in order of first appearance.
    np.testing.assert_array_equal(normalized["asym_id"], np.array([0, 0, 1, 2]))

    fallback_normalized = af2_backend_module._normalize_asym_id(
        {},
        fallback_feature_dict={"asym_id": np.array([1, 1, 3], dtype=np.int32)},
    )
    np.testing.assert_array_equal(fallback_normalized["asym_id"], np.array([0, 0, 1]))

    multimer = af2_backend_module.MultimericObject(
        description="job",
        input_seqs=["AA", "BB"],
        feature_dict={},
    )
    monomer = af2_backend_module.MonomericObject("single", "AC")
    assert af2_backend_module._query_sequence_for_debug(multimer) == "AABB"
    assert af2_backend_module._query_sequence_for_debug(monomer) == "AC"

    af2_backend_module._write_processed_msa_debug_artifact(
        processed_feature_dict={
            "msa": np.array([[1, 2], [3, 4]], dtype=np.int32),
            "deletion_matrix_int": np.zeros((2, 2), dtype=np.int32),
            "num_alignments": np.array([1]),
        },
        multimeric_object=monomer,
        output_dir=tmp_path,
        model_name="modelA",
    )

    debug_path = tmp_path / "modelA_processed_msa.a3m"
    # The stubbed A3M converter reports how many MSA rows it was given (1 here).
    assert debug_path.read_text(encoding="utf-8") == ">query\nAC\n>rows\n1\n"
|
||||
|
||||
|
||||
def test_template_debug_helpers_cover_success_and_failure_paths(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """Exercise the template debug helpers, including a template whose protein construction fails."""
    assert af2_backend_module._decode_debug_value(b"templ") == "templ"
    assert (
        af2_backend_module._sanitize_debug_filename("bad/name:*", "fallback")
        == "bad_name__"
    )

    one_hot = np.eye(22, dtype=np.int32)[[1, 2]]
    np.testing.assert_array_equal(
        af2_backend_module._template_aatype_to_indices(one_hot),
        np.array([1, 2], dtype=np.int32),
    )
    np.testing.assert_array_equal(
        af2_backend_module._template_aatype_to_indices(np.array([0, 21], dtype=np.int32)),
        np.array([0, 20], dtype=np.int32),
    )
    with pytest.raises(ValueError, match="Unsupported template_aatype rank"):
        af2_backend_module._template_aatype_to_indices(np.zeros((1, 1, 1, 1)))

    call_counter = {"value": 0}

    def fake_protein_constructor(**kwargs):
        # Succeed on the first call and fail on the second, so one template
        # is written normally and the other goes down the error path.
        current = call_counter["value"]
        call_counter["value"] += 1
        if current == 1:
            raise RuntimeError("boom")
        return SimpleNamespace(name="template", **kwargs)

    monkeypatch.setattr(af2_backend_module.protein, "Protein", fake_protein_constructor)
    monkeypatch.setattr(af2_backend_module.protein, "to_pdb", lambda prot: "PDB:TEMPLATE")

    af2_backend_module._write_processed_template_debug_artifacts(
        processed_feature_dict={
            "template_all_atom_positions": np.ones((2, 2, 37, 3), dtype=np.float32),
            "template_all_atom_mask": np.ones((2, 2, 37), dtype=np.float32),
            "template_aatype": np.stack([np.eye(22)[[0, 1]], np.eye(22)[[0, 1]]]),
            "template_domain_names": [b"good/template", b"bad"],
            "residue_index": np.array([7, 8], dtype=np.int32),
            "asym_id": np.array([3, 3], dtype=np.int32),
        },
        output_dir=tmp_path,
        model_name="modelB",
    )

    debug_dir = tmp_path / "templates_debug"
    assert (debug_dir / "modelB_good_template_idx0.pdb").is_file()
    error_file = debug_dir / "ERROR_modelB_template_1.txt"
    assert error_file.is_file()
    assert "boom" in error_file.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_setup_configures_model_runners_and_validates_custom_names(af2_backend_module):
    """Check runner naming/configuration from ``setup`` and its name validation."""
    setup = af2_backend_module.AlphaFold2Backend.setup

    setup_kwargs = {
        "model_name": "multimer",
        "num_cycle": 5,
        "model_dir": "/models",
        "num_predictions_per_model": 2,
        "msa_depth_scan": True,
        "model_names_custom": ["model_1_multimer"],
        "dropout": True,
    }
    configured = setup(**setup_kwargs)

    runners = configured["model_runners"]
    expected_runner_names = [
        "model_1_multimer_pred_0_msa_16",
        "model_1_multimer_pred_1_msa_64",
    ]
    assert sorted(runners) == expected_runner_names

    first_runner = runners[expected_runner_names[0]]
    assert first_runner.config["model"]["num_recycle"] == 5
    assert first_runner.config.model.global_config.eval_dropout is True
    assert first_runner.params["data_dir"] == "/models"

    # A custom model name that is not available must be rejected.
    with pytest.raises(Exception, match="Provided model names"):
        setup(
            model_name="multimer",
            num_cycle=1,
            model_dir="/models",
            num_predictions_per_model=1,
            model_names_custom=["missing_model"],
        )
|
||||
|
||||
|
||||
def test_predict_individual_job_writes_outputs_and_runs_debug_hooks(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """Run one fake model through ``predict_individual_job``.

    Verifies that padding and both debug hooks are invoked once, and that
    the result pickle, unrelaxed PDB and timings JSON land in ``tmp_path``.
    """
    complex_object = af2_backend_module.MultimericObject(
        description="complex",
        input_seqs=["AB"],
        feature_dict={"msa": np.ones((1, 2), dtype=np.int32)},
        multimeric_mode=True,
    )

    pad_calls, msa_calls, template_calls = [], [], []

    def fake_pad(feature_dict, desired_num_msa, desired_num_res):
        pad_calls.append((desired_num_res, desired_num_msa))

    processed_feature_dict = {
        "msa": np.ones((1, 2), dtype=np.int32),
        "template_all_atom_positions": np.ones((1, 2, 37, 3), dtype=np.float32),
        "template_all_atom_mask": np.ones((1, 2, 37), dtype=np.float32),
        "template_aatype": np.eye(22)[[0, 1]][None, :, :],
        "residue_index": np.array([0, 1], dtype=np.int32),
        "asym_id": np.array([1, 2], dtype=np.int32),
    }

    def fake_process_features(feature_dict, random_seed):
        # Hand back a fresh shallow copy on every call.
        return dict(processed_feature_dict)

    def fake_predict(processed_feature_dict, random_seed):
        return {
            "plddt": np.array([91.0, 88.0], dtype=np.float32),
            "predicted_aligned_error": np.zeros((2, 2), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
        }

    fake_runner = SimpleNamespace(
        multimer_mode=True,
        process_features=fake_process_features,
        predict=fake_predict,
    )

    def record_msa_debug(**kwargs):
        msa_calls.append(kwargs["model_name"])

    def record_template_debug(**kwargs):
        template_calls.append(kwargs["model_name"])

    monkeypatch.setattr(af2_backend_module, "pad_input_features", fake_pad)
    monkeypatch.setattr(
        af2_backend_module, "_write_processed_msa_debug_artifact", record_msa_debug
    )
    monkeypatch.setattr(
        af2_backend_module,
        "_write_processed_template_debug_artifacts",
        record_template_debug,
    )

    results = af2_backend_module.AlphaFold2Backend.predict_individual_job(
        model_runners={"modelA": fake_runner},
        multimeric_object=complex_object,
        allow_resume=False,
        skip_templates=False,
        output_dir=tmp_path,
        random_seed=7,
        desired_num_res=4,
        desired_num_msa=3,
        debug_msas=True,
        debug_templates=True,
    )

    # Hooks: each one fired exactly once for the single model.
    assert pad_calls == [(4, 3)]
    assert msa_calls == ["modelA"]
    assert template_calls == ["modelA"]

    # Returned payload.
    assert "modelA" in results
    assert results["modelA"]["seqs"] == ["AB"]

    # Files written to the output directory.
    result_pickle = tmp_path / "result_modelA.pkl"
    unrelaxed_pdb = tmp_path / "unrelaxed_modelA.pdb"
    assert result_pickle.is_file()
    assert unrelaxed_pdb.read_text(encoding="utf-8") == "PDB:predicted"

    timings_text = (tmp_path / "timings.json").read_text(encoding="utf-8")
    timings = json.loads(timings_text)
    assert "process_features_modelA" in timings
    assert "predict_and_compile_modelA" in timings

    with open(result_pickle, "rb") as handle:
        payload = pickle.load(handle)
    assert payload["plddt"].tolist() == [91.0, 88.0]
|
||||
|
||||
|
||||
def test_predict_individual_job_rejects_skipped_templates_in_multimer_mode(
    af2_backend_module,
    tmp_path,
):
    """``skip_templates=True`` with a multimer runner must raise ``ValueError``."""
    complex_object = af2_backend_module.MultimericObject(
        description="complex",
        input_seqs=["AB"],
        feature_dict={"msa": np.ones((1, 2), dtype=np.int32)},
        multimeric_mode=True,
    )

    def fake_process_features(feature_dict, random_seed):
        return {
            "seq_length": 2,
            "template_aatype": np.ones((1, 2), dtype=np.int32),
            "template_all_atom_positions": np.zeros((1, 2, 37, 3), dtype=np.float32),
            "template_all_atom_mask": np.ones((1, 2, 37), dtype=np.float32),
            "num_templates": np.array([1], dtype=np.int32),
        }

    def fake_predict(*args, **kwargs):
        return None

    fake_runner = SimpleNamespace(
        multimer_mode=True,
        process_features=fake_process_features,
        predict=fake_predict,
    )

    job_kwargs = {
        "model_runners": {"modelA": fake_runner},
        "multimeric_object": complex_object,
        "allow_resume": False,
        "skip_templates": True,
        "output_dir": tmp_path,
        "random_seed": 1,
    }
    with pytest.raises(ValueError, match="cannot skip templates"):
        af2_backend_module.AlphaFold2Backend.predict_individual_job(**job_kwargs)
|
||||
|
||||
|
||||
def test_predict_individual_job_resumes_completed_models_and_returns_early(
    af2_backend_module,
    tmp_path,
):
    """With ``allow_resume=True`` and existing outputs, ``predict`` is not called.

    A result pickle and an unrelaxed PDB are placed in ``tmp_path`` first;
    the test then expects feature processing to run once and the runner's
    ``predict`` to be skipped entirely.
    """
    single_chain = af2_backend_module.MonomericObject("single", "AB")
    single_chain.feature_dict = {"residue_index": np.array([0, 1], dtype=np.int32)}

    # Pre-existing outputs of a finished "modelA" run.
    stored_result = {"plddt": np.array([91.0, 88.0], dtype=np.float32)}
    with open(tmp_path / "result_modelA.pkl", "wb") as handle:
        pickle.dump(stored_result, handle, protocol=4)
    (tmp_path / "unrelaxed_modelA.pdb").write_text("existing pdb", encoding="utf-8")

    process_calls = []
    predict_calls = []

    def fake_process_features(feature_dict, random_seed):
        process_calls.append(random_seed)
        return dict(feature_dict)

    def fake_predict(*args, **kwargs):
        predict_calls.append((args, kwargs))

    fake_runner = SimpleNamespace(
        multimer_mode=False,
        process_features=fake_process_features,
        predict=fake_predict,
    )

    results = af2_backend_module.AlphaFold2Backend.predict_individual_job(
        model_runners={"modelA": fake_runner},
        multimeric_object=single_chain,
        allow_resume=True,
        skip_templates=False,
        output_dir=tmp_path,
        random_seed=11,
    )

    assert predict_calls == []
    assert process_calls == [11]

    resumed = results["modelA"]
    assert resumed["seqs"] == ["AB"]
    assert resumed["unrelaxed_protein"].name == "predicted"
|
||||
|
||||
|
||||
def test_predict_individual_job_rejects_missing_or_zero_template_positions(
    af2_backend_module,
    tmp_path,
):
    """Missing or all-zero template positions must each raise ``ValueError``."""
    complex_object = af2_backend_module.MultimericObject(
        description="complex",
        input_seqs=["AB"],
        feature_dict={"msa": np.ones((1, 2), dtype=np.int32)},
        multimeric_mode=True,
    )

    def process_without_templates(feature_dict, random_seed):
        return {}

    def process_with_zero_templates(feature_dict, random_seed):
        return {
            "template_all_atom_positions": np.zeros((1, 2, 37, 3), dtype=np.float32),
        }

    def fake_predict(*args, **kwargs):
        return None

    missing_template_runner = SimpleNamespace(
        multimer_mode=True,
        process_features=process_without_templates,
        predict=fake_predict,
    )
    zero_template_runner = SimpleNamespace(
        multimer_mode=True,
        process_features=process_with_zero_templates,
        predict=fake_predict,
    )

    # Same call for both runners; only the runner and expected message differ.
    scenarios = (
        (missing_template_runner, "No template_all_atom_positions key found"),
        (zero_template_runner, "No valid templates found"),
    )
    for runner, expected_message in scenarios:
        with pytest.raises(ValueError, match=expected_message):
            af2_backend_module.AlphaFold2Backend.predict_individual_job(
                model_runners={"modelA": runner},
                multimeric_object=complex_object,
                allow_resume=False,
                skip_templates=False,
                output_dir=tmp_path,
                random_seed=3,
            )
|
||||
|
||||
|
||||
def test_predict_yields_results_for_each_object(af2_backend_module, monkeypatch, tmp_path):
    """``predict`` yields one result dict per input object, in input order."""
    backend_cls = af2_backend_module.AlphaFold2Backend
    monomer_a = af2_backend_module.MonomericObject("a", "AA")
    monomer_b = af2_backend_module.MonomericObject("b", "BB")
    dir_a = str(tmp_path / "a")
    dir_b = str(tmp_path / "b")
    calls = []

    def fake_predict_individual_job(**kwargs):
        description = kwargs["multimeric_object"].description
        calls.append(description)
        return {"modelA": description}

    monkeypatch.setattr(
        backend_cls,
        "predict_individual_job",
        staticmethod(fake_predict_individual_job),
    )

    # predict() is consumed through list(), so it is drained completely here.
    outputs = list(
        backend_cls.predict(
            model_runners={"modelA": object()},
            objects_to_model=[
                {"object": monomer_a, "output_dir": dir_a},
                {"object": monomer_b, "output_dir": dir_b},
            ],
            allow_resume=False,
            skip_templates=False,
            random_seed=5,
        )
    )

    assert calls == ["a", "b"]

    expected_outputs = [
        {
            "object": monomer_a,
            "prediction_results": {"modelA": "a"},
            "output_dir": dir_a,
        },
        {
            "object": monomer_b,
            "prediction_results": {"modelA": "b"},
            "output_dir": dir_b,
        },
    ]
    assert outputs == expected_outputs
|
||||
|
||||
|
||||
def test_recalculate_confidence_handles_multimer_and_monomer_paths(af2_backend_module):
    """Cover the pass-through, multimer and monomer branches of ``recalculate_confidence``."""
    recalculate = af2_backend_module.AlphaFold2Backend.recalculate_confidence

    # When the PAE entry is already an array the very same dict is returned.
    already_numpy = {
        "predicted_aligned_error": np.zeros((1, 1), dtype=np.float32),
        "plddt": np.array([42.0], dtype=np.float32),
    }
    passthrough = recalculate(already_numpy, multimer_mode=False, total_num_res=1)
    assert passthrough is already_numpy

    # Four-residue prediction that is recalculated for total_num_res=2.
    pae_head = {
        "logits": np.ones((4, 4), dtype=np.float32),
        "breaks": np.array([0.5], dtype=np.float32),
        "asym_id": np.array([0, 0, 1, 1], dtype=np.int32),
    }
    padded = {
        "predicted_aligned_error": pae_head,
        "plddt": np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32),
    }

    multimer_output = recalculate(padded, multimer_mode=True, total_num_res=2)
    assert multimer_output["ptm"] == 0.5
    assert multimer_output["iptm"] == 0.8
    assert multimer_output["ranking_confidence"] == pytest.approx(0.74)
    assert multimer_output["predicted_aligned_error"].shape == (2, 2)

    monomer_output = recalculate(padded, multimer_mode=False, total_num_res=2)
    assert monomer_output["ranking_confidence"] == pytest.approx(15.0)
|
||||
|
||||
|
||||
def test_postprocess_ranks_models_relaxes_best_and_runs_cleanup(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """``postprocess`` with ``ModelsToRelax.BEST`` and ModelCIF conversion enabled.

    Expects ranking by ``ranking_confidence``, a relaxed top-ranked PDB, an
    unrelaxed second PDB, a relax-metrics file, and that the cleanup hook and
    ``subprocess.run`` were both invoked.
    """
    plot_calls = []
    cleanup_calls = []
    subprocess_calls = []

    def record_plot(**kwargs):
        plot_calls.append(kwargs["ranking"])

    def record_cleanup(*args, **kwargs):
        cleanup_calls.append((args, kwargs))

    def fake_subprocess_run(*args, **kwargs):
        subprocess_calls.append((args, kwargs))
        return SimpleNamespace(stderr="", stdout="", returncode=0)

    monkeypatch.setattr(af2_backend_module, "plot_pae_from_matrix", record_plot)
    monkeypatch.setattr(af2_backend_module, "post_prediction_process", record_cleanup)
    monkeypatch.setattr(af2_backend_module.subprocess, "run", fake_subprocess_run)

    complex_object = af2_backend_module.MultimericObject(
        description="complex",
        input_seqs=["AA", "BB"],
        feature_dict={},
        multimeric_mode=True,
    )

    def make_result(plddt, ranking_confidence, iptm, ptm, label):
        return {
            "plddt": np.array(plddt, dtype=np.float32),
            "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
            "ranking_confidence": ranking_confidence,
            "iptm": iptm,
            "ptm": ptm,
            "unrelaxed_protein": SimpleNamespace(name=label),
            "seqs": ["AA", "BB"],
        }

    # The lower-confidence model is deliberately listed first.
    prediction_results = {
        "model_low": make_result([70.0, 71.0, 72.0, 73.0], 0.2, 0.1, 0.2, "low"),
        "model_high": make_result([90.0, 91.0, 92.0, 93.0], 0.9, 0.7, 0.8, "high"),
    }

    af2_backend_module.AlphaFold2Backend.postprocess(
        prediction_results=prediction_results,
        multimeric_object=complex_object,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=af2_backend_module.ModelsToRelax.BEST,
        convert_to_modelcif=True,
    )

    ranking_text = (tmp_path / "ranking_debug.json").read_text(encoding="utf-8")
    ranking = json.loads(ranking_text)
    assert ranking["order"] == ["model_high", "model_low"]
    assert plot_calls == [0, 1]

    top_ranked = (tmp_path / "ranked_0.pdb").read_text(encoding="utf-8")
    second_ranked = (tmp_path / "ranked_1.pdb").read_text(encoding="utf-8")
    assert top_ranked == "RELAXED:high"
    assert second_ranked == "PDB:low"

    assert (tmp_path / "relax_metrics.json").is_file()
    assert cleanup_calls
    assert subprocess_calls
|
||||
|
||||
|
||||
def test_postprocess_relaxes_all_using_saved_unrelaxed_pdbs(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """``ModelsToRelax.ALL`` when results carry no ``unrelaxed_protein`` entry.

    Unrelaxed PDB files are placed in ``tmp_path`` beforehand; the test
    expects a relaxed file per model and ranked files in confidence order.
    """
    cleanup_calls = []

    def ignore_plot(**kwargs):
        return None

    def record_cleanup(*args, **kwargs):
        cleanup_calls.append((args, kwargs))

    def fake_from_pdb_string(text):
        return SimpleNamespace(name=text.strip())

    monkeypatch.setattr(af2_backend_module, "plot_pae_from_matrix", ignore_plot)
    monkeypatch.setattr(af2_backend_module, "post_prediction_process", record_cleanup)
    monkeypatch.setattr(af2_backend_module.protein, "from_pdb_string", fake_from_pdb_string)

    for model_name in ("model_a", "model_b"):
        saved_pdb = tmp_path / f"unrelaxed_{model_name}.pdb"
        saved_pdb.write_text(f"{model_name}_unrelaxed", encoding="utf-8")

    complex_object = af2_backend_module.MultimericObject(
        description="complex",
        input_seqs=["AA", "BB"],
        feature_dict={},
        multimeric_mode=True,
    )

    def make_result(plddt, ranking_confidence, iptm, ptm):
        return {
            "plddt": np.array(plddt, dtype=np.float32),
            "predicted_aligned_error": np.zeros((4, 4), dtype=np.float32),
            "max_predicted_aligned_error": 31.0,
            "ranking_confidence": ranking_confidence,
            "iptm": iptm,
            "ptm": ptm,
            "seqs": ["AA", "BB"],
        }

    prediction_results = {
        "model_a": make_result([80.0, 81.0, 82.0, 83.0], 0.6, 0.4, 0.5),
        "model_b": make_result([90.0, 91.0, 92.0, 93.0], 0.9, 0.7, 0.8),
    }

    af2_backend_module.AlphaFold2Backend.postprocess(
        prediction_results=prediction_results,
        multimeric_object=complex_object,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=af2_backend_module.ModelsToRelax.ALL,
        convert_to_modelcif=False,
    )

    def read_output(file_name):
        return (tmp_path / file_name).read_text(encoding="utf-8")

    assert read_output("relaxed_model_a.pdb") == "RELAXED:model_a_unrelaxed"
    assert read_output("relaxed_model_b.pdb") == "RELAXED:model_b_unrelaxed"
    # model_b has the higher ranking_confidence, so it is ranked first.
    assert read_output("ranked_0.pdb") == "RELAXED:model_b_unrelaxed"
    assert read_output("ranked_1.pdb") == "RELAXED:model_a_unrelaxed"
    assert cleanup_calls
|
||||
|
||||
|
||||
def test_postprocess_handles_monomers_without_relaxation_and_logs_modelcif_errors(
    af2_backend_module,
    monkeypatch,
    tmp_path,
):
    """Monomer ``postprocess`` with no relaxation and a failing ModelCIF conversion.

    ``subprocess.run`` is replaced by a stub returning a non-zero exit code,
    and the test expects exactly one error message to be logged.
    """
    cleanup_calls = []
    modelcif_errors = []
    plot_calls = []

    def record_plot(**kwargs):
        plot_calls.append(kwargs["ranking"])

    def record_cleanup(*args, **kwargs):
        cleanup_calls.append((args, kwargs))

    def record_error(message):
        modelcif_errors.append(message)

    def failing_subprocess_run(*args, **kwargs):
        return SimpleNamespace(
            stderr="convert failed",
            stdout="",
            returncode=1,
        )

    monkeypatch.setattr(af2_backend_module, "plot_pae_from_matrix", record_plot)
    monkeypatch.setattr(af2_backend_module, "post_prediction_process", record_cleanup)
    monkeypatch.setattr(af2_backend_module.logging, "error", record_error)
    monkeypatch.setattr(af2_backend_module.subprocess, "run", failing_subprocess_run)

    single_chain = af2_backend_module.MonomericObject("single", "AB")
    model_result = {
        "plddt": np.array([81.0, 82.0], dtype=np.float32),
        "predicted_aligned_error": np.zeros((2, 2), dtype=np.float32),
        "max_predicted_aligned_error": 31.0,
        "ranking_confidence": 81.5,
        "ptm": 0.4,
        "seqs": ["AB"],
        "unrelaxed_protein": SimpleNamespace(name="mono"),
    }
    prediction_results = {"modelA": model_result}

    af2_backend_module.AlphaFold2Backend.postprocess(
        prediction_results=prediction_results,
        multimeric_object=single_chain,
        output_dir=tmp_path,
        features_directory="/features",
        models_to_relax=af2_backend_module.ModelsToRelax.NONE,
        convert_to_modelcif=True,
    )

    ranking_text = (tmp_path / "ranking_debug.json").read_text(encoding="utf-8")
    ranking = json.loads(ranking_text)
    assert ranking["plddts"] == {"modelA": 81.5}
    assert ranking["ptm"] == {"modelA": 0.4}
    assert ranking["order"] == ["modelA"]

    ranked_pdb = tmp_path / "ranked_0.pdb"
    assert ranked_pdb.read_text(encoding="utf-8") == "PDB:mono"
    # No relaxation was requested, so no relaxed file may exist.
    assert not (tmp_path / "relaxed_modelA.pdb").exists()

    assert plot_calls == [0]
    assert cleanup_calls
    assert modelcif_errors == ["Error: convert failed"]
|
||||
File diff suppressed because it is too large
Load Diff
434
test/unit/test_alphalink_backend_helpers.py
Normal file
434
test/unit/test_alphalink_backend_helpers.py
Normal file
@@ -0,0 +1,434 @@
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# Location of the AlphaLink backend source file, resolved relative to this
# file: two directory levels up, then alphapulldown/folding_backend/.
MODULE_PATH = Path(__file__).resolve().parents[2].joinpath(
    "alphapulldown",
    "folding_backend",
    "alphalink_backend.py",
)
|
||||
|
||||
|
||||
def _package(name: str) -> types.ModuleType:
|
||||
module = types.ModuleType(name)
|
||||
module.__path__ = [] # type: ignore[attr-defined]
|
||||
return module
|
||||
|
||||
|
||||
def _restore_modules(saved_modules: dict[str, types.ModuleType | None]) -> None:
|
||||
for name, module in saved_modules.items():
|
||||
if module is None:
|
||||
sys.modules.pop(name, None)
|
||||
else:
|
||||
sys.modules[name] = module
|
||||
|
||||
|
||||
def _tensor_tree_map(func, tree):
|
||||
if isinstance(tree, dict):
|
||||
return {key: _tensor_tree_map(func, value) for key, value in tree.items()}
|
||||
return func(tree)
|
||||
|
||||
|
||||
def _install_alphalink_backend_stubs() -> "dict[str, types.ModuleType | None]":
    """Register stand-in modules for everything the AlphaLink backend imports.

    Stub versions of ``torch``, ``unifold``, ``unicore`` and three
    ``alphapulldown`` modules are placed in ``sys.modules`` so the backend
    source file can be executed without those packages installed.

    Returns:
        Mapping of every replaced module name to the module that was
        registered under it beforehand (``None`` when there was none), in the
        form expected by ``_restore_modules``.
    """
    # Function-scope import: only the ``no_grad`` stub below needs it.
    import contextlib

    # The return annotation is a string so that it is not evaluated at
    # definition time (``dict[...]`` / ``X | None`` fail before Python 3.10).
    names_to_replace = [
        "alphapulldown.folding_backend.alphafold2_backend",
        "alphapulldown.objects",
        "alphapulldown.utils.plotting",
        "torch",
        "unifold",
        "unifold.config",
        "unifold.modules",
        "unifold.modules.alphafold",
        "unifold.dataset",
        "unifold.data",
        "unifold.data.residue_constants",
        "unifold.data.protein",
        "unifold.data.data_ops",
        "unicore",
        "unicore.utils",
    ]
    # Snapshot taken before anything is overwritten.
    saved_modules = {name: sys.modules.get(name) for name in names_to_replace}

    # --- alphapulldown stubs -------------------------------------------
    af2_backend_mod = types.ModuleType("alphapulldown.folding_backend.alphafold2_backend")
    # Writes only ``max_pae`` to pae_<model_name>.json; the matrix is ignored.
    af2_backend_mod._save_pae_json_file = (
        lambda pae, max_pae, output_dir, model_name: Path(output_dir, f"pae_{model_name}.json").write_text(
            json.dumps({"max_pae": max_pae}),
            encoding="utf-8",
        )
    )

    objects_mod = types.ModuleType("alphapulldown.objects")

    class MonomericObject:
        pass

    class MultimericObject:
        pass

    class ChoppedObject:
        pass

    objects_mod.MonomericObject = MonomericObject
    objects_mod.MultimericObject = MultimericObject
    objects_mod.ChoppedObject = ChoppedObject

    plotting_mod = types.ModuleType("alphapulldown.utils.plotting")
    plotting_mod.plot_pae_from_matrix = lambda *args, **kwargs: None

    # --- torch stub -------------------------------------------------------
    torch_mod = types.ModuleType("torch")
    # dtypes are plain string markers.
    torch_mod.bfloat16 = "bfloat16"
    torch_mod.half = "half"

    class FakeTensor:
        """Minimal tensor stand-in wrapping a NumPy array."""

        def __init__(self, value, dtype=None):
            self._value = np.asarray(value)
            self.dtype = dtype if dtype is not None else self._value.dtype

        def float(self):
            # New FakeTensor holding a float32 copy of the data.
            return FakeTensor(self._value.astype(np.float32), np.float32)

        def cpu(self):
            # Returns the wrapped NumPy array itself, not a FakeTensor.
            return self._value

    class _NoGrad(contextlib.ContextDecorator):
        """No-op replacement for the object returned by ``torch.no_grad()``.

        ``with`` looks up ``__enter__``/``__exit__`` on the type, so they have
        to be real methods; attaching them to an instance (for example a
        ``SimpleNamespace``) does not satisfy the protocol. Deriving from
        ``ContextDecorator`` additionally allows ``@torch.no_grad()``.
        """

        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc, tb):
            # Never suppress exceptions raised inside the block.
            return False

    torch_mod.FakeTensor = FakeTensor
    torch_mod.cuda = SimpleNamespace(
        is_available=lambda: False,
        current_device=lambda: 0,
        get_device_properties=lambda _device: SimpleNamespace(
            total_memory=40 * 1024 * 1024 * 1024
        ),
    )
    # Checkpoint layout with "module."-prefixed parameter names.
    torch_mod.load = lambda path: {
        "ema": {"params": {"module.layer": 1, "module.bias": 2}}
    }
    torch_mod.no_grad = _NoGrad
    torch_mod.autograd = SimpleNamespace(set_detect_anomaly=lambda flag: None)
    torch_mod.as_tensor = lambda value, device=None: value
    torch_mod.from_numpy = lambda value: value

    # --- unifold stubs ----------------------------------------------------
    unifold_pkg = _package("unifold")
    config_mod = types.ModuleType("unifold.config")

    def _model_config(_name):
        # Fresh config object per call; the model name is ignored.
        return SimpleNamespace(
            data=SimpleNamespace(
                common=SimpleNamespace(max_recycling_iters=None),
                predict=SimpleNamespace(
                    num_ensembles=None,
                    subsample_templates=False,
                ),
            ),
            globals=SimpleNamespace(max_recycling_iters=None),
        )

    config_mod.model_config = _model_config

    modules_pkg = _package("unifold.modules")
    alphafold_mod = types.ModuleType("unifold.modules.alphafold")

    class FakeAlphaFold:
        """Model stand-in that records which methods were called on it."""

        def __init__(self, config):
            self.config = config
            self.loaded_state_dict = None
            self.device = None
            self.eval_called = False
            self.inference_mode_called = False
            self.bfloat16_called = False
            self.globals = SimpleNamespace(chunk_size=None, block_size=None)

        def load_state_dict(self, state_dict):
            self.loaded_state_dict = state_dict

        def to(self, device):
            self.device = device
            return self

        def eval(self):
            self.eval_called = True

        def inference_mode(self):
            self.inference_mode_called = True

        def bfloat16(self):
            self.bfloat16_called = True

    alphafold_mod.AlphaFold = FakeAlphaFold

    dataset_mod = types.ModuleType("unifold.dataset")
    dataset_mod.process_ap = lambda **kwargs: ({}, None)

    data_pkg = _package("unifold.data")
    residue_constants_mod = types.ModuleType("unifold.data.residue_constants")
    residue_constants_mod.atom_order = {"CA": 1}
    residue_constants_mod.atom_type_num = 37
    protein_mod = types.ModuleType("unifold.data.protein")
    protein_mod.from_prediction = lambda **kwargs: SimpleNamespace(chain_index=np.array([[0]]), aatype=np.array([[0]]))
    protein_mod.to_pdb = lambda protein: "PDB"
    data_ops_mod = types.ModuleType("unifold.data.data_ops")
    data_ops_mod.get_pairwise_distances = lambda coords: np.zeros((1, 1))

    # --- unicore stubs ----------------------------------------------------
    unicore_pkg = _package("unicore")
    unicore_utils_mod = types.ModuleType("unicore.utils")
    unicore_utils_mod.tensor_tree_map = _tensor_tree_map

    modules = {
        "alphapulldown.folding_backend.alphafold2_backend": af2_backend_mod,
        "alphapulldown.objects": objects_mod,
        "alphapulldown.utils.plotting": plotting_mod,
        "torch": torch_mod,
        "unifold": unifold_pkg,
        "unifold.config": config_mod,
        "unifold.modules": modules_pkg,
        "unifold.modules.alphafold": alphafold_mod,
        "unifold.dataset": dataset_mod,
        "unifold.data": data_pkg,
        "unifold.data.residue_constants": residue_constants_mod,
        "unifold.data.protein": protein_mod,
        "unifold.data.data_ops": data_ops_mod,
        "unicore": unicore_pkg,
        "unicore.utils": unicore_utils_mod,
    }

    for name, module in modules.items():
        sys.modules[name] = module

    # Also expose submodules as attributes of their parent packages so that
    # attribute access (``unifold.data.protein``) works like after a real import.
    unifold_pkg.config = config_mod
    unifold_pkg.modules = modules_pkg
    unifold_pkg.dataset = dataset_mod
    unifold_pkg.data = data_pkg
    modules_pkg.alphafold = alphafold_mod
    data_pkg.residue_constants = residue_constants_mod
    data_pkg.protein = protein_mod
    data_pkg.data_ops = data_ops_mod
    unicore_pkg.utils = unicore_utils_mod

    return saved_modules
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def alphalink_backend_module():
    """Yield the AlphaLink backend module, executed against stubbed dependencies.

    The backend source at ``MODULE_PATH`` is loaded under its regular dotted
    name after the stubs are installed. On teardown the loaded module is
    removed from ``sys.modules`` and the stubbed entries are restored.
    """
    module_name = "alphapulldown.folding_backend.alphalink_backend"

    saved_modules = _install_alphalink_backend_stubs()
    # Drop any previously imported copy so a fresh module object is used.
    sys.modules.pop(module_name, None)

    spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
    backend_module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = backend_module
    assert spec.loader is not None

    try:
        spec.loader.exec_module(backend_module)
        yield backend_module
    finally:
        sys.modules.pop(spec.name, None)
        _restore_modules(saved_modules)
|
||||
|
||||
|
||||
def test_setup_accepts_expected_weights_and_validates_missing_inputs(
    alphalink_backend_module,
    monkeypatch,
    tmp_path,
):
    """``AlphaLinkBackend.setup`` with a file, a directory, and two invalid paths."""
    setup = alphalink_backend_module.AlphaLinkBackend.setup

    warning_messages = []
    monkeypatch.setattr(alphalink_backend_module.logging, "warning", warning_messages.append)

    # A directly supplied .pt file is accepted, and a warning is logged.
    direct_file = tmp_path / "custom_weights.pt"
    direct_file.write_text("stub", encoding="utf-8")
    assert setup(str(direct_file))["param_path"] == str(direct_file)
    assert warning_messages

    # A directory resolves to the weights file with this exact name inside it.
    weights_dir = tmp_path / "weights"
    weights_dir.mkdir()
    canonical_file = weights_dir / "AlphaLink-Multimer_SDA_v3.pt"
    canonical_file.write_text("stub", encoding="utf-8")
    assert setup(str(weights_dir))["param_path"] == str(canonical_file)

    # An existing file with the wrong extension is rejected.
    wrong_extension = tmp_path / "weights.bin"
    wrong_extension.write_text("stub", encoding="utf-8")
    with pytest.raises(ValueError, match=".pt extension"):
        setup(str(wrong_extension))

    # A path that does not exist is rejected.
    missing_path = tmp_path / "missing"
    with pytest.raises(FileNotFoundError, match="does not exist"):
        setup(str(missing_path))
|
||||
|
||||
|
||||
def test_unload_tensors_and_prepare_model_runner(alphalink_backend_module):
    """Check tensor unloading to NumPy and the state of a prepared model runner."""
    backend = alphalink_backend_module.AlphaLinkBackend
    fake_tensor = alphalink_backend_module.torch.FakeTensor

    bf16_batch = {"a": fake_tensor([1, 2], dtype=alphalink_backend_module.torch.bfloat16)}
    float_out = {"b": fake_tensor([3, 4], dtype=np.float32)}
    batch, out = backend.unload_tensors(bf16_batch, float_out)

    np.testing.assert_array_equal(batch["a"], np.array([1.0, 2.0], dtype=np.float32))
    np.testing.assert_array_equal(out["b"], np.array([3, 4], dtype=np.float32))

    model = backend.prepare_model_runner(
        "weights.pt",
        bf16=True,
        model_device="cpu",
    )

    assert model.device == "cpu"
    assert model.eval_called is True
    assert model.inference_mode_called is True
    assert model.bfloat16_called is True
    # The stubbed checkpoint uses "module."-prefixed keys; the loaded state
    # dict is expected without that prefix.
    assert model.loaded_state_dict == {"layer": 1, "bias": 2}

    data_config = model.config.data
    assert data_config.common.max_recycling_iters == alphalink_backend_module.MAX_RECYCLING_ITERS
    assert data_config.predict.num_ensembles == alphalink_backend_module.NUM_ENSEMBLES
|
||||
|
||||
|
||||
def test_resume_preprocess_and_chunk_helpers(alphalink_backend_module, tmp_path):
    """Cover ``check_resume_status``, ``preprocess_features`` and ``automatic_chunk_size``."""
    backend = alphalink_backend_module.AlphaLinkBackend

    # --- resume detection -------------------------------------------------
    finished_stem = "AlphaLink2_model_0_seed_123_0.875"
    (tmp_path / f"{finished_stem}.pdb").write_text("pdb", encoding="utf-8")
    (tmp_path / f"pae_{finished_stem}.json").write_text("{}", encoding="utf-8")

    already_exists, iptm_value = backend.check_resume_status(
        "AlphaLink2_model_0_seed_123",
        str(tmp_path),
    )
    assert already_exists is True
    assert iptm_value == pytest.approx(0.875)

    # --- feature preprocessing -------------------------------------------
    raw_features = {
        "seq_length": np.array([7, 7]),
        "num_alignments": np.array([3, 3]),
        "num_templates": np.array([2, 2]),
        "template_all_atom_masks": np.ones((1, 7, 37), dtype=np.float32),
        "template_aatype": np.eye(22)[[0, 1, 2, 3, 4, 5, 6]][None, :, :],
        "template_sum_probs": np.array([0.5], dtype=np.float32),
        "deletion_matrix_int": np.ones((1, 7), dtype=np.float32),
        "deletion_matrix_int_all_seq": np.full((2, 7), 2.0, dtype=np.float32),
        "msa": np.ones((2, 7), dtype=np.int32),
    }
    processed = backend.preprocess_features(raw_features)

    expected_scalars = {"seq_length": 7, "num_alignments": 3, "num_templates": 2}
    for key, expected_value in expected_scalars.items():
        assert processed[key] == expected_value

    assert processed["template_all_atom_mask"].shape == (1, 7, 37)
    assert processed["template_aatype"].shape == (1, 7)
    assert processed["template_sum_probs"].shape == (1, 1)
    np.testing.assert_array_equal(processed["deletion_matrix"], np.ones((1, 7)))
    np.testing.assert_array_equal(
        processed["extra_deletion_matrix"],
        np.full((2, 7), 2.0),
    )
    assert processed["msa_mask"].shape == (2, 7)
    assert processed["msa_row_mask"].shape == (2,)
    np.testing.assert_array_equal(processed["asym_id"], np.zeros(7, dtype=np.int32))
    np.testing.assert_array_equal(processed["entity_id"], np.zeros(7, dtype=np.int32))
    np.testing.assert_array_equal(processed["sym_id"], np.ones(7, dtype=np.int32))

    # --- chunk size selection: (residue count, expected return value) ----
    chunk_expectations = [
        (500, (256, None)),
        (1000, (128, None)),
        (1500, (64, None)),
        (2200, (32, 512)),
        (3000, (4, 256)),
    ]
    for num_res, expected in chunk_expectations:
        assert backend.automatic_chunk_size(num_res, "cpu") == expected
|
||||
|
||||
|
||||
def test_predict_validates_inputs_and_builds_chain_maps(
    alphalink_backend_module,
    monkeypatch,
    tmp_path,
):
    """Check ``AlphaLinkBackend.predict`` validation and chain-map forwarding.

    ``predict_iterations`` is swapped for a recorder, so the assertions only
    inspect the keyword arguments that ``predict`` passes along.
    """
    backend = alphalink_backend_module.AlphaLinkBackend

    # Supplying nothing but an empty job list must be rejected.
    with pytest.raises(ValueError, match="Missing required parameters"):
        list(backend.predict(objects_to_model=[]))

    recorded = []

    def record_call(feature_dict, output_dir, **kwargs):
        recorded.append((feature_dict, output_dir, kwargs))

    monkeypatch.setattr(backend, "predict_iterations", staticmethod(record_call))

    # First object carries no chain_id_map attribute at all.
    without_map = SimpleNamespace(feature_dict={"x": 1}, input_seqs=["AAA"])
    # Second object maps chain ids to plain integers.
    with_int_map = SimpleNamespace(
        feature_dict={"y": 2},
        input_seqs=["AAA", "BBB"],
        chain_id_map={"A": 0, "B": 1},
    )
    jobs = [
        {"object": without_map, "output_dir": str(tmp_path / "default")},
        {"object": with_int_map, "output_dir": str(tmp_path / "int_map")},
    ]

    outputs = list(
        backend.predict(
            objects_to_model=jobs,
            configs=SimpleNamespace(data="cfg"),
            param_path="weights.pt",
            num_predictions_per_model=2,
        )
    )
    assert len(outputs) == 2

    first_kwargs = recorded[0][2]
    assert first_kwargs["crosslinks"] == ""
    fallback_map = first_kwargs["chain_id_map"]
    assert sorted(fallback_map) == ["A"]
    assert fallback_map["A"].description == "default_chain"
    assert fallback_map["A"].sequence == "AAA"

    second_kwargs = recorded[1][2]
    rebuilt_map = second_kwargs["chain_id_map"]
    for chain_id, sequence in (("A", "AAA"), ("B", "BBB")):
        assert rebuilt_map[chain_id].description == f"chain_{chain_id}"
        assert rebuilt_map[chain_id].sequence == sequence
    assert second_kwargs["num_inference"] == 2
|
||||
|
||||
|
||||
def test_postprocess_ranks_nested_prediction_outputs(alphalink_backend_module, tmp_path):
    """Check that ``postprocess`` ranks PDB files found in nested seed directories."""
    # (sub-directory, file name, file content); the score is the file name's last field.
    fixtures = (
        ("seed0", "AlphaLink2_model_0_seed_1_0.500.pdb", "LOW"),
        ("seed1", "AlphaLink2_model_1_seed_2_0.900.pdb", "HIGH"),
    )
    for subdir, filename, content in fixtures:
        seed_dir = tmp_path / subdir
        seed_dir.mkdir()
        (seed_dir / filename).write_text(content, encoding="utf-8")

    alphalink_backend_module.AlphaLinkBackend.postprocess({}, str(tmp_path))

    ranking_text = (tmp_path / "ranking_debug.json").read_text(encoding="utf-8")
    ranking = json.loads(ranking_text)
    # Higher-scoring model comes first.
    assert ranking["order"] == [
        "AlphaLink2_model_1_seed_2_0.900",
        "AlphaLink2_model_0_seed_1_0.500",
    ]
    for rank, content in enumerate(("HIGH", "LOW")):
        assert (tmp_path / f"ranked_{rank}.pdb").read_text(encoding="utf-8") == content
|
||||
104
test/unit/test_cluster_wrapper_helpers.py
Normal file
104
test/unit/test_cluster_wrapper_helpers.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Repository root, two directory levels above this test module.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Cluster test scripts that the tests below load by file path.
AF2_CHECK_PATH = REPO_ROOT.joinpath("test", "cluster", "check_alphafold2_predictions.py")
AF2_WRAPPER_PATH = REPO_ROOT.joinpath("test", "cluster", "run_alphafold2_predictions.py")
AF3_WRAPPER_PATH = REPO_ROOT.joinpath("test", "cluster", "run_alphafold3_predictions.py")
|
||||
|
||||
|
||||
def _load_module(module_name: str, module_path: Path):
    """Import the file at ``module_path`` as a fresh module named ``module_name``.

    Any module already registered under ``module_name`` is discarded first, so
    every call re-executes the file.

    Args:
        module_name: Key under which the module is registered in ``sys.modules``.
        module_path: Path of the Python source file to execute.

    Returns:
        The newly executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for ``module_path``.
    """
    sys.modules.pop(module_name, None)
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate before use: spec_from_file_location returns None when it cannot
    # pick a loader for the path.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so code run during import can find the module by name.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return module
|
||||
|
||||
|
||||
def test_af2_cluster_subprocess_env_sets_safe_gpu_defaults(monkeypatch):
    """Check the defaults ``_af2_subprocess_env`` fills in when the variables are unset."""
    module = _load_module("test_cluster_af2_check_module", AF2_CHECK_PATH)

    # Variables compared for exact equality, in the order they are checked.
    expected_defaults = {
        "OMP_NUM_THREADS": "4",
        "MKL_NUM_THREADS": "4",
        "NUMEXPR_NUM_THREADS": "4",
        "TF_NUM_INTEROP_THREADS": "4",
        "TF_NUM_INTRAOP_THREADS": "4",
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.8",
        "JAX_PLATFORM_NAME": "gpu",
    }

    # Clear every variable under test so values from the host cannot leak in.
    for variable in (*expected_defaults, "XLA_FLAGS"):
        monkeypatch.delenv(variable, raising=False)

    env = module._af2_subprocess_env()

    for variable, value in expected_defaults.items():
        assert env[variable] == value
    assert "--xla_gpu_force_compilation_parallelism=0" in env["XLA_FLAGS"]
|
||||
|
||||
|
||||
def test_af2_cluster_wrapper_job_script_exports_gpu_defaults(tmp_path):
    """Check that the AF2 wrapper writes a job script containing the expected defaults."""
    module = _load_module("test_cluster_af2_wrapper_module", AF2_WRAPPER_PATH)
    job_fields = {
        "index": 1,
        "nodeid": "test/cluster/check_alphafold2_predictions.py::TestRunModes::test__monomer",
        "slug": "af2_node",
        "stdout_path": tmp_path / "stdout.log",
        "stderr_path": tmp_path / "stderr.log",
        "script_path": tmp_path / "job.sh",
        "rerun_command": "python -m pytest",
    }
    job = module.JobSpec(**job_fields)

    module.write_job_script(
        job=job,
        python_executable=sys.executable,
        use_temp_dir=True,
        cpus_per_task=8,
    )

    script_text = job.script_path.read_text(encoding="utf-8")
    expected_fragments = (
        'OMP_NUM_THREADS="${OMP_NUM_THREADS:-4}"',
        'TF_NUM_INTRAOP_THREADS="${TF_NUM_INTRAOP_THREADS:-4}"',
        'JAX_PLATFORM_NAME="${JAX_PLATFORM_NAME:-gpu}"',
        "addopts=-ra --strict-markers",
        "--use-temp-dir",
    )
    for fragment in expected_fragments:
        assert fragment in script_text
|
||||
|
||||
|
||||
def test_af3_cluster_wrapper_job_script_sets_perf_flag(tmp_path):
    """Check that ``include_perf=True`` makes the AF3 job script export the perf flag."""
    module = _load_module("test_cluster_af3_wrapper_module", AF3_WRAPPER_PATH)
    job_fields = {
        "index": 1,
        "nodeid": "test/cluster/check_alphafold3_predictions.py::TestRunModes::test__monomer",
        "slug": "af3_node",
        "stdout_path": tmp_path / "stdout.log",
        "stderr_path": tmp_path / "stderr.log",
        "script_path": tmp_path / "job.sh",
        "rerun_command": "python -m pytest",
    }
    job = module.JobSpec(**job_fields)

    module.write_job_script(
        job=job,
        python_executable=sys.executable,
        use_temp_dir=True,
        include_perf=True,
    )

    script_text = job.script_path.read_text(encoding="utf-8")
    expected_fragments = (
        "export AF3_RUN_PERF_TESTS=1",
        "addopts=-ra --strict-markers",
        "--use-temp-dir",
    )
    for fragment in expected_fragments:
        assert fragment in script_text
|
||||
134
test/unit/test_convert_to_modelcif_helpers.py
Normal file
134
test/unit/test_convert_to_modelcif_helpers.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
pytest.importorskip("ihm")
|
||||
pytest.importorskip("modelcif")
|
||||
|
||||
import alphapulldown.scripts.convert_to_modelcif as convert_to_modelcif
|
||||
|
||||
|
||||
# Bundled prediction fixtures: <directory above this module's folder>/test_data/predictions.
TEST_PREDICTIONS_DIR = Path(__file__).resolve().parents[1].joinpath(
    "test_data", "predictions"
)
|
||||
|
||||
|
||||
def test_cast_param_preserves_expected_python_types():
    """Check ``_cast_param`` on int, float, bool and plain-text strings."""
    cast = convert_to_modelcif._cast_param

    assert cast("7") == 7
    assert cast("3.5") == 3.5
    # Booleans are checked by identity so a truthy non-bool would not pass.
    assert cast("True") is True
    assert cast("False") is False
    # Text that matches none of the above is returned unchanged.
    assert cast("plain-text") == "plain-text"
|
||||
|
||||
|
||||
def test_compress_cif_file_replaces_plaintext_file(tmp_path):
    """Check that ``_compress_cif_file`` leaves only the ``.gz`` file behind."""
    plain_cif = tmp_path.joinpath("ranked_0.cif")
    plain_cif.write_text("data_test\n", encoding="ascii")

    returned_path = convert_to_modelcif._compress_cif_file(str(plain_cif))

    gz_cif = tmp_path.joinpath("ranked_0.cif.gz")
    assert returned_path.endswith(".gz")
    assert not plain_cif.exists()
    assert gz_cif.exists()
|
||||
|
||||
|
||||
def test_get_feature_metadata_falls_back_to_structure_sequence(tmp_path):
    """Check ``_get_feature_metadata`` when the recorded FASTA path does not exist.

    The bundled ``TEST`` prediction is copied so its metadata JSON can be
    rewritten to point at a missing FASTA file; sequences are then expected
    to come from the fallback structure.
    """
    work_dir = tmp_path / "TEST"
    shutil.copytree(TEST_PREDICTIONS_DIR / "TEST", work_dir)

    # Point the metadata at a FASTA file that does not exist.
    metadata_file = next(work_dir.glob("*_feature_metadata_*.json"))
    metadata = json.loads(metadata_file.read_text(encoding="ascii"))
    metadata["other"]["fasta_paths"] = "['/this/path/does/not/exist.fasta']"
    metadata_file.write_text(json.dumps(metadata, indent=2), encoding="ascii")

    collected = {}
    complex_name, fasta_dicts = convert_to_modelcif._get_feature_metadata(
        collected,
        "TEST",
        str(work_dir),
        fallback_structure_path=str(work_dir / "ranked_0.pdb"),
    )

    assert complex_name == "TEST"
    assert fasta_dicts
    first_chain = fasta_dicts[0]
    assert first_chain["description"] == "chain_A"
    assert first_chain["sequence"].startswith("MESAIA")
    # The dict passed in is filled in place.
    assert collected["__meta__"]["TEST"]["databases"]
|
||||
|
||||
|
||||
def test_get_model_list_selects_requested_model_and_tracks_non_selected_models():
    """Check ``_get_model_list`` for model 0 with non-selected tracking enabled."""
    prediction_dir = str(TEST_PREDICTIONS_DIR / "TEST")
    selected_models = convert_to_modelcif._get_model_list(prediction_dir, 0, True)

    assert len(selected_models) == 1
    (entry,) = selected_models
    assert entry["complex"] == "TEST"
    assert len(entry["models"]) == 1
    model_path = entry["models"][0][0]
    assert model_path.endswith("ranked_0.pdb")
    assert len(entry["not_selected"]) == 4
|
||||
|
||||
|
||||
def test_main_processes_associated_models_before_selected_models(monkeypatch, tmp_path):
    """Check the order in which ``main`` converts non-selected and selected models.

    ``FLAGS``, ``_get_model_list`` and ``alphapulldown_model_to_modelcif`` are
    all stubbed, so only the call sequence and arguments are examined.
    """
    conversions = []

    def fake_convert(complex_name, model_tuple, out_dir, compress, additional_assoc_files=None):
        conversions.append(
            dict(
                complex_name=complex_name,
                model_tuple=model_tuple,
                out_dir=out_dir,
                compress=compress,
                additional_assoc_files=additional_assoc_files,
            )
        )
        archive_name = f"{Path(model_tuple[0]).stem}.zip"
        return {archive_name: (str(Path(out_dir) / archive_name), object())}

    def fake_get_model_list(ap_output, model_selected, get_non_selected):
        return [
            {
                "complex": "TEST",
                "path": str(tmp_path / "out"),
                "models": [("ranked_0.pdb", "result_0.pkl", "0", 0)],
                "not_selected": [("ranked_1.pdb", "result_1.pkl", "1", 1)],
            }
        ]

    stub_flags = SimpleNamespace(
        ap_output=str(tmp_path / "predictions"),
        model_selected=0,
        add_associated=True,
        compress=False,
    )
    monkeypatch.setattr(convert_to_modelcif, "FLAGS", stub_flags)
    monkeypatch.setattr(convert_to_modelcif, "_get_model_list", fake_get_model_list)
    monkeypatch.setattr(
        convert_to_modelcif, "alphapulldown_model_to_modelcif", fake_convert
    )

    convert_to_modelcif.main([])

    assert len(conversions) == 2
    associated_call, selected_call = conversions
    # The non-selected model is converted first, into a different directory.
    assert associated_call["model_tuple"][0] == "ranked_1.pdb"
    assert associated_call["additional_assoc_files"] is None
    assert associated_call["out_dir"] != str(tmp_path / "out")
    # The selected model follows and receives associated files.
    assert selected_call["model_tuple"][0] == "ranked_0.pdb"
    assert selected_call["additional_assoc_files"]
|
||||
76
test/unit/test_crosslink_input.py
Normal file
76
test/unit/test_crosslink_input.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from alphafold.data.pipeline_multimer import _FastaChain
|
||||
|
||||
|
||||
torch = pytest.importorskip("torch", reason="AlphaLink crosslink helpers require torch")
|
||||
from unifold.dataset import bin_xl, calculate_offsets, create_xl_features
|
||||
|
||||
|
||||
# Directory one level above the folder that contains this module.
TEST_ROOT = Path(__file__).resolve().parents[1]
# Pickled crosslink input consumed by the ``crosslink_context`` fixture.
CROSSLINK_FIXTURE = TEST_ROOT.joinpath("alphalink", "test_xl_input.pkl")
|
||||
|
||||
|
||||
@pytest.fixture
def crosslink_context():
    """Provide the shared inputs used by the crosslink feature tests.

    Returns:
        A dict with the unpickled ``crosslinks``, an ``asym_id`` tensor for
        three chains of 10, 25 and 40 residues, a ``chain_id_map`` for chains
        A-C, and the ``bins`` tensor running from 0 to 1 in steps of 0.05.
    """
    # Use a context manager so the file handle is closed deterministically;
    # the bare ``open()`` previously used here was never closed.
    # NOTE: pickle is unsafe on untrusted data; this file is expected to be a
    # fixture kept alongside the tests.
    with CROSSLINK_FIXTURE.open("rb") as handle:
        crosslinks = pickle.load(handle)

    return {
        "crosslinks": crosslinks,
        "asym_id": torch.tensor([1] * 10 + [2] * 25 + [3] * 40),
        "chain_id_map": {
            "A": _FastaChain(sequence="", description="chain1"),
            "B": _FastaChain(sequence="", description="chain2"),
            "C": _FastaChain(sequence="", description="chain3"),
        },
        "bins": torch.arange(0, 1.05, 0.05),
    }
|
||||
|
||||
|
||||
def test_calculate_offsets(crosslink_context):
    """Check the offsets computed for chains of 10, 25 and 40 residues."""
    asym_id = crosslink_context["asym_id"]
    computed = calculate_offsets(asym_id)
    # Expected values are the cumulative chain lengths.
    assert computed.tolist() == [0, 10, 35, 75]
|
||||
|
||||
|
||||
def test_create_xl_inputs(crosslink_context):
    """Check the tensor that ``create_xl_features`` builds from the fixture crosslinks."""
    chain_offsets = calculate_offsets(crosslink_context["asym_id"])
    features = create_xl_features(
        crosslink_context["crosslinks"],
        chain_offsets,
        chain_id_map=crosslink_context["chain_id_map"],
    )

    # Each row: (residue index, residue index, third value 0.01).
    expected_rows = [
        [10, 35, 0.01],
        [3, 27, 0.01],
        [5, 56, 0.01],
        [20, 65, 0.01],
    ]
    assert torch.equal(features, torch.tensor(expected_rows))
|
||||
|
||||
|
||||
def test_bin_xl(crosslink_context):
    """Check that ``bin_xl`` writes the same bucket symmetrically for every crosslink."""
    chain_offsets = calculate_offsets(crosslink_context["asym_id"])
    features = create_xl_features(
        crosslink_context["crosslinks"],
        chain_offsets,
        chain_id_map=crosslink_context["chain_id_map"],
    )
    num_res = len(crosslink_context["asym_id"])
    binned = bin_xl(features, num_res)

    # All four crosslinks are expected in the bucket that 0.99 falls into.
    bucket = torch.bucketize(0.99, crosslink_context["bins"])
    expected = np.zeros((num_res, num_res, 1))
    for first, second in ((3, 27), (10, 35), (5, 56), (20, 65)):
        expected[first, second, 0] = expected[second, first, 0] = bucket

    assert np.array_equal(binned, expected)
|
||||
@@ -296,3 +296,96 @@ def test_resolve_species_ids_by_accession_skips_single_accession_fallback_after_
|
||||
'A0A743YDY2': '',
|
||||
}
|
||||
assert calls == [('A0A636IKY3', 'A0A743YDY2')]
|
||||
|
||||
|
||||
def test_build_mmseq_identifier_features_skips_non_uniprot_identifiers(
    monkeypatch,
):
    """Check that only the UniProt-style A3M header reaches the species resolver."""
    resolver_calls = []

    def fake_resolver(accessions):
        resolver_calls.append(tuple(accessions))
        return {'A0A636IKY3': '108619'}

    # Four entries: numeric query id, bare MGnify id, UniRef-wrapped MGnify
    # id, and a UniRef-wrapped UniProt accession.
    a3m = (
        '>101\n'
        'ACDE\n'
        '>MGYP000264027769\n'
        'ACDF\n'
        '>UniRef100_MGYP000264027769\n'
        'ACDG\n'
        '>UniRef100_A0A636IKY3\n'
        'ACDH\n'
    )

    features = mmseqs_species_identifiers.build_mmseq_identifier_features(
        a3m,
        species_resolver=fake_resolver,
    )

    assert resolver_calls == [('A0A636IKY3',)]
    species_ids = features['msa_species_identifiers'].tolist()
    assert species_ids == [b'', b'', b'', b'108619']
    accession_ids = features['msa_uniprot_accession_identifiers'].tolist()
    assert accession_ids == [b'', b'', b'', b'A0A636IKY3']
|
||||
|
||||
|
||||
def test_resolve_species_ids_by_accession_skips_unsupported_accessions(
    monkeypatch,
):
    """Check that each accession goes to the matching batch query, or to none.

    Both batch query helpers are stubbed with recorders that return one canned
    result each.
    """
    recorded = {'uniprot': [], 'uniparc': []}

    def fake_uniprot_query(accessions, *, urlopen):
        recorded['uniprot'].append(tuple(accessions))
        entry = {'primaryAccession': 'A0A636IKY3', 'organism': {'taxonId': 562}}
        return {'results': [entry]}

    def fake_uniparc_query(accessions, *, urlopen):
        recorded['uniparc'].append(tuple(accessions))
        entry = {'uniParcId': 'UPI001118B830', 'organisms': [{'taxonId': 83333}]}
        return {'results': [entry]}

    for helper_name, stub in (
        ('_query_uniprot_batch', fake_uniprot_query),
        ('_query_uniparc_batch', fake_uniparc_query),
    ):
        monkeypatch.setattr(mmseqs_species_identifiers, helper_name, stub)

    requested = ['A0A636IKY3', 'MGYP000264027769', 'UPI001118B830']
    resolved = mmseqs_species_identifiers.resolve_species_ids_by_accession(requested)

    # The MGnify id resolves to an empty string and appears in neither query.
    assert resolved == {
        'A0A636IKY3': '562',
        'MGYP000264027769': '',
        'UPI001118B830': '83333',
    }
    assert recorded['uniprot'] == [('A0A636IKY3',)]
    assert recorded['uniparc'] == [('UPI001118B830',)]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
|
||||
from alphapulldown.utils.output_paths import (
|
||||
build_af3_combined_json_job_name,
|
||||
derive_af3_job_name_from_json,
|
||||
resolve_af3_combined_json_output_dir,
|
||||
resolve_af3_json_output_dir,
|
||||
sanitise_af3_job_name,
|
||||
)
|
||||
@@ -106,3 +108,60 @@ def test_resolve_af3_json_output_dir_keeps_unsafe_json_name_within_root(tmp_path
|
||||
use_ap_style=True,
|
||||
shared_output_root=True,
|
||||
) == str(tmp_path / "predictions" / "input_name")
|
||||
|
||||
|
||||
def test_build_af3_combined_json_job_name_uses_fold_fragments(tmp_path):
    """Check the combined job name built from two AF3 JSON inputs."""
    # (file name, value of the JSON "name" field)
    specs = (
        ("protein_with_ptms.json", "protein_ptms"),
        ("P61626_af3_input.json", "P61626"),
    )
    fold_inputs = []
    for filename, job_name in specs:
        json_path = tmp_path / filename
        document = {"name": job_name, "dialect": "alphafold3", "version": 1}
        json_path.write_text(json.dumps(document), encoding="utf-8")
        fold_inputs.append({"json_input": str(json_path)})

    combined = build_af3_combined_json_job_name(fold_inputs)

    # The expected first fragment matches the file stem, not the JSON "name".
    assert combined == "protein_with_ptms_and_p61626"
|
||||
|
||||
|
||||
def test_resolve_af3_combined_json_output_dir_uses_one_fold_directory(tmp_path):
    """Check that two JSON inputs resolve to one combined output directory."""
    fold_inputs = []
    for accession in ("P01308", "P61626"):
        json_path = tmp_path / f"{accession}_af3_input.json"
        document = {"name": accession, "dialect": "alphafold3", "version": 1}
        json_path.write_text(json.dumps(document), encoding="utf-8")
        fold_inputs.append({"json_input": str(json_path)})

    output_root = tmp_path / "predictions"
    resolved = resolve_af3_combined_json_output_dir(
        fold_inputs,
        str(output_root),
        use_ap_style=True,
    )

    assert resolved == str(output_root / "p01308_and_p61626")
|
||||
|
||||
1105
test/unit/test_script_entrypoints.py
Normal file
1105
test/unit/test_script_entrypoints.py
Normal file
File diff suppressed because it is too large
Load Diff
430
test/unit/test_small_script_entrypoints.py
Normal file
430
test/unit/test_small_script_entrypoints.py
Normal file
@@ -0,0 +1,430 @@
|
||||
import gzip
|
||||
import importlib.util
|
||||
import json
|
||||
import pickle
|
||||
import runpy
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import alphapulldown.scripts.generate_crosslink_pickle as crosslink_pickle
|
||||
import alphapulldown.scripts.rename_colab_search_a3m as rename_a3m
|
||||
|
||||
|
||||
# Repository root, two directory levels above this test module.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Script files that the tests below load or run by file path.
PREPARE_SEQ_NAMES_PATH = REPO_ROOT.joinpath(
    "alphapulldown", "scripts", "prepare_seq_names.py"
)
PARSE_INPUT_PATH = REPO_ROOT.joinpath("alphapulldown", "scripts", "parse_input.py")
SPLIT_JOBS_PATH = REPO_ROOT.joinpath(
    "alphapulldown", "scripts", "split_jobs_into_clusters.py"
)
TRUNCATE_PICKLES_PATH = REPO_ROOT.joinpath(
    "alphapulldown", "scripts", "truncate_pickles.py"
)
|
||||
|
||||
|
||||
def _load_module_from_path(module_name: str, module_path: Path):
    """Import the file at ``module_path`` as a fresh module named ``module_name``.

    Any module already registered under ``module_name`` is discarded first, so
    every call re-executes the file (and sees whatever stubs are currently in
    ``sys.modules``).

    Args:
        module_name: Key under which the module is registered in ``sys.modules``.
        module_path: Path of the Python source file to execute.

    Returns:
        The newly executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for ``module_path``.
    """
    sys.modules.pop(module_name, None)
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate before use: spec_from_file_location returns None when it cannot
    # pick a loader for the path.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so code run during import can find the module by name.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return module
|
||||
|
||||
|
||||
def _load_parse_input_module(monkeypatch):
    """Load ``parse_input.py`` with absl and project dependencies replaced by stubs.

    Returns:
        ``(module, app_calls, parser_calls, modelling_calls)``, where the last
        three record what the loaded script passed to ``absl.app.run``, to
        ``generate_fold_specifications`` and to the modelling-setup helpers.
    """
    app_calls = []
    parser_calls = {}
    modelling_calls = {}

    # --- absl.flags: DEFINE_* simply stores the default on FLAGS ---
    flags_mod = types.ModuleType("absl.flags")
    flags_mod.FLAGS = SimpleNamespace()

    def _store_default(name, default, help_text):
        del help_text
        setattr(flags_mod.FLAGS, name, default)

    flags_mod.DEFINE_list = _store_default
    flags_mod.DEFINE_string = _store_default

    # --- absl.app: run() records the callable instead of invoking it ---
    app_mod = types.ModuleType("absl.app")

    def _record_run(fn):
        return app_calls.append(fn)

    app_mod.run = _record_run

    # --- absl.logging: verbosity changes are ignored ---
    logging_mod = types.ModuleType("absl.logging")
    logging_mod.INFO = 20

    def _ignore(*_args, **_kwargs):
        return None

    logging_mod.set_verbosity = _ignore

    absl_pkg = types.ModuleType("absl")
    absl_pkg.app = app_mod
    absl_pkg.flags = flags_mod
    absl_pkg.logging = logging_mod

    # --- alphapulldown_input_parser: records kwargs, returns one fold ---
    parser_mod = types.ModuleType("alphapulldown_input_parser")

    def _generate_fold_specifications(**kwargs):
        parser_calls["generate"] = kwargs
        return ["foldA"]

    parser_mod.generate_fold_specifications = _generate_fold_specifications

    # --- alphapulldown.utils.modelling_setup: records arguments ---
    modelling_setup_mod = types.ModuleType("alphapulldown.utils.modelling_setup")

    def _parse_fold(specifications, features_directory, delimiter):
        modelling_calls["parse_fold"] = (specifications, features_directory, delimiter)
        return specifications

    def _create_custom_info(parsed):
        modelling_calls["create_custom_info"] = parsed
        return {"parsed": parsed}

    modelling_setup_mod.parse_fold = _parse_fold
    modelling_setup_mod.create_custom_info = _create_custom_info

    stubbed_modules = (
        ("absl", absl_pkg),
        ("absl.app", app_mod),
        ("absl.flags", flags_mod),
        ("absl.logging", logging_mod),
        ("alphapulldown_input_parser", parser_mod),
        ("alphapulldown.utils.modelling_setup", modelling_setup_mod),
    )
    for dotted_name, stub in stubbed_modules:
        monkeypatch.setitem(sys.modules, dotted_name, stub)

    module = _load_module_from_path("test_parse_input_module", PARSE_INPUT_PATH)
    return module, app_calls, parser_calls, modelling_calls
|
||||
|
||||
|
||||
def _load_split_jobs_module(monkeypatch):
    """Load ``split_jobs_into_clusters.py`` with its project imports stubbed.

    Returns:
        The freshly executed script module.
    """
    # --- alphapulldown_input_parser ---
    parser_mod = types.ModuleType("alphapulldown_input_parser")

    def _generate_fold_specifications(**kwargs):
        return ["A,B", "C;D"]

    parser_mod.generate_fold_specifications = _generate_fold_specifications

    # --- alphapulldown.utils.modelling_setup: pass-through helpers ---
    modelling_setup_mod = types.ModuleType("alphapulldown.utils.modelling_setup")

    def _parse_fold(args):
        return args

    def _create_custom_info(parsed_input):
        return parsed_input

    def _create_interactors(data, features_directory, index):
        return [[data, features_directory, index]]

    modelling_setup_mod.parse_fold = _parse_fold
    modelling_setup_mod.create_custom_info = _create_custom_info
    modelling_setup_mod.create_interactors = _create_interactors

    # --- alphapulldown.objects: minimal MultimericObject ---
    objects_mod = types.ModuleType("alphapulldown.objects")

    class StubMultimericObject:
        """Stand-in that exposes a fixed 4 x 12 ``msa`` feature."""

        def __init__(self, interactors):
            self.interactors = interactors
            self.feature_dict = {"msa": np.zeros((4, 12), dtype=np.int32)}

    objects_mod.MultimericObject = StubMultimericObject

    stubbed_modules = (
        ("alphapulldown_input_parser", parser_mod),
        ("alphapulldown.utils.modelling_setup", modelling_setup_mod),
        ("alphapulldown.objects", objects_mod),
    )
    for dotted_name, stub in stubbed_modules:
        monkeypatch.setitem(sys.modules, dotted_name, stub)

    return _load_module_from_path("test_split_jobs_module", SPLIT_JOBS_PATH)
|
||||
|
||||
|
||||
def _load_truncate_pickles_module(monkeypatch):
    """Load ``truncate_pickles.py`` against a stubbed ``absl`` package.

    Returns:
        The freshly executed script module.
    """
    # --- absl.flags: DEFINE_* simply stores the default on FLAGS ---
    flags_mod = types.ModuleType("absl.flags")
    flags_mod.FLAGS = SimpleNamespace()

    def _define_string(name, default, help_text, required=False):
        del help_text, required
        setattr(flags_mod.FLAGS, name, default)

    def _define_plain(name, default, help_text):
        del help_text
        setattr(flags_mod.FLAGS, name, default)

    flags_mod.DEFINE_string = _define_string
    flags_mod.DEFINE_list = _define_plain
    flags_mod.DEFINE_integer = _define_plain

    # --- absl.app: run() calls the entry point immediately with no argv ---
    app_mod = types.ModuleType("absl.app")

    def _run_now(fn):
        return fn([])

    app_mod.run = _run_now

    # --- absl.logging: errors are swallowed ---
    logging_mod = types.ModuleType("absl.logging")

    def _ignore(*_args, **_kwargs):
        return None

    logging_mod.error = _ignore

    absl_pkg = types.ModuleType("absl")
    absl_pkg.app = app_mod
    absl_pkg.flags = flags_mod
    absl_pkg.logging = logging_mod

    stubbed_modules = (
        ("absl", absl_pkg),
        ("absl.app", app_mod),
        ("absl.flags", flags_mod),
        ("absl.logging", logging_mod),
    )
    for dotted_name, stub in stubbed_modules:
        monkeypatch.setitem(sys.modules, dotted_name, stub)

    return _load_module_from_path(
        "test_truncate_pickles_module",
        TRUNCATE_PICKLES_PATH,
    )
|
||||
|
||||
|
||||
def test_prepare_seq_names_rewrites_headers_from_uniprot_style_fasta(
    monkeypatch,
    tmp_path,
    capsys,
):
    """Run ``prepare_seq_names.py`` as a script and check the headers it prints."""
    fasta_path = tmp_path / "input.fasta"
    fasta_records = (
        ">sp|Q9H9K5|Protein alpha OS=Test\nACDE\n"
        ">tr|P12345|Protein beta\nFGHI\n"
    )
    fasta_path.write_text(fasta_records, encoding="utf-8")

    script = str(PREPARE_SEQ_NAMES_PATH)
    monkeypatch.setattr(sys, "argv", [script, str(fasta_path)])
    runpy.run_path(script, run_name="__main__")

    printed_lines = capsys.readouterr().out.strip().splitlines()
    # Only the accession between the pipes is kept in each header.
    assert printed_lines == [">Q9H9K5", "ACDE", ">P12345", "FGHI"]
|
||||
|
||||
|
||||
def test_rename_colab_search_a3m_renames_legacy_search_outputs(monkeypatch, tmp_path):
    """Check that ``0.a3m`` is renamed after its first header when that header is a name."""
    monkeypatch.chdir(tmp_path)
    a3m_text = ">protA\nACDE\n>hit\nACDE\n"
    numbered = tmp_path / "0.a3m"
    numbered.write_text(a3m_text, encoding="utf-8")

    rename_a3m.main()

    # Content is unchanged; only the file name differs.
    assert (tmp_path / "protA.a3m").read_text(encoding="utf-8") == a3m_text
    assert not numbered.exists()
|
||||
|
||||
|
||||
def test_rename_colab_search_a3m_requires_input_fasta_for_new_colabfold(
    monkeypatch,
    tmp_path,
):
    """Check that a numeric ``>101`` header without an input FASTA raises ``ValueError``."""
    monkeypatch.chdir(tmp_path)
    numbered = tmp_path / "0.a3m"
    numbered.write_text(">101\nACDE\n>hit\nACDE\n", encoding="utf-8")

    with pytest.raises(ValueError, match="Please provide the input FASTA file"):
        rename_a3m.main()
|
||||
|
||||
|
||||
def test_rename_colab_search_a3m_uses_input_fasta_names_for_new_colabfold(
    monkeypatch,
    tmp_path,
):
    """Check that a numeric ``>101`` header is replaced by the input FASTA name."""
    monkeypatch.chdir(tmp_path)
    numbered = tmp_path / "0.a3m"
    numbered.write_text(">101\nACDE\n>hit\nACDE\n", encoding="utf-8")
    query_fasta = tmp_path / "input.fasta"
    query_fasta.write_text(">queryA\nACDE\n", encoding="utf-8")

    rename_a3m.main(str(query_fasta))

    # Both the file name and the first header now carry the FASTA name.
    renamed_text = (tmp_path / "queryA.a3m").read_text(encoding="utf-8")
    assert renamed_text == ">queryA\nACDE\n>hit\nACDE\n"
    assert not numbered.exists()
|
||||
|
||||
|
||||
def test_generate_crosslink_pickle_parses_single_row_input(tmp_path, monkeypatch):
    """Check the gzip-pickled dict that ``main`` writes for a one-row crosslink file."""
    links_path = tmp_path / "links.txt"
    links_path.write_text("5 A 9 B 0.05\n", encoding="utf-8")
    output_path = tmp_path / "crosslinks.pkl.gz"

    def fake_parse_arguments():
        return SimpleNamespace(csv=str(links_path), output=str(output_path))

    monkeypatch.setattr(crosslink_pickle, "parse_arguments", fake_parse_arguments)

    crosslink_pickle.main()

    with gzip.open(output_path, "rb") as handle:
        stored = pickle.load(handle)
    # Residues 5 and 9 in the input are stored as 4 and 8.
    assert stored == {"A": {"B": [(4, 8, 0.05)]}}
|
||||
|
||||
|
||||
def test_generate_crosslink_pickle_parse_arguments_reads_cli_flags(monkeypatch):
    """Check that ``parse_arguments`` reads ``--csv`` and ``--output`` from ``sys.argv``."""
    command_line = ["prog", "--csv", "links.txt", "--output", "crosslinks.pkl.gz"]
    monkeypatch.setattr(sys, "argv", command_line)

    parsed = crosslink_pickle.parse_arguments()

    assert parsed.csv == "links.txt"
    assert parsed.output == "crosslinks.pkl.gz"
|
||||
|
||||
|
||||
def test_parse_input_main_writes_fold_spec_json(monkeypatch, tmp_path):
    """Check that ``parse_input.main`` forwards its flags and writes ``<prefix>data.json``."""
    loaded = _load_parse_input_module(monkeypatch)
    module, app_calls, parser_calls, modelling_calls = loaded
    module.FLAGS = SimpleNamespace(
        input_list=["folds.txt"],
        protein_delimiter="+",
        features_directory=["/features"],
        output_prefix=str(tmp_path / "parsed_"),
    )

    module.main([])

    # The stubbed app.run received main when the script was loaded.
    assert app_calls == [module.main]
    expected_generate_kwargs = {
        "input_files": ["folds.txt"],
        "delimiter": "+",
        "exclude_permutations": True,
    }
    assert parser_calls["generate"] == expected_generate_kwargs
    assert modelling_calls["parse_fold"] == (["foldA"], ["/features"], "+")

    written_text = (tmp_path / "parsed_data.json").read_text(encoding="utf-8")
    assert json.loads(written_text) == {"parsed": ["foldA"]}
|
||||
|
||||
|
||||
def test_split_jobs_cluster_jobs_writes_cluster_files(monkeypatch, tmp_path):
    """Check the cluster files and plot arguments produced by ``cluster_jobs``.

    Profiling is stubbed with a fixed three-job table and plotting with a
    recorder.
    """
    module = _load_split_jobs_module(monkeypatch)
    plot_calls = []

    def fake_profile(all_folds, args):
        return pd.DataFrame(
            {
                "name": ["job_a", "job_b", "job_c"],
                "msa_depth": [20, 40, 60],
                "seq_length": [100, 120, 320],
            }
        )

    def fake_plot(X, labels, num_cluster, output_dir):
        return plot_calls.append((X.copy(), np.asarray(labels), num_cluster, output_dir))

    monkeypatch.setattr(module, "profile_all_jobs_and_cluster", fake_profile)
    monkeypatch.setattr(module, "plot_clustering_result", fake_plot)

    job_names = ["job_a", "job_b", "job_c"]
    module.cluster_jobs(job_names, SimpleNamespace(output_dir=str(tmp_path)))

    # The numbers in each file name match the largest seq_length and
    # msa_depth among the jobs listed in that file.
    expected_files = {
        "job_cluster1_120_40.txt": ["job_a", "job_b"],
        "job_cluster2_320_60.txt": ["job_c"],
    }
    for filename, members in expected_files.items():
        listed = (tmp_path / filename).read_text(encoding="utf-8").splitlines()
        assert listed == members

    plotted_points, plotted_labels, plotted_cluster_count = plot_calls[0][:3]
    assert plotted_cluster_count == 2
    np.testing.assert_array_equal(
        plotted_points,
        np.asarray([[100, 20], [120, 40], [320, 60]], dtype=np.int64),
    )
    np.testing.assert_array_equal(plotted_labels, np.asarray([0, 0, 1]))
|
||||
|
||||
|
||||
def test_split_jobs_main_normalises_generated_specs_and_all_vs_all_input(
    monkeypatch,
    tmp_path,
):
    """Check how ``main`` rewrites generated fold specs before calling ``cluster_jobs``."""
    module = _load_split_jobs_module(monkeypatch)
    cluster_calls = []

    def fake_parse_args(self):
        return SimpleNamespace(
            protein_lists=["proteins.txt"],
            protein_delimiter="+",
            mode="all_vs_all",
            features_directory=["/features"],
            output_dir=str(tmp_path),
        )

    def fake_generate(**kwargs):
        return ["A,B", "C;D"]

    def fake_cluster_jobs(all_folds, args):
        return cluster_calls.append((all_folds, args))

    monkeypatch.setattr(module.argparse.ArgumentParser, "parse_args", fake_parse_args)
    monkeypatch.setattr(module, "generate_fold_specifications", fake_generate)
    monkeypatch.setattr(module, "cluster_jobs", fake_cluster_jobs)

    module.main()

    passed_folds, passed_args = cluster_calls[0]
    # "," becomes ":" and ";" becomes the configured protein delimiter.
    assert passed_folds == ["A:B", "C+D"]
    assert passed_args.protein_lists == ["proteins.txt"]
|
||||
|
||||
|
||||
def test_truncate_pickles_main_copies_tree_and_removes_selected_pickle_keys(
    monkeypatch,
    tmp_path,
):
    """Run ``main`` on a small tree and inspect what lands in the destination.

    The source holds one pickle with three keys and one text file; the
    destination already holds a text file of the same name. Afterwards the
    destination pickle must contain only the key that was not excluded, and
    the pre-existing destination text file must be unchanged.
    """
    module = _load_truncate_pickles_module(monkeypatch)

    source_root = tmp_path / "src"
    target_root = tmp_path / "dst"
    source_nested = source_root / "nested"
    target_nested = target_root / "nested"
    for directory in (source_nested, target_nested):
        directory.mkdir(parents=True)

    payload = {
        "keep": 1,
        "aligned_confidence_probs": [1, 2],
        "distogram": [3, 4],
    }
    with open(source_nested / "result.pkl", "wb") as stream:
        pickle.dump(payload, stream)
    (source_nested / "notes.txt").write_text("copied\n", encoding="utf-8")
    (target_nested / "notes.txt").write_text("existing\n", encoding="utf-8")

    flags = SimpleNamespace(
        src_dir=str(source_root),
        dst_dir=str(target_root),
        keys_to_exclude="aligned_confidence_probs,distogram",
        number_of_threads=2,
    )
    monkeypatch.setattr(module, "FLAGS", flags)

    module.main([])

    with open(target_nested / "result.pkl", "rb") as stream:
        assert pickle.load(stream) == {"keep": 1}
    remaining_notes = (target_nested / "notes.txt").read_text(encoding="utf-8")
    assert remaining_notes == "existing\n"
|
||||
|
||||
|
||||
def test_truncate_pickles_main_exits_when_source_dir_is_missing(monkeypatch, tmp_path):
    """Verify ``main`` raises ``SystemExit`` with value 1 for a missing ``src_dir``.

    ``FLAGS`` is replaced with a namespace whose ``src_dir`` points at a path
    that is never created under ``tmp_path``.
    """
    truncate_pickles = _load_truncate_pickles_module(monkeypatch)
    monkeypatch.setattr(
        truncate_pickles,
        "FLAGS",
        SimpleNamespace(
            src_dir=str(tmp_path / "missing"),
            dst_dir=str(tmp_path / "dst"),
            keys_to_exclude="aligned_confidence_probs,distogram",
            number_of_threads=1,
        ),
    )

    # ``match`` is a regex *search* over ``str(exception)``. Anchor it so that
    # only an exit value of exactly 1 passes; an unanchored "1" would also
    # accept 10, 21, or any message that merely contains the digit.
    with pytest.raises(SystemExit, match=r"^1$"):
        truncate_pickles.main([])
|
||||
161
test/unit/test_unifold_backend.py
Normal file
161
test/unit/test_unifold_backend.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
# Location of the backend source file, resolved relative to this test file:
# two directory levels above the file's own directory, then into the package.
MODULE_PATH = Path(__file__).resolve().parents[2].joinpath(
    "alphapulldown",
    "folding_backend",
    "unifold_backend.py",
)
|
||||
|
||||
|
||||
def _restore_modules(saved_modules: "dict[str, types.ModuleType | None]") -> None:
    """Put ``sys.modules`` back into the state recorded in *saved_modules*.

    The annotation is written as a string so it is never evaluated at
    definition time; subscripting ``dict`` and ``X | None`` on types are not
    available on older interpreters.

    Args:
        saved_modules: Maps a module name to the module object that was
            registered under that name, or to ``None`` when no module was
            registered. ``None`` entries are removed from ``sys.modules``;
            all other entries are written back.
    """
    for name, module in saved_modules.items():
        if module is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = module
|
||||
|
||||
|
||||
def _install_unifold_backend_stubs() -> "dict[str, types.ModuleType | None]":
    """Register stand-in modules in ``sys.modules`` and return the prior entries.

    Stubs are created for ``alphapulldown.objects``, ``unifold``,
    ``unifold.config``, ``unifold.inference`` and ``unifold.dataset``. The
    ``unifold.inference`` stub appends every ``unifold_predict`` call to its
    ``calls`` list so a test can inspect the arguments afterwards.

    The return annotation is a string so it is never evaluated at definition
    time; subscripting ``dict`` and ``X | None`` on types are not available on
    older interpreters.

    Returns:
        A mapping from each replaced module name to the module previously
        registered under it, or ``None`` if there was none.
    """
    objects_mod = types.ModuleType("alphapulldown.objects")
    objects_mod.MultimericObject = type("MultimericObject", (), {})

    config_mod = types.ModuleType("unifold.config")
    config_mod.model_config = lambda model_name: {"model_name": model_name}

    inference_mod = types.ModuleType("unifold.inference")
    inference_mod.calls = []

    def config_args(model_dir, target_name, output_dir):
        return {
            "model_dir": model_dir,
            "target_name": target_name,
            "output_dir": output_dir,
        }

    def unifold_config_model(general_args):
        return {"runner_args": general_args}

    def unifold_predict(model_runner, model_args, processed_features):
        # Record the call; the stub itself returns None.
        inference_mod.calls.append((model_runner, model_args, processed_features))

    inference_mod.config_args = config_args
    inference_mod.unifold_config_model = unifold_config_model
    inference_mod.unifold_predict = unifold_predict

    dataset_mod = types.ModuleType("unifold.dataset")

    def process_ap(
        config, features, mode, labels, seed, batch_idx, data_idx, is_distillation
    ):
        # Echo the inputs of interest back; the second tuple element is None.
        processed = {
            "processed_features": features,
            "seed": seed,
            "mode": mode,
            "config": config,
        }
        return processed, None

    dataset_mod.process_ap = process_ap

    unifold_pkg = types.ModuleType("unifold")
    unifold_pkg.config = config_mod
    unifold_pkg.inference = inference_mod
    unifold_pkg.dataset = dataset_mod

    stubs = {
        "alphapulldown.objects": objects_mod,
        "unifold": unifold_pkg,
        "unifold.config": config_mod,
        "unifold.inference": inference_mod,
        "unifold.dataset": dataset_mod,
    }
    # Snapshot from the same mapping that is installed, so the set of saved
    # names can never drift from the set of replaced names.
    saved_modules = {name: sys.modules.get(name) for name in stubs}
    sys.modules.update(stubs)
    return saved_modules
|
||||
|
||||
|
||||
def _load_unifold_backend_module():
    """Execute the backend source file against the stub modules.

    The stubs are installed first, then the file at ``MODULE_PATH`` is loaded
    under the name ``alphapulldown.folding_backend.unifold_backend``.

    Returns:
        A ``(module, saved_modules)`` tuple. The caller is expected to pass
        ``saved_modules`` to ``_restore_modules`` once it is finished.

    Raises:
        ImportError: If no import spec or loader can be built for
            ``MODULE_PATH``.
        BaseException: Anything raised while executing the module is
            re-raised after the stubs have been rolled back.
    """
    module_name = "alphapulldown.folding_backend.unifold_backend"
    saved_modules = _install_unifold_backend_stubs()
    # Everything after the stubs are installed sits inside the guarded region:
    # on failure the caller never receives ``saved_modules``, so this function
    # is the only place that can undo the ``sys.modules`` changes.
    try:
        sys.modules.pop(module_name, None)
        spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot build an import spec for {MODULE_PATH}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    except BaseException:
        # BaseException so that SystemExit/KeyboardInterrupt raised during the
        # import also roll the stubs back before propagating.
        sys.modules.pop(module_name, None)
        _restore_modules(saved_modules)
        raise
    return module, saved_modules
|
||||
|
||||
|
||||
def test_unifold_setup_predict_and_postprocess():
    """Drive ``UnifoldBackend`` through setup, predict and postprocess.

    The backend is loaded against stub ``unifold`` modules, so the assertions
    check only how arguments are routed: what ``setup`` returns, what reaches
    the stubbed ``unifold_predict``, and that ``predict`` and ``postprocess``
    both return ``None``.
    """
    module, saved_modules = _load_unifold_backend_module()
    try:
        backend_cls = module.UnifoldBackend
        multimeric_object = SimpleNamespace(
            description="complex",
            feature_dict={"msa": [1, 2]},
        )

        expected_general_args = {
            "model_dir": "/models",
            "target_name": "complex",
            "output_dir": "/output",
        }
        setup_result = backend_cls.setup(
            model_name="multimer",
            model_dir="/models",
            output_dir="/output",
            multimeric_object=multimeric_object,
        )
        assert setup_result == {
            "model_runner": {"runner_args": expected_general_args},
            "model_args": expected_general_args,
            "model_config": {"model_name": "multimer"},
        }

        prediction = backend_cls().predict(
            model_runner="runner",
            model_args={"arg": 1},
            model_config={"cfg": 2},
            multimeric_object=multimeric_object,
            random_seed=11,
        )
        assert prediction is None

        expected_call = (
            "runner",
            {"arg": 1},
            {
                "processed_features": {"msa": [1, 2]},
                "seed": 11,
                "mode": "predict",
                "config": {"cfg": 2},
            },
        )
        assert sys.modules["unifold.inference"].calls == [expected_call]
        assert backend_cls.postprocess() is None
    finally:
        # Always drop the loaded backend and put the original modules back.
        sys.modules.pop("alphapulldown.folding_backend.unifold_backend", None)
        _restore_modules(saved_modules)
|
||||
Reference in New Issue
Block a user