mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
233 lines
7.4 KiB
Python
Executable File
233 lines
7.4 KiB
Python
Executable File
import sys
|
|
import types
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
# Directory that contains this file; every import root below hangs off it.
REPO_ROOT = Path(__file__).resolve().parent

# Candidate import roots: the repository directory itself plus four source
# directories located underneath it.
TEST_IMPORT_PATHS = (
    REPO_ROOT,
    REPO_ROOT.joinpath("ColabFold"),
    REPO_ROOT.joinpath("alphafold"),
    REPO_ROOT.joinpath("alphafold3", "src"),
    REPO_ROOT.joinpath("AlphaLink2"),
)

# Prepend each root that exists on disk and is not already listed. Every hit
# goes to index 0, so later tuple entries end up ahead of earlier ones.
for import_path in TEST_IMPORT_PATHS:
    import_path_str = str(import_path)
    if not import_path.exists() or import_path_str in sys.path:
        continue
    sys.path.insert(0, import_path_str)
|
|
|
|
|
|
def _install_jax_tree_stub() -> None:
    """Register a small NumPy/dm-tree backed stand-in for ``jax`` when needed.

    If ``import jax`` succeeds, nothing is changed. Otherwise every ``jax`` and
    ``jax.*`` entry is removed from ``sys.modules`` and replaced with stub
    modules (``jax``, ``jax.numpy``, ``jax.nn``, ``jax.lax``, ``jax.tree``,
    ``jax.tree_util``, ``jax.version``) that cover only the attributes assigned
    below.

    Raises:
        ImportError: if the fallback is required and the ``tree`` (dm-tree)
            package cannot be imported.
    """
    try:
        import jax  # noqa: F401
    except Exception:
        # Deliberately broad: any failure while importing jax is treated as
        # "jax is unusable". Drop whatever partial state the failed import left
        # behind so the stubs registered below are the only jax modules.
        for module_name in list(sys.modules):
            if module_name == "jax" or module_name.startswith("jax."):
                sys.modules.pop(module_name, None)
    else:
        # A working jax is available; leave it alone.
        return

    import tree as dm_tree

    jax_stub = types.ModuleType("jax")
    jax_numpy_stub = types.ModuleType("jax.numpy")
    jax_nn_stub = types.ModuleType("jax.nn")
    jax_lax_stub = types.ModuleType("jax.lax")
    tree_stub = types.ModuleType("jax.tree")
    tree_util_stub = types.ModuleType("jax.tree_util")
    version_stub = types.ModuleType("jax.version")
    version_stub.__version__ = "0.0-test"

    def _tree_map(func, *structures):
        return dm_tree.map_structure(func, *structures)

    def _tree_flatten(structure):
        # The original structure doubles as the "treedef" handed back to
        # _tree_unflatten.
        return dm_tree.flatten(structure), structure

    def _tree_unflatten(treedef, leaves):
        return dm_tree.unflatten_as(treedef, leaves)

    def _tree_leaves(structure):
        return dm_tree.flatten(structure)

    def _register_pytree_node(*args, **kwargs):
        # Registration is a no-op in the stub.
        return None

    def _register_pytree_node_class(cls):
        # Usable as a class decorator; returns the class unchanged.
        return cls

    def _softmax(x, axis=-1):
        # Subtract the per-axis maximum before exponentiating; the shifted
        # exponentials are computed once and reused for the normaliser.
        shifted_exp = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return shifted_exp / np.sum(shifted_exp, axis=axis, keepdims=True)

    tree_stub.map = _tree_map
    # jax.tree.flatten returns a (leaves, treedef) pair, matching
    # jax.tree_util.tree_flatten; jax.tree.leaves is the leaves-only variant.
    tree_stub.flatten = _tree_flatten
    tree_stub.leaves = _tree_leaves
    tree_stub.unflatten = _tree_unflatten

    tree_util_stub.tree_map = _tree_map
    tree_util_stub.tree_flatten = _tree_flatten
    tree_util_stub.tree_unflatten = _tree_unflatten
    tree_util_stub.tree_leaves = _tree_leaves
    tree_util_stub.tree_structure = lambda structure: structure
    tree_util_stub.register_pytree_node = _register_pytree_node
    tree_util_stub.register_pytree_node_class = _register_pytree_node_class

    # Expose the whole NumPy namespace under jax.numpy.
    jax_numpy_stub.__dict__.update(np.__dict__)
    jax_nn_stub.softmax = _softmax
    jax_lax_stub.stop_gradient = lambda x: x

    jax_stub.tree = tree_stub
    jax_stub.tree_map = _tree_map
    jax_stub.tree_util = tree_util_stub
    jax_stub.numpy = jax_numpy_stub
    jax_stub.nn = jax_nn_stub
    jax_stub.lax = jax_lax_stub
    jax_stub.Array = np.ndarray
    jax_stub.local_devices = lambda: [types.SimpleNamespace(platform="cpu")]
    jax_stub.devices = lambda *_args, **_kwargs: [types.SimpleNamespace(platform="cpu")]
    jax_stub.default_backend = lambda: "cpu"
    jax_stub.version = version_stub

    # Register the submodules too so ``import jax.numpy`` style imports resolve.
    sys.modules["jax"] = jax_stub
    sys.modules["jax.numpy"] = jax_numpy_stub
    sys.modules["jax.nn"] = jax_nn_stub
    sys.modules["jax.lax"] = jax_lax_stub
    sys.modules["jax.tree"] = tree_stub
    sys.modules["jax.tree_util"] = tree_util_stub
    sys.modules["jax.version"] = version_stub
