diff --git a/.env b/.env.sample similarity index 77% rename from .env rename to .env.sample index 4576273..c14106c 100644 --- a/.env +++ b/.env.sample @@ -11,7 +11,7 @@ # expected that you use the same saving conventions as the RCSB PDB, which means: # `1a2b` --> /path/to/pdb_mirror/a2/1a2b.cif.gz # To set up a mirror, you can use tha atomworks commandline: `atomworks pdb sync /path/to/mirror` -PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2025_07_13_pdb +PDB_MIRROR_PATH= # The `CCD_MIRROR_PATH` is a path to a local mirror of the CCD database. # It's expected that you use the same saving conventions as the RCSB CCD, which means: @@ -20,10 +20,10 @@ PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2025_07_13_pdb # If no mirror is provided, the internal biotite CCD will be used as a fallback. To provide a # custom CCD for a ligand, you can place it in the in the CCD mirror path following the CCDs pattern. # Example: /path/to/ccd_mirror/M/MYLIGAND1/MYLIGAND1.cif -CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2025_07_13_ccd +CCD_MIRROR_PATH= # --- Local MSA directories --- -LOCAL_MSA_DIRS=/projects/msa/hhblits,/projects/msa/mmseqs_gpu,/projects/msa/lab,/squash/mgnify_distill_rf3/msas +LOCAL_MSA_DIRS= # --- External tools --- @@ -32,24 +32,27 @@ LOCAL_MSA_DIRS=/projects/msa/hhblits,/projects/msa/mmseqs_gpu,/projects/msa/lab, # Example: /path/to/x3dna-v2.4 X3DNA_PATH= +# For secondary structure prediction (not currently used) +DSSP_PATH= + # The `HHFILTER_PATH` is a path to the hhfilter tool from the HH-suite, which is used for # filtering MSAs to reduce redundancy. # Example: /path/to/hhsuite/build/bin/hhfilter -HHFILTER_PATH=/net/software/hhsuite/build/bin/hhfilter +HHFILTER_PATH= # The `MMSEQS2_PATH` is a path to the mmseqs2 tool, which is used for fast sequence searching. # Example: /path/to/mmseqs-gpu/bin/mmseqs -MMSEQS2_PATH=/net/software/mmseqs-gpu/bin/mmseqs +MMSEQS2_PATH= # CollabFold MMseqs2 database paths for GPU and CPU usage. # Local access (preferred) # NOTE: MMseqs2 databases are best stored on local drives of compute nodes for performance -COLABFOLD_LOCAL_DB_PATH_GPU=/local/colabfold/gpu -COLABFOLD_LOCAL_DB_PATH_CPU=/local/databases/colabfold/ +COLABFOLD_LOCAL_DB_PATH_GPU= +COLABFOLD_LOCAL_DB_PATH_CPU= # Network access (fallback; may cause IO-related issues) -COLABFOLD_NET_DB_PATH_GPU=/net/databases/colabfold/gpu -COLABFOLD_NET_DB_PATH_CPU=/net/databases/colabfold/ +COLABFOLD_NET_DB_PATH_GPU= +COLABFOLD_NET_DB_PATH_CPU= diff --git a/.gitignore b/.gitignore index aceebc1..1fd32e8 100644 --- a/.gitignore +++ b/.gitignore @@ -211,3 +211,5 @@ logs/ **/scripts/slurm/ **/benchmarks/ +# AI +**/CLAUDE.md diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..d68c38a --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "lib/atomworks"] + path = lib/atomworks + url = git@github.com:baker-laboratory/atomworks-dev.git diff --git a/.ipd/apptainer/rf3-dev.def b/.ipd/apptainer/rf3-dev.def new file mode 100644 index 0000000..35b04cb --- /dev/null +++ b/.ipd/apptainer/rf3-dev.def @@ -0,0 +1,87 @@ +Bootstrap: docker +From: nvcr.io/nvidia/pytorch:25.06-py3 +IncludeCmd: yes + +# RF3 Development Container - Dependencies only (excludes atomworks, which we expose through a submodule for development) + +# To build this Apptainer, from the project root, run: +# apptainer build rf3-dev.sif .ipd/apptainer/rf3-dev.def + +%labels + Author Institute for Protein Design + Version 1.0-dev + Description RosettaFold3 - Dependencies only (for development and model training) + +%setup + # Create a directory in the container to bind the host's current working directory + mkdir ${APPTAINER_ROOTFS}/modelhub_host + # ... for mounting `/projects` with --bind + mkdir ${APPTAINER_ROOTFS}/projects + # ... for mounting `/databases` with --bind + mkdir ${APPTAINER_ROOTFS}/net + # ... for mounting `/squash` with --bind + mkdir ${APPTAINER_ROOTFS}/squash + +%files + /etc/localtime + /etc/hosts + pyproject.toml /opt/core_pyproject.toml + models/rf3/pyproject.toml /opt/rf3_pyproject.toml + +%post + ## GENERAL SETUP + + # Common symlinks (within container) + ln -s /net/databases /databases + ln -s /net/software /software + ln -s /home /mnt/home + ln -s /projects /mnt/projects + ln -s /net /mnt/net + + ## PYTHON DEPENDENCY INSTALLATION + + # Fix NGC constraints that conflict with our required packages + # ... remove packaging constraint to allow biotite 1.3.0 installation + sed -i '/packaging==/d' /etc/pip/constraint.txt + + # ... remove pytest constraint + sed -i '/pytest==/d' /etc/pip/constraint.txt + + # Install uv for fast dependency resolution + pip install uv + + # Auto-generate dependency lists for both core and RF3, with all extra dependencies + # (Core) + mv /opt/core_pyproject.toml /opt/pyproject.toml + uv pip compile /opt/pyproject.toml --output-file /opt/core_requirements.txt --all-extras + rm /opt/pyproject.toml + + # (RF3) + mv /opt/rf3_pyproject.toml /opt/pyproject.toml + uv pip compile /opt/pyproject.toml --output-file /opt/rf3_requirements.txt --all-extras + rm /opt/pyproject.toml + + # Install core dependencies (excluding torch packages and numpy from NGC container) + grep -vE "^(torch(|vision|audio)|numpy)==" /opt/core_requirements.txt > /opt/core_requirements_filtered.txt + uv pip install --system --break-system-packages -r /opt/core_requirements_filtered.txt + + # Install RF3 dependencies (excluding atomworks, torch packages, and numpy) + grep -vE "^(atomworks|torch(|vision|audio)|numpy)==" /opt/rf3_requirements.txt > /opt/rf3_requirements_filtered.txt + uv pip install --system --break-system-packages -r /opt/rf3_requirements_filtered.txt + + # Downgrade NumPy to <2.0 (NVIDIA 25.06 is not compatible with NumPy >=2) + uv pip install --system --break-system-packages "numpy<2" + +%environment + # (Flag to increase accessible GPU memory) + export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + + # (Turn off NVLink) + export NCCL_P2P_DISABLE=1 + +%runscript + # NOTE: The %runscript is invoked when the container is run without specifying a different command. + exec python "$@" + +%help + RosettaFold3 (RF3) development container with dependencies only. diff --git a/.ipd/shebang/README.md b/.ipd/shebang/README.md new file mode 100644 index 0000000..9566241 --- /dev/null +++ b/.ipd/shebang/README.md @@ -0,0 +1,6 @@ +This directory contains scripts that are not to be run directly by the user. +They are [SHEBANG scripts](https://en.wikipedia.org/wiki/Shebang_(Unix)) that are used to run the appropriate apptainer container. + +For example, the script `rf3_exec.sh` is used to run the RF3 apptainer container with the latest apptainer image stored locally or at the IPD. + +The shebang lines (`#!/bin/bash` ...) at the top of entry point scripts like `train.py` redirect the system to here to find the correct apptainer container. \ No newline at end of file diff --git a/.ipd/shebang/rf3_exec.sh b/.ipd/shebang/rf3_exec.sh new file mode 100755 index 0000000..13ec394 --- /dev/null +++ b/.ipd/shebang/rf3_exec.sh @@ -0,0 +1,116 @@ +#!/usr/bin/bash + +################### +# You can add the path to this file as the shebang line in your python script. +# Then by default, the python script will be executed with the python interpreter +# in the SIF_PATH container. Here, we launch the container with nvidia gpu and slurm support. +# +# Example shebang: #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/.ipd/shebang/rf3_exec.sh" "$0" "$@"' +################### + +# Let the user know this script is setting things up behind the scene +SCRIPT_PATH=$(realpath $0) +SCRIPT_DIR=$(dirname $SCRIPT_PATH) +echo '################## Start shebang info ##################' +echo "The file $SCRIPT_PATH is being run as a shebang executable. + It will... + + 1. Add 'src/modelhub', 'models/rf3/src', and 'lib/atomworks/src' to your PYTHONPATH. + 2. Run your python script from the right container, which contains all dependencies. + 3. Launch the container with slurm and nvidia gpu support." + +# Extract the path to the Python script from the arguments +PYTHON_SCRIPT=$(realpath "$1") +shift + +# Find repository root by looking for .project-root file +find_repo_root() { + local current_dir="$1" + while [ "$current_dir" != "/" ]; do + if [ -f "$current_dir/.project-root" ]; then + echo "$current_dir" + return 0 + fi + current_dir="$(dirname "$current_dir")" + done + return 1 +} + +echo +echo "Searching for repository root directory..." +REPO_ROOT=$(find_repo_root "$(dirname "$PYTHON_SCRIPT")") +if [ -z "$REPO_ROOT" ]; then + echo "Error: Could not find .project-root file in any parent directory" + exit 1 +else + echo "... found repository root at '$REPO_ROOT'" +fi + +# Function to add a directory to PYTHONPATH if it's not already included +add_to_pythonpath() { + local dir_path="$1" + if [[ ":$PYTHONPATH:" != *":$dir_path:"* ]]; then + export PYTHONPATH="$dir_path:$PYTHONPATH" + echo "Added '$dir_path' to PYTHONPATH." + else + echo "'$dir_path' is already in PYTHONPATH." + fi +} + +# Add modelhub, rf3, and atomworks to PYTHONPATH +echo +echo "Adding modelhub, RF3, and atomworks to PYTHONPATH..." +MODELHUB_PATH="$REPO_ROOT/src" +RF3_PATH="$REPO_ROOT/models/rf3/src" +ATOMWORKS_PATH="$REPO_ROOT/lib/atomworks/src" +add_to_pythonpath "$MODELHUB_PATH" +add_to_pythonpath "$RF3_PATH" +add_to_pythonpath "$ATOMWORKS_PATH" + +echo +echo "Fetching the appropriate apptainer image..." + +SIF_PATH="$REPO_ROOT/.ipd/apptainer/rf3-dev.sif" + +echo "... looking for a local apptainer image at '$SIF_PATH'" +if [ ! -f "$SIF_PATH" ]; then + echo "... apptainer not found. To build it, run: apptainer build .ipd/apptainer/rf3-dev.sif .ipd/apptainer/rf3-dev.def" + echo "Attempting to run $PYTHON_SCRIPT with $(which python)" + SIF_PATH="" +fi + +# Function to print debug=mode warning +print_debug_warning() { + echo + echo "###############################################################################" + echo "# #" + echo "# ⚠️ WARNING ⚠️ #" + echo "# RUNNING WITH DEBUGPY ON PORT $DEBUG_PORT #" + echo "# DON'T FORGET TO ATTACH A DEBUGGER #" + echo "# #" + echo "###############################################################################" + echo +} + +if [ -n "$DEBUG_PORT" ]; then + print_debug_warning + python_cmd="python -m debugpy --listen $DEBUG_PORT --wait-for-client" +else + python_cmd="python" + echo +fi + +if [ ! -z $SIF_PATH ]; then + echo "Running $PYTHON_SCRIPT with apptainer: $SIF_PATH." + echo '################## End shebang info ####################' + echo + /usr/bin/apptainer exec --nv --slurm \ + --bind "$REPO_ROOT:$REPO_ROOT" \ + --env PYTHONPATH="\$PYTHONPATH:$PYTHONPATH" \ + $SIF_PATH $python_cmd "$PYTHON_SCRIPT" "$@" +else + echo "Running $PYTHON_SCRIPT with python: $(which python)" + echo '################## End shebang info ####################' + echo + $python_cmd "$PYTHON_SCRIPT" "$@" +fi diff --git a/.ipd/slurm/launch_rf3.sh b/.ipd/slurm/launch_rf3.sh new file mode 100644 index 0000000..51adbfe --- /dev/null +++ b/.ipd/slurm/launch_rf3.sh @@ -0,0 +1,65 @@ +#!/bin/bash +#SBATCH -p gpu-train +#SBATCH --nodes 1 +#SBATCH --gres=gpu:l40:8 +#SBATCH --ntasks-per-node 8 +#SBATCH -c 4 +#SBATCH --mem=512g +#SBATCH -t 1-00:00:00 +#SBATCH -J none-00-dummy +#SBATCH -o slurm_logs/%x_%j.out +#SBATCH -e slurm_logs/%x_%j.err +#SBATCH --no-kill=off + +### To call this script run: `sbatch launch.sh` from this directory +### For reference, see the Lightning Fabric + SLURM guide: https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html + +# (In case we're still running in debug mode) +unset DEBUG_PORT + +# (SLURM setup, ensuring we have a unique port per job, and setting the master address to Rank 0) +export MASTER_PORT=$((1024 + RANDOM % 64512)) +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) + +### Install the project (which will set the paths correctly) + +### Environment flags + +# Debugging flags (optional) +export NCCL_DEBUG=INFO # NCCL internal debugging +export PYTHONFAULTHANDLER=1 # Catches Python core dumps (e.g., segmentation faults) + +# Expand CUDA memory +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Turn off NVLink (L40 do not have NVLink) +export NCCL_P2P_DISABLE=1 + +# OPENMP and OPENBLAS optimizations +# https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#utilize-openmp +# NOTE: Must be optimized per-system; see: https://github.com/pytorch/pytorch/blob/65e6194aeb3269a182cfe2c05c122159da12770f/torch/distributed/run.py#L596-L608 +export OMP_NUM_THREADS=4 +export OPENBLAS_NUM_THREADS=4 + +### Set the effective batch size +### NOTE: Should be adjusted based on specific use case +EFFECTIVE_BATCH_SIZE=16 + +### Compose the training script +DEVICES_PER_NODE=${SLURM_NTASKS_PER_NODE:-8} # Default to 8 if not set +echo "Running on $SLURM_NNODES nodes with $DEVICES_PER_NODE tasks per node" + +### Calculate grad_accum_steps +GRAD_ACCUM_STEPS=$((EFFECTIVE_BATCH_SIZE / (DEVICES_PER_NODE * SLURM_NNODES))) +echo "Grad Accumulation Steps: $GRAD_ACCUM_STEPS" + +command="srun --kill-on-bad-exit ../../models/rf3/src/rf3/train.py \ + experiment=$SLURM_JOB_NAME \ + ++trainer.devices_per_node=$DEVICES_PER_NODE \ + ++trainer.num_nodes=$SLURM_NNODES \ + ++trainer.grad_accum_steps=$GRAD_ACCUM_STEPS" + +echo -e "command\t$command" + +# Let 'er rip +$command diff --git a/Makefile b/Makefile index 7aa13e1..ed1e572 100644 --- a/Makefile +++ b/Makefile @@ -14,8 +14,8 @@ clean: ## Format src directory using black format: - ruff format src tests - ruff check --fix src tests + ruff format src models tests + ruff check --fix src models tests ################################################################################# # Self Documenting Commands # diff --git a/lib/atomworks b/lib/atomworks new file mode 160000 index 0000000..20682ed --- /dev/null +++ b/lib/atomworks @@ -0,0 +1 @@ +Subproject commit 20682edcb5842a236b7145ae1593a3d0578a2fbb diff --git a/models/rf3/CONTAINER.md b/models/rf3/CONTAINER.md deleted file mode 100644 index 840d828..0000000 --- a/models/rf3/CONTAINER.md +++ /dev/null @@ -1,89 +0,0 @@ -# RF3 Apptainer Containers - -This directory contains two Apptainer definition files for different use cases. - -## Container Options - -### 1. `rf3-standalone.def` - Standalone Container -Contains a complete snapshot of the modelhub repository at build time. Use this for: -- Production deployments -- Reproducible inference with a fixed codebase -- Running on systems where you don't have the repository - -### 2. `rf3-dev.def` - Development Container -Contains only Python dependencies (from `requirements.txt`). Use this for: -- Active development and testing -- Working with your local modelhub code -- Debugging and modifying RF3 code - -**Note**: Generate `requirements.txt` first using `uv pip compile pyproject.toml -o requirements.txt` before building this container. - -## Prerequisites - -- Apptainer/Singularity installed -- NVIDIA GPU with CUDA 12.1+ support -- Sufficient disk space (~10GB per container) - -## Building Containers - -Build from the `models/rf3/` directory: - -```bash -cd models/rf3/ - -# Build standalone container (includes modelhub snapshot) -apptainer build rf3-standalone.sif rf3-standalone.def - -# Build development container (dependencies only) -apptainer build rf3-dev.sif rf3-dev.def - -# Or build with sandbox for debugging -apptainer build --sandbox rf3_sandbox/ rf3-standalone.def -``` - -## Using the Standalone Container - -The standalone container has the full repository baked in. - -### Basic Inference - -```bash -# Run inference on a single input -apptainer exec --nv rf3-standalone.sif rf3 fold inputs='input.json' - -# Process CIF/PDB files -apptainer exec --nv rf3-standalone.sif rf3 fold inputs='structure.cif' - -# Batch processing -apptainer exec --nv rf3-standalone.sif rf3 fold inputs='[file1.cif,file2.json,file3.pdb]' -``` - -### With Custom Weights - -```bash -# Mount weights directory and specify checkpoint -apptainer exec --nv \ - --bind /path/to/weights:/weights \ - rf3-standalone.sif \ - rf3 fold inputs='input.json' ckpt_path='/weights/rf3_latest.pt' -``` - -## Using the Development Container - -The development container requires mounting your local modelhub repository. - -### Basic Usage - -```bash -# Run with local modelhub repository -apptainer exec --nv \ - --bind /path/to/modelhub:/opt/modelhub \ - rf3-dev.sif \ - rf3 fold inputs='input.json' - -# Example with actual path -apptainer exec --nv \ - --bind $PWD/../..:/opt/modelhub \ - rf3-dev.sif \ - rf3 fold inputs='input.json' -``` diff --git a/models/rf3/rf3-dev.def b/models/rf3/rf3-dev.def deleted file mode 100644 index 0d8e871..0000000 --- a/models/rf3/rf3-dev.def +++ /dev/null @@ -1,61 +0,0 @@ -Bootstrap: docker -From: nvcr.io/nvidia/pytorch:25.04-py3 -IncludeCmd: yes - -# RF3 Development Container - Dependencies only - -%labels - Author Institute for Protein Design - Version 1.0-dev - Description RosettaFold3 - Dependencies only (for development and model training) - -%setup - # Create a directory in the container to bind the host's current working directory - mkdir ${APPTAINER_ROOTFS}/modelhub_host - # ... for mounting `/projects` with --bind - mkdir ${APPTAINER_ROOTFS}/projects - # ... for mounting `/databases` with --bind - mkdir ${APPTAINER_ROOTFS}/net - # ... for mounting `/squash` with --bind - mkdir ${APPTAINER_ROOTFS}/squash - -%files - /etc/localtime - /etc/hosts - pyproject.toml /opt/pyproject.toml - -%post - ## GENERAL SETUP - - # Common symlinks (within container) - ln -s /net/databases /databases - ln -s /net/software /software - ln -s /home /mnt/home - ln -s /projects /mnt/projects - ln -s /net /mnt/net - - ## PYTHON DEPENDENCY INSTALLATION - - # Install uv for fast package installation (10-100x faster than pip) - pip install uv - - # Auto-generate a dependency list - uv pip compile /opt/pyproject.toml -o /opt/requirements.txt - - # Install all Python dependencies using requirements.txt with uv - uv pip install --system --break-system-packages -r /opt/requirements.txt - - -%environment - # (Flag to increase accessible GPU memory) - export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - - # (Turn off NVLink) - export NCCL_P2P_DISABLE=1 - -%runscript - # NOTE: The %runscript is invoked when the container is run without specifying a different command. - exec python "$@" - -%help - RosettaFold3 (RF3) development container with dependencies only. diff --git a/models/rf3/src/rf3/_version.py b/models/rf3/src/rf3/_version.py index f1bd7c5..1c3ec6f 100644 --- a/models/rf3/src/rf3/_version.py +++ b/models/rf3/src/rf3/_version.py @@ -12,8 +12,7 @@ __all__ = [ TYPE_CHECKING = False if TYPE_CHECKING: - from typing import Tuple - from typing import Union + from typing import Tuple, Union VERSION_TUPLE = Tuple[Union[int, str], ...] COMMIT_ID = Union[str, None] @@ -28,7 +27,7 @@ version_tuple: VERSION_TUPLE commit_id: COMMIT_ID __commit_id__: COMMIT_ID -__version__ = version = '0.1.dev917+gcbbe4c6a6.d20251001' -__version_tuple__ = version_tuple = (0, 1, 'dev917', 'gcbbe4c6a6.d20251001') +__version__ = version = "0.1.dev917+gcbbe4c6a6.d20251001" +__version_tuple__ = version_tuple = (0, 1, "dev917", "gcbbe4c6a6.d20251001") __commit_id__ = commit_id = None diff --git a/models/rf3/src/rf3/callbacks/dump_validation_structures.py b/models/rf3/src/rf3/callbacks/dump_validation_structures.py index 0358e06..af107e4 100644 --- a/models/rf3/src/rf3/callbacks/dump_validation_structures.py +++ b/models/rf3/src/rf3/callbacks/dump_validation_structures.py @@ -3,14 +3,14 @@ from pathlib import Path from atomworks.ml.example_id import parse_example_id from beartype.typing import Any - -from modelhub.callbacks.callback import BaseCallback from rf3.utils.io import ( build_stack_from_atom_array_and_batched_coords, dump_structures, dump_trajectories, ) +from modelhub.callbacks.callback import BaseCallback + class DumpValidationStructuresCallback(BaseCallback): """Dump predicted structures and/or diffusion trajectories during validation""" diff --git a/models/rf3/src/rf3/cli.py b/models/rf3/src/rf3/cli.py index e78bcd2..df585ea 100644 --- a/models/rf3/src/rf3/cli.py +++ b/models/rf3/src/rf3/cli.py @@ -1,11 +1,8 @@ -import os from pathlib import Path import typer from hydra import compose, initialize_config_dir -from rf3.inference import run_inference - app = typer.Typer() @@ -42,6 +39,9 @@ def fold(ctx: typer.Context): with initialize_config_dir(config_dir=config_path, version_base="1.3"): cfg = compose(config_name="inference", overrides=hydra_overrides) + # Lazy import to avoid loading heavy dependencies at CLI startup + from rf3.inference import run_inference + run_inference(cfg) diff --git a/models/rf3/src/rf3/data/pipeline_utils.py b/models/rf3/src/rf3/data/pipeline_utils.py index 18f39b3..e7c2b54 100644 --- a/models/rf3/src/rf3/data/pipeline_utils.py +++ b/models/rf3/src/rf3/data/pipeline_utils.py @@ -5,7 +5,6 @@ from atomworks.enums import ChainType from atomworks.ml.transforms._checks import check_atom_array_annotation from atomworks.ml.transforms.crop import compute_local_hash from omegaconf import DictConfig - from rf3.data.ground_truth_template import ( FeaturizeNoisedGroundTruthAsTemplateDistogram, TokenGroupNoiseScaleSampler, diff --git a/models/rf3/src/rf3/data/pipelines.py b/models/rf3/src/rf3/data/pipelines.py index 8157ee2..af9b950 100644 --- a/models/rf3/src/rf3/data/pipelines.py +++ b/models/rf3/src/rf3/data/pipelines.py @@ -98,7 +98,6 @@ from atomworks.ml.transforms.random_atomize_residues import RandomAtomizeResidue from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX from omegaconf import DictConfig - from rf3.data.extra_xforms import CheckForNaNsInInputs from rf3.data.pipeline_utils import ( annotate_post_crop_hash, diff --git a/models/rf3/src/rf3/data/rotation_augmentation.py b/models/rf3/src/rf3/data/rotation_augmentation.py index 765b010..857f52b 100644 --- a/models/rf3/src/rf3/data/rotation_augmentation.py +++ b/models/rf3/src/rf3/data/rotation_augmentation.py @@ -1,7 +1,6 @@ import math import torch - from rf3.flow_matching.rigid_utils import rot_vec_mul diff --git a/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py b/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py index 9190aba..012b67f 100755 --- a/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py +++ b/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py @@ -1,8 +1,8 @@ import torch from beartype.typing import Any, Literal from jaxtyping import Float - from rf3.data.rotation_augmentation import centre_random_augmentation + from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rf3/src/rf3/inference.py b/models/rf3/src/rf3/inference.py index ba8e8e1..75f62fd 100755 --- a/models/rf3/src/rf3/inference.py +++ b/models/rf3/src/rf3/inference.py @@ -1,4 +1,4 @@ -#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../scripts/shebang/modelhub_exec.sh" "$0" "$@"' +#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"' import os import tempfile @@ -6,23 +6,19 @@ from pathlib import Path import hydra import rootutils -from dotenv import load_dotenv from hydra.utils import instantiate from omegaconf import DictConfig +from modelhub.utils.env import load_ipd_dotenv from modelhub.utils.logging import suppress_warnings -load_dotenv(override=True) - # Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils) # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located) rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -# Find the RF3 configs directory relative to this file -# This file is at: models/rf3/src/rf3/inference.py -# Configs are at: models/rf3/configs/ -rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/ -_config_path = str(rf3_package_dir / "configs") +load_ipd_dotenv(override=True) + +_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs") @hydra.main( diff --git a/models/rf3/src/rf3/inference_engines/rf3.py b/models/rf3/src/rf3/inference_engines/rf3.py index eda6076..9a3c36b 100644 --- a/models/rf3/src/rf3/inference_engines/rf3.py +++ b/models/rf3/src/rf3/inference_engines/rf3.py @@ -10,11 +10,12 @@ from atomworks.io.transforms.categories import category_to_dict from lightning.fabric import seed_everything from omegaconf import OmegaConf +from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability +from modelhub.utils.logging import print_config_tree from rf3.model.RF3 import ShouldEarlyStopFn from rf3.utils.datasets import ( assemble_distributed_inference_loader_from_list_of_paths, ) -from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability from rf3.utils.inference import ( apply_conformer_and_template_selections, build_file_paths_for_prediction, @@ -24,14 +25,17 @@ from rf3.utils.io import ( dump_structures, dump_trajectories, ) -from modelhub.utils.logging import print_config_tree from rf3.utils.predicted_error import ( annotate_atom_array_b_factor_with_plddt, compile_af3_confidence_outputs, get_mean_atomwise_plddt, ) -logging.basicConfig(level=logging.INFO) +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + datefmt="%H:%M:%S", +) ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rf3/src/rf3/loss/af3_confidence_loss.py b/models/rf3/src/rf3/loss/af3_confidence_loss.py index 3c2a906..3bd57f3 100644 --- a/models/rf3/src/rf3/loss/af3_confidence_loss.py +++ b/models/rf3/src/rf3/loss/af3_confidence_loss.py @@ -1,7 +1,5 @@ import torch import torch.nn as nn -from scipy.stats import spearmanr - from rf3.chemical import NFRAMES, NHEAVY, frame_indices # TODO: REFACTOR; COPIED FROM RF2AA. WE NEED TO ADD DOCSTRINGS, EXAMPLES, HOPEFULLY TESTS, AND CLEAN UP @@ -14,6 +12,7 @@ from rf3.utils.frames import ( mask_unresolved_frames_batched, rigid_from_3_points, ) +from scipy.stats import spearmanr class ConfidenceLoss(nn.Module): diff --git a/models/rf3/src/rf3/loss/af3_losses.py b/models/rf3/src/rf3/loss/af3_losses.py index f567da3..b81c5bc 100644 --- a/models/rf3/src/rf3/loss/af3_losses.py +++ b/models/rf3/src/rf3/loss/af3_losses.py @@ -2,7 +2,6 @@ import hydra import numpy as np import torch import torch.nn as nn - from rf3.alignment import weighted_rigid_align from rf3.training.checkpoint import activation_checkpointing diff --git a/models/rf3/src/rf3/metrics/chiral.py b/models/rf3/src/rf3/metrics/chiral.py index d8d8ce3..37b99ca 100644 --- a/models/rf3/src/rf3/metrics/chiral.py +++ b/models/rf3/src/rf3/metrics/chiral.py @@ -8,8 +8,8 @@ from atomworks.ml.transforms.rdkit_utils import get_rdkit_chiral_centers from beartype.typing import Any from biotite.structure import AtomArray, AtomArrayStack from jaxtyping import Bool, Float - from rf3.kinematics import get_dih + from modelhub.metrics.metric import Metric diff --git a/models/rf3/src/rf3/metrics/distogram.py b/models/rf3/src/rf3/metrics/distogram.py index 81cfd95..d514f11 100644 --- a/models/rf3/src/rf3/metrics/distogram.py +++ b/models/rf3/src/rf3/metrics/distogram.py @@ -9,8 +9,8 @@ from beartype.typing import Any, Literal from biotite.structure import AtomArrayStack from einops import rearrange, repeat from jaxtyping import Bool, Float - from rf3.loss.af3_losses import distogram_loss + from modelhub.metrics.metric import Metric from modelhub.utils.torch import assert_no_nans diff --git a/models/rf3/src/rf3/metrics/predicted_error.py b/models/rf3/src/rf3/metrics/predicted_error.py index 468e9d7..37a5160 100644 --- a/models/rf3/src/rf3/metrics/predicted_error.py +++ b/models/rf3/src/rf3/metrics/predicted_error.py @@ -1,9 +1,9 @@ from typing import Any import torch +from rf3.metrics.metric_utils import find_bin_midpoints from modelhub.metrics.metric import Metric -from rf3.metrics.metric_utils import find_bin_midpoints def compute_ptm( diff --git a/models/rf3/src/rf3/model/RF3.py b/models/rf3/src/rf3/model/RF3.py index 5aaa6ed..b57eef5 100644 --- a/models/rf3/src/rf3/model/RF3.py +++ b/models/rf3/src/rf3/model/RF3.py @@ -5,8 +5,6 @@ import torch import torch.utils.checkpoint as checkpoint from beartype.typing import Any, Generator, Protocol from omegaconf import DictConfig -from torch import nn - from rf3.diffusion_samplers.inference_sampler import ( SampleDiffusion, SamplePartialDiffusion, @@ -16,6 +14,7 @@ from rf3.model.layers.pairformer_layers import ( ) from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler from rf3.training.checkpoint import create_custom_forward +from torch import nn """ Shape Annotation Glossary: diff --git a/models/rf3/src/rf3/model/RF3_blocks.py b/models/rf3/src/rf3/model/RF3_blocks.py index e274600..082082a 100644 --- a/models/rf3/src/rf3/model/RF3_blocks.py +++ b/models/rf3/src/rf3/model/RF3_blocks.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from rf3.training.checkpoint import activation_checkpointing diff --git a/models/rf3/src/rf3/model/RF3_structure.py b/models/rf3/src/rf3/model/RF3_structure.py index b1b7402..b9f25fe 100644 --- a/models/rf3/src/rf3/model/RF3_structure.py +++ b/models/rf3/src/rf3/model/RF3_structure.py @@ -2,7 +2,6 @@ import logging import torch import torch.nn as nn - from rf3.model.layers.af3_diffusion_transformer import ( AtomAttentionEncoderDiffusion, AtomTransformer, diff --git a/models/rf3/src/rf3/model/layers/Attention_module.py b/models/rf3/src/rf3/model/layers/Attention_module.py index 2f79ffd..5a78c82 100644 --- a/models/rf3/src/rf3/model/layers/Attention_module.py +++ b/models/rf3/src/rf3/model/layers/Attention_module.py @@ -5,11 +5,11 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange from opt_einsum import contract as einsum - -from modelhub import SHOULD_USE_CUEQUIVARIANCE from rf3.training.checkpoint import activation_checkpointing from rf3.util_module import init_lecun_normal +from modelhub import SHOULD_USE_CUEQUIVARIANCE + if SHOULD_USE_CUEQUIVARIANCE: import cuequivariance_torch as cuet diff --git a/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py b/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py index e9aaa69..1258c76 100644 --- a/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py +++ b/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn from jaxtyping import Float +from rf3.util_module import init_lecun_normal from modelhub import SHOULD_USE_CUEQUIVARIANCE -from rf3.util_module import init_lecun_normal if SHOULD_USE_CUEQUIVARIANCE: import cuequivariance_torch as cuet diff --git a/models/rf3/src/rf3/model/layers/af3_auxiliary_heads.py b/models/rf3/src/rf3/model/layers/af3_auxiliary_heads.py index 9eb7db3..ff789f0 100644 --- a/models/rf3/src/rf3/model/layers/af3_auxiliary_heads.py +++ b/models/rf3/src/rf3/model/layers/af3_auxiliary_heads.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +from rf3.model.RF3_structure import PairformerBlock, linearNoBias import src -from rf3.model.RF3_structure import PairformerBlock, linearNoBias # TODO: Get from RF2AA encoding instead CHEM_DATA_LEGACY = {"NHEAVY": 23, "aa2num": {"UNK": 20, "GLY": 7, "MAS": 21}} diff --git a/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py index 9e5d303..5086b0f 100644 --- a/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py +++ b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py @@ -1,7 +1,6 @@ import numpy as np import torch import torch.nn as nn - from rf3.loss.loss import calc_chiral_grads_flat_impl from rf3.model.layers.layer_utils import ( AdaLN, @@ -12,6 +11,7 @@ from rf3.model.layers.layer_utils import ( ) from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage from rf3.training.checkpoint import activation_checkpointing + from modelhub.utils.torch import device_of diff --git a/models/rf3/src/rf3/model/layers/layer_utils.py b/models/rf3/src/rf3/model/layers/layer_utils.py index 8d6149f..f9fbfde 100644 --- a/models/rf3/src/rf3/model/layers/layer_utils.py +++ b/models/rf3/src/rf3/model/layers/layer_utils.py @@ -3,9 +3,8 @@ from functools import partial import numpy as np import torch import torch.nn as nn -from torch.nn.functional import silu - from rf3.training.checkpoint import activation_checkpointing +from torch.nn.functional import silu linearNoBias = partial(torch.nn.Linear, bias=False) diff --git a/models/rf3/src/rf3/model/layers/outer_product.py b/models/rf3/src/rf3/model/layers/outer_product.py index 4657e07..486b35f 100644 --- a/models/rf3/src/rf3/model/layers/outer_product.py +++ b/models/rf3/src/rf3/model/layers/outer_product.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn - from rf3.training.checkpoint import activation_checkpointing from rf3.util_module import init_lecun_normal diff --git a/models/rf3/src/rf3/model/layers/pairformer_layers.py b/models/rf3/src/rf3/model/layers/pairformer_layers.py index 61874ab..0604024 100644 --- a/models/rf3/src/rf3/model/layers/pairformer_layers.py +++ b/models/rf3/src/rf3/model/layers/pairformer_layers.py @@ -1,7 +1,4 @@ import torch -from torch import nn -from torch.nn.functional import one_hot, relu - from rf3.data.ground_truth_template import ( af3_noise_scale_to_noise_level, ) @@ -26,6 +23,8 @@ from rf3.model.layers.outer_product import ( from rf3.model.RF3_blocks import MSAPairWeightedAverage, MSASubsampleEmbedder from rf3.training.checkpoint import activation_checkpointing from rf3.util_module import Dropout +from torch import nn +from torch.nn.functional import one_hot, relu class AtomAttentionEncoderPairformer(nn.Module): diff --git a/models/rf3/src/rf3/model/layers/structure_bias.py b/models/rf3/src/rf3/model/layers/structure_bias.py index 5043ef6..21babd0 100644 --- a/models/rf3/src/rf3/model/layers/structure_bias.py +++ b/models/rf3/src/rf3/model/layers/structure_bias.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn from opt_einsum import contract as einsum - from rf3.util_module import init_lecun_normal, rbf diff --git a/models/rf3/src/rf3/symmetry/resolve.py b/models/rf3/src/rf3/symmetry/resolve.py index 3b18a62..d7d1115 100644 --- a/models/rf3/src/rf3/symmetry/resolve.py +++ b/models/rf3/src/rf3/symmetry/resolve.py @@ -14,7 +14,6 @@ from atomworks.ml.transforms.base import Compose, convert_to_torch from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX from biotite.structure import AtomArray, AtomArrayStack from jaxtyping import Bool, Float, Int - from rf3.loss.af3_losses import ( ResidueSymmetryResolution, SubunitSymmetryResolution, diff --git a/models/rf3/src/rf3/train.py b/models/rf3/src/rf3/train.py index 299ef2d..4428a22 100755 --- a/models/rf3/src/rf3/train.py +++ b/models/rf3/src/rf3/train.py @@ -1,23 +1,22 @@ -#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../scripts/shebang/modelhub_exec.sh" "$0" "$@"' +#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"' import logging import os import hydra import rootutils -from dotenv import load_dotenv from omegaconf import DictConfig +from modelhub.utils.env import load_ipd_dotenv from modelhub.utils.logging import suppress_warnings from modelhub.utils.weights import CheckpointConfig -load_dotenv(override=True) - - # Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils) # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located) rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +load_ipd_dotenv(override=True) + _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs") _spawning_process_logger = logging.getLogger(__name__) diff --git a/models/rf3/src/rf3/trainers/rf3.py b/models/rf3/src/rf3/trainers/rf3.py index a603e00..0dff9d8 100644 --- a/models/rf3/src/rf3/trainers/rf3.py +++ b/models/rf3/src/rf3/trainers/rf3.py @@ -5,20 +5,20 @@ from einops import repeat from jaxtyping import Float, Int from lightning_utilities import apply_to_collection from omegaconf import DictConfig - -from modelhub.common import exists from rf3.loss.af3_losses import Loss as AF3Loss from rf3.loss.af3_losses import ( ResidueSymmetryResolution, SubunitSymmetryResolution, ) -from modelhub.metrics.metric import MetricManager from rf3.model.RF3 import ShouldEarlyStopFn -from modelhub.trainers.fabric import FabricTrainer from rf3.training.EMA import EMA -from modelhub.utils.ddp import RankedLogger from rf3.utils.io import build_stack_from_atom_array_and_batched_coords from rf3.utils.recycling import get_recycle_schedule + +from modelhub.common import exists +from modelhub.metrics.metric import MetricManager +from modelhub.trainers.fabric import FabricTrainer +from modelhub.utils.ddp import RankedLogger from modelhub.utils.torch import assert_no_nans, assert_same_shape ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rf3/src/rf3/training/checkpoint.py b/models/rf3/src/rf3/training/checkpoint.py index 94b185c..79458c3 100644 --- a/models/rf3/src/rf3/training/checkpoint.py +++ b/models/rf3/src/rf3/training/checkpoint.py @@ -13,7 +13,7 @@ from torch.utils.checkpoint import checkpoint def create_custom_forward(module, **kwargs): """Create a custom forward function for gradient checkpointing with fixed kwargs. - Enables passing keyword arguments to a module when using PyTorch's checkpoint function, + Enables passing keyword arguments to a module when using PyTorch's checkpoint function, which only accepts positional arguments for the function to be checkpointed. Args: diff --git a/models/rf3/src/rf3/utils/frames.py b/models/rf3/src/rf3/utils/frames.py index f707c89..94be2e3 100644 --- a/models/rf3/src/rf3/utils/frames.py +++ b/models/rf3/src/rf3/utils/frames.py @@ -1,7 +1,6 @@ # TODO: REFACTOR; COPIED FROM RF2AA. WE NEED TO ADD DOCSTRINGS, EXAMPLES, HOPEFULLY TESTS, AND CLEAN UP import torch - from rf3.chemical import NFRAMES, NNAPROTAAS, costgtNA diff --git a/models/rf3/src/rf3/utils/inference.py b/models/rf3/src/rf3/utils/inference.py index 0776294..1959582 100644 --- a/models/rf3/src/rf3/utils/inference.py +++ b/models/rf3/src/rf3/utils/inference.py @@ -15,7 +15,6 @@ from atomworks.io.tools.inference import ( from atomworks.io.utils.io_utils import to_cif_file from atomworks.io.utils.selection import AtomSelectionStack from biotite.structure import AtomArray - from rf3.utils.io import ( CIF_LIKE_EXTENSIONS, DICTIONARY_LIKE_EXTENSIONS, diff --git a/models/rf3/src/rf3/utils/io.py b/models/rf3/src/rf3/utils/io.py index 6d6f275..3eed91e 100644 --- a/models/rf3/src/rf3/utils/io.py +++ b/models/rf3/src/rf3/utils/io.py @@ -7,8 +7,8 @@ import torch from atomworks.io.utils.io_utils import to_cif_file from beartype.typing import Literal from biotite.structure import AtomArray, AtomArrayStack, stack - from rf3.alignment import weighted_rigid_align + from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rf3/src/rf3/utils/predicted_error.py b/models/rf3/src/rf3/utils/predicted_error.py index 4ef5010..39ea42d 100644 --- a/models/rf3/src/rf3/utils/predicted_error.py +++ b/models/rf3/src/rf3/utils/predicted_error.py @@ -9,7 +9,6 @@ import tree from beartype.typing import Any from biotite.structure import AtomArray, AtomArrayStack from omegaconf import DictConfig - from rf3.chemical import NHEAVY from rf3.metrics.metric_utils import ( compute_mean_over_subsampled_pairs, diff --git a/models/rf3/src/rf3/validate.py b/models/rf3/src/rf3/validate.py index d623686..6ea0707 100755 --- a/models/rf3/src/rf3/validate.py +++ b/models/rf3/src/rf3/validate.py @@ -1,18 +1,16 @@ -#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../scripts/shebang/modelhub_exec.sh" "$0" "$@"' +#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"' import logging import os -from pathlib import Path import hydra import rootutils -from dotenv import load_dotenv from omegaconf import DictConfig +from modelhub.utils.env import load_ipd_dotenv from modelhub.utils.logging import suppress_warnings -load_dotenv(override=True) - +load_ipd_dotenv(override=True) # Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils) # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located) diff --git a/models/rf3/tests/test_chiral_metrics.py b/models/rf3/tests/test_chiral_metrics.py index 27e6b75..36912a2 100644 --- a/models/rf3/tests/test_chiral_metrics.py +++ b/models/rf3/tests/test_chiral_metrics.py @@ -2,7 +2,6 @@ from copy import deepcopy import pytest from atomworks.ml.utils.testing import cached_parse - from rf3.metrics.chiral import ChiralLoss diff --git a/models/rf3/tests/test_write_confidence.py b/models/rf3/tests/test_write_confidence.py index a82aa3d..2684e18 100644 --- a/models/rf3/tests/test_write_confidence.py +++ b/models/rf3/tests/test_write_confidence.py @@ -3,7 +3,6 @@ import pytest import torch from lightning.fabric import seed_everything from omegaconf import DictConfig - from rf3.chemical import NHEAVY, heavyatom_mask from rf3.metrics.metric_utils import ( find_bin_midpoints, diff --git a/projects/latent/2d_pipe.py b/projects/latent/2d_pipe.py new file mode 100644 index 0000000..76a4dc9 --- /dev/null +++ b/projects/latent/2d_pipe.py @@ -0,0 +1,1305 @@ +"""Transform pipeline for Atom14 Design with 2D Conditioning""" + +import logging + +# Turn off warnings for now +# warnings.filterwarnings("ignore", category=RuntimeWarning) +# warnings.filterwarnings("ignore", category=DeprecationWarning) +import warnings +from pathlib import Path +from typing import Final + +import biotite.structure as struc +import numpy as np +import torch +import torch.nn.functional as F +from beartype.typing import Any +from cifutils.constants import ( + AF3_EXCLUDED_LIGANDS, + GAP, + STANDARD_AA, + STANDARD_DNA, + STANDARD_RNA, +) +from cifutils.utils.selection import get_residue_starts +from datahub.encoding_definitions import AF3SequenceEncoding +from datahub.enums import GroundTruthConformerPolicy +from datahub.transforms._checks import ( + check_atom_array_annotation, + check_contains_keys, +) +from datahub.transforms.af3_reference_molecule import ( + ELEMENT_NAME_TO_ATOMIC_NUMBER, + _encode_atom_names_like_af3, + get_af3_reference_molecule_features, +) +from datahub.transforms.atom_array import ( + AddGlobalAtomIdAnnotation, + AddGlobalTokenIdAnnotation, + AddProteinTerminiAnnotation, + AddWithinChainInstanceResIdx, + AddWithinPolyResIdxAnnotation, + ComputeAtomToTokenMap, + CopyAnnotation, + get_within_entity_idx, +) +from datahub.transforms.atomize import ( + AtomizeByCCDName, + FlagNonPolymersForAtomization, +) +from datahub.transforms.base import ( + AddData, + Compose, + ConditionalRoute, + ConvertToTorch, + RandomRoute, + RemoveKeys, + SubsetToKeys, + Transform, +) +from datahub.transforms.bonds import AddAF3TokenBondFeatures +from datahub.transforms.cached_residue_data import LoadCachedResidueLevelData +from datahub.transforms.covalent_modifications import ( + FlagAndReassignCovalentModifications, +) +from datahub.transforms.crop import ( + CropContiguousLikeAF3, + CropSpatialLikeAF3, +) +from datahub.transforms.diffusion.batch_structures import ( + BatchStructuresForDiffusionNoising, +) +from datahub.transforms.diffusion.edm import SampleEDMNoise +from datahub.transforms.featurize_unresolved_residues import ( + MaskPolymerResiduesWithUnresolvedFrameAtoms, + PlaceUnresolvedTokenAtomsOnRepresentativeAtom, + PlaceUnresolvedTokenOnClosestResolvedTokenInSequence, +) +from datahub.transforms.filters import ( + FilterToSpecifiedPNUnits, + HandleUndesiredResTokens, + RemoveHydrogens, + RemoveNucleicAcidTerminalOxygen, + RemovePolymersWithTooFewResolvedResidues, + RemoveTerminalOxygen, + RemoveUnresolvedLigandAtomsIfTooMany, + RemoveUnresolvedPNUnits, +) +from datahub.utils.token import ( + apply_and_spread_token_wise, + get_token_starts, +) + +from modelhub.common import exists + +# from projects.aa_design.constants import ( +# CENTRAL_ATOM, +# MASKED_ATOM_NAME, +# MASKED_RES_NAME, +# VIRTUAL_ATOM_ELEMENT, +# VIRTUAL_ATOM_NAME_PREFIX, +# ) +from projects.aa_design.transforms.condition import ( + C_CRD, + C_CTR, + C_DIS, + C_HOT, + C_IDX, + C_NTR, + C_SEQ, +) +from projects.aa_design.transforms.condition_2d.annotator import ( + ensure_annotations, +) +from projects.aa_design.transforms.condition_2d.design_task import ( + SampleDesignTask, + TipAtomDistanceTask, + UnconditionalTask, +) +from projects.aa_design.transforms.condition_2d.random_atomize_residues import ( + AtomizeByMaskFunction, + RandomAtomizeResidues, +) +from projects.aa_design.transforms.condition_2d.virtual_atoms import ( + MaskAnnotationsForTokensWithoutSequenceConditioning, +) +from projects.aa_design.transforms.conditioning_base import UnindexFlaggedTokens +from projects.aa_design.transforms.design_transforms import ( + AddGroundTruthSequence, +) +from projects.aa_design.transforms.util_transforms import ( + AggregateFeaturesLikeAF3WithoutMSA, + RemoveTokensWithoutCorrespondingCentralAtom, + get_af3_token_representative_masks, +) +from projects.aa_design.transforms.virtual_atoms import PadTokensWithVirtualAtoms + +warnings.filterwarnings( + "ignore", message="Category 'chem_comp_bond' not found", category=UserWarning +) +warnings.filterwarnings( + "ignore", message="The coordinates are missing", category=UserWarning +) + +# Turn DeprecationWarnings into exceptions +# warnings.filterwarnings("error", category=DeprecationWarning) + +warnings.filterwarnings("ignore", message="datetime", category=DeprecationWarning) +logging.getLogger("datahub").setLevel(logging.ERROR) +logging.getLogger("cifutils").setLevel(logging.ERROR) +logging.getLogger("cifutils.tools.rdkit").setLevel(logging.ERROR) + + +###################################################################################### +# Common transforms +###################################################################################### +af3_sequence_encoding = AF3SequenceEncoding() + +CENTRAL_ATOM: Final[str] = "CB" +"""Central atom name for virtual atoms.""" + +VIRTUAL_ATOM_ELEMENT: Final[str] = "X" +"""Virtual atom element.""" + +VIRTUAL_ATOM_NAME_PREFIX: Final[str] = "V" +"""Virtual atom name prefix.""" + +MASKED_ATOM_NAME: Final[str] = "VX" +"""The symbol to use for masked atoms.""" + +MASKED_RES_NAME = GAP +"""The symbal to use for masked residues + "" - Residue name used for all masked atoms (both virtual and real atoms with masked identities) +""" + + +def get_diffusion_transforms( + *, + sigma_data: float, + diffusion_batch_size: int, +): + return [ + ConvertToTorch(keys=["feats"]), + # Prepare coordinates for noising (without modifying the ground truth) + # ... add placeholder coordinates for noising + CopyAnnotation(annotation_to_copy="coord", new_annotation="coord_to_be_noised"), + # ... handling of unresolved residues (NOTE: best done after inputs are processed) + PlaceUnresolvedTokenAtomsOnRepresentativeAtom( + annotation_to_update="coord_to_be_noised" + ), + PlaceUnresolvedTokenOnClosestResolvedTokenInSequence( + annotation_to_update="coord_to_be_noised", + annotation_to_copy="coord_to_be_noised", + ), + # Feature aggregation + AggregateFeaturesLikeAF3WithoutMSA(), + # ... batching and noise sampling for diffusion + BatchStructuresForDiffusionNoising(batch_size=diffusion_batch_size), + SampleEDMNoise( + sigma_data=sigma_data, diffusion_batch_size=diffusion_batch_size + ), + ] + + +###################################################################################### +# Custom Transforms +###################################################################################### + + +class EncodeAtomLevelFeaturesWithSequenceMasking(Transform): + """ + Encodes atom-level reference features using featurization annotations with sequence masking. + + Uses featurization annotations (`element_to_featurize`, `res_name_to_featurize`, etc.) instead of + ground truth annotations for generating reference features. This allows sequence masking to be + handled by upstream transforms that update the featurization annotations appropriately. + + The following features are added to `data['feats']`: + - ref_pos: Reference atom positions. (np.ndarray, shape: (n_atoms, 3), dtype: float32) + - ref_mask: Reference atom mask. (np.ndarray, shape: (n_atoms,), dtype: bool) + - ref_element: Reference atom element indices. (np.ndarray, shape: (n_atoms,), dtype: int64) + - ref_charge: Reference atom charges. (np.ndarray, shape: (n_atoms,), dtype: int8) + - ref_space_uid: Unique residue segment index. (np.ndarray, shape: (n_atoms,), dtype: int64) + - ref_pos_is_ground_truth: Boolean indicator for whether the reference_conformer is ground truth. (np.ndarray, shape: (n_atoms,), dtype: bool) + - motif_pos: Ground truth motif positions (if coordinate conditioned, otherwise 0s). (np.ndarray, shape: (n_atoms, 3), dtype: float32) + - is_seq_conditioned_atom_level: Boolean indicator for sequence conditioning. (torch.Tensor, shape: (n_atoms,), dtype: bool) + - is_dist_conditioned_atom_level: Boolean indicator for distance conditioning. (torch.Tensor, shape: (n_atoms,), dtype: bool) + - mask_hotspot_1_atom: Boolean indicator for hotspot conditioning. (torch.Tensor, shape: (n_atoms,), dtype: bool) + - feature_hotspot_1_atom: Boolean indicator for whether the atom is a hotspot. (torch.Tensor, shape: (n_atoms,), dtype: bool) + - feature_distance_2_atom: Distance feature for 2D conditioning. (np.ndarray, shape: (n_atoms, n_atoms), dtype: float32) + - mask_distance_2_atom: Mask for 2D conditioning. (np.ndarray, shape: (n_atoms, n_atoms), dtype: bool) + + Args: + **kwargs: Additional keyword arguments passed to `get_af3_reference_molecule_features` (e.g., conformer generation timeout). + """ + + def __init__( + self, ground_truth_conformer_policy=GroundTruthConformerPolicy.IGNORE, **kwargs + ): + DEFAULT_KWARGS = dict( + conformer_generation_timeout=(3.0, 0.15), + ) + self.conformer_generation_kwargs = DEFAULT_KWARGS | kwargs + self.ground_truth_conformer_policy = ground_truth_conformer_policy + + def check_input(self, data: dict): + check_contains_keys(data, ["atom_array"]) + + def forward(self, data: dict) -> dict: + atom_array = data["atom_array"] + + L = atom_array.array_length() # n_atoms + + # ... Set up default reference features (all zeros) + ref_pos = np.zeros_like(atom_array.coord, dtype=np.float32) + ref_mask = np.zeros((L,), dtype=bool) + ref_element = np.zeros((L,), dtype=np.int64) + ref_charge = np.zeros((L,), dtype=np.int8) + ref_atom_name_chars = np.zeros((L, 4), dtype=np.int8) + ref_pos_is_ground_truth = np.zeros((L,), dtype=bool) + + # Get residue boundaries and assign unique IDs to each residue segment + residue_starts = get_residue_starts(atom_array, add_exclusive_stop=True) + ref_space_uid = struc.segments.spread_segment_wise( + residue_starts, np.arange(len(residue_starts) - 1, dtype=np.int64) + ) + + # ... Create reference features for sequence-conditioned atoms + is_seq_conditioned = C_SEQ.mask(atom_array, default="generate") + if np.any(is_seq_conditioned): + # We need an atom array with ground-truth atom names for reference conformer generation and the ground truth conformer policy + atom_array_with_gt_atom_name = atom_array.copy() + atom_array_with_gt_atom_name.atom_name = ( + atom_array.gt_atom_name + ) # Set in PadTokensWithVirtualAtoms + atom_array_with_gt_atom_name.set_annotation( + "ground_truth_conformer_policy", + np.full( + atom_array_with_gt_atom_name.array_length(), + self.ground_truth_conformer_policy.value, + ), + ) + + # If we are showing the model the sequence, we must generate conformers + reference_features, _ = get_af3_reference_molecule_features( + atom_array_with_gt_atom_name[is_seq_conditioned], + cached_residue_level_data=data["cached_residue_level_data"] + if "cached_residue_level_data" in data + else None, + **self.conformer_generation_kwargs, + ) # (n_atoms_with_seq_conditioning, n_features) + + # ... overwrite reference features for atoms with sequence conditioning + ref_pos[is_seq_conditioned] = reference_features["ref_pos"] + ref_mask[is_seq_conditioned] = reference_features["ref_mask"] + ref_charge[is_seq_conditioned] = reference_features["ref_charge"] + ref_atom_name_chars[is_seq_conditioned] = reference_features[ + "ref_atom_name_chars" + ] + ref_pos_is_ground_truth[is_seq_conditioned] = reference_features[ + "ref_pos_is_ground_truth" + ] + + # ... show element for all 'real' backbone atoms & any non-standard AA's + ref_element = np.array( + [ + ELEMENT_NAME_TO_ATOMIC_NUMBER.get(a, 0) + for a in atom_array.element_to_featurize + ] + ) + ref_atom_name_chars = _encode_atom_names_like_af3( + atom_array.atom_name_to_featurize + ) + # ensure_annotations(atom_array, "is_protein_backbone", "is_standard_aa") + # is_element_shown = atom_array.mask("(~is_virtual & is_protein_backbone) | ~is_standard_aa") + # ref_element[is_element_shown] = atom_array.atomic_number[is_element_shown] + + # 2D Features + mask_dist_2d = C_DIS.mask(atom_array, default="generate").as_dense_array() + feature_dist_2d = C_DIS.annotation( + atom_array, default="generate" + ).as_dense_array() + feature_dist_2d[~mask_dist_2d] = 0.0 + + reference_features = { + # ... standard AF3 `ref` features + "ref_pos": ref_pos, # (n_atoms, 3) + "ref_mask": ref_mask, # (n_atoms) + "ref_element": ref_element, # (n_atoms) + "ref_charge": ref_charge, # (n_atoms) + "ref_space_uid": ref_space_uid, # (n_atoms) + "ref_atom_name_chars": ref_atom_name_chars, # (n_atoms, 4, 64) + "ref_pos_is_ground_truth": ref_pos_is_ground_truth, # (n_atoms) + "motif_pos": np.nan_to_num( + C_CRD.annotation(atom_array, default="generate") + ), # (n_atoms, 3) + # ... extra condition features + C_SEQ.get_mask_name(1, "atom"): C_SEQ.mask( + atom_array, default="generate" + ), # (n_atoms) + # TODO(Discuss w. Nate): Do we still need this? + C_DIS.get_mask_name(1, "atom"): C_DIS.mask(atom_array, default="generate") + .as_dense_array(default=False) + .any(axis=0), # (n_atoms) + # TODO(Discuss w. Max): Do we need both? + C_HOT.get_mask_name(1, "atom"): C_HOT.mask( + atom_array, default="generate" + ), # (n_atoms) + C_HOT.get_feature_name(1, "atom"): C_HOT.annotation( + atom_array, default="generate" + ), # (n_atoms) + # 2D Features + C_DIS.get_feature_name(2, "atom"): feature_dist_2d, # (n_atoms, n_atoms) + C_DIS.get_mask_name(2, "atom"): mask_dist_2d, # (n_atoms, n_atoms) + } + # Verify all features have n_atoms as first dimension + assert all( + v.shape[0] == L for v in reference_features.values() + ), "All features must have n_atoms as first dimension" + + # Sanity Check: All features are null for unconditioned atoms + idx_unconditioned = ~is_seq_conditioned + assert np.allclose( + ref_pos[idx_unconditioned], np.zeros_like(ref_pos[idx_unconditioned]) + ), "ref_pos not null for unconditioned atoms" + assert np.all( + ~ref_mask[idx_unconditioned] + ), "ref_mask not null for unconditioned atoms" + # assert np.all( + # ref_element[idx_unconditioned] == 0 + # ), "ref_element not null for unconditioned atoms" + assert np.all( + ref_charge[idx_unconditioned] == 0 + ), "ref_charge not null for unconditioned atoms" + + if "feats" not in data: + data["feats"] = {} + + data["feats"].update(reference_features) + + return data + + +class HackInRequiredAnnotations(Transform): + """ + Hack to add in some annotations that UnindexFlaggedTokens should have added but didn't. + This functionality is taken from AddIsX, which already implements this but later in the pipeline. + Required annotations are: + "is_motif_atom", + "is_motif_token", + "is_motif_atom_unindexed", + "is_motif_atom_unindexed_motif_breakpoint", + "is_motif_atom_with_fixed_seq", + "is_motif_atom_with_fixed_coord", + "is_flexible_motif_atom", + + NOTE: In the future we want to rework UnindexFlaggedTokens to not need this hack. + """ + + def check_input(self, data): + check_contains_keys(data, ["atom_array"]) + + def forward(self, data: dict) -> dict: + # TODO: We should be storing in the ground truth key, not feats, unless it is literally a feature for the model + # (vs. info that we use for losses / metrics) + + atom_array = data["atom_array"] + + ########## HACK: SPOOF LEGACY MOTIF FEATURES ########## + # --- alias table -- + a = atom_array + # ------------------ + + from datahub.utils.token import ( + apply_token_wise, + get_token_starts, + spread_token_wise, + ) + + token_starts = get_token_starts(a) + token_segments = np.concatenate([token_starts, [a.array_length()]]) + + # `is_motif_atom` + # ... get relevant atom-level annotations + has_coord = C_CRD.mask(a, default="generate") + has_dist = C_DIS.mask(a, default="generate").as_dense_array(False).any(axis=0) + has_idx = C_IDX.mask(a, default="generate") + has_seq = C_SEQ.mask(a, default="generate") + is_motif_atom = has_coord | has_dist + a.set_annotation("is_motif_atom", is_motif_atom) + a.set_annotation("is_motif_atom_with_fixed_coord", has_coord) + # NOTE: You would think this should be is_motif_atom & has_seq, but based on PadTokensWithVirtualAtoms + # it really is just has_seq. + a.set_annotation("is_motif_atom_with_fixed_seq", has_seq) + a.set_annotation( + "is_flexible_motif_atom", + has_seq & ~is_motif_atom & np.isin(a.atom_name, ["N", "CA", "C", "O"]), + ) # TODO: Confirm this is desired behavior + a.set_annotation("is_motif_atom_unindexed", is_motif_atom & ~has_idx) + a.set_annotation( + "is_motif_atom_unindexed_motif_breakpoint", np.zeros_like(is_motif_atom) + ) # FIXME: Hack for now but will not work when we try unindexing motifs! + + # `is_motif_token` + # ... get relevant token-level annotations + _to_token_lvl = lambda x: apply_token_wise( # noqa: E731 + a, x, np.any, token_starts=token_segments + ) + token_has_coord = _to_token_lvl(has_coord) + token_has_dist = _to_token_lvl(has_dist) + # ... combine them to create legacy features + is_motif_token = token_has_coord | token_has_dist + a.set_annotation( + "is_motif_token", spread_token_wise(a, is_motif_token, token_segments) + ) + + data["atom_array"] = a + return data + + +class AddIsX(Transform): + def __init__( + self, + X=[ + "is_backbone", # ... part of the protein backbone (N, CA, C, O) + "is_sidechain", # ... part of the protein sidechain (all atoms except N, CA, C, O, OXT) + "is_virtual", # ... virtual atoms that do not exist in the ground truth + "is_central", # ... token representative atom (CA) + "is_ca", # ... true CA atoms + "is_masked", # ... atoms that are masked out (i.e. appear virtual to the model but exist in the ground truth) + ], + central_atom=CENTRAL_ATOM, + virtual_atom_element=VIRTUAL_ATOM_ELEMENT, + ): + self.X = X + self.central_atom = central_atom + self.virtual_atom_element = virtual_atom_element + + def check_input(self, data): + check_contains_keys(data, ["atom_array", "feats"]) + + def forward(self, data: dict) -> dict: + # TODO: We should be storing in the ground truth key, not feats, unless it is literally a feature for the model + # (vs. info that we use for losses / metrics) + + atom_array = data["atom_array"] + # ... Add backbone and sidechain annotations + ensure_annotations( + atom_array, + # "is_protein", + "is_protein_backbone", + "is_protein_sidechain", + ) + + _token_rep_mask = get_af3_token_representative_masks( + atom_array, central_atom=self.central_atom + ) + + # Initialize ground_truth dict if it doesn't exist + if "feats" not in data: + data["feats"] = {} + + # ... Basic features + if "is_backbone" in self.X: + is_backbone = atom_array.get_annotation("is_protein_backbone") + data["feats"]["is_backbone"] = torch.from_numpy(is_backbone).to( + dtype=torch.bool + ) + + if "is_sidechain" in self.X: + is_sidechain = atom_array.get_annotation("is_protein_sidechain") + data["feats"]["is_sidechain"] = torch.from_numpy(is_sidechain).to( + dtype=torch.bool + ) + + # Virtual atom feats + if "is_virtual" in self.X: + data["feats"]["is_virtual"] = ( + atom_array.element == self.virtual_atom_element + ) + + if "is_masked" in self.X: + data["feats"]["is_masked"] = ( + atom_array.element_to_featurize == self.virtual_atom_element + ) + + # ... Central + if "is_central" in self.X: + data["feats"]["is_central"] = _token_rep_mask + + # NOTE: Check end of function for is_ca. Need to do it then because for now we are relying on some of the spoofed legacy features + + # Set occupancy feature + if data.get( + "is_inference", False + ): # HACK: Pretend all occupancy is 1.0 during inference + data["feats"]["has_zero_occupancy"] = np.zeros_like( + atom_array.occupancy, dtype=bool + ) + else: + data["feats"]["has_zero_occupancy"] = atom_array.occupancy == 0.0 + + ########## HACK: SPOOF LEGACY MOTIF FEATURES ########## + # --- alias table -- + f = data["feats"] + a = atom_array + from_np = torch.from_numpy + # ------------------ + # TODO: Investigate this + + from datahub.utils.token import ( + apply_token_wise, + get_token_starts, + spread_token_wise, + ) + + n_atoms = atom_array.array_length() + token_starts = get_token_starts(a) + token_segments = np.concatenate([token_starts, [a.array_length()]]) + n_tokens = len(token_starts) + + # `is_motif_atom` + # ... get relevant atom-level annotations + has_coord = C_CRD.mask(a, default="generate") + has_dist = C_DIS.mask(a, default="generate").as_dense_array(False).any(axis=0) + has_idx = C_IDX.mask(a, default="generate") + has_seq = C_SEQ.mask(a, default="generate") + is_motif_atom = has_coord | has_dist + a.set_annotation("is_motif_atom", is_motif_atom) + a.set_annotation("is_motif_atom_with_fixed_coord", has_coord) + # NOTE: You would think this should be is_motif_atom & has_seq, but based on PadTokensWithVirtualAtoms + # it really is just has_seq. + a.set_annotation("is_motif_atom_with_fixed_seq", has_seq) + a.set_annotation( + "is_flexible_motif_atom", + has_seq & ~is_motif_atom & np.isin(a.atom_name, ["N", "CA", "C", "O"]), + ) # TODO: Confirm this is desired behavior + a.set_annotation("is_motif_atom_unindexed", is_motif_atom & ~has_idx) + a.set_annotation( + "is_motif_atom_unindexed_motif_breakpoint", np.zeros_like(is_motif_atom) + ) # FIXME: Hack for now but will not work when we try unindexing motifs! + # ... combine them to create legacy features (L, ) + f["is_motif_atom"] = from_np(a.is_motif_atom) + f["is_motif_atom_with_fixed_coord"] = from_np(a.is_motif_atom_with_fixed_coord) + f["ref_is_motif_atom_with_fixed_coord"] = from_np( + a.is_motif_atom_with_fixed_coord + ) + f["is_flexible_motif_atom"] = from_np(a.is_flexible_motif_atom) + f["is_motif_atom_with_fixed_seq"] = from_np(a.is_motif_atom_with_fixed_seq) + f["is_motif_atom_unindexed"] = from_np(a.is_motif_atom_unindexed) # (L,) + f["ref_is_motif_atom_unindexed"] = from_np(a.is_motif_atom_unindexed) # (L,) + # TODO: What is the difference between `is_motif_atom` and `ref_is_motif_atom`? + # [1, 0] = non-motif, [0, 1] = motif + ref_is_motif_atom = np.zeros((n_atoms, 2), dtype=bool) + ref_is_motif_atom[~is_motif_atom, 0] = True + ref_is_motif_atom[is_motif_atom, 1] = True + f["ref_is_motif_atom"] = from_np(ref_is_motif_atom) # (L, 2) + # [1, 0, 0] = non-motif; [0, 1, 0] = indexed motif; [0, 0, 1] = unindexed motif + ref_motif_atom_type = np.zeros((n_atoms, 3), dtype=bool) + ref_motif_atom_type[~is_motif_atom, 0] = True + ref_motif_atom_type[is_motif_atom & has_idx, 1] = True + ref_motif_atom_type[is_motif_atom & ~has_idx, 2] = True + f["ref_motif_atom_type"] = from_np(ref_motif_atom_type) # (L, 3) + + # `is_motif_token` + # ... get relevant token-level annotations + _to_token_lvl = lambda x: apply_token_wise( # noqa: E731 + a, x, np.any, token_starts=token_segments + ) + token_has_coord = _to_token_lvl(has_coord) + token_has_dist = _to_token_lvl(has_dist) + token_has_idx = _to_token_lvl(has_idx) + # ... combine them to create legacy features + is_motif_token = token_has_coord | token_has_dist + a.set_annotation( + "is_motif_token", spread_token_wise(a, is_motif_token, token_segments) + ) + f["is_motif_token"] = from_np(is_motif_token) # (I,) + # TODO: What is the difference between `is_motif_token` and `ref_is_motif_token`? NOTE: Max - ref_is_motif_token is one-hot encoded so it has dimension 2. However I don't think it's actually used anywhere by the model. + # Token-level motif indicator: [1, 0] = non-motif, [0, 1] = motif + ref_is_motif_token = np.zeros((n_tokens, 2), dtype=bool) + ref_is_motif_token[~is_motif_token, 0] = True + ref_is_motif_token[is_motif_token, 1] = True + f["ref_is_motif_token"] = from_np(ref_is_motif_token) # (I, 2) + # [1, 0, 0] = non-motif; [0, 1, 0] = indexed motif; [0, 0, 1] = unindexed motif + ref_motif_token_type = np.zeros((n_tokens, 3), dtype=bool) + ref_motif_token_type[~is_motif_token, 0] = True + ref_motif_token_type[is_motif_token & token_has_idx, 1] = True + ref_motif_token_type[is_motif_token & ~token_has_idx, 2] = True + f["ref_motif_token_type"] = from_np(ref_motif_token_type) # (I, 3) + f["is_motif_token_with_fully_fixed_coord"] = apply_token_wise( + a, has_coord, np.all, token_starts=token_segments + ) + + # TODO: What is the difference between this and is_motif_atom? + f["ref_is_motif_atom_mask"] = is_motif_atom + + # ... CA + if "is_ca" in self.X: + # NOTE from Max: This seems to be the fix to the glycine bug -- use CA as your central atom instead of CB for certain tasks. + # This feature is called `is_ca` but it really should probably be called `is_central_atom_if_central_atom_was_ca`. + # Basically we sometimes want to use the central atom, but can't use the real central atom (usually CB) because it will leak the glycine's identity. + # So instead we use `is_ca` in those spots which means it not only needs to mark CA atoms but also all central atoms (for ligands and such) + + # Split into components to handle separately + atom_array_indexed = atom_array[~atom_array.is_motif_atom_unindexed] + _token_rep_mask_indexed = get_af3_token_representative_masks( + atom_array_indexed, central_atom="CA" + ) + if atom_array.is_motif_atom_unindexed.any(): + atom_array_unindexed = atom_array[atom_array.is_motif_atom_unindexed] + + # Ensure is_ca represents one and the first atom only for unindexed tokens + def first_nonzero(n): + assert n > 0 + x = np.zeros(n, dtype=bool) + x[0] = 1 + return x + + starts = get_token_starts(atom_array_unindexed, add_exclusive_stop=True) + _token_rep_mask_unindexed = np.concatenate( + [ + first_nonzero(end - start) + for start, end in zip(starts[:-1], starts[1:]) + ] + ) + _token_rep_mask = np.concatenate( + [ + _token_rep_mask_indexed, + _token_rep_mask_unindexed, + ], + axis=0, + ) + else: + _token_rep_mask = _token_rep_mask_indexed + data["feats"]["is_ca"] = _token_rep_mask + + return data + + +class EncodeTokenLevelFeaturesWithSequenceMasking(Transform): + """ + Encodes token-level model features using specified featurization annotations. + + Uses `res_name_to_featurize` instead of `res_name` for encoding residue types. + This allows for sequence masking to be handled by upstream transforms that update the featurization annotations appropriately. + + Computes and stores the following token-level features in the `data['feats']` dictionary: + - `residue_index`: Index of the residue within its chain (int, shape: (N_tokens,)) + - `token_index`: Index of the token in the sequence (int, shape: (N_tokens,)) + - `asym_id`: Unique integer for each distinct chain instance (int, shape: (N_tokens,)) + - `entity_id`: Unique integer for each distinct sequence entity (int, shape: (N_tokens,)) + - `sym_id`: Unique integer within chains of the same sequence (int, shape: (N_tokens,)) + - `restype`: One-hot encoding of the residue type (float, shape: (N_tokens, n_tokens)), using featurization annotations + - `is_protein`, `is_rna`, `is_dna`, `is_ligand`: Boolean masks for molecule type (bool, shape: (N_tokens,)) + + Metadata for chain and entity names is stored in `data['feat_metadata']`. + + Args: + sequence_encoding (AF3SequenceEncoding): + An encoding object that provides methods for mapping residue names to AF3 token indices and one-hot encodings. + """ + + def __init__( + self, + sequence_encoding: AF3SequenceEncoding, + ): + self.sequence_encoding = sequence_encoding + + def check_input(self, data: dict[str, Any]) -> None: + check_atom_array_annotation( + data, + [ + "atomize", + "pn_unit_iid", + "chain_entity", + "res_name", + "res_name_to_featurize", + "within_chain_res_idx", + ], + ) + + def forward(self, data: dict[str, Any]) -> dict[str, Any]: + atom_array = data["atom_array"] + + # ... get token-level array + token_starts = get_token_starts(atom_array) + n_tokens = len(token_starts) + token_level_array = atom_array[token_starts] + + # ... identifier tokens + # ... (residue) + residue_index = token_level_array.within_chain_res_idx + # ... (token) + token_index = np.arange(len(token_starts)) + # ... (chain instance) + asym_name, asym_id = np.unique( + token_level_array.pn_unit_iid, return_inverse=True + ) + # ... (chain entity) + entity_name, entity_id = np.unique( + token_level_array.pn_unit_entity, return_inverse=True + ) + # ... (within chain entity) + sym_name, sym_id = get_within_entity_idx(token_level_array, level="pn_unit") + + # ... molecule type (protein, RNA, DNA, ligand) - use ground truth res_name + _aa_like_res_names = self.sequence_encoding.all_res_names[ + self.sequence_encoding.is_aa_like + ] + is_protein = np.isin(token_level_array.res_name, _aa_like_res_names) + + _rna_like_res_names = self.sequence_encoding.all_res_names[ + self.sequence_encoding.is_rna_like + ] + is_rna = np.isin(token_level_array.res_name, _rna_like_res_names) + + _dna_like_res_names = self.sequence_encoding.all_res_names[ + self.sequence_encoding.is_dna_like + ] + is_dna = np.isin(token_level_array.res_name, _dna_like_res_names) + + is_ligand = ~(is_protein | is_rna | is_dna) + + # ... sequence tokens - use featurization annotations + res_names_to_featurize = token_level_array.res_name_to_featurize + + restype = self.sequence_encoding.encode(res_names_to_featurize) + restype = F.one_hot( + torch.tensor(restype), num_classes=self.sequence_encoding.n_tokens + ) + + # Indicator variables for conditioning (atom-level, but indicates token-level conditioning) + is_dist_conditioned_atom_level = np.any( + C_DIS.mask(atom_array, 2, "atom", default="generate").as_dense_array(), + axis=1, + ) + token_segments = np.concatenate([token_starts, [atom_array.array_length()]]) + is_dist_conditioned_token_level = torch.from_numpy( + apply_and_spread_token_wise( + atom_array, + is_dist_conditioned_atom_level, + np.any, + token_starts=token_segments, + ) + )[token_starts] # (L,) + is_seq_conditioned_token_level = torch.from_numpy( + C_SEQ.mask(token_level_array, default="generate") + ) # (L,) + + # ... add terminus type # TODO: Turn into proper `Condition` + terminus_type = torch.zeros((n_tokens, 2), dtype=torch.long) + is_c_terminus = C_CTR.mask(token_level_array, default="raise") + is_n_terminus = C_NTR.mask(token_level_array, default="raise") + terminus_type[is_c_terminus, 0] = 1 + terminus_type[is_n_terminus, 1] = 1 + + # ... add to data dict + if "feats" not in data: + data["feats"] = {} + if "feat_metadata" not in data: + data["feat_metadata"] = {} + + # Build dictionary of features + new_feats = { + "residue_index": residue_index, # (N_tokens) (int) + "token_index": token_index, # (N_tokens) (int) + "asym_id": asym_id, # (N_tokens) (int) + "entity_id": entity_id, # (N_tokens) (int) + "sym_id": sym_id, # (N_tokens) (int) + "restype": restype, # (N_tokens, 32) (float, one-hot) (using featurization annotations) + "is_protein": is_protein, # (N_tokens) (bool) + "is_rna": is_rna, # (N_tokens) (bool) + "is_dna": is_dna, # (N_tokens) (bool) + "is_ligand": is_ligand, # (N_tokens) (bool) + "terminus_type": terminus_type, # (N_tokens, 2) (int) + C_DIS.get_mask_name( + 1, "token" + ): is_dist_conditioned_token_level, # (N_tokens,) (bool) + C_SEQ.get_mask_name( + 1, "token" + ): is_seq_conditioned_token_level, # (N_tokens,) (bool) + } + + # Assert all features have matching first dimension + n_tokens = len(residue_index) + for key, value in new_feats.items(): + assert ( + value.shape[0] == n_tokens + ), f"{key} has first dim {value.shape[0]} but expected {n_tokens}!" + + # Merge into data dict + data["feats"] |= new_feats + + # Maps from numerical indices to string names (returned from np.unique with return_inverse=True) + # (May be helpful for debugging) + data["feat_metadata"] |= { + "asym_name": asym_name, # (N_asyms) + "entity_name": entity_name, # (N_entities) + "sym_name": sym_name, # (N_entities) + } + + return data + + +class RemoveCenterOfMass(Transform): + def forward(self, data: dict) -> dict: + atom_array = data["atom_array"] + center_of_mass = atom_array.coord[atom_array.mask("~has_nan_coord()")].mean( + axis=0 + ) + atom_array.coord -= center_of_mass + data["atom_array"] = atom_array + return data + + +class JitterCenterOfMass(Transform): + def __init__(self, jitter_sigma: float = 8.0): + self.jitter_sigma = jitter_sigma + + def forward(self, data: dict) -> dict: + atom_array = data["atom_array"] + atom_array.coord += np.random.normal(0, self.jitter_sigma, (3,)) + data["atom_array"] = atom_array + return data + + +###################################################################################### +# Pipelines +###################################################################################### +class _ListBuilder: + """A small convenience class to build lists element by element with the '+=' operator""" + + def __init__(self): + self.list = [] + + def __add__(self, other): + if isinstance(other, list): + self.list.extend(other) + elif isinstance(other, _ListBuilder): + self.list.extend(other.list) + else: + self.list.append(other) + return self + + def tolist(self): + return self.list + + +_DEFAULT_DESIGN_TASKS = { + "unconditional": { + "transform": UnconditionalTask(), + "frequency": 0.01, + }, + "tip_atom_distance": { + "transform": TipAtomDistanceTask( + min_residues=2, + max_residues=20, + p_tip_atom=0.8, + knockout_p=0.8, + dropout_min_fraction=0.1, + dropout_max_fraction=0.9, + ), + "frequency": 1.0, + }, +} +"""A dummy list of examplary design tasks to provide a discoverable interface for the design task sampling system. + +For actual training runs this should be set via the hydra config. +""" + + +def _get_design_task_name(data: dict) -> str: + """Get the design task name from the data dict. + NOTE: This is implemented as a global function to allow pickle-ing for multiprocessing. + """ + return data["task"]["name"] + + +def _build_rfd3_train_pipeline( + crop_size: int | None = None, + crop_contiguous_probability: float = 0.5, + crop_spatial_probability: float = 0.5, + crop_center_cutoff_distance: float = 15.0, + max_atoms_in_crop: int | None = None, + is_inference: bool = False, + p_atomize_residues: float = 0.0, + design_tasks: dict[str, dict[str, Any]] = _DEFAULT_DESIGN_TASKS, + undesired_res_names: list[str] = AF3_EXCLUDED_LIGANDS, + central_atom: str = CENTRAL_ATOM, + **kwargs, # to dump remaining args for now +) -> list[Transform]: + # TODO: Step 0 (before picking example): Sample problem type + # T = transform shorthand for readability + T = _ListBuilder() + + # (We may want to run the train pipe with is_inference=True; e.g., for simple validations of our training objective) + T += AddData({"is_inference": is_inference}) + ########################################################### + ### Step 1: Filter unwanted information ### + ########################################################### + # ... cleanup + T += RemoveKeys(["atom_array_stack"], require_keys_exist=False) + T += RemoveHydrogens() + T += FilterToSpecifiedPNUnits( + extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing" + ) + T += RemoveTerminalOxygen() + T += RemoveNucleicAcidTerminalOxygen() + T += RemoveUnresolvedPNUnits() + T += HandleUndesiredResTokens( + undesired_res_tokens=undesired_res_names + ) # e.g., non-standard residues + T += RemovePolymersWithTooFewResolvedResidues(min_residues=1) + T += MaskPolymerResiduesWithUnresolvedFrameAtoms() + T += RemoveUnresolvedLigandAtomsIfTooMany(unresolved_ligand_atom_limit=5) + T += RemoveTokensWithoutCorrespondingCentralAtom(central_atom=central_atom) + + # ... add basic annotations + T += AddGlobalAtomIdAnnotation() + T += AddWithinChainInstanceResIdx() + T += AddWithinPolyResIdxAnnotation() + T += AddProteinTerminiAnnotation() # Also handled by conditions now? + T += FlagAndReassignCovalentModifications() + T += FlagNonPolymersForAtomization() + T += RandomAtomizeResidues(p_atomize=p_atomize_residues) + T += AtomizeByCCDName( + atomize_by_default=True, + res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA, + move_atomized_part_to_end=False, + validate_atomize=False, + ) + + # ... crop + if crop_size and (crop_contiguous_probability > 0 or crop_spatial_probability > 0): + assert crop_size > 0, "Crop size must be greater than 0" + assert ( + crop_center_cutoff_distance > 0 + ), "Crop center cutoff distance must be greater than 0" + + T += RandomRoute.from_list( + [ + ( + crop_contiguous_probability, # DISCUSS: Does this even make sense for design?? + CropContiguousLikeAF3( + crop_size=crop_size, + keep_uncropped_atom_array=False, + max_atoms_in_crop=max_atoms_in_crop, + annotate_crop_boundary=True, + ), + ), + ( + crop_spatial_probability, + CropSpatialLikeAF3( + crop_size=crop_size, + crop_center_cutoff_distance=crop_center_cutoff_distance, + keep_uncropped_atom_array=False, + max_atoms_in_crop=max_atoms_in_crop, + raise_if_missing_query=False, + annotate_crop_boundary=True, + ), + ), + ] + ) + + # ... remove center of mass & add jitter information to not leak the crop center + # TODO: Switch for CenterRandomAugmentation + T += RemoveCenterOfMass() + T += JitterCenterOfMass(jitter_sigma=5.0) + + ########################################################### + ### Step 2: Build design task ### + ########################################################### + # Design task sampling: Sample and apply design tasks + # + # This section implements a flexible design task sampling system where different protein design + # objectives can be conditionally applied based on the input data. The system works as follows: + # 1. SampleDesignTask evaluates all available design tasks and samples one based on their + # frequencies among eligible tasks (those that can be applied to the current data) + # 2. Each design task implements can_apply() to determine applicability and + # annotate_for_task_selection() to add required annotations + # 3. The sampled task is then applied via ConditionalRoute to generate the appropriate + # design problem configuration + # 4. Possible `DesignTaskModifier` functions can then be applied afterwards to modify the task + # These exist to reduce the overhead of writing full-fledged new design tasks by simply allowing + # to change the condition annotation values (e.g. `rasa`, `sequence`, `distance`) for a given, + # existing task. + T += SampleDesignTask(design_tasks=design_tasks) + + # ... create the design task + T += ConditionalRoute( + condition_func=_get_design_task_name, + transform_map={name: task["transform"] for name, task in design_tasks.items()}, + ) + + # ... (optionally) modify the design task post-hoc + # TODO: Template this + + return T.tolist() + + +def _build_rfd3_featurize_pipeline( + n_atoms_per_token: int = 14, + central_atom: str = CENTRAL_ATOM, + sigma_data: float = 16.0, + diffusion_batch_size: int = 32, + return_atom_array: bool = False, + residue_cache_dir: str | None = None, + association_scheme: str | None = None, +) -> list[Transform]: + T = _ListBuilder() + + if exists(residue_cache_dir): + T += LoadCachedResidueLevelData( + dir=Path(residue_cache_dir), + sharding_depth=1, + ) + + # MISSING TRANSFORMS FROM OTHER PIPELINE (Besides all the ones that would be replaced by design tasks): + # AddIsDAminoAcidFeat() and RandomlyMirrorInputs() + # MotifCenterRandomAugmentation() + # AugmentNoise() + + ###################################################################################### + # Virtual Atoms and Masked Atoms + ###################################################################################### + # ... Annotate finalized token ids + # TODO: Add terminus_type + T += AddGlobalTokenIdAnnotation() + + # ... Copy annotations to create featurization annotations + T += CopyAnnotation( + annotation_to_copy="res_name", new_annotation="res_name_to_featurize" + ) + T += CopyAnnotation( + annotation_to_copy="atom_name", new_annotation="atom_name_to_featurize" + ) + T += CopyAnnotation( + annotation_to_copy="element", new_annotation="element_to_featurize" + ) + T += HackInRequiredAnnotations() # UnindexFlaggedTokens needs these annotations + # ... unindex flagged tokens. In this pipeline, we don't have any unindexed tokens. + T += UnindexFlaggedTokens( + central_atom=central_atom + ) # TODO: Fix the hacks in this function to actually work in this framework!!! + # ... add virtual atoms to protein residues (WITHOUT sequence conditioning) + # (Since if we know the sequence, we know exactly how many atoms there are — no need for virtual atoms) + T += PadTokensWithVirtualAtoms( # Use this over AddVirtualAtoms for now to handle the naming permutations + n_atoms_per_token=n_atoms_per_token, + atom_to_pad_from=central_atom, + association_scheme=association_scheme, + ) + # ... Create masked atoms: mask chemical identities for all atoms in tokens without sequence conditioning + # This includes both virtual atoms (is_virtual=True) and real atoms (is_virtual=False) + T += MaskAnnotationsForTokensWithoutSequenceConditioning( + masked_atom_element=VIRTUAL_ATOM_ELEMENT, + masked_atom_name=MASKED_ATOM_NAME, + masked_res_name=MASKED_RES_NAME, # For featurization masking + res_name_annotation_to_featurize="res_name_to_featurize", + atom_name_annotation_to_featurize="atom_name_to_featurize", + element_annotation_to_featurize="element_to_featurize", + ) + + ###################################################################################### + # Featurize all conditions + ###################################################################################### + # ... Compute atom-to-token mapping after token structure is finalized (we have added virtual atoms already) + T += ComputeAtomToTokenMap() + # ... AF3 token level-encoding, with sequence masking when we don't have sequence + T += EncodeTokenLevelFeaturesWithSequenceMasking( + sequence_encoding=af3_sequence_encoding, + ) + # ... Atom-level reference features + T += EncodeAtomLevelFeaturesWithSequenceMasking() + # ... Bonds + T += AddAF3TokenBondFeatures() + # ... Add useful features for losses / metrics + # (We add to ground truth, not feats, to distinguish from features that we show the model vs. features that we use for losses and metrics) + T += AddIsX( + X=[ + # Basic + "is_backbone", + "is_sidechain", + # Virtual atom + "is_masked", + "is_virtual", + "is_central", + "is_ca", + ], + central_atom=central_atom, + ) + T += AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding) + # EDM-style wrap-up (no additional features added at this point) + T += get_diffusion_transforms( + sigma_data=sigma_data, diffusion_batch_size=diffusion_batch_size + ) + + # Subset to necessary keys only + keys_to_keep = [ + "example_id", + "feats", + "t", + "noise", + "ground_truth", + "coord_atom_lvl_to_be_noised", + "symmetry_resolution", + "extra_info", + "task", + ] + + if return_atom_array: + keys_to_keep.append("atom_array") + + T += SubsetToKeys(keys_to_keep) + + return T.tolist() + + +class AnnotateConditionsForInference(Transform): + """Annotate conditions for inference.""" + + def check_input(self, data: dict) -> None: + assert data.get( + "is_inference", False + ), "This transform should only be run in inference mode." + + def forward(self, data: dict) -> dict: + atom_array = data["atom_array"] + + # TODO: Revisit these once the `sampling` of lengths/contigs is implemented. + # Explicitly set conditions whose defaults may change if the atom-array is in-context conditioned + # (e.g. when running `unindexed`) + # ... terminus type conditioning + is_n_terminus = C_NTR.mask(atom_array, default="generate") + is_c_terminus = C_CTR.mask(atom_array, default="generate") + + C_NTR.set_mask(atom_array, is_n_terminus) + C_CTR.set_mask(atom_array, is_c_terminus) + + return data + + +def _build_rfd3_validation_pipeline( + atomize_distance_conditioned_tokens: bool = False, + **kwargs, # to dump remaining args for now +) -> list[Transform]: + # T = transform shorthand for readability + # ... initialize + T = _ListBuilder() + T += AddData({"is_inference": True}) + T += RemoveTerminalOxygen() + + ########################################################### + ### Step 1: Add potentially missing information ### + ########################################################### + # TODO: This needs to sample a set of residues inbetween & at the end of our motif + # (essentially the same as contigs in RFD) + # T += SampleSequenceLength() + T += AddGlobalAtomIdAnnotation(allow_overwrite=True) + T += AddWithinChainInstanceResIdx() + T += AtomizeByCCDName( + atomize_by_default=True, + res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA, + ) + if atomize_distance_conditioned_tokens: + T += AtomizeByMaskFunction( + mask_function=lambda x: C_DIS.mask(x, default="generate") + .as_dense_array(default=False) + .any(axis=0), + ) + T += AnnotateConditionsForInference() + return T.tolist() + + +def build_rfd3_train_pipeline( + # ... training specific + crop_size: int | None = None, + crop_contiguous_probability: float = 0.5, + crop_spatial_probability: float = 0.5, + crop_center_cutoff_distance: float = 15, + p_atomize_residues: float = 0.0, + # ... design task specific + design_tasks: dict[str, dict] = _DEFAULT_DESIGN_TASKS, + # ... featurization specific + sigma_data: float = 0.5, + diffusion_batch_size: int = 16, + n_atoms_per_token: int = 14, + central_atom: str = "CB", + return_atom_array: bool = False, + is_inference: bool = False, + residue_cache_dir: str | None = None, + association_scheme: str = "atom14", + **kwargs, # to dump remaining args for now +) -> list[Transform]: + """Build the RFD3 training pipeline. + + Args: + sigma_data: Scale of noise to add during training + diffusion_batch_size: Number of diffusion samples to generate per batch + central_atom: Name of the central atom to use for virtual atoms + virtual_atom_element: Element symbol to use for virtual atoms + return_atom_array: Whether to return the atom array in the output + is_inference: Whether to run the pipeline in inference mode (in case we want to validate our training objective) + residue_cache_dir: Directory to load residue-level cached data from + + Returns: + List of transforms for the training pipeline + """ + transforms = [] + transforms += _build_rfd3_train_pipeline( + crop_size=crop_size, + crop_contiguous_probability=crop_contiguous_probability, + crop_spatial_probability=crop_spatial_probability, + crop_center_cutoff_distance=crop_center_cutoff_distance, + design_tasks=design_tasks, + is_inference=is_inference, + p_atomize_residues=p_atomize_residues, + central_atom=central_atom, + ) + transforms += _build_rfd3_featurize_pipeline( + n_atoms_per_token=n_atoms_per_token, + sigma_data=sigma_data, + diffusion_batch_size=diffusion_batch_size, + central_atom=central_atom, + return_atom_array=return_atom_array, + residue_cache_dir=residue_cache_dir, + association_scheme=association_scheme, + ) + return Compose(transforms) + + +def build_rfd3_validation_pipeline( + # ... featurization specific + sigma_data: float = 0.5, + diffusion_batch_size: int = 16, + n_atoms_per_token: int = 14, + central_atom: str = "CB", + return_atom_array: bool = False, + atomize_distance_conditioned_tokens: bool = False, + association_scheme: str = "atom14", +) -> list[Transform]: + """Build the RFD3 validation pipeline. + + Args: + sigma_data: Scale of noise to add during validation + diffusion_batch_size: Number of diffusion samples to generate per batch + central_atom: Name of the central atom to use for virtual atoms + virtual_atom_element: Element symbol to use for virtual atoms + return_atom_array: Whether to return the atom array in the output + + Returns: + List of transforms for the validation pipeline + """ + transforms = [] + transforms += _build_rfd3_validation_pipeline( + atomize_distance_conditioned_tokens=atomize_distance_conditioned_tokens, + ) + transforms += _build_rfd3_featurize_pipeline( + sigma_data=sigma_data, + diffusion_batch_size=diffusion_batch_size, + central_atom=central_atom, + n_atoms_per_token=n_atoms_per_token, + return_atom_array=return_atom_array, + association_scheme=association_scheme, + ) + return Compose(transforms) \ No newline at end of file diff --git a/projects/latent/pipeline.py b/projects/latent/pipeline.py new file mode 100644 index 0000000..a35285f --- /dev/null +++ b/projects/latent/pipeline.py @@ -0,0 +1 @@ +from atomworks.ml.datasets import PandasDataset \ No newline at end of file diff --git a/src/modelhub/utils/env.py b/src/modelhub/utils/env.py new file mode 100644 index 0000000..4d152a4 --- /dev/null +++ b/src/modelhub/utils/env.py @@ -0,0 +1,35 @@ +"""Environment loading utilities.""" + +import os +from pathlib import Path + +from dotenv import load_dotenv + + +def load_ipd_dotenv(override: bool = True) -> None: + """Load environment variables, prioritizing IPD-specific config. + + First checks for ``.ipd/.env`` in the project root. If found, loads it. + Otherwise falls back to standard ``.env`` loading. + + Note: + Requires ``PROJECT_ROOT`` environment variable to be set (via ``rootutils.setup_root()``). + + Args: + override: Override existing environment variables. Defaults to ``True``. + + Raises: + RuntimeError: If ``PROJECT_ROOT`` is not set. + """ + project_root = os.environ.get("PROJECT_ROOT") + if not project_root: + raise RuntimeError( + "PROJECT_ROOT environment variable not set. " + "Call rootutils.setup_root() before load_ipd_dotenv()." + ) + + ipd_env = Path(project_root) / ".ipd" / ".env" + if ipd_env.exists(): + load_dotenv(dotenv_path=ipd_env, override=override) + else: + load_dotenv(override=override)