|
|
|
|
|
|
# Executed at module import time; it registers stub modules only when
# ``import jax`` fails and otherwise leaves the real package in place.
_install_jax_tree_stub()
|
|
|
|
|
|
def pytest_addoption(parser):
    """Register the ``--use-temp-dir`` command-line switch.

    The switch is a boolean flag that defaults to ``False``.
    """
    option_settings = {
        "action": "store_true",
        "default": False,
        "help": "Run functional test suites with isolated temporary output directories.",
    }
    parser.addoption("--use-temp-dir", **option_settings)
|
|
|
|
|
|
def pytest_collection_modifyitems(config, items):
    """Tag collected items with a marker named after their containing directory.

    An item whose file path has a component exactly equal to ``unit``,
    ``integration``, ``functional`` or ``cluster`` receives the marker of the
    same name. An item can receive several markers; ``config`` is unused.
    """
    directory_markers = ("unit", "integration", "functional", "cluster")
    for item in items:
        components = set(Path(str(item.fspath)).parts)
        for marker_name in directory_markers:
            if marker_name in components:
                item.add_marker(getattr(pytest.mark, marker_name))
|
|
|
|
|
|
@pytest.hookimpl(tryfirst=True)
def pytest_itemcollected(item):
    """Overwrite ``item._nodeid`` with a label built from docstrings.

    The label is "<prefix> <suffix>". The prefix comes from ``item.parent.obj``:
    its stripped docstring, or the name of its class when it has no docstring.
    The suffix comes from ``item.obj``: its stripped docstring, or its
    ``__name__``. Empty parts are left out, and the node id is untouched when
    both are empty. This is best effort: any exception is swallowed and the
    original node id is kept.
    """
    try:
        parent_obj = getattr(item.parent, "obj", None)
        test_obj = getattr(item, "obj", None)

        if not parent_obj:
            prefix = ""
        elif getattr(parent_obj, "__doc__", None):
            prefix = parent_obj.__doc__.strip()
        else:
            prefix = parent_obj.__class__.__name__

        if not test_obj:
            suffix = ""
        elif getattr(test_obj, "__doc__", None):
            suffix = test_obj.__doc__.strip()
        else:
            suffix = test_obj.__name__

        label_parts = [part for part in (prefix, suffix) if part]
        if label_parts:
            item._nodeid = " ".join(label_parts)
    except Exception:
        # Relabelling is cosmetic; never let it break collection.
        pass
|
|
|
|
@pytest.fixture
def tmp_flags(monkeypatch, tmp_path):
    """
    Return the ``FLAGS`` object of ``create_individual_features`` preloaded
    with test defaults.

    Each non-None default is converted to text and passed to the matching
    flag's ``parse`` method. Defaults that are None are skipped, so those flags
    keep their declared default. Names that are not registered flags are set
    as plain attributes instead. File-based defaults point inside ``tmp_path``.
    """
    import alphapulldown.scripts.create_individual_features as cf

    flag_values = cf.FLAGS

    # for save_meta_data.get_meta_dict
    if not hasattr(flag_values, "flag_values_dict"):

        def _public_attribute_snapshot():
            return {
                attr: getattr(flag_values, attr, None)
                for attr in dir(flag_values)
                if not attr.startswith("_")
            }

        flag_values.flag_values_dict = _public_attribute_snapshot

    overrides = {
        "data_pipeline": "alphafold2",
        "fasta_paths": [str(tmp_path / "a.fasta")],
        "data_dir": "/db",
        "output_dir": str(tmp_path / "out"),
        "max_template_date": "2021-09-30",
        "use_mmseqs2": False,
        "use_precomputed_msas": False,
        "save_msa_files": False,
        "skip_existing": False,
        "compress_features": False,
        "db_preset": "full_dbs",
        "use_hhsearch": False,
        # database flags intentionally left as None (let code fill from data_dir)
        "uniref90_database_path": None,
        "uniref30_database_path": None,
        "mgnify_database_path": None,
        "bfd_database_path": None,
        "small_bfd_database_path": None,
        "pdb70_database_path": None,
        "uniprot_database_path": None,
        "pdb_seqres_database_path": None,
        "template_mmcif_dir": None,
        "obsolete_pdbs_path": None,
        # miscellaneous
        "seq_index": None,
        "path_to_mmt": None,
        "description_file": None,
        "multiple_mmts": False,
        # external binaries
        "jackhmmer_binary_path": "jackhmmer",
        "hhblits_binary_path": "hhblits",
        "hhsearch_binary_path": "hhsearch",
        "hmmsearch_binary_path": "hmmsearch",
        "hmmbuild_binary_path": "hmmbuild",
        "nhmmer_binary_path": "nhmmer",
        "hmmalign_binary_path": "hmmalign",
        "kalign_binary_path": "kalign",
    }

    def _as_flag_text(val):
        """Render a default as the text handed to a flag parser (None stays None)."""
        if val is None:
            return None
        if isinstance(val, bool):
            # Must be tested before the generic str() path: bool is an int.
            return "true" if val else "false"
        if isinstance(val, (list, tuple)):
            # absl list flags expect a comma-separated string
            return ",".join(map(str, val))
        return str(val)

    registered = flag_values._flags()  # name -> Flag
    for flag_name, wanted in overrides.items():
        flag_obj = registered.get(flag_name)
        if flag_obj is None:
            # Not a registered flag on this FLAGS; attach as a plain attribute.
            setattr(flag_values, flag_name, wanted)
            continue
        text = _as_flag_text(wanted)
        if text is not None:
            # None means "keep the declared default", so only parse real values.
            flag_obj.parse(text)

    # Mark the whole set as parsed so flags left at their default can be read.
    if hasattr(flag_values, "mark_as_parsed"):
        flag_values.mark_as_parsed()
    else:
        flag_values(["pytest"])

    # Ensure the module under test uses this FLAGS instance.
    monkeypatch.setattr(cf, "FLAGS", flag_values)
    return flag_values
|