From 3dba499b6d037e6e9b0e23c0259f18887efe6d8b Mon Sep 17 00:00:00 2001
From: Jasper Butcher <66851659+Ubiquinone-dot@users.noreply.github.com>
Date: Thu, 20 Nov 2025 16:29:47 -0800
Subject: [PATCH] refactor source files for open sourcing (#648)
* mc
* Add base class for inference engine
* refactor inference engine
* Move constants and components out of rfd3 folder
* Fixes to engine
* Update with working checkpoint
* revert layer utils
* Fix more imports
* Move alignment, conditiontransitionblock
* Update sampler name
* mc
* More import fixes
* make format
* Minor fixes
* mc
* Fix rf3 inference engine
* Fix inference sampler
* Fix modules
* Running inference
* Make format
* add pre-commit hook
* fix: RF3 inference (#670)
fix: make rf3 tests in new format
* Minor cleanup
---------
Co-authored-by: Nathaniel Corley
---
.pre-commit-config.yaml | 8 +
README.md | 11 +
lib/atomworks | 2 +-
models/rf3/README.md | 1 +
models/rf3/pyproject.toml | 2 +-
models/rf3/src/rf3/data/paired_msa.py | 2 +-
.../diffusion_samplers/inference_sampler.py | 2 +-
models/rf3/src/rf3/inference_engines/rf3.py | 217 +++----
models/rf3/src/rf3/loss/af3_losses.py | 28 +-
models/rf3/src/rf3/model/RF3_structure.py | 24 +-
.../src/rf3/model/layers/pairformer_layers.py | 2 +-
models/rf3/src/rf3/trainers/rf3.py | 5 +-
models/rf3/src/rf3/util_module.py | 24 -
models/rf3/src/rf3/utils/io.py | 2 +-
models/rf3/tests/test_inference_regression.py | 3 +-
models/rfd3/README.md | 20 +
.../datasets/val/design_validation_base.yaml | 11 -
models/rfd3/configs/experiment/rfd3-full.yaml | 17 -
.../rfd3/configs/inference_engine/base.yaml | 5 +-
.../inference_engine/rfdiffusion3.yaml | 18 +-
.../configs/model/components/rfd3_net.yaml | 4 +-
models/rfd3/configs/paths/data/default.yaml | 2 +-
models/rfd3/configs/trainer/aa_design.yaml | 12 +-
.../trainer/loss/losses/diffusion_loss.yaml | 26 +-
.../trainer/loss/losses/sequence_loss.yaml | 3 +
models/rfd3/pyproject.toml | 2 +-
models/rfd3/src/rfd3/cli.py | 19 +-
models/rfd3/src/rfd3/constants.py | 48 +-
models/rfd3/src/rfd3/docs/input.md | 6 +-
models/rfd3/src/rfd3/engine.py | 435 ++++++++++++++
models/rfd3/src/rfd3/inference/datasets.py | 96 +---
models/rfd3/src/rfd3/inference/engine.py | 448 ---------------
.../rfd3/src/rfd3/inference/input_parsing.py | 535 ++++++++++--------
.../rfd3/inference/legacy_input_parsing.py | 34 +-
models/rfd3/src/rfd3/inference/parsing.py | 3 +-
.../rfd3/src/rfd3/inference/specification.py | 0
.../src/rfd3/inference/symmetry/contigs.py | 3 +-
.../rfd3/inference/symmetry/symmetry_utils.py | 2 +-
models/rfd3/src/rfd3/metrics/losses.py | 62 +-
.../src/rfd3/metrics/sidechain_metrics.py | 2 +-
models/rfd3/src/rfd3/model/RFD3.py | 105 ++++
...ion_module.py => RFD3_diffusion_module.py} | 62 +-
models/rfd3/src/rfd3/model/af3_design.py | 12 -
.../{aa_design.py => inference_sampler.py} | 424 ++++----------
.../src/rfd3/model/{ => layers}/attention.py | 9 +-
.../rfd3/model/{ => layers}/block_utils.py | 0
.../src/rfd3/model/{ => layers}/blocks.py | 49 +-
.../model/{ => layers}/chunked_pairwise.py | 3 +-
.../src/rfd3/model/{ => layers}/encoders.py | 14 +-
.../rfd3/src/rfd3/model/layers/layer_utils.py | 197 +++++++
.../rfd3/model/layers/pairformer_layers.py | 128 +++++
models/rfd3/src/rfd3/run_inference.py | 27 +-
models/rfd3/src/rfd3/testing/debug.py | 3 +-
models/rfd3/src/rfd3/testing/debug_utils.py | 2 +-
models/rfd3/src/rfd3/testing/testing_utils.py | 19 +-
.../rfd3/src/rfd3/trainer/fabric_trainer.py | 1 -
.../rfd3/trainer/{rfd3_trainer.py => rfd3.py} | 418 ++------------
models/rfd3/src/rfd3/trainer/trainer_utils.py | 502 ++++++++++++++++
.../bfactor_conditioned_transforms.py | 56 --
.../src/rfd3/transforms/conditioning_utils.py | 194 -------
models/rfd3/src/rfd3/transforms/pipelines.py | 44 +-
.../rfd3/src/rfd3/transforms/virtual_atoms.py | 4 +-
.../inference_utils.py => utils/inference.py} | 13 +-
models/rfd3/src/rfd3/{util => utils}/io.py | 3 +-
.../src/rfd3/{util => utils}/vizualize.py | 2 +-
models/rfd3/tests/test_aa_design.py | 22 +-
models/rfd3/tests/test_metrics.py | 2 +-
models/rfd3/tests/test_selections.py | 15 +-
models/rfd3/tests/test_unindexing.py | 8 +-
.../transforms/test_pipeline_regression.py | 2 +
pyproject.toml | 10 +-
src/modelhub/__init__.py | 6 +-
src/modelhub/constants.py | 28 +
src/modelhub/inference_engines/base.py | 216 +++++++
src/modelhub/metrics/losses.py | 27 +
src/modelhub/model/layers/blocks.py | 47 ++
src/modelhub/trainers/fabric.py | 35 +-
.../util => src/modelhub/utils}/alignment.py | 11 +-
.../modelhub/utils}/components.py | 0
src/modelhub/utils/ddp.py | 1 +
.../modelhub/utils/rigid.py | 0
.../modelhub/utils}/rotation_augmentation.py | 3 +-
82 files changed, 2552 insertions(+), 2318 deletions(-)
create mode 100644 .pre-commit-config.yaml
create mode 100644 models/rfd3/configs/trainer/loss/losses/sequence_loss.yaml
create mode 100644 models/rfd3/src/rfd3/engine.py
delete mode 100644 models/rfd3/src/rfd3/inference/engine.py
delete mode 100644 models/rfd3/src/rfd3/inference/specification.py
create mode 100644 models/rfd3/src/rfd3/model/RFD3.py
rename models/rfd3/src/rfd3/model/{rfd3_diffusion_module.py => RFD3_diffusion_module.py} (90%)
delete mode 100644 models/rfd3/src/rfd3/model/af3_design.py
rename models/rfd3/src/rfd3/model/{aa_design.py => inference_sampler.py} (64%)
rename models/rfd3/src/rfd3/model/{ => layers}/attention.py (99%)
rename models/rfd3/src/rfd3/model/{ => layers}/block_utils.py (100%)
rename models/rfd3/src/rfd3/model/{ => layers}/blocks.py (94%)
rename models/rfd3/src/rfd3/model/{ => layers}/chunked_pairwise.py (99%)
rename models/rfd3/src/rfd3/model/{ => layers}/encoders.py (98%)
create mode 100644 models/rfd3/src/rfd3/model/layers/layer_utils.py
create mode 100644 models/rfd3/src/rfd3/model/layers/pairformer_layers.py
rename models/rfd3/src/rfd3/trainer/{rfd3_trainer.py => rfd3.py} (52%)
create mode 100644 models/rfd3/src/rfd3/trainer/trainer_utils.py
delete mode 100644 models/rfd3/src/rfd3/transforms/bfactor_conditioned_transforms.py
rename models/rfd3/src/rfd3/{inference/inference_utils.py => utils/inference.py} (99%)
rename models/rfd3/src/rfd3/{util => utils}/io.py (99%)
rename models/rfd3/src/rfd3/{util => utils}/vizualize.py (99%)
create mode 100644 src/modelhub/constants.py
create mode 100644 src/modelhub/inference_engines/base.py
create mode 100644 src/modelhub/metrics/losses.py
create mode 100644 src/modelhub/model/layers/blocks.py
rename {models/rfd3/src/rfd3/util => src/modelhub/utils}/alignment.py (88%)
rename {models/rfd3/src/rfd3/inference => src/modelhub/utils}/components.py (100%)
rename models/rf3/src/rf3/flow_matching/rigid_utils.py => src/modelhub/utils/rigid.py (100%)
rename {models/rf3/src/rf3/data => src/modelhub/utils}/rotation_augmentation.py (97%)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..1ad50e6
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,8 @@
+repos:
+ - repo: local
+ hooks:
+ - id: make-format
+ name: Run `make format`
+ entry: make format
+ language: system
+ pass_filenames: false
diff --git a/README.md b/README.md
index 7cfc8f9..1f54ad5 100644
--- a/README.md
+++ b/README.md
@@ -111,3 +111,14 @@ To add a new model:
2. Add `modelhub` as a dependency
3. Implement model-specific code in `models//src/`
4. Users can install with: `uv pip install -e ./models/`
+
+### Pre-commit Formatting
+
+We ship a `.pre-commit-config.yaml` that runs `make format` (via `ruff format`) before each commit. Enable it once per clone:
+
+```bash
+pip install pre-commit # if not already installed
+pre-commit install
+```
+
+After installation the hook automatically formats the repo whenever you `git commit`. Use `pre-commit run --all-files` to apply it manually.
diff --git a/lib/atomworks b/lib/atomworks
index 7e12a8a..04da354 160000
--- a/lib/atomworks
+++ b/lib/atomworks
@@ -1 +1 @@
-Subproject commit 7e12a8a11bf810e21a2f52342d2e79e5ec085c22
+Subproject commit 04da3547b29bc444582967786b915298b7edb151
diff --git a/models/rf3/README.md b/models/rf3/README.md
index 8987dc2..172d046 100644
--- a/models/rf3/README.md
+++ b/models/rf3/README.md
@@ -621,3 +621,4 @@ view(atom_array)
**Alternative viewing options:**
- View in PyMol like normal, or using `pymol_remote`
- Use the `view_pymol()` function for direct PyMol integration
+
diff --git a/models/rf3/pyproject.toml b/models/rf3/pyproject.toml
index 337e4cc..b9b32d7 100644
--- a/models/rf3/pyproject.toml
+++ b/models/rf3/pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
dependencies = [
# Core functionality shared across all models
- "modelforge",
+ # "modelforge",
# CLI
"typer>=0.9.0,<1",
# RF3-specific ML dependencies
diff --git a/models/rf3/src/rf3/data/paired_msa.py b/models/rf3/src/rf3/data/paired_msa.py
index 85d0738..a368a88 100644
--- a/models/rf3/src/rf3/data/paired_msa.py
+++ b/models/rf3/src/rf3/data/paired_msa.py
@@ -7,7 +7,7 @@ from typing import Any
import numpy as np
from atomworks.common import exists
from atomworks.enums import ChainType
-from atomworks.ml.datasets import logger, StructuralDatasetWrapper
+from atomworks.ml.datasets import StructuralDatasetWrapper, logger
from atomworks.ml.datasets.parsers import (
MetadataRowParser,
load_example_from_metadata_row,
diff --git a/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py b/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py
index 012b67f..0c4f149 100755
--- a/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py
+++ b/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py
@@ -1,9 +1,9 @@
import torch
from beartype.typing import Any, Literal
from jaxtyping import Float
-from rf3.data.rotation_augmentation import centre_random_augmentation
from modelhub.utils.ddp import RankedLogger
+from modelhub.utils.rotation_augmentation import centre_random_augmentation
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
diff --git a/models/rf3/src/rf3/inference_engines/rf3.py b/models/rf3/src/rf3/inference_engines/rf3.py
index eb9af8b..f18c3d9 100644
--- a/models/rf3/src/rf3/inference_engines/rf3.py
+++ b/models/rf3/src/rf3/inference_engines/rf3.py
@@ -2,7 +2,6 @@ import logging
from os import PathLike
from pathlib import Path
-import hydra
import pandas as pd
import torch
import torch.distributed as dist
@@ -12,13 +11,12 @@ from atomworks.ml.preprocessing.msa.finding import (
)
from atomworks.ml.samplers import LoadBalancedDistributedSampler
from biotite.structure import AtomArray
-from lightning.fabric import seed_everything
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
+from modelhub.inference_engines.base import BaseInferenceEngine
from modelhub.metrics.metric import MetricManager
-from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
-from modelhub.utils.logging import print_config_tree
+from modelhub.utils.ddp import RankedLogger
from rf3.model.RF3 import ShouldEarlyStopFn
from rf3.utils.inference import (
InferenceInput,
@@ -63,7 +61,7 @@ def should_early_stop_by_mean_plddt(
return fn
-class RF3InferenceEngine:
+class RF3InferenceEngine(BaseInferenceEngine):
"""RF3 inference engine.
Separates model setup (expensive, once) from inference (can run multiple times).
@@ -84,71 +82,44 @@ class RF3InferenceEngine:
def __init__(
self,
- ckpt_path: PathLike,
# Model parameters
n_recycles: int = 10,
diffusion_batch_size: int = 5,
num_steps: int = 50,
- seed: int | None = None,
+ # Templating, MSAs, etc.
template_noise_scale: float = 1e-5,
- early_stopping_plddt_threshold: float | None = None,
- metrics_cfg: dict | OmegaConf | MetricManager | None = None,
- num_nodes: int = 1,
- devices_per_node: int = 1,
+ raise_if_missing_msa_for_protein_of_length_n: int | None = None,
# Output control
compress_outputs: bool = True,
- # Debug
- print_config: bool = False,
- raise_if_missing_msa_for_protein_of_length_n: int | None = None,
+ early_stopping_plddt_threshold: float | None = None,
+ # Metrics
+ metrics_cfg: dict | OmegaConf | MetricManager | None = None,
+ **kwargs,
):
"""Initialize inference engine and load model.
Model config is loaded from checkpoint and overridden with parameters provided here.
Args:
- ckpt_path: Path to model checkpoint.
n_recycles: Number of recycles. Defaults to ``10``.
diffusion_batch_size: Number of structures to generate per input. Defaults to ``5``.
num_steps: Number of diffusion steps. Defaults to ``50``.
- seed: Random seed. If None, uses external RNG state. Defaults to ``None``.
template_noise_scale: Noise scale for template coordinates. Defaults to ``1e-5``.
+ raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
+ compress_outputs: Whether to gzip output files. Defaults to ``True``.
early_stopping_plddt_threshold: Stop early if pLDDT below threshold. Defaults to ``None``.
metrics_cfg: Metrics configuration. Can be:
- dict/OmegaConf with Hydra configs
- Pre-instantiated MetricManager
- None (no metrics).
Defaults to ``None``.
- num_nodes: Number of nodes for distributed inference. Defaults to ``1``.
- devices_per_node: Number of devices per node. Defaults to ``1``.
- compress_outputs: Whether to gzip output files. Defaults to ``True``.
- print_config: Whether to print config trees. Defaults to ``False``.
- raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
+ **kwargs: Additional arguments passed to BaseInferenceEngine:
+ - ckpt_path (PathLike, required): Path to model checkpoint.
+ - seed (int | None): Random seed. If None, uses external RNG state. Defaults to ``None``.
+ - num_nodes (int): Number of nodes for distributed inference. Defaults to ``1``.
+ - devices_per_node (int): Number of devices per node. Defaults to ``1``.
+ - print_config (bool): Whether to print config trees. Defaults to ``False``.
"""
- # Load checkpoint and config
- ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...")
- checkpoint = torch.load(ckpt_path, "cpu", weights_only=False)
- self.cfg = OmegaConf.create(checkpoint["train_cfg"])
-
- # Override config with inference parameters
- self.cfg.model.net.inference_sampler.num_timesteps = num_steps
- self.cfg.trainer.num_nodes = num_nodes
- self.cfg.trainer.devices_per_node = devices_per_node
-
- # Set metrics - can be dict/OmegaConf or MetricManager
- # Store MetricManager separately since OmegaConf can't serialize it
- self._custom_metric_manager = None
- if isinstance(metrics_cfg, MetricManager):
- # Already instantiated - store separately and pass to trainer later
- self._custom_metric_manager = metrics_cfg
- self.cfg.trainer["metrics"] = {} # Empty dict in config
- elif metrics_cfg is not None:
- # Hydra config dict
- self.cfg.trainer["metrics"] = metrics_cfg
- else:
- self.cfg.trainer["metrics"] = {}
-
- set_accelerator_based_on_availability(self.cfg)
-
# set MSA directories
if env_var_msa_dirs := get_msa_dirs_from_env(raise_if_not_set=False):
override_msa_dirs = [str(msa_dir) for msa_dir in env_var_msa_dirs]
@@ -163,112 +134,72 @@ class RF3InferenceEngine:
]
ranked_logger.info(f"Using default MSA directories: {override_msa_dirs}")
- # Dataset overrides
- self.dataset_overrides = {
- "diffusion_batch_size": diffusion_batch_size,
- "n_recycles": n_recycles,
- "raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
- "undesired_res_names": [],
- "template_noise_scales": {
- "atomized": template_noise_scale,
- "not_atomized": template_noise_scale,
+ super().__init__(
+ transform_overrides={
+ "diffusion_batch_size": diffusion_batch_size,
+ "n_recycles": n_recycles,
+ "raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
+ "undesired_res_names": [],
+ "template_noise_scales": {
+ "atomized": template_noise_scale,
+ "not_atomized": template_noise_scale,
+ },
+ "allowed_chain_types_for_conditioning": None,
+ "protein_msa_dirs": [
+ {
+ "dir": msa_dir,
+ "extension": extension.value,
+ "directory_depth": depth,
+ }
+ for msa_dir, depth, extension in [
+ (msa_dir, *get_msa_depth_and_ext_from_folder(Path(msa_dir)))
+ for msa_dir in override_msa_dirs
+ ]
+ ],
+ "rna_msa_dirs": [],
+ # (Paranoia - in validation, these should be set correctly anyhow)
+ "p_give_polymer_ref_conf": 0.0,
+ "p_give_non_polymer_ref_conf": 0.0,
+ "p_dropout_ref_conf": 0.0,
+ "use_element_for_atom_names_of_atomized_tokens": True,
},
- "allowed_chain_types_for_conditioning": None,
- "protein_msa_dirs": [
- {
- "dir": msa_dir,
- "extension": extension.value,
- "directory_depth": depth,
- }
- for msa_dir, depth, extension in [
- (msa_dir, *get_msa_depth_and_ext_from_folder(Path(msa_dir)))
- for msa_dir in override_msa_dirs
- ]
- ],
- "rna_msa_dirs": [],
- # (Paranoia - in validation, these should be set correctly anyhow)
- "p_give_polymer_ref_conf": 0.0,
- "p_give_non_polymer_ref_conf": 0.0,
- "p_dropout_ref_conf": 0.0,
- "use_element_for_atom_names_of_atomized_tokens": True,
- }
-
- self.print_config = print_config
-
- # Set random seed (only if seed is not None)
- if seed is not None or self.cfg.seed is not None:
- seed = seed or self.cfg.seed
- ranked_logger.info(f"Seeding everything with seed={seed}...")
- seed_everything(seed, workers=True, verbose=True)
- else:
- ranked_logger.info("Seed is None - using external RNG state")
-
- # Instantiate trainer
- ranked_logger.info("Instantiating trainer...")
- if self.print_config:
- print_config_tree(
- self.cfg.trainer, resolve=True, title="INFERENCE TRAINER CONFIGURATION"
- )
-
- self.trainer = hydra.utils.instantiate(
- self.cfg.trainer,
- _convert_="partial",
- _recursive_=False,
+ inference_sampler_overrides={
+ "num_timesteps": num_steps,
+ },
+ **kwargs,
)
- # If we have a custom MetricManager, override the trainer's metrics
- if self._custom_metric_manager is not None:
- self.trainer.metrics = self._custom_metric_manager
+ # remove loss override if present (i.e. keep from checkpoint)
+ self.overrides["trainer"].pop("loss", None)
- self.ckpt_path = ckpt_path
+ # Store metrics config for later - will be set directly on trainer in initialize()
+ self._metrics_cfg = metrics_cfg
+
+ # Dataset overrides
self.early_stopping_plddt_threshold = early_stopping_plddt_threshold
self.compress_outputs = compress_outputs
- # Setup model
- ranked_logger.info("Setting up model...")
- self.trainer.fabric.launch()
- self.trainer.initialize_or_update_trainer_state({"train_cfg": self.cfg})
- self.trainer.construct_model()
+ def initialize(self):
+ cfg = super().initialize()
- ranked_logger.info("Loading model weights from checkpoint...")
- self.trainer.load_checkpoint(checkpoint=checkpoint)
+ if cfg is not None:
+ self.cfg = cfg # store for later use
- # Ensure optimizer isn't loaded
- self.trainer.state["optimizer"] = None
- self.trainer.state["train_cfg"].model.optimizer = None
+ # Set trainer metrics directly based on what was requested
+ # This bypasses the OmegaConf merge issue with empty dicts
+ if isinstance(self._metrics_cfg, MetricManager):
+ # Already instantiated - use directly
+ self.trainer.metrics = self._metrics_cfg
+ elif self._metrics_cfg is not None:
+ # Hydra config dict - instantiate MetricManager
+ self.trainer.metrics = MetricManager.instantiate_from_hydra(
+ metrics_cfg=self._metrics_cfg
+ )
+ else:
+ # No metrics requested - disable them
+ self.trainer.metrics = None
- self.trainer.setup_model_optimizers_and_schedulers()
- self.trainer.state["model"].eval()
-
- # Construct pipeline
- ranked_logger.info("Building Transform pipeline...")
- first_val_dataset_key, first_val_dataset = next(
- iter(self.cfg.datasets.val.items())
- )
- ranked_logger.info(
- f"Using settings from validation dataset: {first_val_dataset_key}."
- )
-
- assert (
- first_val_dataset.dataset.transform.is_inference
- ), "Inference must be enabled for the validation dataset."
-
- # Provide manual overrides to dataset config
- for key, value in self.dataset_overrides.items():
- first_val_dataset.dataset.transform[key] = value
-
- if self.print_config:
- print_config_tree(
- first_val_dataset.dataset.transform,
- resolve=True,
- title="INFERENCE TRANSFORM PIPELINE",
- )
-
- self.pipeline = hydra.utils.instantiate(
- first_val_dataset.dataset.transform,
- )
-
- ranked_logger.info("Model loaded and ready for inference.")
+ return cfg
def run(
self,
@@ -314,6 +245,8 @@ class RF3InferenceEngine:
If ``out_dir`` is None: Dict mapping example_id to results dict.
If ``out_dir`` is set: None (results saved to disk).
"""
+ self.initialize()
+
# Setup output directory if provided
out_dir = Path(out_dir) if out_dir else None
if out_dir:
diff --git a/models/rf3/src/rf3/loss/af3_losses.py b/models/rf3/src/rf3/loss/af3_losses.py
index 4079727..7da3b8a 100644
--- a/models/rf3/src/rf3/loss/af3_losses.py
+++ b/models/rf3/src/rf3/loss/af3_losses.py
@@ -1,10 +1,9 @@
-import hydra
import numpy as np
import torch
import torch.nn as nn
-from rf3.alignment import weighted_rigid_align
from modelhub.training.checkpoint import activation_checkpointing
+from modelhub.utils.alignment import weighted_rigid_align
# resolve residue-level symmetries in native vs pred
@@ -360,31 +359,6 @@ class SubunitSymmetryResolution(nn.Module):
return loss_input
-class Loss(nn.Module):
- def __init__(self, **losses):
- super().__init__()
- self.to_compute = []
- for loss_name, loss in losses.items():
- loss_fn = hydra.utils.instantiate(loss)
- print(f"Adding loss {loss_name} to the loss function")
- self.to_compute.append(loss_fn)
-
- def forward(
- self,
- network_input,
- network_output,
- loss_input,
- ):
- loss_dict = {}
- loss = 0
- for loss_fn in self.to_compute:
- loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
- loss += loss_
- loss_dict.update(loss_dict_)
- loss_dict["total_loss"] = loss.detach()
- return loss, loss_dict
-
-
class ProteinLigandBondLoss(nn.Module):
def __init__(self, weight):
super().__init__()
diff --git a/models/rf3/src/rf3/model/RF3_structure.py b/models/rf3/src/rf3/model/RF3_structure.py
index fc54f03..85da39e 100644
--- a/models/rf3/src/rf3/model/RF3_structure.py
+++ b/models/rf3/src/rf3/model/RF3_structure.py
@@ -15,6 +15,7 @@ from rf3.model.layers.pairformer_layers import (
RF3TemplateEmbedder,
)
+from modelhub.model.layers.blocks import FourierEmbedding
from modelhub.training.checkpoint import activation_checkpointing
logger = logging.getLogger(__name__)
@@ -229,29 +230,6 @@ class DiffusionConditioning(nn.Module):
return _run_conditioning(Z_II, S_trunk_I, S_inputs_I)
-pi = torch.acos(torch.zeros(1)).item() * 2
-
-
-class FourierEmbedding(nn.Module):
- def __init__(self, c):
- super().__init__()
- self.c = c
- self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
- self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
- # super().reset_parameters()
- nn.init.normal_(self.w)
- nn.init.normal_(self.b)
-
- def forward(
- self,
- t, # [D]
- ):
- return torch.cos(2 * pi * (t[:, None] * self.w + self.b))
-
-
class DistogramHead(nn.Module):
def __init__(
self,
diff --git a/models/rf3/src/rf3/model/layers/pairformer_layers.py b/models/rf3/src/rf3/model/layers/pairformer_layers.py
index cbaff5c..8d07559 100644
--- a/models/rf3/src/rf3/model/layers/pairformer_layers.py
+++ b/models/rf3/src/rf3/model/layers/pairformer_layers.py
@@ -16,10 +16,10 @@ from rf3.model.layers.outer_product import (
OuterProductMean_AF3,
)
from rf3.model.RF3_blocks import MSAPairWeightedAverage, MSASubsampleEmbedder
-from rf3.util_module import Dropout
from torch import nn
from torch.nn.functional import one_hot, relu
+from modelhub.model.layers.blocks import Dropout
from modelhub.training.checkpoint import activation_checkpointing
diff --git a/models/rf3/src/rf3/trainers/rf3.py b/models/rf3/src/rf3/trainers/rf3.py
index 8182582..5221017 100644
--- a/models/rf3/src/rf3/trainers/rf3.py
+++ b/models/rf3/src/rf3/trainers/rf3.py
@@ -5,7 +5,6 @@ from einops import repeat
from jaxtyping import Float, Int
from lightning_utilities import apply_to_collection
from omegaconf import DictConfig
-from rf3.loss.af3_losses import Loss as AF3Loss
from rf3.loss.af3_losses import (
ResidueSymmetryResolution,
SubunitSymmetryResolution,
@@ -15,6 +14,7 @@ from rf3.utils.io import build_stack_from_atom_array_and_batched_coords
from rf3.utils.recycling import get_recycle_schedule
from modelhub.common import exists
+from modelhub.metrics.losses import Loss
from modelhub.metrics.metric import MetricManager
from modelhub.trainers.fabric import FabricTrainer
from modelhub.training.EMA import EMA
@@ -42,6 +42,7 @@ class RF3Trainer(FabricTrainer):
n_recycles_train: int | None = None,
loss: DictConfig | dict | None = None,
metrics: DictConfig | dict | MetricManager | None = None,
+ seed=None, # dumped
**kwargs,
):
"""See `FabricTrainer` for the additional initialization arguments.
@@ -79,7 +80,7 @@ class RF3Trainer(FabricTrainer):
self.metrics = None
# Loss
- self.loss = AF3Loss(**loss) if loss else None
+ self.loss = Loss(**loss) if loss else None
# (Symmetry resolution)
self.subunit_symm_resolve = SubunitSymmetryResolution()
diff --git a/models/rf3/src/rf3/util_module.py b/models/rf3/src/rf3/util_module.py
index f653987..78eed3e 100644
--- a/models/rf3/src/rf3/util_module.py
+++ b/models/rf3/src/rf3/util_module.py
@@ -1,6 +1,5 @@
import numpy as np
import torch
-import torch.nn as nn
def init_lecun_normal(module, scale=1.0):
@@ -38,29 +37,6 @@ def create_custom_forward(module, **kwargs):
return custom_forward
-class Dropout(nn.Module):
- # Dropout entire row or column
- def __init__(self, broadcast_dim=None, p_drop=0.15):
- super(Dropout, self).__init__()
- # give ones with probability of 1-p_drop / zeros with p_drop
- self.sampler = torch.distributions.bernoulli.Bernoulli(
- torch.tensor([1 - p_drop])
- )
- self.broadcast_dim = broadcast_dim
- self.p_drop = p_drop
-
- def forward(self, x):
- if not self.training: # no drophead during evaluation mode
- return x
- shape = list(x.shape)
- if self.broadcast_dim is not None:
- shape[self.broadcast_dim] = 1
- mask = self.sampler.sample(shape).to(x.device).view(shape)
-
- x = mask * x / (1.0 - self.p_drop)
- return x
-
-
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
# Distance radial basis function
D_max = D_min + (D_count - 1) * D_sigma
diff --git a/models/rf3/src/rf3/utils/io.py b/models/rf3/src/rf3/utils/io.py
index 6fb6949..e62cf7a 100644
--- a/models/rf3/src/rf3/utils/io.py
+++ b/models/rf3/src/rf3/utils/io.py
@@ -8,8 +8,8 @@ from atomworks.ml.utils.io import apply_sharding_pattern
from atomworks.ml.utils.misc import hash_sequence
from beartype.typing import Literal
from biotite.structure import AtomArray, AtomArrayStack, stack
-from rf3.alignment import weighted_rigid_align
+from modelhub.utils.alignment import weighted_rigid_align
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
diff --git a/models/rf3/tests/test_inference_regression.py b/models/rf3/tests/test_inference_regression.py
index 8a9f195..4497262 100755
--- a/models/rf3/tests/test_inference_regression.py
+++ b/models/rf3/tests/test_inference_regression.py
@@ -171,7 +171,7 @@ def test_inference_regression(example_id, rmsd_tolerance, csv_tolerance):
baseline_dir = TEST_DATA_DIR / "inference_regression_tests" / example_id
with (
- initialize(config_path="../configs"),
+ initialize(config_path="../configs", version_base="1.3"),
tempfile.TemporaryDirectory() as temp_dir,
rng_state(create_rng_state_from_seeds(1, 1, 1)),
):
@@ -327,5 +327,6 @@ def test_inference_regression_in_memory(example_id, rmsd_tolerance, csv_toleranc
rmsd_difference < rmsd_tolerance
), f"Mean RMSD difference {rmsd_difference:.4f}Å exceeds {rmsd_tolerance}Å tolerance for {example_id}"
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/models/rfd3/README.md b/models/rfd3/README.md
index f1bef15..363a50d 100644
--- a/models/rfd3/README.md
+++ b/models/rfd3/README.md
@@ -108,3 +108,23 @@ This is then passed through the same processing pipeline as in training with `is
Overview of important transforms in the Atom14 conditioning pipeline.
+
+
+
+## Citation
+
+If you use this code or data in your work, please cite:
+
+```bibtex
+@article {butcher2025_rfdiffusion3,
+ author = {Butcher, Jasper and Krishna, Rohith and Mitra, Raktim and Brent, Rafael Isaac and Li, Yanjing and Corley, Nathaniel and Kim, Paul T and Funk, Jonathan and Mathis, Simon Valentin and Salike, Saman and Muraishi, Aiko and Eisenach, Helen and Thompson, Tuscan Rock and Chen, Jie and Politanska, Yuliya and Sehgal, Enisha and Coventry, Brian and Zhang, Odin and Qiang, Bo and Didi, Kieran and Kazman, Maxwell and DiMaio, Frank and Baker, David},
+ title = {De novo Design of All-atom Biomolecular Interactions with RFdiffusion3},
+ elocation-id = {2025.09.18.676967},
+ year = {2025},
+ doi = {10.1101/2025.09.18.676967},
+ publisher = {Cold Spring Harbor Laboratory},
+ URL = {https://www.biorxiv.org/content/early/2025/11/19/2025.09.18.676967},
+ eprint = {https://www.biorxiv.org/content/early/2025/11/19/2025.09.18.676967.full.pdf},
+ journal = {bioRxiv}
+}
+```
\ No newline at end of file
diff --git a/models/rfd3/configs/datasets/val/design_validation_base.yaml b/models/rfd3/configs/datasets/val/design_validation_base.yaml
index bde63bf..5aabcc0 100644
--- a/models/rfd3/configs/datasets/val/design_validation_base.yaml
+++ b/models/rfd3/configs/datasets/val/design_validation_base.yaml
@@ -35,17 +35,6 @@ dataset:
use_element_for_atom_names_of_atomized_tokens: ${datasets.global_transform_args.use_element_for_atom_names_of_atomized_tokens}
residue_cache_dir: ${paths.data.residue_cache_dir}
- # PPI
- keep_full_binder_in_spatial_crop: ${datasets.global_transform_args.keep_full_binder_in_spatial_crop}
- max_binder_length: ${datasets.global_transform_args.max_binder_length}
- max_ppi_hotspots_frac_to_provide: ${datasets.global_transform_args.max_ppi_hotspots_frac_to_provide}
- ppi_hotspot_max_distance: ${datasets.global_transform_args.ppi_hotspot_max_distance}
-
- # Secondary structure
- max_ss_frac_to_provide: ${datasets.global_transform_args.max_ss_frac_to_provide}
- min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len}
- max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len}
-
# Other dataset-specific parameters
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
token_1d_features: ${model.net.token_initializer.token_1d_features}
\ No newline at end of file
diff --git a/models/rfd3/configs/experiment/rfd3-full.yaml b/models/rfd3/configs/experiment/rfd3-full.yaml
index 239edfa..fda9145 100644
--- a/models/rfd3/configs/experiment/rfd3-full.yaml
+++ b/models/rfd3/configs/experiment/rfd3-full.yaml
@@ -159,23 +159,6 @@ trainer:
prevalidate: False
validate_every_n_epochs: 4
precision: bf16-mixed
- loss:
- verbose_diffusion_loss: null
- simple_diffusion_loss:
- _target_: rfd3.metrics.losses.SimpleDiffusionLoss
- sigma_data: ${model.net.diffusion_module.sigma_data}
- weight: 4.0
- lddt_weight: 0.25
- alpha_virtual_atom: 1.0
- alpha_polar_residues: 1.0
-
- lp_weight: 0.0
- unindexed_norm_p: 1.0
- alpha_unindexed_diffused: 1.0
- unindexed_t_alpha: 0.75
- normalize_virtual_atom_weight: False
- alpha_ligand: 10.0
-
# callbacks:
# activations_tracking_callback:
# _target_: modelhub.callbacks.health_logging.ActivationsGradientsWeightsTracker
diff --git a/models/rfd3/configs/inference_engine/base.yaml b/models/rfd3/configs/inference_engine/base.yaml
index e7679ef..74e4336 100644
--- a/models/rfd3/configs/inference_engine/base.yaml
+++ b/models/rfd3/configs/inference_engine/base.yaml
@@ -7,10 +7,9 @@ defaults:
ckpt_path: ???
num_nodes: 1
devices_per_node: 1
+print_config: False
+seed: null
# Parameters for RFD3InferenceEngine.run()
inputs: ???
out_dir: ???
-dump_predictions: true
-dump_trajectories: false
-one_model_per_file: false
diff --git a/models/rfd3/configs/inference_engine/rfdiffusion3.yaml b/models/rfd3/configs/inference_engine/rfdiffusion3.yaml
index bf2a271..c9a3b5e 100644
--- a/models/rfd3/configs/inference_engine/rfdiffusion3.yaml
+++ b/models/rfd3/configs/inference_engine/rfdiffusion3.yaml
@@ -3,14 +3,14 @@ defaults:
- base
- _self_
-_target_: rfd3.inference.engine.RFD3InferenceEngine
+_target_: rfd3.engine.RFD3InferenceEngine
out_dir: ???
inputs: ??? # null, json, pdb or
-ckpt_path: /projects/ml/aa_design/models/rfd3_latest.ckpt
+# ckpt_path: /projects/ml/aa_design/models/rfd3_latest.ckpt
+ckpt_path: /projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt
json_keys_subset: null
skip_existing: True
-seed: null # if null samples seed integer based on timestamp
#########################################################
# Design spec args: overrides args from input json
@@ -21,7 +21,6 @@ specification: {}
diffusion_batch_size: 8
n_batches: 1
-
# Inference sampler args | set to None to use the default in the checkpoint's config
inference_sampler:
kind: "default" # "default" or "symmetry" to choose the sampler
@@ -35,14 +34,9 @@ inference_sampler:
cfg_t_max: null # max t to apply cfg guidance
cfg_scale: 1.5
center_option: "all" # Options are ["all", "motif", "diffuse"]
- move_noise_to_reset_com: False # Reset the COM of the diffuse region after the re-noising operation in each diffusion step
s_trans: 1.0 # Translational noise scale for augmentation during inference
- fraction_of_steps_to_fix_motif: 0.0 # Fraction of steps to let the model not move the motif. e.g. if we have 10 steps, set this value to 0.2 will make model not move motif for the last 2 steps.
- skip_few_diffusion_steps: False # Choose to skip some diffusion steps based on the noise scheme
inference_noise_scaling_factor: 1.0
allow_realignment: False
- zero_drift_noise: False
- use_frame_guidance: False
# Diffusion args:
num_timesteps: 200
@@ -51,7 +45,6 @@ inference_sampler:
p: 7
gamma_0: 0.6 # Previously 1.0 | 0.0 for ODE sampling
gamma_min: 1.0
- gamma_min2: 0.0
s_jitter_origin: 0.0 # Sigma of gaussian noise to jitter the motif offset (equivalent to ORI token Jitter)
# Saving args
@@ -66,13 +59,8 @@ output_full_json: True
# e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
# e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
global_prefix: null
-
dump_prediction_metadata_json: True
dump_trajectories: False
-one_model_per_file: True
align_trajectory_structures: False
-
-# Additional args
-print_config: False
prevalidate_inputs: True
low_memory_mode: False # False for standard mode, True for memory efficient tokenization mode
diff --git a/models/rfd3/configs/model/components/rfd3_net.yaml b/models/rfd3/configs/model/components/rfd3_net.yaml
index d496c4c..4033483 100644
--- a/models/rfd3/configs/model/components/rfd3_net.yaml
+++ b/models/rfd3/configs/model/components/rfd3_net.yaml
@@ -1,4 +1,4 @@
-_target_: rfd3.model.aa_design.RFD3
+_target_: rfd3.model.RFD3.RFD3
c_s: 384
c_z: 128
@@ -53,7 +53,7 @@ token_initializer: # formerly known as the trunk
n_attn_keys: 128
diffusion_module:
- _target_: rfd3.model.rfd3_diffusion_module.RFD3DiffusionModule
+ _target_: rfd3.model.RFD3_diffusion_module.RFD3DiffusionModule
c_token: 768
c_t_embed: 256 # Time embedding dimension
sigma_data: 16
diff --git a/models/rfd3/configs/paths/data/default.yaml b/models/rfd3/configs/paths/data/default.yaml
index a2faf95..b92ad06 100644
--- a/models/rfd3/configs/paths/data/default.yaml
+++ b/models/rfd3/configs/paths/data/default.yaml
@@ -26,6 +26,6 @@ design_benchmark_data_dir: /projects/ml/aa_design/benchmarks
design_model_weight_dir: /projects/ml/aa_design/models
# path to directory with cached residue data
-residue_cache_dir: /net/tukwila/ncorley/datahub/MACE-OFF23_medium
+residue_cache_dir: /net/tukwila/lschaaf/datahub/MACE-Egret-3-noH/mace_embeddings
cif_cache_dir: /net/tukwila/ncorley/cifutils/cache
\ No newline at end of file
diff --git a/models/rfd3/configs/trainer/aa_design.yaml b/models/rfd3/configs/trainer/aa_design.yaml
index b16a74a..269809d 100644
--- a/models/rfd3/configs/trainer/aa_design.yaml
+++ b/models/rfd3/configs/trainer/aa_design.yaml
@@ -1,6 +1,7 @@
defaults:
- ddp
- - loss/losses/diffusion_loss@loss.verbose_diffusion_loss
+ - loss/losses/diffusion_loss@loss.diffusion_loss
+ - loss/losses/sequence_loss@loss.sequence_loss
- metrics: design_metrics
- _self_
@@ -27,11 +28,4 @@ clip_grad_max_norm: 10.0
output_dir: ${paths.output_dir}
n_recycles_train: 2
grad_accum_steps: 3 # overridden by launch.sh
-skip_optimizer_loading: True
-
-loss:
- verbose_diffusion_loss:
- sequence_head_loss:
- _target_: rfd3.metrics.losses.SequenceLoss
- weight: 0.1
- max_t: 1
+skip_optimizer_loading: True
\ No newline at end of file
diff --git a/models/rfd3/configs/trainer/loss/losses/diffusion_loss.yaml b/models/rfd3/configs/trainer/loss/losses/diffusion_loss.yaml
index cc99a26..7dd300e 100644
--- a/models/rfd3/configs/trainer/loss/losses/diffusion_loss.yaml
+++ b/models/rfd3/configs/trainer/loss/losses/diffusion_loss.yaml
@@ -1,21 +1,11 @@
-# _target_: modelhub.loss.af3_losses.DiffusionLoss
-# sigma_data: ${model.net.diffusion_module.sigma_data}
-# alpha_dna: 5
-# alpha_rna: 5
-# alpha_ligand: 10
-# edm_lambda: True
-# se3_invariant_loss: True
-# clamp_diffusion_loss: False
+sigma_data: ${model.net.diffusion_module.sigma_data}
weight: 4.0
-
-_target_: rfd3.metrics.losses.VerboseDiffusionLoss
-alpha_motif: 1.0
-alpha_ca_atom: 1.0
-alpha_virtual_atom: 1.0
-alpha_fixed_motif: 2.0
-alpha_unindexed_diffused: 2.0
lddt_weight: 0.25
-clamp_diffusion_loss: True
-align_prediction: False
-use_motif_aligned_loss: False
+alpha_virtual_atom: 1.0
+alpha_polar_residues: 1.0
+lp_weight: 0.0
+unindexed_norm_p: 1.0
+alpha_unindexed_diffused: 1.0
+unindexed_t_alpha: 0.75
normalize_virtual_atom_weight: False
+alpha_ligand: 10.0
\ No newline at end of file
diff --git a/models/rfd3/configs/trainer/loss/losses/sequence_loss.yaml b/models/rfd3/configs/trainer/loss/losses/sequence_loss.yaml
new file mode 100644
index 0000000..5ab6fbb
--- /dev/null
+++ b/models/rfd3/configs/trainer/loss/losses/sequence_loss.yaml
@@ -0,0 +1,3 @@
+_target_: rfd3.metrics.losses.SequenceLoss
+weight: 0.1
+max_t: 1
diff --git a/models/rfd3/pyproject.toml b/models/rfd3/pyproject.toml
index 193d8fa..179ae36 100644
--- a/models/rfd3/pyproject.toml
+++ b/models/rfd3/pyproject.toml
@@ -3,7 +3,7 @@ name = "rfd3"
dynamic = ["version"]
description = "De novo Design of All-atom Biomolecular Interactions with RFdiffusion3"
readme = "README.md"
-requires-python = ">= 3.12"
+requires-python = ">=3.12"
authors = [
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
]
diff --git a/models/rfd3/src/rfd3/cli.py b/models/rfd3/src/rfd3/cli.py
index 17d84bf..1173fe7 100644
--- a/models/rfd3/src/rfd3/cli.py
+++ b/models/rfd3/src/rfd3/cli.py
@@ -19,26 +19,15 @@ def design(ctx: typer.Context):
# Get all arguments
args = ctx.params.get("args", []) + ctx.args
-
- # Parse arguments
- hydra_overrides = []
-
- if len(args) == 1 and "=" not in args[0]:
- # Old style: single positional argument assumed to be inputs
- hydra_overrides.append(f"inputs={args[0]}")
- else:
- # New style: all arguments are hydra overrides
- hydra_overrides.extend(args)
+ args = [a for a in args if a not in ["design", "fold"]]
# Ensure we have at least a default inference_engine if not specified
- has_inference_engine = any(
- arg.startswith("inference_engine=") for arg in hydra_overrides
- )
+ has_inference_engine = any(arg.startswith("inference_engine=") for arg in args)
if not has_inference_engine:
- hydra_overrides.append("inference_engine=rfdiffusion3")
+ args.append("inference_engine=rfdiffusion3")
with initialize_config_dir(config_dir=config_path, version_base="1.3"):
- cfg = compose(config_name="inference", overrides=hydra_overrides)
+ cfg = compose(config_name="inference", overrides=args)
# Lazy import to avoid loading heavy dependencies at CLI startup
from rfd3.run_inference import run_inference
diff --git a/models/rfd3/src/rfd3/constants.py b/models/rfd3/src/rfd3/constants.py
index ec1fadf..3ea42ad 100644
--- a/models/rfd3/src/rfd3/constants.py
+++ b/models/rfd3/src/rfd3/constants.py
@@ -1,5 +1,8 @@
import numpy as np
-from atomworks.constants import CRYSTALLIZATION_AIDS
+
+from modelhub.constants import TIP_BY_RESTYPE
+
+TIP_BY_RESTYPE
# Annot: default (diffused default)
REQUIRED_CONDITIONING_ANNOTATION_VALUES = {
@@ -204,46 +207,3 @@ SELECTION_NONPROTEIN = [
"MACROLIDE",
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]
-
-# fmt: off
-# ... For convenience, we allow BKBN, or TIP to be used as a shortcut | TIP is the largest set of fixed atom given at least 2 tip atoms
-TIP_BY_RESTYPE = {
- "TRP": ["CG","CD1","CD2","NE1","CE2","CE3","CZ2","CZ3","CH2"], # fix both rings
- "HIS": ["CG","ND1","CD2","CE1","NE2"], # fixed ring
- "TYR": ["CZ","OH"], # keeps ring dihedral flexible
- "PHE": ["CG","CD1","CD2","CE1","CE2","CZ"],
- "ASN": ["CB", "CG","OD1","ND2"],
- "ASP": ["CB", "CG","OD1","OD2"],
- "GLN": ["CG", "CD","OE1","NE2"],
- "GLU": ["CG", "CD","OE1","OE2"],
- "CYS": ["CB", "SG"],
- "SER": ["CB", "OG"],
- "THR": ["CB", "OG1"],
- "LEU": ["CB", "CG", "CD1", "CD2"],
- "VAL": ["CG1", "CG2"],
- "ILE": ["CB", "CG2"],
- "MET": ["SD", "CE"],
- "LYS": ["CE","NZ"],
- "ARG": ["CD","NE","CZ","NH1","NH2"],
- "PRO": None,
- "ALA": None,
- "GLY": None,
- "UNK": None,
- "MSK": None
-}
-# fmt: on
-
-STANDARD_PARSER_ARGS = {
- "add_missing_atoms": True,
- "add_id_and_entity_annotations": True,
- "add_bond_types_from_struct_conn": ("covale",),
- "remove_ccds": tuple(CRYSTALLIZATION_AIDS),
- "remove_waters": True,
- "fix_ligands_at_symmetry_centers": True,
- "fix_arginines": True,
- "fix_formal_charges": True,
- "fix_bond_types": True,
- "convert_mse_to_met": True, # Changed from False to True vs. atomworks.io.parser.parse default
- "hydrogen_policy": "keep",
- "model": None, # all models
-}
diff --git a/models/rfd3/src/rfd3/docs/input.md b/models/rfd3/src/rfd3/docs/input.md
index 5809d97..ba1f3cf 100644
--- a/models/rfd3/src/rfd3/docs/input.md
+++ b/models/rfd3/src/rfd3/docs/input.md
@@ -1,7 +1,7 @@
# RFdiffusion3 — Input specification (dialect **2**)
> **TL;DR**
-> Inputs are now defined with a single `InputSpecification` class.
+> Inputs are now defined with a single `DesignInputSpecification` class.
> Selections like “what’s fixed?”, “what’s sequence-free?”, “which atoms are donors/acceptors?” are all expressed with the same **InputSelection** mini-language.
> Everything is reproducibly logged back out alongside your generation.
@@ -10,7 +10,7 @@
- [What changed (high level)](#what-changed-high-level)
- [Quick start](#quick-start)
- [The `InputSelection` mini-language](#the-inputselection-mini-language)
-- [Full schema: `InputSpecification`](#full-schema-inputspecification)
+- [Full schema: `DesignInputSpecification`](#full-schema-DesignInputSpecification)
- [Common recipes (cookbook)](#common-recipes-cookbook)
- [Partial diffusion](#partial-diffusion)
- [Symmetry](#symmetry)
@@ -40,7 +40,7 @@
---
-## InputSpecification
+## DesignInputSpecification
| Field | Type | Description |
| -------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------- |
diff --git a/models/rfd3/src/rfd3/engine.py b/models/rfd3/src/rfd3/engine.py
new file mode 100644
index 0000000..d9ba54d
--- /dev/null
+++ b/models/rfd3/src/rfd3/engine.py
@@ -0,0 +1,435 @@
+import json
+import logging
+import os
+import time
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict, List
+
+import torch
+import yaml
+from biotite.structure import AtomArray
+from toolz import merge_with
+
+from modelhub.common import exists
+from modelhub.inference_engines.base import BaseInferenceEngine
+from modelhub.utils.ddp import RankedLogger
+from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
+from rfd3.inference.datasets import (
+ assemble_distributed_inference_loader_from_json,
+)
+from rfd3.inference.input_parsing import DesignInputSpecification
+from rfd3.utils.inference import ensure_input_is_abspath
+from rfd3.utils.io import (
+ CIF_LIKE_EXTENSIONS,
+ dump_metadata,
+ dump_structures,
+ dump_trajectories,
+ extract_example_id_from_path,
+ find_files_with_extension,
+)
+
+logging.basicConfig(level=logging.INFO)
+ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+class RFD3InferenceEngine(BaseInferenceEngine):
+ """Inference engine for RFdiffusion3"""
+
+ def __init__(
+ self,
+ *,
+ # Default input handling args
+ skip_existing: bool,
+ json_keys_subset: None | List[str],
+ prevalidate_inputs: bool,
+ # Base inference engine args
+ diffusion_batch_size: int,
+ inference_sampler: dict,
+ specification: dict | None,
+ # Structure dumping arguments
+ global_prefix: str | None,
+ cleanup_guideposts: bool,
+ cleanup_virtual_atoms: bool,
+ read_sequence_from_sequence_head: bool,
+ output_full_json: bool,
+ dump_prediction_metadata_json: bool,
+ dump_trajectories: bool,
+ align_trajectory_structures: bool,
+ low_memory_mode: bool,
+ **kwargs,
+ ):
+ super().__init__(
+ transform_overrides={"diffusion_batch_size": diffusion_batch_size},
+ inference_sampler_overrides={**inference_sampler},
+ trainer_overrides={
+ "cleanup_guideposts": cleanup_guideposts,
+ "cleanup_virtual_atoms": cleanup_virtual_atoms,
+ "read_sequence_from_sequence_head": read_sequence_from_sequence_head,
+ "output_full_json": output_full_json,
+ },
+ **kwargs,
+ )
+ # save
+ self.specification_overrides = dict(specification or {})
+
+ # Setup output directories and args
+ self.global_prefix = global_prefix
+ self.json_keys_subset = json_keys_subset
+ self.prevalidate_inputs = prevalidate_inputs
+ self.skip_existing = skip_existing
+
+ # Saving / other args
+ self.dump_prediction_metadata_json = dump_prediction_metadata_json
+ self.dump_trajectories = dump_trajectories
+ self.align_trajectory_structures = align_trajectory_structures
+ if not cleanup_guideposts:
+ ranked_logger.warning(
+ "Guideposts will not be cleaned up. This is intended for debugging purposes."
+ )
+ if not cleanup_virtual_atoms:
+ ranked_logger.warning(
+ "Virtual atoms will not be cleaned up. Some tools like MPNN may run, but outputs will not be like native structures."
+ )
+
+ # Check which example ids already exist in the output directory
+ if low_memory_mode:
+ ranked_logger.info("Low memory mode enabled.")
+ # HACK: Set attribute to the diffusion module
+ os.environ["RFD3_LOW_MEMORY_MODE"] = "1"
+
+ def run(
+ self,
+ *,
+ inputs: str | PathLike | AtomArray | DesignInputSpecification,
+ n_batches: int | None = None,
+ out_dir: str | PathLike | None = None,
+ ):
+ self._set_out_dir(out_dir)
+ inputs = self._canonicalize_inputs(inputs)
+ design_specifications = self._multiply_specifications(
+ inputs=inputs,
+ n_batches=n_batches,
+ )
+ # init before
+ self.initialize()
+ outputs = self._run_multi(design_specifications)
+ return outputs
+
+ def _set_out_dir(self, out_dir: str | PathLike | None):
+ out_dir = Path(out_dir) if out_dir else None
+ if out_dir:
+ out_dir.mkdir(parents=True, exist_ok=True)
+ ranked_logger.info(f"Outputs will be written to {out_dir.resolve()}.")
+ self.out_dir = out_dir
+
+ def _run_multi(self, specs):
+ # ==============================================================================
+ # Prepare pipeline and inference loader
+ # ==============================================================================
+ loader = assemble_distributed_inference_loader_from_json(
+ # Passed directly to ContigJSONDataset
+ # data={spec.example_id: spec for spec in spec.values()},
+ data=specs,
+ transform=self.pipeline,
+ name="inference-dataset",
+ cif_parser_args=None,
+ subset_to_keys=None,
+ eval_every_n=1,
+ # Sampler args
+ world_size=self.trainer.fabric.world_size,
+ rank=self.trainer.fabric.global_rank,
+ )
+ loader = self.trainer.fabric.setup_dataloaders(
+ loader,
+ use_distributed_sampler=False,
+ )
+
+ # ==============================================================================
+ # Evaluate, using `validation_step`
+ # ==============================================================================
+ outputs = {}
+ for batch_idx, batch in enumerate(loader):
+ pipeline_output = batch[0]
+ output = self._model_forward(pipeline_output)
+
+ if self.out_dir:
+ self.save_batch_outputs(
+ out_dir=self.out_dir,
+ example_id=pipeline_output["example_id"],
+ network_output=output["network_output"],
+ prediction_metadata=output["prediction_metadata"],
+ predicted_atom_array_stack=output["predicted_atom_array_stack"],
+ pipeline_output=pipeline_output,
+ )
+ else:
+ outputs[pipeline_output["example_id"]] = {
+ "network_output": output["network_output"],
+ "prediction_metadata": output["prediction_metadata"],
+ "predicted_atom_array_stack": output["predicted_atom_array_stack"],
+ }
+ return outputs
+
+ def _model_forward(self, pipeline_output):
+ t0 = time.time()
+ with torch.no_grad():
+ pipeline_output = self.trainer.fabric.to_device(pipeline_output)
+ output = self.trainer.validation_step(
+ batch=pipeline_output,
+ batch_idx=0,
+ compute_metrics=False,
+ )
+ t_end = time.time()
+
+ # Add additional information to prediction metadata
+ for key in output["prediction_metadata"].keys():
+ ckpt = Path(self.ckpt_path)
+ if ckpt.is_symlink():
+ ckpt = ckpt.resolve(strict=True) # follow symlink to target
+ output["prediction_metadata"][key]["ckpt_path"] = str(ckpt)
+ output["prediction_metadata"][key]["seed"] = self.seed
+
+ ranked_logger.info(f"Finished inference batch in {t_end - t0:.2f} seconds.")
+ return output
+
+ ###############################################
+ # Input merging
+ ###############################################
+
+ def _canonicalize_inputs(
+ self, inputs
+ ) -> Dict[str, dict | DesignInputSpecification]:
+ is_json_like = (isinstance(inputs, (str, PathLike, Path))) or (
+ isinstance(inputs, list)
+ and all([isinstance(i, (str, PathLike, Path)) for i in inputs])
+ )
+ is_specification_like = isinstance(inputs, DesignInputSpecification) or (
+ isinstance(inputs, list)
+ and all([isinstance(i, DesignInputSpecification) for i in inputs])
+ )
+ is_atom_array_like = isinstance(inputs, (AtomArray, list)) or (
+ isinstance(inputs, list) and all([isinstance(i, AtomArray) for i in inputs])
+ )
+ if inputs is None:
+ # Create empty specification dictionary
+ return {"": {**self.specification_overrides}}
+ elif is_json_like:
+ # List of file paths
+ inputs = process_input(
+ inputs,
+ json_keys_subset=self.json_keys_subset,
+ global_prefix=self.global_prefix,
+ specification_overrides=self.specification_overrides,
+ validate=self.prevalidate_inputs,
+ ) # any -> Dict[Name: DesignInputSpecification]
+ elif is_specification_like:
+ # List of DesignInputSpecifications
+ if isinstance(inputs, DesignInputSpecification):
+ inputs = [inputs]
+ inputs = {f"backbone_{i}": spec for i, spec in enumerate(inputs)}
+ elif is_atom_array_like:
+ raise NotImplementedError("AtomArray inputs not yet supported.")
+ else:
+ raise ValueError(
+ f"Invalid input type: {type(inputs)}. Expected JSON/YAML file paths, AtomArray, or DesignInputSpecification.\nInput: {inputs}"
+ )
+
+ return design_specifications
+
+ def _multiply_specifications(
+ self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
+ ) -> Dict[str, Dict[str, Any]]:
+ # Find existing example IDS in output directory
+ if exists(self.out_dir):
+ existing_example_ids = set(
+ extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
+ for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
+ )
+ ranked_logger.info(
+ f"Found {len(existing_example_ids)} existing example IDs in the output directory."
+ )
+
+ # Based on inputs, construct the specifications to loop through
+ design_specifications = {}
+ for prefix, example_spec in inputs.items():
+ # ... Create n_batches for example
+ for batch_id in range((n_batches) if exists(n_batches) else 1):
+ # ... Example ID
+ example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
+
+ if (
+ self.skip_existing
+ and exists(self.out_dir)
+ and example_id in existing_example_ids
+ ):
+ ranked_logger.info(
+ f"Skipping design specification for example {example_id} | Already exists."
+ )
+ continue
+ design_specifications[example_id] = example_spec
+ return design_specifications
+
+ def save_batch_outputs(
+ self,
+ *,
+ out_dir,
+ network_output,
+ prediction_metadata,
+ predicted_atom_array_stack,
+ pipeline_output,
+ example_id,
+ ):
+ out_dir = Path(out_dir)
+ dump_structures(
+ atom_arrays=predicted_atom_array_stack,
+ base_path=out_dir / example_id,
+ one_model_per_file=True,
+ extra_fields=SAVED_CONDITIONING_ANNOTATIONS,
+ )
+
+ if self.dump_prediction_metadata_json:
+ dump_metadata(
+ prediction_metadata=prediction_metadata,
+ base_path=out_dir / example_id,
+ one_model_per_file=True,
+ )
+
+ if self.dump_trajectories:
+ dump_trajectories(
+ trajectory_list=network_output["X_denoised_L_traj"],
+ atom_array=pipeline_output["atom_array"],
+ base_path=out_dir / f"{example_id}_denoised",
+ align_structures=self.align_trajectory_structures,
+ )
+ dump_trajectories(
+ trajectory_list=network_output["X_noisy_L_traj"],
+ atom_array=pipeline_output["atom_array"],
+ base_path=out_dir / f"{example_id}_noisy",
+ align_structures=self.align_trajectory_structures,
+ )
+
+ ranked_logger.info(
+ f"Outputs for {example_id} written to {out_dir / example_id}."
+ )
+
+
+def normalize_inputs(inputs: str | list | None) -> list[str | None]:
+ """
+ inputs: str | list[str] | None
+ - Can be:
+ - A single path to a JSON, YAML, or regular input file (cif or pdb)
+ - A comma-separated string of paths (e.g. "a.json,b.json")
+ - A list of file paths
+ - None or an empty list, in which case a dummy input is added (used for e.g. motif-only design)
+ - Returns list of paths or [None] if no inputs are provided
+ """
+ if inputs is None or (isinstance(inputs, list) and len(inputs) == 0):
+ inputs = [None]
+ elif isinstance(inputs, str):
+ inputs = inputs.split(",")
+ elif not isinstance(inputs, list):
+ raise ValueError(
+ f"Invalid input type: {type(inputs)}. Expected str, list, or None.\nInput: {inputs}"
+ )
+ return inputs
+
+
+def process_input(
+ inputs: str | list | None,
+ json_keys_subset: str | list | None = None,
+ global_prefix: str | None = None,
+ specification_overrides: dict | None = None,
+ validate: bool = True,
+) -> Dict[str, dict]:
+ """
+ inputs: Any -> list[str | None] (see normalize_inputs)
+ json_keys_subset: extract only subset of JSON keys. None will keep all keys
+ prefix: If provided, prefix all example ids with said prefix
+
+ returns: Dictionaries of specifcation args pre-batching:
+ {
+ 'jsonfile_jsonkey1': {
+ **args_from_key1
+ },
+ 'jsonfile_jsonkey2': {
+ **args_from_key2
+ }
+ }
+ """
+ specification_overrides = dict(specification_overrides or {})
+
+ def merge_args(example_args: dict) -> dict:
+ return merge_with(lambda x: x[-1], example_args, specification_overrides)
+
+ inputs = normalize_inputs(inputs)
+
+ # If global_prefix is not provided, then default to using the basename of the JSON or YAML file (when provided)
+ if global_prefix is None:
+ use_json_basename_prefix = True
+ else:
+ use_json_basename_prefix = False
+
+ # ... Convert all inputs to list of inputs (e.g. if comma-separated)
+ if exists(inputs) and "," in inputs:
+ inputs = inputs.split(",")
+ elif not exists(inputs):
+ # If inputs is None or empty, we will create a dummy input
+ inputs = []
+ inputs = inputs if isinstance(inputs, list) else [inputs]
+ if len(inputs) == 0:
+ inputs = [None]
+
+ # ... Determine prefix of sample to create
+ all_specs = {}
+ for input in inputs:
+ if exists(input) and (input.endswith(".json") or input.endswith(".yaml")):
+ # ... Load JSON or YAML file
+ with open(input, "r") as f:
+ data = json.load(f) if input.endswith(".json") else yaml.safe_load(f)
+
+ # ... Apply any global args for this input file
+ if "global_args" in data:
+ global_args = data.pop("global_args")
+ for example in data:
+ data[example].update(global_args)
+
+ # ... Subset to keys
+ if json_keys_subset is not None:
+ json_keys_subset = (
+ json_keys_subset.split(",")
+ if isinstance(json_keys_subset, str)
+ else json_keys_subset
+ )
+ data = {
+ example: data[example]
+ for example in json_keys_subset
+ if example in data
+ }
+
+ # ... Extract each accumulated example in data.
+ for example, args in data.items():
+ args = ensure_input_is_abspath(args, input)
+ if use_json_basename_prefix:
+ name = os.path.splitext(os.path.basename(input))[0]
+ prefix = f"{name}_{example}"
+ else:
+ prefix = f"{global_prefix}{example}"
+ args["extra"] = args.get("extra", {}) | {"example": example}
+ all_specs[prefix] = dict(merge_args(args))
+
+ elif exists(input):
+ prefix = os.path.basename(os.path.splitext(input)[0])
+ if global_prefix is not None:
+ prefix = f"{global_prefix}{prefix}"
+ all_specs[prefix] = dict(merge_args({"input": input}))
+ else:
+ all_specs["backbone"] = dict(specification_overrides)
+
+ if validate:
+ for prefix, example_spec in all_specs.items():
+ ranked_logger.info(
+ f"Prevalidating design specification for example: {prefix}"
+ )
+ DesignInputSpecification.safe_init(**example_spec)
+
+ return all_specs
diff --git a/models/rfd3/src/rfd3/inference/datasets.py b/models/rfd3/src/rfd3/inference/datasets.py
index 669f84a..e6e720e 100644
--- a/models/rfd3/src/rfd3/inference/datasets.py
+++ b/models/rfd3/src/rfd3/inference/datasets.py
@@ -5,29 +5,17 @@
import json
import os
import textwrap
-import time
from os import PathLike
-from typing import Any, List
+from typing import Any, Dict, List
import yaml
-from atomworks.io.parser import parse_atom_array
-
from atomworks.ml.datasets import MolecularDataset
-from atomworks.ml.transforms.base import Compose, Transform, TransformedDict
-from biotite.structure import BondList
+from atomworks.ml.transforms.base import Compose, Transform
from omegaconf import DictConfig, OmegaConf
-from rfd3.constants import (
- INFERENCE_ANNOTATIONS,
- REQUIRED_INFERENCE_ANNOTATIONS,
-)
-from rfd3.inference.inference_utils import ensure_input_is_abspath
from rfd3.inference.input_parsing import (
- create_atom_array_from_design_specification,
-)
-from rfd3.transforms.conditioning_base import (
- check_has_required_conditioning_annotations,
- convert_existing_annotations_to_bool,
+ DesignInputSpecification,
)
+from rfd3.utils.inference import ensure_input_is_abspath
from torch.utils.data import (
DataLoader,
SequentialSampler,
@@ -49,7 +37,7 @@ class ContigJsonDataset(MolecularDataset):
def __init__(
self,
*,
- data: PathLike | dict,
+ data: PathLike | Dict[str, dict | DesignInputSpecification],
cif_parser_args: dict | None,
transform: Transform | Compose | None,
name: str | None,
@@ -137,7 +125,7 @@ class ContigJsonDataset(MolecularDataset):
def _check_json_keys(self):
"""Check if the JSON keys are valid."""
for k, data in self.data.items():
- if not isinstance(data, dict):
+ if not isinstance(data, (dict, DesignInputSpecification)):
raise ValueError("Each item in the JSON data must be a dictionary.")
@property
@@ -171,77 +159,19 @@ class ContigJsonDataset(MolecularDataset):
spec = self.data[example_id]
# if 'input' in metadata and not abspath, prepend the source json directory to the file path
- spec = ensure_input_is_abspath(spec, self.json_path)
+ if not isinstance(spec, DesignInputSpecification):
+ spec = ensure_input_is_abspath(spec, self.json_path)
+ spec["cif_parser_args"] = self.cif_parser_args
+ spec = DesignInputSpecification(**spec)
- # ... Create atom array with conditioning annotations
- atom_array, spec_dict = create_atom_array_from_design_specification(
- cif_parser_args=self.cif_parser_args, **spec
- )
+ # Create pipeline input
+ data = spec.to_pipeline_input(example_id=example_id)
- # ... Forward into
- data = prepare_pipeline_input_from_atom_array(atom_array)
- data["example_id"] = example_id
-
- # ... Wrap up with additional features
- if "extra" not in spec_dict:
- spec_dict["extra"] = {}
- spec_dict["extra"]["example_id"] = example_id
- data["specification"] = spec_dict
-
- # ... Send through pipeline
+ # Apply transforms and return
data = self.transform(data)
return data
-def prepare_pipeline_input_from_atom_array( # see atomworks.ml.datasets.parsers.base.load_example_from_metadata_row
- atom_array_orig,
-) -> dict:
- """
- Load or create an example from a metadata dictionary.
- If the file path is not provided in the metadata dictionary, create a spoofed CIF file based on the length.
- Args:
- atom_array_orig: Atom array instantiated with conditioning annotations
-
- Returns:
- dict: A dictionary containing the parsed row data and additional loaded CIF data.
- """
- _start_parse_time = time.time()
- # HACK: Set empty bond graph:
- if atom_array_orig.bonds is None:
- atom_array_orig.bonds = BondList(atom_array_orig.array_length())
-
- # Temporary spoof of chain IDs to ensure duplicates aren't dropped:
- result_dict = parse_atom_array(
- atom_array_orig,
- remove_ccds=[],
- fix_arginines=False,
- add_missing_atoms=False,
- extra_fields=INFERENCE_ANNOTATIONS,
- build_assembly=None,
- hydrogen_policy="remove",
- )
- atom_array = result_dict["asym_unit"][0]
-
- # HACK: Set iid information manually
- # We currently do not preserve this information from the input,
- # if you want these we'd need to remove the spoofing here
- check_has_required_conditioning_annotations(
- atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
- )
- atom_array = convert_existing_annotations_to_bool(atom_array)
- atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
- atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])
- data = {
- "atom_array": atom_array, # First model
- "chain_info": result_dict["chain_info"],
- "ligand_info": result_dict["ligand_info"],
- "metadata": result_dict["metadata"],
- }
- _stop_parse_time = time.time()
- data = TransformedDict(data)
- return data
-
-
def assemble_distributed_inference_loader_from_json(
*, rank: int, world_size: int, **dataset_kwargs
) -> DataLoader:
diff --git a/models/rfd3/src/rfd3/inference/engine.py b/models/rfd3/src/rfd3/inference/engine.py
deleted file mode 100644
index b6e2f50..0000000
--- a/models/rfd3/src/rfd3/inference/engine.py
+++ /dev/null
@@ -1,448 +0,0 @@
-import json
-import logging
-import os
-import time
-from os import PathLike
-from pathlib import Path
-from typing import Dict
-
-import hydra
-import torch
-import yaml
-from lightning.fabric import seed_everything
-from omegaconf import OmegaConf
-from rfd3.constants import (
- SAVED_CONDITIONING_ANNOTATIONS,
-)
-from rfd3.inference.datasets import (
- assemble_distributed_inference_loader_from_json,
-)
-from rfd3.inference.inference_utils import ensure_input_is_abspath
-from rfd3.inference.input_parsing import InputSpecification
-from toolz import merge_with
-
-from modelhub.common import exists
-from modelhub.inference_engines.af3 import AF3InferenceEngine
-from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
-from modelhub.utils.io import (
- CIF_LIKE_EXTENSIONS,
- extract_example_id_from_path,
- find_files_with_extension,
-)
-from modelhub.utils.logging import print_config_tree
-
-logging.basicConfig(level=logging.INFO)
-ranked_logger = RankedLogger(__name__, rank_zero_only=True)
-
-
-def normalize_inputs(inputs: str | list | None) -> list[str | None]:
- """
- inputs: str | list[str] | None
- - Can be:
- - A single path to a JSON, YAML, or regular input file (cif or pdb)
- - A comma-separated string of paths (e.g. "a.json,b.json")
- - A list of file paths
- - None or an empty list, in which case a dummy input is added (used for e.g. motif-only design)
- - Returns list of paths or [None] if no inputs are provided
- """
- if inputs is None or (isinstance(inputs, list) and len(inputs) == 0):
- inputs = [None]
- elif isinstance(inputs, str):
- inputs = inputs.split(",")
- elif not isinstance(inputs, list):
- raise ValueError(
- f"Invalid input type: {type(inputs)}. Expected str, list, or None.\nInput: {inputs}"
- )
- return inputs
-
-
-def process_input(
- inputs: str | list | None,
- json_keys_subset: str | list | None = None,
- global_prefix: str = None,
- specification_overrides: dict = {},
-) -> Dict[str, dict]:
- """
- inputs: Any -> list[str | None] (see normalize_inputs)
- json_keys_subset: extract only subset of JSON keys. None will keep all keys
- prefix: If provided, prefix all example ids with said prefix
-
- returns: Dictionaries of specifcation args pre-batching:
- {
- 'jsonfile_jsonkey1': {
- **args_from_key1
- },
- 'jsonfile_jsonkey2': {
- **args_from_key2
- }
- }
- """
- merge_args = lambda d: merge_with(lambda x: x[-1], d, specification_overrides) # noqa
- inputs = normalize_inputs(inputs)
-
- # If global_prefix is not provided, then default to using the basename of the JSON or YAML file (when provided)
- if global_prefix is None:
- use_json_basename_prefix = True
- else:
- use_json_basename_prefix = False
-
- # ... Convert all inputs to list of inputs (e.g. if comma-separated)
- if exists(inputs) and "," in inputs:
- inputs = inputs.split(",")
- elif not exists(inputs):
- # If inputs is None or empty, we will create a dummy input
- inputs = []
- inputs = inputs if isinstance(inputs, list) else [inputs]
- if len(inputs) == 0:
- inputs = [None]
-
- # ... Determine prefix of sample to create
- all_specs = {}
- for input in inputs:
- if exists(input) and (input.endswith(".json") or input.endswith(".yaml")):
- # ... Load JSON or YAML file
- with open(input, "r") as f:
- data = json.load(f) if input.endswith(".json") else yaml.safe_load(f)
-
- # ... Apply any global args for this input file
- if "global_args" in data:
- global_args = data.pop("global_args")
- for example in data:
- data[example].update(global_args)
-
- # ... Subset to keys
- if json_keys_subset is not None:
- json_keys_subset = (
- json_keys_subset.split(",")
- if isinstance(json_keys_subset, str)
- else json_keys_subset
- )
- data = {
- example: data[example]
- for example in json_keys_subset
- if example in data
- }
-
- # ... Extract each accumulated example in data.
- for example, args in data.items():
- args = ensure_input_is_abspath(args, input)
- if use_json_basename_prefix:
- name = os.path.splitext(os.path.basename(input))[0]
- prefix = f"{name}_{example}"
- else:
- prefix = f"{global_prefix}{example}"
- args["extra"] = args.get("extra", {}) | {"example": example}
- all_specs[prefix] = dict(merge_args(args))
-
- elif exists(input):
- prefix = os.path.basename(os.path.splitext(input)[0])
- if global_prefix is not None:
- prefix = f"{global_prefix}{prefix}"
- all_specs[prefix] = dict(merge_args({"input": input}))
- else:
- all_specs["backbone"] = specification_overrides
-
- return all_specs
-
-
-class RFD3InferenceEngine(AF3InferenceEngine):
- """Inference engine for RFdiffusion3"""
-
- def __init__(
- self,
- # Required args:
- out_dir: str | PathLike,
- ckpt_path: str | PathLike,
- inputs: str | PathLike,
- json_keys_subset: None | list[str],
- n_batches: int,
- *,
- # Default design specification args:
- specification: dict,
- # Base inference engine args
- diffusion_batch_size: int,
- skip_existing: bool,
- inference_sampler: dict,
- # Structure dumping arguments
- cleanup_guideposts: bool,
- cleanup_virtual_atoms: bool,
- read_sequence_from_sequence_head: bool,
- output_full_json: bool,
- dump_prediction_metadata_json: bool,
- dump_trajectories: bool,
- align_trajectory_structures: bool,
- one_model_per_file: bool,
- global_prefix: str | None,
- ###############################################
- num_nodes: int,
- devices_per_node: int,
- print_config: bool,
- seed: int | None,
- temp_dir=None,
- low_memory_mode: bool = False,
- prevalidate_inputs: bool = True,
- # Atom array instantiation args collapsed into default dict:
- ):
- """
- Design specification args:
- # Main:
- inputs: JSON, YAML, PDB or CIF (comma-separated string or single) containing coordinate data to be parsed or args to override
- length: length of designed structure (int or str like '20-100'). Default: None (specified by contig string)
- contig: string of residues to use as backbone. Default: None (specified by length) Example: '10-20,A11-13,10-20'
- fixed_atoms: Dict[str] of atom names to use as indexed motif atoms with fixed coordinates. Default: None
- unindex: list of residues in input_src to unindex in design. Default: None
-
- redesign_motif_sidechains: bool or str. Specifies which motif residues have fixed sidechains by default. Default: True
- ligand: str. Ligands in input file to include as motif
- atomwise_rasa: Dict[str] of atomwise rasas to use for design. Default: None
- ori_token: list of 3 floats indicating the origin to center the design around. Default: None
- seed: int or None. If None, a random seed will be sampled.
-
- Additional args:
- ckpt_path: Path to checkpoint file
- n_batches: Number of samples to
- """
- if not os.path.isabs(out_dir):
- out_dir = os.path.abspath(out_dir)
- ranked_logger.info("Using absolute path for out_dir: {}".format(out_dir))
-
- # Convert input sources to design specification dictionaries
- inputs = process_input(
- inputs,
- json_keys_subset=json_keys_subset,
- global_prefix=global_prefix,
- specification_overrides=specification,
- ) # any -> Dict[Name: InputSpecification]
- self.design_specifications = {}
- for prefix, example_spec in inputs.items():
- # ... Set example key as the prefix
- if prevalidate_inputs:
- ranked_logger.info(
- f"Prevalidating design specification for example: {prefix}"
- )
- InputSpecification.safe_init(**example_spec)
-
- # ... Create n_batches for example
- for batch_i in range(n_batches):
- # ... Example ID
- example_id = f"{prefix}_{batch_i}"
- self.design_specifications[example_id] = example_spec
-
- ############################################################
- # Feed-forward inputs similar to MH-AF3 inference engine
- ############################################################
-
- # ... set the random seed for reproducibility (and for augmentation, e.g., for antibodies)
- if not exists(seed):
- seed = int(time.time() * 1000) % (2**31)
- ranked_logger.info(f"Seeding everything with seed={seed}...")
- seed_everything(seed, workers=True, verbose=True)
- self.seed = seed
-
- # We only extract the `train_cfg` from the checkpoint initially
- self.load_and_override_ckpt_config(
- ckpt_path=ckpt_path,
- num_nodes=num_nodes,
- devices_per_node=devices_per_node,
- inference_sampler=inference_sampler,
- )
-
- set_accelerator_based_on_availability(self.cfg)
-
- # (b) based on the dataset (we will apply when constructing the pipeline)
- self.dataset_overrides = {
- "diffusion_batch_size": diffusion_batch_size,
- }
-
- # ... instantiate the trainer with the (modified) configuration
- self.trainer = hydra.utils.instantiate(
- self.cfg.trainer,
- _convert_="partial",
- _recursive_=False,
- )
-
- # Set the output directory for the CIF files (e.g., predicted structures)
- self.cif_out_dir = Path(out_dir) if out_dir else Path("./")
-
- # Structure dumping
- self.dump_prediction_metadata_json = dump_prediction_metadata_json
- self.dump_trajectories = dump_trajectories
- self.one_model_per_file = one_model_per_file
- self.align_trajectory_structures = align_trajectory_structures
-
- self.trainer.cleanup_virtual_atoms = cleanup_virtual_atoms
- self.trainer.cleanup_guideposts = cleanup_guideposts
- self.trainer.read_sequence_from_sequence_head = read_sequence_from_sequence_head
- self.trainer.output_full_json = output_full_json
- self.trainer.inference_sampler_overrides = inference_sampler
- self.prediction_extra_fields = SAVED_CONDITIONING_ANNOTATIONS
- self.skip_existing = skip_existing
- self.dump_predictions = True
- self.print_config = print_config
-
- if not cleanup_guideposts:
- ranked_logger.warning(
- "Guideposts will not be cleaned up. This is intended for debugging purposes."
- )
- if not cleanup_virtual_atoms:
- ranked_logger.warning(
- "Virtual atoms will not be cleaned up. Some tools like MPNN may run, but outputs will not be like native structures."
- )
-
- # Check which example ids already exist in the output directory
- self.existing_example_ids = set(
- extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
- for path in find_files_with_extension(out_dir, CIF_LIKE_EXTENSIONS)
- )
- ranked_logger.info(
- f"Found {len(self.existing_example_ids)} existing example IDs in the output directory."
- )
-
- if low_memory_mode:
- ranked_logger.info("Low memory mode enabled.")
- # HACK: Set attribute to the diffusion module
- os.environ["RFD3_LOW_MEMORY_MODE"] = "1"
-
- def load_and_override_ckpt_config(
- self, ckpt_path, num_nodes, devices_per_node, inference_sampler
- ):
- assert exists(ckpt_path), f"Checkpoint path ({ckpt_path}) not provided."
- ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...")
- self.ckpt_path = ckpt_path
-
- self.cfg = OmegaConf.create(
- torch.load(self.ckpt_path, "cpu", weights_only=False)["train_cfg"]
- )
-
- # Override specific parameters within the Hydra config:
- # (a) based on the input arguments
- self.cfg.trainer.num_nodes = num_nodes
- self.cfg.trainer.devices_per_node = devices_per_node
- for k, v in inference_sampler.items():
- if v is None:
- continue
- setattr(self.cfg.model.net.inference_sampler, k, v)
-
- # Set metrics / callbacks to be null s.t. they aren't loaded
- self.cfg.trainer.metrics = None
-
- # Record the random seed to be dumped in the output JSON
- self.cfg.trainer.seed = self.seed
-
- def example_id_exists(self, example_id, verbose=False):
- # TODO: Move this to another file to standardize better with src
- if not self.one_model_per_file:
- # Check if one file exists
- all_exist = example_id in self.existing_example_ids
- if all_exist and verbose:
- ranked_logger.info(
- f"Model file for example {example_id} already exists in the output directory."
- )
- else:
- all_exist = all(
- [
- (f"{example_id}_model_{i}" in self.existing_example_ids)
- for i in range(self.dataset_overrides["diffusion_batch_size"])
- ]
- )
- if all_exist and verbose:
- ranked_logger.info(
- f"All models for example {example_id} already exist in the output directory."
- )
- return all_exist
-
- def eval(self):
- """
- Run design on a set of specifications
- """
- if self.print_config:
- print_config_tree(
- self.cfg.model, resolve=True, title="INFERENCE MODEL CONFIGURATION"
- )
-
- # ... spawn processes for distributed training, if using multiple GPUs
- ranked_logger.info(
- f"Spawning {self.trainer.fabric.world_size} processes from {self.trainer.fabric.global_rank}..."
- )
-
- # ==============================================================================
- # Construct the model and load the checkpoint
- # ==============================================================================
- self.trainer.initialize_or_update_trainer_state({"train_cfg": self.cfg})
- self.trainer.construct_model()
- self.trainer.load_checkpoint(ckpt_path=self.ckpt_path, is_inference=True)
-
- # Ensure optimizer isn't loaded
- self.trainer.state["optimizer"] = None
- self.trainer.state["train_cfg"].model.optimizer = None
-
- self.trainer.setup_model_optimizers_and_schedulers()
- self.trainer.state["model"].eval()
-
- # ==============================================================================
- # Prepare pipeline and inference loader
- # ==============================================================================
- # TODO: have name be the basename of the JSON or YAML file
- loader = assemble_distributed_inference_loader_from_json(
- # Passed directly to ContigJSONDataset
- data=self.design_specifications,
- transform=self.construct_pipeline(),
- name="inference-dataset",
- cif_parser_args={},
- subset_to_keys=None,
- eval_every_n=1,
- # Sampler args
- world_size=self.trainer.fabric.world_size,
- rank=self.trainer.fabric.global_rank,
- )
- loader = self.trainer.fabric.setup_dataloaders(
- loader,
- use_distributed_sampler=False,
- )
-
- # ==============================================================================
- # Evaluate, using `validation_step``
- # ==============================================================================
-
- for batch_idx, batch in enumerate(loader):
- pipeline_output = batch[0]
- example_id = pipeline_output["example_id"]
-
- if self.skip_existing:
- if self.example_id_exists(example_id, verbose=True):
- ranked_logger.info(
- f"Skipping structure {batch_idx + 1}/{len(loader)}: {example_id} | Already exists."
- )
- continue
- else:
- ranked_logger.info(
- f"Predicting structure {batch_idx + 1}/{len(loader)}: {example_id}"
- )
-
- # Model inference
- t0 = time.time()
- with torch.no_grad():
- pipeline_output = self.trainer.fabric.to_device(pipeline_output)
- output = self.trainer.validation_step(
- batch=pipeline_output,
- batch_idx=0,
- compute_metrics=False,
- )
- t_end = time.time()
-
- # Add additional information to prediction metadata
- for key in output["prediction_metadata"].keys():
- ckpt = Path(self.ckpt_path)
- if ckpt.is_symlink():
- ckpt = ckpt.resolve(strict=True) # follow symlink to target
- output["prediction_metadata"][key]["ckpt_path"] = str(ckpt)
- output["prediction_metadata"][key]["seed"] = self.seed
-
- ranked_logger.info(f"Finished inference batch in {t_end - t0:.2f} seconds.")
- self.save_batch_outputs(
- example_id=example_id,
- network_output=output["network_output"],
- prediction_metadata=output["prediction_metadata"],
- predicted_atom_array_stack=output["predicted_atom_array_stack"],
- pipeline_output=pipeline_output,
- )
diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py
index 0b4b8f4..1122952 100644
--- a/models/rfd3/src/rfd3/inference/input_parsing.py
+++ b/models/rfd3/src/rfd3/inference/input_parsing.py
@@ -1,11 +1,18 @@
import copy
+import json
import logging
+import os
+import time
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
import numpy as np
from atomworks.constants import STANDARD_AA
+from atomworks.io.parser import parse_atom_array
+
+# from atomworks.ml.datasets.datasets import BaseDataset
+from atomworks.ml.transforms.base import TransformedDict
from atomworks.ml.utils.token import (
get_token_starts,
)
@@ -18,20 +25,10 @@ from pydantic import (
model_validator,
)
from rfd3.constants import (
+ INFERENCE_ANNOTATIONS,
REQUIRED_CONDITIONING_ANNOTATION_VALUES,
REQUIRED_INFERENCE_ANNOTATIONS,
)
-from rfd3.inference.components import (
- get_design_pattern_with_constraints,
- get_motif_components_and_breaks,
-)
-from rfd3.inference.inference_utils import (
- extract_ligand_array,
- inference_load_,
- set_com,
- set_common_annotations,
- set_indices,
-)
from rfd3.inference.legacy_input_parsing import (
create_atom_array_from_design_specification_legacy,
)
@@ -48,8 +45,19 @@ from rfd3.transforms.conditioning_base import (
set_default_conditioning_annotations,
)
from rfd3.transforms.util_transforms import assign_types_
+from rfd3.utils.inference import (
+ extract_ligand_array,
+ inference_load_,
+ set_com,
+ set_common_annotations,
+ set_indices,
+)
from modelhub.common import exists
+from modelhub.utils.components import (
+ get_design_pattern_with_constraints,
+ get_motif_components_and_breaks,
+)
from modelhub.utils.ddp import RankedLogger
logging.basicConfig(level=logging.DEBUG)
@@ -62,216 +70,6 @@ logger = RankedLogger(__name__, rank_zero_only=True)
#################################################################################
-@contextmanager
-def validator_context(validator_name: str, data: dict = None):
- """Context manager for validator execution with logging."""
- logger.debug(f"Starting validator: {validator_name}")
- try:
- yield
- logger.debug(f"✓ Completed validator: {validator_name}")
- except Exception as e:
- logger.error(
- f"✗ Failed in validator: {validator_name}\n"
- f" Error: {str(e)}\n"
- f" Error type: {type(e).__name__}"
- )
- raise e
-
-
-def create_diffused_residues(n, additional_annotations=None):
- if n <= 0:
- raise ValueError(f"Negative/null residue count ({n}) not allowed.")
-
- atoms = []
- [
- atoms.extend(
- [
- struc.Atom(
- np.array([0.0, 0.0, 0.0], dtype=np.float32),
- res_name="ALA",
- res_id=idx,
- )
- for _ in range(5)
- ]
- )
- for idx in range(1, n + 1)
- ]
- array = struc.array(atoms)
- array.set_annotation(
- "element", np.array(["N", "C", "C", "O", "C"] * n, dtype=" AtomArray:
- # ... Create list of components
- assert (
- x := (set(list(indexed_tokens.keys()) + list(unindexed_tokens.keys())))
- ).issubset(
- (y := set(components_to_accumulate))
- ), "Unindexed and indexed set {} is not subset of components to accumulate {}".format(
- x, y
- )
- all_tokens = indexed_tokens | unindexed_tokens
- all_annots = []
- [
- all_annots.extend(list(tok.get_annotation_categories()))
- for tok in all_tokens.values()
- ]
- all_annots = set(all_annots)
-
- # ... For-loop accum variables
- unindexed_components_started = (
- False # once one unindexed component is added, stop adding diffused residues
- )
- chain = start_chain
- res_id = start_resid
- molecule_id = 0
-
- # ... Insert contig information one- by one-
- assert len(components_to_accumulate) == len(
- unindexed_breaks
- ), "Mismatch in number of components to accumulate and breaks"
- for component, is_break in zip(components_to_accumulate, unindexed_breaks):
- if component == "/0":
- # Reset iterators on next chain
- chain = chr(ord(chain) + 1)
- molecule_id += 1
- res_id = 1
- continue
-
- # ... Create array to insert
- if str(component)[0].isalpha(): # motif (e.g. "A22")
- n = 1
-
- # ... Fetch the motif residue
- token = all_tokens[component]
-
- # ... Insert breakpoint when break clause is met
- if exists(is_break) and is_break:
- if not unindexed_components_started:
- chain = start_chain
- unindexed_components_started = True
- token.set_annotation(
- "is_motif_atom_unindexed_motif_breakpoint",
- np.ones(token.shape[0], dtype=int),
- )
- else:
- token.set_annotation(
- "is_motif_atom_unindexed_motif_breakpoint",
- np.zeros(token.shape[0], dtype=int),
- )
- else:
- n = int(component)
- # ... Skip if none or unindexed
- if n == 0 or unindexed_components_started:
- res_id += n
- continue
-
- # ... Create diffused residues
- token = create_diffused_residues(n, all_annots)
-
- # ... Set index of insertion
- token = set_indices(
- array=token,
- chain=chain,
- res_id_start=res_id,
- molecule_id=molecule_id,
- component=component,
- )
-
- assert (
- len(get_token_starts(token)) == n
- ), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"
-
- # ... Insert & Increment residue ID
- atom_array_accum.append(token)
- res_id += n
-
- # ... Concatenate all components
- atom_array_accum = struc.concatenate(atom_array_accum)
- atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
-
- # Reset res_id for unindexed residues to avoid duplicates
- if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
- atom_array_accum.is_motif_atom_unindexed.astype(bool)
- ):
- max_id = np.max(
- atom_array_accum[
- ~atom_array_accum.is_motif_atom_unindexed.astype(bool)
- ].res_id
- )
- atom_array_accum.res_id[
- atom_array_accum.is_motif_atom_unindexed.astype(bool)
- ] += max_id + 1
-
- # ... Bonds
- if atom_array_accum.bonds is None:
- atom_array_accum.bonds = BondList(atom_array_accum.array_length())
- return atom_array_accum
-
-
class LegacySpecification(BaseModel):
"""Legacy specification for compatibility with legacy input parsing."""
@@ -280,20 +78,34 @@ class LegacySpecification(BaseModel):
extra="allow",
)
- def build(self):
+ def build(self, *args, **kwargs):
"""Build atom array using legacy input parsing."""
atom_array = create_atom_array_from_design_specification_legacy(
design_specification=self.model_dump(),
)
return atom_array, self.model_dump()
+ def to_pipeline_input(self, example_id):
+ atom_array, spec_dict = self.build(return_metadata=True)
+
+ # ... Forward into
+ data = prepare_pipeline_input_from_atom_array(atom_array)
+ data["example_id"] = example_id
+
+ # ... Wrap up with additional features
+ if "extra" not in spec_dict:
+ spec_dict["extra"] = {}
+ spec_dict["extra"]["example_id"] = example_id
+ data["specification"] = spec_dict
+ return data
+
# ========================================================================
# Input specification
# ========================================================================
-class InputSpecification(BaseModel):
+class DesignInputSpecification(BaseModel):
"""Validated and parsed input specification before resolution."""
model_config = ConfigDict(
@@ -932,12 +744,75 @@ class InputSpecification(BaseModel):
else:
return cls(**spec_kwargs)
+ def to_pipeline_input(self, example_id):
+ atom_array, spec_dict = self.build(return_metadata=True)
+
+ # ... Forward into
+ data = prepare_pipeline_input_from_atom_array(atom_array)
+ data["example_id"] = example_id
+
+ # ... Wrap up with additional features
+ if "extra" not in spec_dict:
+ spec_dict["extra"] = {}
+ spec_dict["extra"]["example_id"] = example_id
+ data["specification"] = spec_dict
+ return data
+
# ============================================================================
-# Public API
+# APIs and utils
# ============================================================================
+def prepare_pipeline_input_from_atom_array( # see atomworks.ml.datasets.parsers.base.load_example_from_metadata_row
+ atom_array_orig,
+) -> dict:
+ """
+ Load or create an example from a metadata dictionary.
+ If the file path is not provided in the metadata dictionary, create a spoofed CIF file based on the length.
+ Args:
+ atom_array_orig: Atom array instantiated with conditioning annotations
+
+ Returns:
+ dict: A dictionary containing the parsed row data and additional loaded CIF data.
+ """
+ _start_parse_time = time.time()
+ # HACK: Set empty bond graph:
+ if atom_array_orig.bonds is None:
+ atom_array_orig.bonds = BondList(atom_array_orig.array_length())
+
+ # Temporary spoof of chain IDs to ensure duplicates aren't dropped:
+ result_dict = parse_atom_array(
+ atom_array_orig,
+ remove_ccds=[],
+ fix_arginines=False,
+ add_missing_atoms=False,
+ extra_fields=INFERENCE_ANNOTATIONS,
+ build_assembly=None,
+ hydrogen_policy="remove",
+ )
+ atom_array = result_dict["asym_unit"][0]
+
+ # HACK: Set iid information manually
+ # We currently do not preserve this information from the input,
+ # if you want these we'd need to remove the spoofing here
+ check_has_required_conditioning_annotations(
+ atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
+ )
+ atom_array = convert_existing_annotations_to_bool(atom_array)
+ atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
+ atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])
+ data = {
+ "atom_array": atom_array, # First model
+ "chain_info": result_dict["chain_info"],
+ "ligand_info": result_dict["ligand_info"],
+ "metadata": result_dict["metadata"],
+ }
+ _stop_parse_time = time.time()
+ data = TransformedDict(data)
+ return data
+
+
def create_atom_array_from_design_specification(
**spec_kwargs,
) -> tuple[AtomArray, dict]:
@@ -952,6 +827,216 @@ def create_atom_array_from_design_specification(
return atom_array, {}
# Create input specfication and build
- spec = InputSpecification(**spec_kwargs)
+ spec = DesignInputSpecification(**spec_kwargs)
atom_array, metadata = spec.build(return_metadata=True)
return atom_array, metadata
+
+
+@contextmanager
+def validator_context(validator_name: str, data: dict = None):
+ """Context manager for validator execution with logging."""
+ logger.debug(f"Starting validator: {validator_name}")
+ try:
+ yield
+ logger.debug(f"✓ Completed validator: {validator_name}")
+ except Exception as e:
+ logger.error(
+ f"✗ Failed in validator: {validator_name}\n"
+ f" Error: {str(e)}\n"
+ f" Error type: {type(e).__name__}"
+ )
+ raise e
+
+
+def create_diffused_residues(n, additional_annotations=None):
+ if n <= 0:
+ raise ValueError(f"Negative/null residue count ({n}) not allowed.")
+
+ atoms = []
+ [
+ atoms.extend(
+ [
+ struc.Atom(
+ np.array([0.0, 0.0, 0.0], dtype=np.float32),
+ res_name="ALA",
+ res_id=idx,
+ )
+ for _ in range(5)
+ ]
+ )
+ for idx in range(1, n + 1)
+ ]
+ array = struc.array(atoms)
+ array.set_annotation(
+ "element", np.array(["N", "C", "C", "O", "C"] * n, dtype=" AtomArray:
+ # ... Create list of components
+ assert (
+ x := (set(list(indexed_tokens.keys()) + list(unindexed_tokens.keys())))
+ ).issubset(
+ (y := set(components_to_accumulate))
+ ), "Unindexed and indexed set {} is not subset of components to accumulate {}".format(
+ x, y
+ )
+ all_tokens = indexed_tokens | unindexed_tokens
+ all_annots = []
+ [
+ all_annots.extend(list(tok.get_annotation_categories()))
+ for tok in all_tokens.values()
+ ]
+ all_annots = set(all_annots)
+
+ # ... For-loop accum variables
+ unindexed_components_started = (
+ False # once one unindexed component is added, stop adding diffused residues
+ )
+ chain = start_chain
+ res_id = start_resid
+ molecule_id = 0
+
+ # ... Insert contig information one- by one-
+ assert len(components_to_accumulate) == len(
+ unindexed_breaks
+ ), "Mismatch in number of components to accumulate and breaks"
+ for component, is_break in zip(components_to_accumulate, unindexed_breaks):
+ if component == "/0":
+ # Reset iterators on next chain
+ chain = chr(ord(chain) + 1)
+ molecule_id += 1
+ res_id = 1
+ continue
+
+ # ... Create array to insert
+ if str(component)[0].isalpha(): # motif (e.g. "A22")
+ n = 1
+
+ # ... Fetch the motif residue
+ token = all_tokens[component]
+
+ # ... Insert breakpoint when break clause is met
+ if exists(is_break) and is_break:
+ if not unindexed_components_started:
+ chain = start_chain
+ unindexed_components_started = True
+ token.set_annotation(
+ "is_motif_atom_unindexed_motif_breakpoint",
+ np.ones(token.shape[0], dtype=int),
+ )
+ else:
+ token.set_annotation(
+ "is_motif_atom_unindexed_motif_breakpoint",
+ np.zeros(token.shape[0], dtype=int),
+ )
+ else:
+ n = int(component)
+ # ... Skip if none or unindexed
+ if n == 0 or unindexed_components_started:
+ res_id += n
+ continue
+
+ # ... Create diffused residues
+ token = create_diffused_residues(n, all_annots)
+
+ # ... Set index of insertion
+ token = set_indices(
+ array=token,
+ chain=chain,
+ res_id_start=res_id,
+ molecule_id=molecule_id,
+ component=component,
+ )
+
+ assert (
+ len(get_token_starts(token)) == n
+ ), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"
+
+ # ... Insert & Increment residue ID
+ atom_array_accum.append(token)
+ res_id += n
+
+ # ... Concatenate all components
+ atom_array_accum = struc.concatenate(atom_array_accum)
+ atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
+
+ # Reset res_id for unindexed residues to avoid duplicates
+ if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
+ atom_array_accum.is_motif_atom_unindexed.astype(bool)
+ ):
+ max_id = np.max(
+ atom_array_accum[
+ ~atom_array_accum.is_motif_atom_unindexed.astype(bool)
+ ].res_id
+ )
+ atom_array_accum.res_id[
+ atom_array_accum.is_motif_atom_unindexed.astype(bool)
+ ] += max_id + 1
+
+ # ... Bonds
+ if atom_array_accum.bonds is None:
+ atom_array_accum.bonds = BondList(atom_array_accum.array_length())
+ return atom_array_accum
diff --git a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py
index 9b9c019..d1a3a4e 100644
--- a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py
+++ b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py
@@ -17,23 +17,6 @@ from rfd3.constants import (
OPTIONAL_CONDITIONING_VALUES,
REQUIRED_INFERENCE_ANNOTATIONS,
)
-from rfd3.inference.components import (
- fetch_mask_from_component,
- fetch_mask_from_idx,
- get_design_pattern_with_constraints,
- get_motif_components_and_breaks,
- get_name_mask,
- split_contig,
-)
-from rfd3.inference.inference_utils import (
- create_cb_atoms,
- create_o_atoms,
- extract_ligand_array,
- inference_load_,
- set_com,
- set_common_annotations,
- set_indices,
-)
from rfd3.inference.symmetry.symmetry_utils import (
center_symmetric_src_atom_array,
make_symmetric_atom_array,
@@ -45,8 +28,25 @@ from rfd3.transforms.conditioning_base import (
set_default_conditioning_annotations,
)
from rfd3.transforms.util_transforms import assign_types_
+from rfd3.utils.inference import (
+ create_cb_atoms,
+ create_o_atoms,
+ extract_ligand_array,
+ inference_load_,
+ set_com,
+ set_common_annotations,
+ set_indices,
+)
from modelhub.common import exists
+from modelhub.utils.components import (
+ fetch_mask_from_component,
+ fetch_mask_from_idx,
+ get_design_pattern_with_constraints,
+ get_motif_components_and_breaks,
+ get_name_mask,
+ split_contig,
+)
from modelhub.utils.ddp import RankedLogger
logging.basicConfig(level=logging.INFO)
diff --git a/models/rfd3/src/rfd3/inference/parsing.py b/models/rfd3/src/rfd3/inference/parsing.py
index 6fc1791..f6fbdf5 100644
--- a/models/rfd3/src/rfd3/inference/parsing.py
+++ b/models/rfd3/src/rfd3/inference/parsing.py
@@ -9,7 +9,8 @@ from pydantic import (
model_serializer,
model_validator,
)
-from rfd3.inference.components import (
+
+from modelhub.utils.components import (
ComponentStr,
fetch_mask_from_idx,
get_name_mask,
diff --git a/models/rfd3/src/rfd3/inference/specification.py b/models/rfd3/src/rfd3/inference/specification.py
deleted file mode 100644
index e69de29..0000000
diff --git a/models/rfd3/src/rfd3/inference/symmetry/contigs.py b/models/rfd3/src/rfd3/inference/symmetry/contigs.py
index 679bf74..7af1e2b 100755
--- a/models/rfd3/src/rfd3/inference/symmetry/contigs.py
+++ b/models/rfd3/src/rfd3/inference/symmetry/contigs.py
@@ -1,5 +1,6 @@
import numpy as np
-from rfd3.inference.components import fetch_mask_from_idx
+
+from modelhub.utils.components import fetch_mask_from_idx
def expand_contig_to_resid_from_string(contig_string):
diff --git a/models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py b/models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py
index cdabd30..03397c2 100644
--- a/models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py
+++ b/models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py
@@ -8,7 +8,6 @@ from pydantic import (
ConfigDict,
Field,
)
-from rfd3.inference.components import fetch_mask_from_component
from rfd3.inference.symmetry.atom_array import (
FIXED_ENTITY_ID,
FIXED_TRANSFORM_ID,
@@ -33,6 +32,7 @@ from rfd3.inference.symmetry.frames import (
)
from rfd3.transforms.conditioning_base import get_motif_features
+from modelhub.utils.components import fetch_mask_from_component
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
diff --git a/models/rfd3/src/rfd3/metrics/losses.py b/models/rfd3/src/rfd3/metrics/losses.py
index 0c9f712..95c49d2 100644
--- a/models/rfd3/src/rfd3/metrics/losses.py
+++ b/models/rfd3/src/rfd3/metrics/losses.py
@@ -4,30 +4,6 @@ import torch.nn as nn
from modelhub.training.checkpoint import activation_checkpointing
-class Loss(nn.Module):
- def __init__(self, **losses):
- super().__init__()
- self.to_compute = []
- for loss_name, loss in losses.items():
- loss_fn = hydra.utils.instantiate(loss)
- print(f"Adding loss {loss_name} to the loss function")
- self.to_compute.append(loss_fn)
-
- def forward(
- self,
- network_input,
- network_output,
- loss_input,
- ):
- loss_dict = {}
- loss = 0
- for loss_fn in self.to_compute:
- loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
- loss += loss_
- loss_dict.update(loss_dict_)
- loss_dict["total_loss"] = loss.detach()
- return loss, loss_dict
-
class SequenceLoss(nn.Module):
def __init__(self, weight, min_t=0, max_t=torch.inf):
super().__init__()
@@ -80,7 +56,7 @@ class SequenceLoss(nn.Module):
return self.weight * token_loss, outs
-class SimpleDiffusionLoss(nn.Module):
+class DiffusionLoss(nn.Module):
def __init__(
self,
*,
@@ -94,8 +70,7 @@ class SimpleDiffusionLoss(nn.Module):
unindexed_t_alpha=1.0,
unindexed_norm_p=1.0,
lp_weight=0.0,
- normalize_virtual_atom_weight=False,
- alpha_normalized_virtual_atom_weight=3.9,
+ **_, # dump args from old configs
):
super().__init__()
self.weight = weight
@@ -108,8 +83,6 @@ class SimpleDiffusionLoss(nn.Module):
self.unindexed_t_alpha = unindexed_t_alpha
self.lp_weight = lp_weight
self.alpha_ligand = alpha_ligand
- self.normalize_virtual_atom_weight = normalize_virtual_atom_weight
- self.alpha_normalized_virtual_atom_weight = alpha_normalized_virtual_atom_weight
self.alpha_polar_residues = alpha_polar_residues
self.get_lambda = (
@@ -123,18 +96,10 @@ class SimpleDiffusionLoss(nn.Module):
crd_mask_L = loss_input["crd_mask_L"] # (D, L)
crd_mask_L = crd_mask_L.unsqueeze(0).expand(D, -1)
tok_idx = network_input["f"]["atom_to_token_map"]
- is_motif_token = network_input["f"]["is_motif_token"] # N
t = network_input["t"] # (D,)
is_original_unindexed_token = loss_input["is_original_unindexed_token"][tok_idx]
is_polar_atom = network_input["f"]["is_polar"][tok_idx]
is_ligand = network_input["f"]["is_ligand"][tok_idx]
-
- # Treat fully fixed atoms as non-lossable atoms to provide stable normalization
- # is_motif_atom_with_fixed_coord = network_input["f"][
- # "is_motif_atom_with_fixed_coord"
- # ]
- # crd_mask_L = crd_mask_L * ~is_motif_atom_with_fixed_coord[None]
-
is_virtual_atom = network_input["f"]["is_virtual"] # L
is_sidechain_atom = network_input["f"]["is_sidechain"] # L
is_sidechain_atom = is_sidechain_atom & ~is_virtual_atom
@@ -148,29 +113,6 @@ class SimpleDiffusionLoss(nn.Module):
# Upweight polar residues
w_L[is_polar_atom] *= self.alpha_polar_residues
-
- if self.normalize_virtual_atom_weight:
- # Divide by the number of virtual atoms within a token
- n_virtual_atoms_per_token = torch.zeros_like(is_motif_token).float()
- n_virtual_atoms_per_token.scatter_add_(
- 0, tok_idx.long(), is_virtual_atom.float()
- )
- n_virtual_atoms_per_token = n_virtual_atoms_per_token.clamp(min=1)
-
- # Also divide by the number of heavy atoms within a token
- n_sc_atoms_per_token = torch.zeros_like(is_motif_token).float()
- n_sc_atoms_per_token.scatter_add_(
- 0, tok_idx.long(), is_sidechain_atom.float()
- )
- n_sc_atoms_per_token = n_sc_atoms_per_token.clamp(min=1)
-
- w_L[is_virtual_atom] /= (
- self.alpha_normalized_virtual_atom_weight
- ) * n_virtual_atoms_per_token[tok_idx][is_virtual_atom]
- w_L[is_sidechain_atom] /= (
- 10 - self.alpha_normalized_virtual_atom_weight
- ) * n_sc_atoms_per_token[tok_idx][is_sidechain_atom]
-
w_L = w_L[None].expand(D, -1) * crd_mask_L
X_gt_L = torch.nan_to_num(loss_input["X_gt_L_in_input_frame"])
diff --git a/models/rfd3/src/rfd3/metrics/sidechain_metrics.py b/models/rfd3/src/rfd3/metrics/sidechain_metrics.py
index 9d9ce16..435eb86 100644
--- a/models/rfd3/src/rfd3/metrics/sidechain_metrics.py
+++ b/models/rfd3/src/rfd3/metrics/sidechain_metrics.py
@@ -3,7 +3,7 @@ import numpy as np
from biotite.structure.info import residue
from scipy.spatial.distance import cdist
-from modelhub.metrics.base import Metric
+from modelhub.metrics.metric import Metric
def collapsing_virtual_atoms_batched(
diff --git a/models/rfd3/src/rfd3/model/RFD3.py b/models/rfd3/src/rfd3/model/RFD3.py
new file mode 100644
index 0000000..ecd6544
--- /dev/null
+++ b/models/rfd3/src/rfd3/model/RFD3.py
@@ -0,0 +1,105 @@
+import os
+
+import hydra
+import torch
+from omegaconf import DictConfig
+from rfd3.model.cfg_utils import (
+ strip_f,
+)
+from rfd3.model.inference_sampler import ConditionalDiffusionSampler
+from rfd3.model.layers.encoders import TokenInitializer
+from torch import nn
+
+from modelhub.utils.ddp import RankedLogger
+
+ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+class RFD3(nn.Module):
+ """
+ Simplified model for generation
+ This module level serves to wrap the diffusion module of AF3
+ to be roughly equivalent to the AF3 model w/o trunk processing.
+
+ Allows the same sampler to be used
+ """
+
+ def __init__(
+ self,
+ *,
+ # Channel dimensions ('global' features)
+ c_s: int,
+ c_z: int,
+ c_atom: int,
+ c_atompair: int,
+ # Arguments for modules that will be instantiated
+ token_initializer: DictConfig | dict,
+ diffusion_module: DictConfig | dict,
+ inference_sampler: DictConfig | dict,
+ **_,
+ ):
+ super().__init__()
+ # Check for chunked P_LL mode via environment variable
+ use_chunked_pll = os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"
+ ranked_logger.info(f"RFD3 initialized with chunked_pll={use_chunked_pll}")
+
+ # Simple constant-feature initializer
+ self.token_initializer = TokenInitializer(
+ c_s=c_s,
+ c_z=c_z,
+ c_atom=c_atom,
+ c_atompair=c_atompair,
+ use_chunked_pll=use_chunked_pll,
+ **token_initializer,
+ )
+
+ # Diffusion module instantiated to allow for config scripting
+ self.diffusion_module = hydra.utils.instantiate(
+ diffusion_module, c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z
+ )
+
+ self.use_classifier_free_guidance = (
+ inference_sampler["use_classifier_free_guidance"]
+ and inference_sampler["cfg_scale"] != 1.0
+ )
+ self.cfg_features = inference_sampler.pop("cfg_features", [])
+
+ # ... initialize the inference sampler, which performs a full diffusion rollout during inference
+ self.inference_sampler = ConditionalDiffusionSampler(**inference_sampler)
+
+ def forward(
+ self,
+ input: dict,
+ coord_atom_lvl_to_be_noised: torch.Tensor = None,
+ n_cycle=None,
+ **_,
+ ) -> dict:
+ initializer_outputs = self.token_initializer(input["f"])
+
+ if self.training:
+ # Single denoising step
+ return self.diffusion_module(
+ X_noisy_L=input["X_noisy_L"],
+ t=input["t"],
+ f=input["f"],
+ n_recycle=n_cycle,
+ **initializer_outputs,
+ ) # [D, L, 3]
+ else:
+ if self.use_classifier_free_guidance:
+ f_ref = strip_f(input["f"], self.cfg_features)
+ ref_initializer_outputs = self.token_initializer(f_ref)
+ else:
+ f_ref = None
+ ref_initializer_outputs = None
+
+ return self.inference_sampler.sample_diffusion_like_af3(
+ f=input["f"],
+ f_ref=f_ref, # for cfg
+ diffusion_module=self.diffusion_module,
+ diffusion_batch_size=coord_atom_lvl_to_be_noised.shape[0],
+ coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
+ # Forwarded as **kwargs:
+ initializer_outputs=initializer_outputs,
+ ref_initializer_outputs=ref_initializer_outputs, # for cfg
+ )
diff --git a/models/rfd3/src/rfd3/model/rfd3_diffusion_module.py b/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py
similarity index 90%
rename from models/rfd3/src/rfd3/model/rfd3_diffusion_module.py
rename to models/rfd3/src/rfd3/model/RFD3_diffusion_module.py
index ba419fd..f7c4200 100644
--- a/models/rfd3/src/rfd3/model/rfd3_diffusion_module.py
+++ b/models/rfd3/src/rfd3/model/RFD3_diffusion_module.py
@@ -5,11 +5,11 @@ from contextlib import ExitStack
import torch
import torch.nn as nn
-from rfd3.model.block_utils import (
+from rfd3.model.layers.block_utils import (
bucketize_scaled_distogram,
create_attention_indices,
)
-from rfd3.model.blocks import (
+from rfd3.model.layers.blocks import (
CompactStreamingDecoder,
Downcast,
LinearEmbedWithPool,
@@ -17,17 +17,14 @@ from rfd3.model.blocks import (
LocalAtomTransformer,
LocalTokenTransformer,
)
-from rfd3.model.encoders import (
+from rfd3.model.layers.encoders import (
DiffusionTokenEncoder,
)
+from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
-from modelhub.model.AF3_structure import (
+from modelhub.model.layers.blocks import (
FourierEmbedding,
)
-from modelhub.model.layers.af3_diffusion_transformer import (
- DiffusionTransformer,
-)
-from modelhub.model.layers.layer_utils import RMSNorm, linearNoBias
logger = logging.getLogger(__name__)
@@ -53,7 +50,7 @@ class RFD3DiffusionModule(nn.Module):
atom_attention_decoder,
# upcast,
downcast,
- use_local_token_attention=False,
+ use_local_token_attention=True,
**_,
):
super().__init__()
@@ -111,20 +108,12 @@ class RFD3DiffusionModule(nn.Module):
**diffusion_token_encoder,
)
- if not use_local_token_attention:
- self.diffusion_transformer = DiffusionTransformer(
- c_token=c_token,
- c_tokenpair=c_z,
- c_s=c_s,
- **diffusion_transformer,
- )
- else:
- self.diffusion_transformer = LocalTokenTransformer(
- c_token=c_token,
- c_tokenpair=c_z,
- c_s=c_s,
- **diffusion_transformer,
- )
+ self.diffusion_transformer = LocalTokenTransformer(
+ c_token=c_token,
+ c_tokenpair=c_z,
+ c_s=c_s,
+ **diffusion_transformer,
+ )
self.decoder = CompactStreamingDecoder(
c_atom=c_atom,
@@ -342,21 +331,18 @@ class RFD3DiffusionModule(nn.Module):
)
# ... Diffusion transformer
- if not self.use_local_token_attention:
- A_I = self.diffusion_transformer(A_I, S_I, Z_II, Beta_II=None)
- else:
- A_I = self.diffusion_transformer(
- A_I,
- S_I,
- Z_II,
- f=f,
- X_L=(
- X_noisy_L[..., f["is_ca"], :]
- if X_L_self is None
- else X_L_self[..., f["is_ca"], :]
- ),
- full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"),
- )
+ A_I = self.diffusion_transformer(
+ A_I,
+ S_I,
+ Z_II,
+ f=f,
+ X_L=(
+ X_noisy_L[..., f["is_ca"], :]
+ if X_L_self is None
+ else X_L_self[..., f["is_ca"], :]
+ ),
+ full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"),
+ )
# ... Decoder readout
# Check if using chunked P_LL mode
diff --git a/models/rfd3/src/rfd3/model/af3_design.py b/models/rfd3/src/rfd3/model/af3_design.py
deleted file mode 100644
index f553392..0000000
--- a/models/rfd3/src/rfd3/model/af3_design.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from rfd3.model.encoders import SimpleRecycler
-
-from modelhub.model.AF3 import AF3
-
-
-class AF3DesignTrunk(AF3):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self.recycler = SimpleRecycler(
- c_s=kwargs["c_s"], c_z=kwargs["c_z"], **kwargs["recycler"]
- )
- self.distogram_head = None
diff --git a/models/rfd3/src/rfd3/model/aa_design.py b/models/rfd3/src/rfd3/model/inference_sampler.py
similarity index 64%
rename from models/rfd3/src/rfd3/model/aa_design.py
rename to models/rfd3/src/rfd3/model/inference_sampler.py
index 5bbaf65..123c3ad 100644
--- a/models/rfd3/src/rfd3/model/aa_design.py
+++ b/models/rfd3/src/rfd3/model/inference_sampler.py
@@ -1,143 +1,20 @@
import inspect
-import os
+from dataclasses import dataclass
+from typing import Any, Literal
-import hydra
import torch
-from beartype.typing import Any
from jaxtyping import Float
-from omegaconf import DictConfig
-from rfd3.inference.symmetry.symmetry_utils import (
- apply_symmetry_to_xyz_atomwise,
-)
-from rfd3.model.cfg_utils import (
- strip_f,
- strip_X,
-)
-from rfd3.model.encoders import TokenInitializer
-from torch import nn
-from modelhub import SWAP_LAYER_NORM_FOR_RMS_NORM
-from modelhub.alignment import weighted_rigid_align
from modelhub.common import exists
-from modelhub.data.rotation_augmentation import (
+from modelhub.utils.ddp import RankedLogger
+from modelhub.utils.rotation_augmentation import (
rot_vec_mul,
uniform_random_rotation,
)
-from modelhub.diffusion_samplers.inference_sampler import SampleDiffusion
-from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
-class RFD3(nn.Module):
- """
- Simplified model for generation
- This module level serves to wrap the diffusion module of AF3
- to be roughly equivalent to the AF3 model w/o trunk processing.
-
- Allows the same sampler to be used
- """
-
- def __init__(
- self,
- *,
- # Channel dimensions ('global' features)
- c_s: int,
- c_z: int,
- c_atom: int,
- c_atompair: int,
- # Arguments for modules that will be instantiated
- token_initializer: DictConfig | dict,
- diffusion_module: DictConfig | dict,
- inference_sampler: DictConfig | dict,
- **_,
- ):
- super().__init__()
-
- # Register whether the model uses RMSNorms or LayerNorms
- self.register_buffer(
- "use_rmsnorm",
- torch.tensor(SWAP_LAYER_NORM_FOR_RMS_NORM, dtype=torch.bool),
- )
-
- # Check for chunked P_LL mode via environment variable
- use_chunked_pll = os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"
- ranked_logger.info(f"RFD3 initialized with chunked_pll={use_chunked_pll}")
-
- # Simple constant-feature initializer
- self.token_initializer = TokenInitializer(
- c_s=c_s,
- c_z=c_z,
- c_atom=c_atom,
- c_atompair=c_atompair,
- use_chunked_pll=use_chunked_pll,
- **token_initializer,
- )
-
- # Diffusion module instantiated to allow for config scripting
- self.diffusion_module = hydra.utils.instantiate(
- diffusion_module, c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z
- )
-
- self.use_classifier_free_guidance = (
- inference_sampler["use_classifier_free_guidance"]
- and inference_sampler["cfg_scale"] != 1.0
- )
- self.cfg_features = inference_sampler.pop("cfg_features", [])
-
- # ... initialize the inference sampler, which performs a full diffusion rollout during inference
- self.inference_sampler = ConditionalDiffusionSampler(**inference_sampler)
-
- def forward(
- self,
- input: dict,
- coord_atom_lvl_to_be_noised: torch.Tensor = None,
- n_cycle=None,
- **_,
- ) -> dict:
- # Assert that the correct swap is used
- if bool(self.use_rmsnorm.item()) != SWAP_LAYER_NORM_FOR_RMS_NORM:
- raise ValueError(
- "Loaded checkpoint with use RMSNorm {} but environment variable set expects {}".format(
- self.use_rmsnorm.item(),
- SWAP_LAYER_NORM_FOR_RMS_NORM,
- )
- + " Set environment variable SWAP_LAYER_NORM_FOR_RMS_NORM to {}".format(
- {True: "1", False: "0"}[self.use_rmsnorm.item()]
- )
- )
-
- initializer_outputs = self.token_initializer(input["f"])
-
- if self.training:
- # Single denoising step
- return self.diffusion_module(
- X_noisy_L=input["X_noisy_L"],
- t=input["t"],
- f=input["f"],
- n_recycle=n_cycle,
- **initializer_outputs,
- ) # [D, L, 3]
- else:
- if self.use_classifier_free_guidance:
- f_ref = strip_f(input["f"], self.cfg_features)
- ref_initializer_outputs = self.token_initializer(f_ref)
- else:
- f_ref = None
- ref_initializer_outputs = None
-
- return self.inference_sampler.sample_diffusion_like_af3(
- f=input["f"],
- f_ref=f_ref, # for cfg
- diffusion_module=self.diffusion_module,
- diffusion_batch_size=input["t"].shape[0],
- coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
- # Forwarded as **kwargs:
- initializer_outputs=initializer_outputs,
- ref_initializer_outputs=ref_initializer_outputs, # for cfg
- )
-
-
def centre_random_augment_around_motif(
X_L: torch.Tensor, # (D, L, 3) noisy diffused coordinates
coord_atom_lvl_to_be_noised: torch.Tensor, # (D, L, 3) original coordinates
@@ -193,159 +70,70 @@ def centre_random_augment_around_motif(
return X_L, R
-class SampleDiffusionWithMotif(SampleDiffusion):
- def __init__(
- self,
- center_option: str = "all",
- move_noise_to_reset_com: bool = False, # Reset the COM of the diffuse region after the re-noising operation in each diffusion step
- s_trans: float = 1.0, # Translational noise scale for augmentation during inference
- s_jitter_origin: float = 0.0, # Random translation of motif at the start of diffusion
- fraction_of_steps_to_fix_motif: float = 0.0, # Fraction of steps to let the model not move the motif. e.g. if we have 10 steps, set this value to 0.2 will make model not move motif for the first 2 steps.
- skip_few_diffusion_steps: bool = False, # Choose to skip some diffusion steps based on the noise scheme
- inference_noise_scaling_factor: float = 1.0,
- # Additional argumnets
- gamma_min2: float = 0.0,
- allow_realignment: bool = False,
- insert_motif_at_end: bool = True,
- use_classifier_free_guidance: bool = False,
- cfg_scale: float = 2.0,
- use_frame_guidance: bool = False, # Use frame guidance to align the virtual atoms to the central atom
- fg_scale: float = 1.5,
- zero_drift_noise: bool = False,
- cfg_t_max: float
- | None = None, # If not None, use classifier-free guidance only for t < cfg_t_max
- **kwargs,
- ):
- self.gamma_min2 = gamma_min2
- self.allow_realignment = allow_realignment
- self.insert_motif_at_end = insert_motif_at_end
- self.use_classifier_free_guidance = use_classifier_free_guidance
- self.cfg_scale = cfg_scale
- self.cfg_t_max = cfg_t_max
+@dataclass(kw_only=True)
+class SampleDiffusionWithMotif:
+ """Diffusion sampler that supports optional motif alignment."""
- self.center_option = center_option
- self.fraction_of_steps_to_fix_motif = fraction_of_steps_to_fix_motif
- self.move_noise_to_reset_com = move_noise_to_reset_com
- self.s_trans = s_trans
- self.skip_few_diffusion_steps = skip_few_diffusion_steps
- self.s_jitter_origin = s_jitter_origin
- self.inference_noise_scaling_factor = inference_noise_scaling_factor
- self.zero_drift_noise = zero_drift_noise
+ # Standard EDM args
+ num_timesteps: int # AF-3: 200
+ min_t: int # AF-3: 0
+ max_t: int # AF-3: 1
+ sigma_data: int # AF-3: 16
+ s_min: float # AF-3: 4e-4
+ s_max: int # AF-3: 160
+ p: int # AF-3: 7
+ gamma_0: float # AF-3: 0.8
+ gamma_min: float # AF-3: 1.0
+ noise_scale: float # AF-3: 1.003
+ step_scale: float # AF-3: 1.5
+ solver: Literal["af3"]
- self.use_frame_guidance = use_frame_guidance
- self.fg_scale = fg_scale
+ # RFD3 / design args
+ center_option: str = "all"
+ s_trans: float = 1.0
+ s_jitter_origin: float = 0.0
+ fraction_of_steps_to_fix_motif: float = 0.0
+ skip_few_diffusion_steps: bool = False
+ allow_realignment: bool = False
+ insert_motif_at_end: bool = True
+ use_classifier_free_guidance: bool = False
+ cfg_scale: float = 2.0
+ cfg_t_max: float | None = None
+ use_frame_guidance: bool = False
+ fg_scale: float = 1.0
- super().__init__(**kwargs)
-
- # TODO: Make this a properly-parametrized function in terms of instance variables provided in the configs
- # For now, it's just hard-coded for early testing
- def modify_noise_schedule(self, noise_schedule: torch.Tensor) -> torch.Tensor:
- """
- Modify the noise schedule to skip more steps at high noise and fewer at low noise.
- """
- mask = torch.ones_like(noise_schedule, dtype=bool)
- mask_len = len(mask)
- mask[: mask_len // 4] = torch.arange(mask_len // 4) % 5 == 0
- mask[mask_len // 4 : mask_len // 2] = (
- torch.arange(mask_len // 4, mask_len // 2) % 3 == 0
- )
- mask[mask_len // 2 : -mask_len // 4] = (
- torch.arange(mask_len // 2, mask_len - mask_len // 4) % 2 == 0
- )
- return noise_schedule[mask]
-
- def _get_initial_structure(
- self,
- c0: torch.Tensor,
- D: int,
- L: int,
- coord_atom_lvl_to_be_noised: torch.Tensor,
- is_motif_atom_with_fixed_coord,
+ def _construct_inference_noise_schedule(
+ self, device: torch.device, partial_t: float = None
) -> torch.Tensor:
- noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
- noise[..., is_motif_atom_with_fixed_coord, :] = 0 # Zero out noise going in
- X_L = noise + coord_atom_lvl_to_be_noised
- return X_L
+ """Constructs a noise schedule for use during inference.
- def _move_noise_to_reset_com(self, X_noisy_L, is_motif_atom_with_fixed_coord):
+ The inference noise schedule is defined in the AF-3 supplement as:
+
+ t_hat = sigma_data * (s_max**(1/p) + t * (s_min**(1/p) - s_max**(1/p)))**p
+
+ Returns:
+ torch.Tensor: A tensor representing the noise schedule `t_hat`.
+
+ Reference:
+ AlphaFold 3 Supplement, Section 3.7.1.
"""
- Reset the COM of the diffuse region after the re-noising operation in each diffusion step.
- """
- if self.center_option == "motif":
- print(
- "Warning: Moving the noise is not relevant when centering on the motif! Will be ignored."
+ # Create a linearly spaced tensor of timesteps between min_t and max_t
+ t = torch.linspace(self.min_t, self.max_t, self.num_timesteps, device=device)
+
+ # Construct the noise schedule, using the formula provided in the reference
+ t_hat = (
+ self.sigma_data
+ * (
+ (self.s_max) ** (1 / self.p)
+ + t * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
)
- elif self.center_option == "diffuse":
- displacement_vec = torch.mean(
- X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :],
- dim=-2,
- keepdim=True,
- ) # (D, 1, 3) - COM of noisy diffused atoms
-
- X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] = (
- X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] - displacement_vec
- )
- else:
- n_diffused = (~is_motif_atom_with_fixed_coord).sum()
- displacement_vec = (
- torch.sum(
- X_noisy_L,
- dim=-2,
- keepdim=True,
- )
- / n_diffused
- )
-
- X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] = (
- X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] - displacement_vec
- )
-
- return X_noisy_L
-
- def _skip_few_diffusion_steps(self, noise_schedule: torch.Tensor) -> torch.Tensor:
- """
- Modify the noise schedule to skip more steps at high noise and fewer at low noise.
- i.e. When the noise is high (first few diffusion steps), skip more steps;
- When the noise is lower, skip fewer steps;
- When the noise is low, keep all the steps.
- """
- mask = torch.ones_like(noise_schedule, dtype=bool)
- mask_len = len(mask)
- mask[: mask_len // 4] = torch.arange(mask_len // 4) % 5 == 0
- mask[mask_len // 4 : mask_len // 2] = (
- torch.arange(mask_len // 4, mask_len // 2) % 3 == 0
- )
- mask[mask_len // 2 : -mask_len // 4] = (
- torch.arange(mask_len // 2, mask_len - mask_len // 4) % 2 == 0
- )
- return noise_schedule[mask]
-
- def sample_diffusion_like_af3(
- self,
- *,
- f: dict[str, Any],
- diffusion_module: torch.nn.Module,
- diffusion_batch_size: int,
- coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
- initializer_outputs,
- ref_initializer_outputs: dict[str, Any] | None,
- f_ref: dict[str, Any] | None,
- ) -> dict[str, Any]:
- # Motif setup to recenter the motif at every step
- is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
-
- # Book-keeping
- noise_schedule = self._construct_inference_noise_schedule(
- device=coord_atom_lvl_to_be_noised.device
+ ** self.p
)
- # Choose to adjust the noise schedule
- if self.skip_few_diffusion_steps:
- noise_schedule = self._skip_few_diffusion_steps(noise_schedule)
-
- if "partial_t" in f:
+ if partial_t is not None:
# For now, partial t is a global parameter
- partial_t = f["partial_t"].mean()
+ partial_t = float(partial_t.mean())
+ noise_schedule = t_hat
ranked_logger.info("Using partial diffusion with t={}".format(partial_t))
# Debug the noise schedule filtering
@@ -382,11 +170,44 @@ class SampleDiffusionWithMotif(SampleDiffusion):
f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
)
+ return t_hat
+
+ def _get_initial_structure(
+ self,
+ c0: torch.Tensor,
+ D: int,
+ L: int,
+ coord_atom_lvl_to_be_noised: torch.Tensor,
+ is_motif_atom_with_fixed_coord,
+ ) -> torch.Tensor:
+ noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
+ noise[..., is_motif_atom_with_fixed_coord, :] = 0 # Zero out noise going in
+ X_L = noise + coord_atom_lvl_to_be_noised
+ return X_L
+
+ def sample_diffusion_like_af3(
+ self,
+ *,
+ f: dict[str, Any],
+ diffusion_module: torch.nn.Module,
+ diffusion_batch_size: int,
+ coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
+ initializer_outputs,
+ ref_initializer_outputs: dict[str, Any] | None,
+ f_ref: dict[str, Any] | None,
+ ) -> dict[str, Any]:
+ # Motif setup to recenter the motif at every step
+ is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
+
+ # Book-keeping
+ noise_schedule = self._construct_inference_noise_schedule(
+ device=coord_atom_lvl_to_be_noised.device,
+ partial_t=f.get("partial_t", None),
+ )
+
L = f["ref_element"].shape[0]
D = diffusion_batch_size
- noise_schedule = noise_schedule * self.inference_noise_scaling_factor
-
X_L = self._get_initial_structure(
c0=noise_schedule[0],
D=D,
@@ -433,7 +254,7 @@ class SampleDiffusionWithMotif(SampleDiffusion):
# Update gamma & step scale
gamma = self.gamma_0 if c_t > self.gamma_min else 0
- step_scale = self.step_scale if c_t > self.gamma_min2 else 3.0
+ step_scale = self.step_scale
# Compute the value of t_hat
t_hat = c_t_minus_1 * (gamma + 1)
@@ -444,19 +265,11 @@ class SampleDiffusionWithMotif(SampleDiffusion):
* torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
* torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
)
- if self.zero_drift_noise:
- epsilon_L = epsilon_L - torch.mean(epsilon_L, dim=-2, keepdim=True)
epsilon_L[..., is_motif_atom_with_fixed_coord, :] = (
0 # No noise injection for fixed atoms
)
X_noisy_L = X_L + epsilon_L
- # Adjustg the center of mass
- if self.move_noise_to_reset_com:
- X_noisy_L = self._move_noise_to_reset_com(
- X_noisy_L, is_motif_atom_with_fixed_coord
- )
-
# Denoise the coordinates
# Handle chunked mode vs standard mode
if "chunked_pairwise_embedder" in initializer_outputs:
@@ -624,53 +437,10 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
# Book-keeping
noise_schedule = self._construct_inference_noise_schedule(
- device=coord_atom_lvl_to_be_noised.device
+ device=coord_atom_lvl_to_be_noised.device,
+ partial_t=f.get("partial_t", None),
)
- # Handle partial_t for symmetry sampler (same as regular sampler)
- if "partial_t" in f:
- # For now, partial t is a global parameter
- partial_t = f["partial_t"].mean()
- ranked_logger.info(
- "Symmetry sampler: Using partial diffusion with t={}".format(partial_t)
- )
-
- # Debug the noise schedule filtering
- original_schedule_len = len(noise_schedule)
- original_max = noise_schedule.max().item()
- original_min = noise_schedule.min().item()
-
- noise_schedule = noise_schedule[noise_schedule <= partial_t]
-
- new_schedule_len = len(noise_schedule)
- if new_schedule_len > 0:
- new_max = noise_schedule.max().item()
- new_min = noise_schedule.min().item()
- ranked_logger.info(
- f"Symmetry noise schedule: {original_schedule_len} → {new_schedule_len} steps"
- )
- ranked_logger.info(
- f"Symmetry original range: [{original_min:.3f}, {original_max:.3f}]"
- )
- ranked_logger.info(
- f"Symmetry filtered range: [{new_min:.3f}, {new_max:.3f}]"
- )
- else:
- ranked_logger.warning(
- f"Symmetry sampler: No noise schedule steps found with t <= {partial_t}!"
- )
- ranked_logger.info(
- f"Symmetry original schedule range: [{original_min:.3f}, {original_max:.3f}]"
- )
- # Fallback to smallest available step
- noise_schedule_original = self._construct_inference_noise_schedule(
- device=coord_atom_lvl_to_be_noised.device
- )
- noise_schedule = noise_schedule_original[-1:] # Just use the final step
- ranked_logger.info(
- f"Symmetry using fallback: final step with t={noise_schedule[0].item():.6f}"
- )
-
L = f["ref_element"].shape[0]
D = diffusion_batch_size
X_L = self._get_initial_structure(
@@ -711,7 +481,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
# Update gamma & step scale
gamma = self.gamma_0 if c_t > self.gamma_min else 0
- step_scale = self.step_scale if c_t > self.gamma_min2 else 1.05
+ step_scale = self.step_scale
# Compute the value of t_hat
t_hat = c_t_minus_1 * (gamma + 1)
diff --git a/models/rfd3/src/rfd3/model/attention.py b/models/rfd3/src/rfd3/model/layers/attention.py
similarity index 99%
rename from models/rfd3/src/rfd3/model/attention.py
rename to models/rfd3/src/rfd3/model/layers/attention.py
index ced003b..a3ead49 100644
--- a/models/rfd3/src/rfd3/model/attention.py
+++ b/models/rfd3/src/rfd3/model/layers/attention.py
@@ -6,18 +6,18 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from opt_einsum import contract as einsum
-from rfd3.model.block_utils import (
+from rfd3.model.layers.block_utils import (
create_attention_indices,
indices_to_mask,
)
-
-from modelhub.common import exists
-from modelhub.model.layers.layer_utils import (
+from rfd3.model.layers.layer_utils import (
AdaLN,
LinearBiasInit,
RMSNorm,
linearNoBias,
)
+
+from modelhub.common import exists
from modelhub.training.checkpoint import activation_checkpointing
from modelhub.utils.ddp import RankedLogger
@@ -417,7 +417,6 @@ def sparse_pairbias_attention(
if B.ndim == 3:
B_gathered = B[query_idx, indices, :] # (D, L, k, H)
elif B.ndim == 4: # (D, L, L, H)
- # import ipdb; ipdb.set_trace()
B_gathered = B[batch_idx, query_idx, indices, :] # (D, L, k, H)
else:
assert B.shape == (D, L, k, H), "B must be batched with shape (D, L, k, H)"
diff --git a/models/rfd3/src/rfd3/model/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py
similarity index 100%
rename from models/rfd3/src/rfd3/model/block_utils.py
rename to models/rfd3/src/rfd3/model/layers/block_utils.py
diff --git a/models/rfd3/src/rfd3/model/blocks.py b/models/rfd3/src/rfd3/model/layers/blocks.py
similarity index 94%
rename from models/rfd3/src/rfd3/model/blocks.py
rename to models/rfd3/src/rfd3/model/layers/blocks.py
index dbd0334..51a678b 100644
--- a/models/rfd3/src/rfd3/model/blocks.py
+++ b/models/rfd3/src/rfd3/model/layers/blocks.py
@@ -6,35 +6,62 @@ import torch.nn as nn
import torch.nn.functional as F
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from einops import rearrange
-from rfd3.model.attention import (
+from rfd3.model.layers.attention import (
GatedCrossAttention,
LocalAttentionPairBias,
)
-from rfd3.model.block_utils import (
+from rfd3.model.layers.block_utils import (
build_valid_mask,
create_attention_indices,
group_atoms,
ungroup_atoms,
)
-from torch.nn.functional import one_hot
-
-from modelhub import DISABLE_CHECKPOINTING
-from modelhub.common import exists
-from modelhub.model.layers.af3_diffusion_transformer import (
- ConditionedTransitionBlock,
-)
-from modelhub.model.layers.layer_utils import (
+from rfd3.model.layers.layer_utils import (
+ AdaLN,
EmbeddingLayer,
+ LinearBiasInit,
RMSNorm,
Transition,
collapse,
linearNoBias,
)
-from modelhub.model.layers.pairformer_layers import PairformerBlock
+from rfd3.model.layers.pairformer_layers import PairformerBlock
+from torch.nn.functional import one_hot
+
+from modelhub import DISABLE_CHECKPOINTING
+from modelhub.common import exists
logger = logging.getLogger(__name__)
+# SwiGLU transition block with adaptive layernorm
+class ConditionedTransitionBlock(nn.Module):
+ def __init__(self, c_token, c_s, n=2):
+ super().__init__()
+ self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
+ self.linear_1 = linearNoBias(c_token, c_token * n)
+ self.linear_2 = linearNoBias(c_token, c_token * n)
+ self.linear_output_project = nn.Sequential(
+ LinearBiasInit(c_s, c_token, biasinit=-2.0),
+ nn.Sigmoid(),
+ )
+ self.linear_3 = linearNoBias(c_token * n, c_token)
+
+ def forward(
+ self,
+ Ai, # [B, I, C_token]
+ Si, # [B, I, C_token]
+ ):
+ Ai = self.ada_ln(Ai, Si)
+ # BUG: This is not the correct implementation of SwiGLU
+ # Bi = torch.sigmoid(self.linear_1(Ai)) * self.linear_2(Ai)
+ # FIX: This is the correct implementation of SwiGLU
+ Bi = torch.nn.functional.silu(self.linear_1(Ai)) * self.linear_2(Ai)
+
+ # Output projection (from adaLN-Zero)
+ return self.linear_output_project(Si) * self.linear_3(Bi)
+
+
class PositionPairDistEmbedder(nn.Module):
def __init__(self, c_atompair, embed_frame=True):
super().__init__()
diff --git a/models/rfd3/src/rfd3/model/chunked_pairwise.py b/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py
similarity index 99%
rename from models/rfd3/src/rfd3/model/chunked_pairwise.py
rename to models/rfd3/src/rfd3/model/layers/chunked_pairwise.py
index fb028a4..9f14d44 100644
--- a/models/rfd3/src/rfd3/model/chunked_pairwise.py
+++ b/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py
@@ -10,8 +10,7 @@ from typing import Optional
import torch
import torch.nn as nn
-
-from modelhub.model.layers.layer_utils import RMSNorm, linearNoBias
+from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
class ChunkedPositionPairDistEmbedder(nn.Module):
diff --git a/models/rfd3/src/rfd3/model/encoders.py b/models/rfd3/src/rfd3/model/layers/encoders.py
similarity index 98%
rename from models/rfd3/src/rfd3/model/encoders.py
rename to models/rfd3/src/rfd3/model/layers/encoders.py
index cdd7455..d24f8e2 100644
--- a/models/rfd3/src/rfd3/model/encoders.py
+++ b/models/rfd3/src/rfd3/model/layers/encoders.py
@@ -3,11 +3,11 @@ import logging
import torch
import torch.nn as nn
-from rfd3.model.block_utils import (
+from rfd3.model.layers.block_utils import (
bucketize_scaled_distogram,
pairwise_mean_pool,
)
-from rfd3.model.blocks import (
+from rfd3.model.layers.blocks import (
Downcast,
LocalAtomTransformer,
OneDFeatureEmbedder,
@@ -15,19 +15,19 @@ from rfd3.model.blocks import (
RelativePositionEncodingWithIndexRemoval,
SinusoidalDistEmbed,
)
-from rfd3.model.chunked_pairwise import (
+from rfd3.model.layers.chunked_pairwise import (
ChunkedPairwiseEmbedder,
ChunkedPositionPairDistEmbedder,
ChunkedSinusoidalDistEmbed,
)
-
-from modelhub.common import exists
-from modelhub.model.layers.layer_utils import (
+from rfd3.model.layers.layer_utils import (
RMSNorm,
Transition,
linearNoBias,
)
-from modelhub.model.layers.pairformer_layers import PairformerBlock
+from rfd3.model.layers.pairformer_layers import PairformerBlock
+
+from modelhub.common import exists
from modelhub.training.checkpoint import activation_checkpointing
logger = logging.getLogger(__name__)
diff --git a/models/rfd3/src/rfd3/model/layers/layer_utils.py b/models/rfd3/src/rfd3/model/layers/layer_utils.py
new file mode 100644
index 0000000..598cb38
--- /dev/null
+++ b/models/rfd3/src/rfd3/model/layers/layer_utils.py
@@ -0,0 +1,197 @@
+import math
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.functional import silu
+
+from modelhub.training.checkpoint import activation_checkpointing
+from modelhub.utils.ddp import RankedLogger
+
+ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+try:
+ from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+ ranked_logger.info("Fused RMSNorm enabled!")
+ RMSNorm_ = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+ ranked_logger.warning(
+ "Using nn.RMSNorm instead of apex.normalization.fused_layer_norm.FusedRMSNorm."
+ "Ensure you're using the correct apptainer"
+ )
+ RMSNorm_ = nn.RMSNorm
+
+
+# Allow bias=False to be passed for RMSNorm
+def RMSNorm(*args, **kwargs):
+ if "bias" in kwargs:
+ kwargs.pop("bias")
+ return RMSNorm_(*args, **kwargs)
+
+
+SWAP_LAYER_NORM_FOR_RMS_NORM = True
+RMSNorm = RMSNorm if SWAP_LAYER_NORM_FOR_RMS_NORM else nn.LayerNorm
+linearNoBias = partial(torch.nn.Linear, bias=False)
+
+
+class EmbeddingLayer(nn.Linear):
+ """
+ Specialized linear layer for correct weight initialization for embedding layers.
+
+ Embedding layers are functionally a multiplication of an N channel input by an NxC weight matrix to produce an
+ embedding of length C. However, we compute the components separately with a ModuleDict, then sum at the end, for
+ embedding reusability and interoperability purposes.
+
+ This layer uses Xavier initialization as described in [1]_.
+
+ References
+ ----------
+ .. [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty
+ of training deep feedforward neural networks." (2010)
+ http://proceedings.mlr.press/v9/glorot10a.html
+ """
+
+ def __init__(
+ self,
+ this_in_features,
+ total_embedding_features,
+ out_features,
+ device=None,
+ dtype=None,
+ ):
+ self.total_embedding_features = total_embedding_features
+ self.out_features = out_features
+ super().__init__(
+ this_in_features, out_features, bias=False, device=device, dtype=dtype
+ )
+ self.reset_parameters()
+
+ def reset_parameters(self, **kwargs):
+ super().reset_parameters()
+ a = math.sqrt(6.0 / float(self.total_embedding_features + self.out_features))
+ nn.init._no_grad_uniform_(self.weight, -a, a)
+
+
+def collapse(x, L):
+ return x.reshape((L, x.numel() // L))
+
+
+class MultiDimLinear(nn.Linear):
+ def __init__(self, in_features, out_shape, norm=False, **kwargs):
+ self.out_shape = out_shape
+ out_features = np.prod(out_shape)
+ super().__init__(in_features, out_features, **kwargs)
+ if norm:
+ self.ln = RMSNorm((out_features,))
+ self.use_ln = True
+ else:
+ self.use_ln = False
+ self.reset_parameters()
+
+ def reset_parameters(self, **kwargs) -> None:
+ super().reset_parameters()
+ nn.init.xavier_uniform_(self.weight)
+
+ def forward(self, x):
+ out = super().forward(x)
+ if self.use_ln:
+ out = self.ln(out)
+ return out.reshape(x.shape[:-1] + self.out_shape)
+
+
+class LinearBiasInit(nn.Linear):
+ def __init__(self, *args, biasinit, **kwargs):
+ assert biasinit == -2.0 # Sanity check
+ self.biasinit = biasinit
+ super().__init__(*args, **kwargs)
+
+ def reset_parameters(self) -> None:
+ super().reset_parameters()
+ self.bias.data.fill_(self.biasinit)
+
+
+class Transition(nn.Module):
+ def __init__(self, n, c):
+ super().__init__()
+ self.layer_norm_1 = RMSNorm(c)
+ self.linear_1 = linearNoBias(c, n * c)
+ self.linear_2 = linearNoBias(c, n * c)
+ self.linear_3 = linearNoBias(n * c, c)
+
+ @activation_checkpointing
+ def forward(
+ self,
+ X,
+ ):
+ X = self.layer_norm_1(X)
+ A = self.linear_1(X)
+ B = self.linear_2(X)
+ X = self.linear_3(silu(A) * B)
+ return X
+
+
+class AdaLN(nn.Module):
+ def __init__(self, c_a, c_s, n=2):
+ super().__init__()
+ self.ln_a = RMSNorm(normalized_shape=(c_a,), elementwise_affine=False)
+ self.ln_s = RMSNorm(normalized_shape=(c_s,), bias=False)
+ self.to_gain = nn.Sequential(
+ nn.Linear(c_s, c_a),
+ nn.Sigmoid(),
+ )
+ self.to_bias = linearNoBias(c_s, c_a)
+
+ def forward(
+ self,
+ Ai, # [B, I, C_a]
+ Si, # [B, I, C_s]
+ ):
+ """
+ Output:
+ [B, I, C_a]
+ """
+ Ai = self.ln_a(Ai)
+ Si = self.ln_s(Si)
+ return self.to_gain(Si) * Ai + self.to_bias(Si)
+
+
+def create_batch_dimension_if_not_present(batched_n_dim):
+ """
+ Decorator for adapting a function which expects batched arguments with ndim `batched_n_dim` also
+ accept unbatched arguments.
+ """
+
+ def wrap(f):
+ def _wrap(arg):
+ inserted_batch_dim = False
+ if arg.ndim == batched_n_dim - 1:
+ arg = arg[None]
+ inserted_batch_dim = True
+ elif arg.ndim == batched_n_dim:
+ pass
+ else:
+ raise Exception(
+ f"arg must have {batched_n_dim - 1} or {batched_n_dim} dimensions, got shape {arg.shape=}"
+ )
+ o = f(arg)
+
+ if inserted_batch_dim:
+ assert o.shape[0] == 1, f"{o.shape=}[0] != 1"
+ return o[0]
+ return o
+
+ return _wrap
+
+ return wrap
+
+
+def unpack_args_for_checkpointing(arg_names):
+ def wrap(f):
+ def _wrap(*args):
+ f = args[0]
+ return f(**dict(zip(arg_names, args)))
+
+ return _wrap
+
+ return wrap
diff --git a/models/rfd3/src/rfd3/model/layers/pairformer_layers.py b/models/rfd3/src/rfd3/model/layers/pairformer_layers.py
new file mode 100644
index 0000000..c036f34
--- /dev/null
+++ b/models/rfd3/src/rfd3/model/layers/pairformer_layers.py
@@ -0,0 +1,128 @@
+import torch
+from rfd3.model.layers.layer_utils import (
+ MultiDimLinear,
+ RMSNorm,
+ Transition,
+ linearNoBias,
+)
+from torch import nn
+
+from modelhub.training.checkpoint import activation_checkpointing
+from modelhub.utils.torch import device_of
+
+
+class AttentionPairBiasPairformerDeepspeed(nn.Module):
+ def __init__(self, c_a, c_s, c_pair, n_head, kq_norm=False):
+ super().__init__()
+ self.n_head = n_head
+ self.c_a = c_a
+ self.c_pair = c_pair
+ self.c = c_a // n_head
+
+ self.to_q = MultiDimLinear(c_a, (n_head, self.c))
+ self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
+ self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
+ self.to_b = linearNoBias(c_pair, n_head)
+ self.to_g = nn.Sequential(
+ MultiDimLinear(c_a, (n_head, self.c), bias=False),
+ nn.Sigmoid(),
+ )
+ self.to_a = linearNoBias(c_a, c_a)
+ # self.linear_output_project = nn.Sequential(
+ # LinearBiasInit(c_s, c_a, biasinit=-2.),
+ # nn.Sigmoid(),
+ # )
+ self.ln_0 = RMSNorm((c_pair,))
+ # self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
+ self.ln_1 = RMSNorm((c_a,))
+ self.use_deepspeed_evo = False
+ self.force_bfloat16 = True
+
+ def forward(
+ self,
+ A_I, # [I, C_a]
+ S_I, # [I, C_a] | None
+ Z_II, # [I, I, C_z]
+ Beta_II=None, # [I, I]
+ ):
+ # Input projections
+ assert S_I is None
+ A_I = self.ln_1(A_I)
+
+ if self.use_deepspeed_evo or self.force_bfloat16:
+ A_I = A_I.to(torch.bfloat16)
+
+ Q_IH = self.to_q(A_I) # / np.sqrt(self.c)
+ K_IH = self.to_k(A_I)
+ V_IH = self.to_v(A_I)
+ B_IIH = self.to_b(self.ln_0(Z_II)) + Beta_II[..., None]
+ G_IH = self.to_g(A_I)
+
+ B, L = B_IIH.shape[:2]
+
+ if not self.use_deepspeed_evo or L <= 24:
+ Q_IH = Q_IH / torch.sqrt(
+ torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
+ )
+ # Attention
+ A_IIH = torch.softmax(
+ torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2
+ ) # softmax over j
+ ## G_IH: [I, H, C]
+ ## A_IIH: [I, I, H]
+ ## V_IH: [I, H, C]
+ A_I = torch.einsum("...ijh,...jhc->...ihc", A_IIH, V_IH)
+ A_I = G_IH * A_I # [B, I, H, C]
+ A_I = A_I.flatten(start_dim=-2) # [B, I, Ca]
+ else:
+ raise NotImplementedError
+
+ A_I = self.to_a(A_I)
+
+ return A_I
+
+
+class PairformerBlock(nn.Module):
+ """
+ Attempt to replicate AF3 architecture from scratch.
+ """
+
+ def __init__(
+ self,
+ c_s,
+ c_z,
+ attention_pair_bias,
+ p_drop=0.1,
+ triangle_multiplication=None,
+ triangle_attention=None,
+ n_transition=4,
+ use_deepspeed_evo=True,
+ use_triangle_mult=False,
+ use_triangle_attn=False,
+ ):
+ super().__init__()
+
+ # self.drop_row = Dropout(broadcast_dim=-2, p_drop=p_drop)
+ # self.drop_col = Dropout(broadcast_dim=-3, p_drop=p_drop)
+
+ self.z_transition = Transition(c=c_z, n=n_transition)
+
+ if c_s > 0:
+ self.s_transition = Transition(c=c_s, n=n_transition)
+
+ self.attention_pair_bias = AttentionPairBiasPairformerDeepspeed(
+ c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
+ )
+
+ @activation_checkpointing
+ def forward(self, S_I, Z_II):
+ with torch.amp.autocast(
+ device_type=device_of(self).type, enabled=True, dtype=torch.bfloat16
+ ):
+ Z_II = Z_II + self.z_transition(Z_II)
+ if S_I is not None:
+ S_I = S_I + self.attention_pair_bias(
+ S_I, None, Z_II, Beta_II=torch.tensor([0.0], device=Z_II.device)
+ )
+ S_I = S_I + self.s_transition(S_I)
+ return S_I, Z_II
diff --git a/models/rfd3/src/rfd3/run_inference.py b/models/rfd3/src/rfd3/run_inference.py
index 98554db..f4d98f6 100644
--- a/models/rfd3/src/rfd3/run_inference.py
+++ b/models/rfd3/src/rfd3/run_inference.py
@@ -1,13 +1,12 @@
-#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rfd3_exec.sh" "$0" "$@"'
+#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
import os
-import tempfile
-from pathlib import Path
import hydra
import rootutils
+from dotenv import load_dotenv
from hydra.utils import instantiate
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from modelhub.utils.logging import suppress_warnings
@@ -15,6 +14,8 @@ from modelhub.utils.logging import suppress_warnings
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+load_dotenv(override=True)
+
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
@@ -27,16 +28,18 @@ _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
def run_inference(cfg: DictConfig) -> None:
"""Execute the specified inference pipeline"""
- with tempfile.TemporaryDirectory() as temp_dir:
- temp_dir = Path(temp_dir)
- temp_dir.mkdir(parents=True, exist_ok=True)
- import ipdb
+ run_params_set = {"inputs", "n_batches", "out_dir"}
+ run_params = {k: v for k, v in cfg.items() if k in run_params_set}
- ipdb.set_trace()
+ # Create __init__ args by filtering for all configs not in run_params
+ cfg_dict = OmegaConf.to_container(cfg, resolve=True)
+ init_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in run_params_set}
+ init_cfg = OmegaConf.create(init_cfg_dict)
+ inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)
- inference_engine = instantiate(cfg, temp_dir=temp_dir, _convert_="partial")
- with suppress_warnings(is_inference=True):
- inference_engine.eval()
+ # # Run inference
+ with suppress_warnings(is_inference=True):
+ inference_engine.run(**run_params)
if __name__ == "__main__":
diff --git a/models/rfd3/src/rfd3/testing/debug.py b/models/rfd3/src/rfd3/testing/debug.py
index 0ba98b8..2275d9b 100755
--- a/models/rfd3/src/rfd3/testing/debug.py
+++ b/models/rfd3/src/rfd3/testing/debug.py
@@ -4,7 +4,6 @@ import logging
import os
import sys
import time
-from pathlib import Path
import hydra
import ipdb # noqa: F401
@@ -38,7 +37,7 @@ print(f"Project root: {os.environ.get('PROJECT_ROOT', '../..')}")
is_inference = True
-outdir = Path("/home/jbutch/Projects/HT25/af3/modelhub_refactor/rfd3/tests/outs")
+# outdir = Path("/home/jbutch/Projects/HT25/af3/modelhub_refactor/rfd3/tests/outs")
args = TEST_JSON_DATA["1qys-1-refactored"]
input = instantiate_example(args, is_inference=is_inference)
diff --git a/models/rfd3/src/rfd3/testing/debug_utils.py b/models/rfd3/src/rfd3/testing/debug_utils.py
index 0d7170c..7952791 100644
--- a/models/rfd3/src/rfd3/testing/debug_utils.py
+++ b/models/rfd3/src/rfd3/testing/debug_utils.py
@@ -3,7 +3,7 @@ from atomworks.common import sum_string_arrays
from atomworks.io.utils.io_utils import to_cif_file
from atomworks.ml.transforms.center_random_augmentation import CenterRandomAugmentation
from biotite.structure import AtomArrayStack
-from rfd3.trainer.rfd3_trainer import _reassign_unindexed_token_chains
+from rfd3.trainer.rfd3 import _reassign_unindexed_token_chains
from rfd3.transforms.design_transforms import (
MotifCenterRandomAugmentation,
)
diff --git a/models/rfd3/src/rfd3/testing/testing_utils.py b/models/rfd3/src/rfd3/testing/testing_utils.py
index e70d73f..3ff9428 100644
--- a/models/rfd3/src/rfd3/testing/testing_utils.py
+++ b/models/rfd3/src/rfd3/testing/testing_utils.py
@@ -1,6 +1,5 @@
import copy
import getpass
-import inspect
import json
import logging
import os
@@ -23,14 +22,12 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../src")
import atomworks
from atomworks import parse
+from atomworks.io.parser import STANDARD_PARSER_ARGS
from atomworks.io.utils.io_utils import to_cif_file
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
-from rfd3.constants import STANDARD_PARSER_ARGS
-from rfd3.inference.datasets import (
- prepare_pipeline_input_from_atom_array,
-)
from rfd3.inference.input_parsing import (
+ DesignInputSpecification,
create_atom_array_from_design_specification,
)
from rfd3.transforms.pipelines import (
@@ -106,7 +103,7 @@ TEST_CFG_TRAIN = load_train_or_val_cfg()
##########################################################################################
DIRS = [
- os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../../tests'),
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../tests"),
os.path.join(os.path.dirname(os.path.abspath(__file__))),
TEST_CFG_TRAIN.paths.data.design_benchmark_data_dir,
]
@@ -156,9 +153,6 @@ def load_test_json():
TEST_JSON_DATA = load_test_json()
assert TEST_JSON_DATA, "No test json data loaded!"
-sig = inspect.signature(create_atom_array_from_design_specification)
-valid_keys_ = sig.parameters.keys()
-
def filter_inference_args(args):
return {k: v for k, v in args.items() if k in valid_keys_}
@@ -171,8 +165,11 @@ def instantiate_example(args, is_inference=True):
if is_inference:
# Keep only the kwargs that the function actually accepts
# args = filter_inference_args(args)
- atom_array, spec = create_atom_array_from_design_specification(**args)
- input = prepare_pipeline_input_from_atom_array(atom_array)
+ # atom_array, spec = create_atom_array_from_design_specification(**args)
+ # input = prepare_pipeline_input_from_atom_array(atom_array)
+ input = DesignInputSpecification.safe_init(**args).to_pipeline_input(
+ example_id=args.get("example_id", "example")
+ )
else:
file = args.get("input")
if file is None:
diff --git a/models/rfd3/src/rfd3/trainer/fabric_trainer.py b/models/rfd3/src/rfd3/trainer/fabric_trainer.py
index 15eb4ab..cc95f70 100644
--- a/models/rfd3/src/rfd3/trainer/fabric_trainer.py
+++ b/models/rfd3/src/rfd3/trainer/fabric_trainer.py
@@ -219,7 +219,6 @@ class FabricTrainer(ABC):
)
self.initialize_or_update_trainer_state({"scheduler_cfg": scheduler_cfg})
- @abstractmethod
def construct_model(self):
"""Instantiate the model, updating the trainer state in-place.
diff --git a/models/rfd3/src/rfd3/trainer/rfd3_trainer.py b/models/rfd3/src/rfd3/trainer/rfd3.py
similarity index 52%
rename from models/rfd3/src/rfd3/trainer/rfd3_trainer.py
rename to models/rfd3/src/rfd3/trainer/rfd3.py
index faa8203..865af24 100644
--- a/models/rfd3/src/rfd3/trainer/rfd3_trainer.py
+++ b/models/rfd3/src/rfd3/trainer/rfd3.py
@@ -1,54 +1,34 @@
-from collections import OrderedDict
-
-import hydra
import numpy as np
import torch
-from atomworks.ml.encoding_definitions import AF3SequenceEncoding
-from atomworks.ml.utils.token import (
- get_token_starts,
- spread_token_wise,
-)
from beartype.typing import Any, List, Union
-from biotite.structure import AtomArray, AtomArrayStack, concatenate, infer_elements
+from biotite.structure import AtomArray, AtomArrayStack
from biotite.structure.residues import get_residue_starts
from einops import repeat
-from jaxtyping import Float, Int
from lightning_utilities import apply_to_collection
from omegaconf import DictConfig
-from rfd3.constants import (
- ATOM14_ATOM_NAMES,
- VIRTUAL_ATOM_ELEMENT_NAME,
- association_schemes,
- association_schemes_stripped,
-)
from rfd3.metrics.design_metrics import get_all_backbone_metrics
from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics
-from rfd3.trainer.fabric_trainer import FabricTrainer
from rfd3.trainer.recycling import get_recycle_schedule
-from rfd3.transforms.conditioning_utils import (
+from rfd3.trainer.trainer_utils import (
+ _build_atom_array_stack,
+ _cleanup_virtual_atoms_and_assign_atom_name_elements,
+ _reassign_unindexed_token_chains,
+ _reorder_dict,
process_unindexed_outputs,
)
-from rfd3.util.io import (
+from rfd3.utils.io import (
build_stack_from_atom_array_and_batched_coords,
)
-from modelhub.common import exists
+from modelhub.metrics.losses import Loss
from modelhub.metrics.metric import MetricManager
-from modelhub.training.EMA import EMA
+from modelhub.trainers.fabric import FabricTrainer
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.torch import assert_no_nans, assert_same_shape
global_logger = RankedLogger(__name__, rank_zero_only=False)
-def _remap_outputs(
- xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
-) -> Float[torch.Tensor, "D L 3"]:
- """Helper function to remap outputs using a mapping tensor."""
- for i in range(xyz.shape[0]):
- xyz[i, mapping[i]] = xyz[i].clone()
- return xyz
-
class AADesignTrainer(FabricTrainer):
"""Mostly for unique things like saving outputs and parsing inputs
@@ -90,9 +70,6 @@ class AADesignTrainer(FabricTrainer):
)
self.association_scheme = association_scheme
self.seed = None
- self.inference_sampler_overrides = None
-
- super().__init__(**kwargs)
# (Initialize recycle schedule upfront so all GPU's can sample the same number of recycles within a batch)
self.n_recycles_train = n_recycles_train
@@ -110,31 +87,10 @@ class AADesignTrainer(FabricTrainer):
if metrics
else None
)
-
# Loss (full precision)
with torch.autocast(device_type=self.fabric.device.type, enabled=False):
- self.loss = AF3Loss(**loss) if loss else None
+ self.loss = Loss(**loss) if loss else None
- # FROM RF3
- def construct_model(self):
- """Construct the model and optionally wrap with EMA."""
- # ... instantiate model with Hydra and Fabric
- with self.fabric.init_module():
- ranked_logger.info("Instantiating model...")
-
- model = hydra.utils.instantiate(
- self.state["train_cfg"].model.net,
- _recursive_=False,
- )
-
- # Optionally, wrap the model with EMA
- if self.state["train_cfg"].model.ema is not None:
- ranked_logger.info("Wrapping model with EMA...")
- model = EMA(model, **self.state["train_cfg"].model.ema)
-
- self.initialize_or_update_trainer_state({"model": model})
-
- # ~~~FROM RF3
def _assemble_network_inputs(self, example: dict) -> dict:
"""Assemble and validate the network inputs."""
assert_same_shape(example["coord_atom_lvl_to_be_noised"], example["noise"])
@@ -159,7 +115,7 @@ class AADesignTrainer(FabricTrainer):
network_input["X_noisy_L"] = torch.nan_to_num(
network_input["X_noisy_L"]
)
- ranked_logger.warning(str(e))
+ global_logger.warning(str(e))
else:
# During validation, since we do not crop, there should be no NaN's in the coordinates to noise
# (They were either removed, as is done with fully unresolved chains, or resolved accoring to our pipeline's rules)
@@ -172,40 +128,6 @@ class AADesignTrainer(FabricTrainer):
return network_input
- # FROM RF3
- def _assemble_metrics_extra_info(self, example: dict, network_output: dict) -> dict:
- """Prepares the extra info for the metrics"""
- # We need the same information as for the loss...
- metrics_extra_info = self._assemble_loss_extra_info(example)
-
- # ... and possibly some additional metadata from the example dictionary
- # TODO: Generalize, so we always use the `extra_info` key, rather than unpacking the ground truth as well
- metrics_extra_info.update(
- {
- # TODO: Remove, instead using `extra_info` for all keys
- **{
- k: example["ground_truth"][k]
- for k in [
- "interfaces_to_score",
- "pn_units_to_score",
- "chain_iid_token_lvl",
- ]
- if k in example["ground_truth"]
- },
- "example_id": example[
- "example_id"
- ], # We require the example ID for logging
- # (From the parser)
- **example.get("extra_info", {}),
- }
- )
-
- # Record metrics_tags for this example
- metrics_extra_info["metrics_tags"] = example.get("metrics_tags", set())
-
- # (Create a shallow copy to avoid modifying the original dictionary)
- return {**metrics_extra_info}
-
def training_step(
self,
batch: Any,
@@ -296,9 +218,6 @@ class AADesignTrainer(FabricTrainer):
# (Note that forward() passes to the EMA/shadow model if the model is not training)
network_output = model.forward(
input=network_input,
- n_cycle=example["feats"]["msa_stack"].shape[
- 0
- ], # Determine the number of recycles from the MSA stack shape
coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"],
)
@@ -307,28 +226,19 @@ class AADesignTrainer(FabricTrainer):
msg=f"network_output for example_id: {example['example_id']}",
)
+ # ... Convert output to a stack of atom arrays
+ predicted_atom_array_stack, prediction_metadata = (
+ self._build_predicted_atom_array_stack(network_output, example)
+ )
+
metrics_output = {}
- if compute_metrics and exists(self.metrics):
+ if compute_metrics:
+ assert self.metrics is not None, "Metrics are not defined!"
+
metrics_extra_info = self._assemble_metrics_extra_info(
example, network_output
)
- # Symmetry resolution
- # TODO: Refactor such that symmetry returns the ideal coordinate permutation, we apply permutation, and pass adjusted prediction to metrics
- # (without needing to use `extra_info` as we are now)
- # TODO: Update symmetry resolution to be functional (vs. using class variable), take explicit inputs (vs. all from netowork_ouput), and use extra_info for the keys it needs
- metrics_extra_info = self.subunit_symm_resolve(
- network_output,
- metrics_extra_info,
- example["symmetry_resolution"],
- )
-
- metrics_extra_info = self.residue_symm_resolve(
- network_output,
- metrics_extra_info,
- example["automorphisms"],
- )
-
metrics_output = self.metrics(
network_input=network_input,
network_output=network_output,
@@ -337,22 +247,33 @@ class AADesignTrainer(FabricTrainer):
ground_truth_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
metrics_extra_info["X_gt_L"], example.get("atom_array", None)
),
- predicted_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
- network_output["X_L"], example.get("atom_array", None)
- ),
+ predicted_atom_array_stack=predicted_atom_array_stack,
+ prediction_metadata=prediction_metadata,
)
+ if "X_gt_index_to_X" in metrics_extra_info:
+ # Remap outputs to minimize error with ground truth
+ # TODO: Remap before computing metrics, so that we can avoid pass `extra_info` to metrics (we instead just pass the remapped prediction)
+ mapping = metrics_extra_info["X_gt_index_to_X"] # [D, L]
+ network_output["X_L"] = _remap_outputs(network_output["X_L"], mapping)
+
# Avoid gradients in stored values to prevent memory leaks
if metrics_output is not None:
metrics_output = apply_to_collection(
metrics_output, torch.Tensor, lambda x: x.detach()
)
- network_output = apply_to_collection(
- network_output, torch.Tensor, lambda x: x.detach()
- )
+ if network_output is not None:
+ network_output = apply_to_collection(
+ network_output, torch.Tensor, lambda x: x.detach()
+ )
- return {"metrics_output": metrics_output, "network_output": network_output}
+ return {
+ "metrics_output": metrics_output,
+ "network_output": network_output,
+ "predicted_atom_array_stack": predicted_atom_array_stack,
+ "prediction_metadata": prediction_metadata,
+ }
def _assemble_loss_extra_info(self, example: dict) -> dict:
"""Assembles metadata arguments to the loss function (incremental to the network inputs and outputs)."""
@@ -501,7 +422,7 @@ class AADesignTrainer(FabricTrainer):
# Align ca and calculate RMSD:
if xyz_ca_input.shape == xyz_ca_output.shape:
try:
- from rfd3.util.alignment import weighted_rigid_align
+ from rfd3.utils.alignment import weighted_rigid_align
xyz_ca_output_aligned = (
weighted_rigid_align(
@@ -532,266 +453,3 @@ class AADesignTrainer(FabricTrainer):
# Reorder metadata dictionaries to ensure 'metrics' and 'specification' are last
metadata_dict = {k: _reorder_dict(d) for k, d in metadata_dict.items()}
return atom_array_stack, metadata_dict
-
-
-def _reorder_dict(d: dict) -> OrderedDict:
- """
- Reorders keys in the dictionary to ensure 'metrics' and 'specification' are last (in that order if both present).
- """
- ordered = OrderedDict()
- first_keys = ["task", "diffused_index_map"]
- last_keys = ["metrics", "specification", "inference_sampler"]
- # First
- for k in first_keys:
- if k in d:
- ordered[k] = d[k]
- # Middle
- for k in d:
- if k not in last_keys and k not in first_keys:
- ordered[k] = d[k]
- # Last
- for k in last_keys:
- if k in d:
- ordered[k] = d[k]
- return ordered
-
-
-def _build_atom_array_stack(
- coords,
- src_atom_array,
- sequence_indices,
- sequence_logits,
- allow_sequence_outputs=True,
- read_sequence_from_sequence_head=True,
- association_scheme: str = "atom14",
-):
- """
- Wraps around build_atom_array_and_batched_coords to also include additional modifications to atom array
- """
- atom_array_stack = build_stack_from_atom_array_and_batched_coords(
- coords, src_atom_array.copy()
- )
-
- # ... Spoof empty sequences to alanines
- atom_array_stack.res_name[
- atom_array_stack.is_protein & (atom_array_stack.res_name == "UNK")
- ] = "ALA"
-
- # ... Add sequence if available
- if allow_sequence_outputs:
- array_list = []
- if read_sequence_from_sequence_head and exists(sequence_logits):
- sequence_encoding = AF3SequenceEncoding()
- for i, (atom_array, seq_indices, seq_logits) in enumerate(
- zip(atom_array_stack, sequence_indices, sequence_logits)
- ):
- # Set residue names
- diffused_mask = ~atom_array.is_motif_atom_with_fixed_seq
- three_letter_sequence = sequence_encoding.decode(
- seq_indices.cpu().numpy().astype(int)
- ) # [I]
-
- atom_array.res_name[diffused_mask] = three_letter_sequence[
- atom_array.token_id
- ][diffused_mask] # [L]
-
- # Set bfactor column as entropy of sequence logits
- p = torch.softmax(seq_logits, dim=-1).cpu().numpy() # shape (L, 32)
- res_entropy = -np.sum(p * np.log(p + 1e-10), axis=-1) # shape (L,)
- atom_array.b_factor = spread_token_wise(atom_array, res_entropy)
- array_list.append(atom_array.copy())
- else:
- # This automatically deletes virtual atoms and assigns resname, atom name, and elements
- for atom_array in atom_array_stack:
- atom_array = _readout_seq_from_struc(
- atom_array, association_scheme=association_scheme
- )
- array_list.append(atom_array)
-
- # Return as list
- atom_array_stack = array_list
-
- return atom_array_stack
-
-
-def _reassign_unindexed_token_chains(atom_array):
- if np.any((mask := atom_array.is_motif_atom_unindexed)):
- # HACK: Since res_ids are the same, we should save them with a different chain index.
- atom_array.chain_id[mask] = "X"
- atom_array.res_id[mask] = atom_array.orig_res_id[mask]
-
- # Parse to separate chains
- starts = get_token_starts(atom_array)
- unindexed_starts = starts[mask[starts]]
- token_breaks = atom_array[
- unindexed_starts
- ].is_motif_atom_unindexed_motif_breakpoint
- token_group_id = np.cumsum(token_breaks, dtype=int) # Group by motif breaks
- token_chain_id = np.array([f"X{i}" for i in token_group_id])
-
- chains = atom_array.chain_id[starts]
- chains[mask[starts]] = token_chain_id
- atom_array.chain_id = spread_token_wise(atom_array, chains)
- return atom_array
-
-
-def _cleanup_virtual_atoms_and_assign_atom_name_elements(
- atom_array, association_scheme: str = "atom14"
-):
- ## remove virtual atoms based on predicted residue and assign correct atom name and elements
- ret_mask = []
- atom_names = []
- # This is used to indicate which residue is unidentified, probably due to an invalid structure.
- # This is different from the ref_mask, which is used to delete virtual atoms, but this one is used to assign UNK resname for invalid residues.
- invalid_mask = []
-
- # ... Iterate through each residue.
- # Here we iterate through res_id instead of token_id to avoid some atomization cases or something else.
- res_ids = atom_array.res_id
- res_start_indices = np.concatenate(
- [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
- )
- res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
- warning_issued = False
- for start, end in zip(res_start_indices, res_end_indices):
- res_array = atom_array[start:end]
-
- is_seq_known = all(
- np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool)
- ) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool))
-
- # ... If sequence is known for the original atom array, just skip
- if is_seq_known:
- ret_mask += [True] * len(res_array)
- invalid_mask += [False] * len(res_array)
- res_name = res_array[0].res_name
- atom_names += res_array.gt_atom_name.tolist()
- continue
-
- # ... If sequence is unknown for the original atom array, use the predicted / inferred sequence
- res_name = res_array[0].res_name
- if res_name not in association_schemes[association_scheme]:
- global_logger.warning(
- "Model predicted non-protein sequence for diffused residue. Cannot clean up outputs. Assigning unknown residue token."
- )
- warning_issued = True
- ret_mask += [True] * len(res_array)
- invalid_mask += [True] * len(res_array)
- atom_names += res_array.atom_name.tolist()
- continue
-
- scheme = association_schemes[association_scheme][res_name]
- ret_mask += [True if item is not None else False for item in scheme]
- atom_names += [item.strip() if item is not None else "VX" for item in scheme]
- invalid_mask += [False] * len(scheme)
-
- if len(atom_names) != atom_array.array_length():
- global_logger.warning(
- f"{atom_names=}\n{atom_array.atom_name=}\nAtom names length {len(atom_names)} does not match original array length {atom_array.array_length()}."
- "\nCould not cleanup atom array!!!"
- )
- if not warning_issued:
- raise ValueError("Atom names length does not match original array length. ")
- return atom_array
- atom_array.atom_name = atom_names
- atom_array.element = np.where(
- atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME,
- infer_elements(atom_names),
- atom_array.element,
- )
- atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
- return atom_array[ret_mask]
-
-
-def _readout_seq_from_struc(
- atom_array, central_atom="CB", threshold=0.5, association_scheme: str = "atom14"
-):
- cur_atom_array_list = []
-
- # Iterate through each residue
- res_ids = atom_array.res_id
- res_start_indices = np.concatenate(
- [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
- )
- res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
-
- for start, end in zip(res_start_indices, res_end_indices):
- # ... Check if the current residue is after padding (seq unknown):
- cur_res_atom_array = atom_array[start:end]
- is_seq_known = all(
- np.array(cur_res_atom_array.is_motif_atom_with_fixed_seq, dtype=bool)
- )
-
- # Here it assumes that every non-protein part has its sequence shown (not padded)
- if not is_seq_known:
- # For Glycine: it doesn't have CB, so set the virtual atom as CA.
- # The current way to handle this is to check if predicted CA and CB are too close, because in the case of glycine and we pad virtual atoms based on CB, CB's coords are set as CA.
- # There might be a better way to do this.
- CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
- CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
- if np.linalg.norm(CA_coord - CB_coord) < threshold:
- cur_central_atom = "CA"
- else:
- cur_central_atom = central_atom
-
- central_mask = cur_res_atom_array.atom_name == cur_central_atom
-
- # ... Calculate the distance to the central atom
- central_coord = cur_res_atom_array.coord[central_mask][
- 0
- ] # Should only have one central atom anyway
- dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)
-
- # ... Select virtual atom by the distance. Shouldn't count the central atom itself.
- is_virtual = (dists < threshold) & ~central_mask
-
- # ... Throw away virtual atoms
- cur_res_atom_array_wo_virtual = cur_res_atom_array[~is_virtual]
- cur_pred_res_atom_names = (
- cur_res_atom_array_wo_virtual.atom_name
- ) # e.g. [N, CA, C, O, CB, V6, V2]
-
- # ... Iterate over the possible restypes and find the matched one if there is any
- has_restype_assigned = False
- for restype, atom_names in association_schemes_stripped[
- association_scheme
- ].items():
- atom_names = np.array(atom_names)
-
- # Shouldn't match these two
- if restype in ["UNK", "MSK"]:
- continue
-
- # ... Find the index of virtual atom names in the standard atom14 names
- atom_name_idx_in_atom14_scheme = np.array(
- [
- np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
- for atom_name in cur_pred_res_atom_names
- ]
- ) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
- atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
- atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
-
- # ... Find the matched restype by checking if all the non-None posititons and None positions match
- # This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
- if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
- x is None for x in atom_names[~atom14_scheme_mask]
- ):
- cur_res_atom_array.res_name = np.array(
- [restype] * len(cur_res_atom_array)
- )
- cur_atom_array_list.append(cur_res_atom_array)
- has_restype_assigned = True
- break
- else:
- cur_atom_array_list.append(cur_res_atom_array)
- has_restype_assigned = True
-
- # ... Give UNK as the residue name if the mapping fails (unrealistic sidechain)
- if not has_restype_assigned:
- cur_res_atom_array.res_name = np.array(["UNK"] * len(cur_res_atom_array))
- cur_atom_array_list.append(cur_res_atom_array)
-
- cur_atom_array = concatenate(cur_atom_array_list)
-
- return cur_atom_array
diff --git a/models/rfd3/src/rfd3/trainer/trainer_utils.py b/models/rfd3/src/rfd3/trainer/trainer_utils.py
new file mode 100644
index 0000000..77fbb3a
--- /dev/null
+++ b/models/rfd3/src/rfd3/trainer/trainer_utils.py
@@ -0,0 +1,502 @@
+from collections import Counter, OrderedDict
+
+import numpy as np
+import torch
+from atomworks.ml.encoding_definitions import AF3SequenceEncoding
+from atomworks.ml.utils.token import (
+ get_token_starts,
+ spread_token_wise,
+)
+from biotite.structure import concatenate, infer_elements
+from jaxtyping import Float, Int
+from rfd3.constants import (
+ ATOM14_ATOM_NAMES,
+ VIRTUAL_ATOM_ELEMENT_NAME,
+ association_schemes,
+ association_schemes_stripped,
+)
+from rfd3.utils.io import (
+ build_stack_from_atom_array_and_batched_coords,
+)
+from scipy.optimize import linear_sum_assignment
+
+from modelhub.common import exists
+from modelhub.utils.ddp import RankedLogger
+
+global_logger = RankedLogger(__name__, rank_zero_only=False)
+
+#######################################################################
+# Pythonic Helper functions
+#######################################################################
+
+
+def _remap_outputs(
+ xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
+) -> Float[torch.Tensor, "D L 3"]:
+ """Helper function to remap outputs using a mapping tensor."""
+ for i in range(xyz.shape[0]):
+ xyz[i, mapping[i]] = xyz[i].clone()
+ return xyz
+
+
+def _reorder_dict(d: dict) -> OrderedDict:
+ """
+ Reorders keys in the dictionary to ensure 'metrics' and 'specification' are last (in that order if both present).
+ """
+ ordered = OrderedDict()
+ first_keys = ["task", "diffused_index_map"]
+ last_keys = ["metrics", "specification", "inference_sampler"]
+ # First
+ for k in first_keys:
+ if k in d:
+ ordered[k] = d[k]
+ # Middle
+ for k in d:
+ if k not in last_keys and k not in first_keys:
+ ordered[k] = d[k]
+ # Last
+ for k in last_keys:
+ if k in d:
+ ordered[k] = d[k]
+ return ordered
+
+
+#######################################################################
+# Biotite-related helper functions
+#######################################################################
+
+
+def _build_atom_array_stack(
+ coords,
+ src_atom_array,
+ sequence_indices,
+ sequence_logits,
+ allow_sequence_outputs=True,
+ read_sequence_from_sequence_head=True,
+ association_scheme: str = "atom14",
+):
+ """
+ Wraps around build_atom_array_and_batched_coords to also include additional modifications to atom array
+ """
+ atom_array_stack = build_stack_from_atom_array_and_batched_coords(
+ coords, src_atom_array.copy()
+ )
+
+ # ... Spoof empty sequences to alanines
+ atom_array_stack.res_name[
+ atom_array_stack.is_protein & (atom_array_stack.res_name == "UNK")
+ ] = "ALA"
+
+ # ... Add sequence if available
+ if allow_sequence_outputs:
+ array_list = []
+ if read_sequence_from_sequence_head and exists(sequence_logits):
+ sequence_encoding = AF3SequenceEncoding()
+ for i, (atom_array, seq_indices, seq_logits) in enumerate(
+ zip(atom_array_stack, sequence_indices, sequence_logits)
+ ):
+ # Set residue names
+ diffused_mask = ~atom_array.is_motif_atom_with_fixed_seq
+ three_letter_sequence = sequence_encoding.decode(
+ seq_indices.cpu().numpy().astype(int)
+ ) # [I]
+
+ atom_array.res_name[diffused_mask] = three_letter_sequence[
+ atom_array.token_id
+ ][diffused_mask] # [L]
+
+ # Set bfactor column as entropy of sequence logits
+ p = torch.softmax(seq_logits, dim=-1).cpu().numpy() # shape (L, 32)
+ res_entropy = -np.sum(p * np.log(p + 1e-10), axis=-1) # shape (L,)
+ atom_array.b_factor = spread_token_wise(atom_array, res_entropy)
+ array_list.append(atom_array.copy())
+ else:
+ # This automatically deletes virtual atoms and assigns resname, atom name, and elements
+ for atom_array in atom_array_stack:
+ atom_array = _readout_seq_from_struc(
+ atom_array, association_scheme=association_scheme
+ )
+ array_list.append(atom_array)
+
+ # Return as list
+ atom_array_stack = array_list
+
+ return atom_array_stack
+
+
+def _cleanup_virtual_atoms_and_assign_atom_name_elements(
+ atom_array, association_scheme: str = "atom14"
+):
+ ## remove virtual atoms based on predicted residue and assign correct atom name and elements
+ ret_mask = []
+ atom_names = []
+ # This is used to indicate which residue is unidentified, probably due to an invalid structure.
+ # This is different from the ref_mask, which is used to delete virtual atoms, but this one is used to assign UNK resname for invalid residues.
+ invalid_mask = []
+
+ # ... Iterate through each residue.
+ # Here we iterate through res_id instead of token_id to avoid some atomization cases or something else.
+ res_ids = atom_array.res_id
+ res_start_indices = np.concatenate(
+ [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
+ )
+ res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
+ warning_issued = False
+ for start, end in zip(res_start_indices, res_end_indices):
+ res_array = atom_array[start:end]
+
+ is_seq_known = all(
+ np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool)
+ ) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool))
+
+ # ... If sequence is known for the original atom array, just skip
+ if is_seq_known:
+ ret_mask += [True] * len(res_array)
+ invalid_mask += [False] * len(res_array)
+ res_name = res_array[0].res_name
+ atom_names += res_array.gt_atom_name.tolist()
+ continue
+
+ # ... If sequence is unknown for the original atom array, use the predicted / inferred sequence
+ res_name = res_array[0].res_name
+ if res_name not in association_schemes[association_scheme]:
+ global_logger.warning(
+ "Model predicted non-protein sequence for diffused residue. Cannot clean up outputs. Assigning unknown residue token."
+ )
+ warning_issued = True
+ ret_mask += [True] * len(res_array)
+ invalid_mask += [True] * len(res_array)
+ atom_names += res_array.atom_name.tolist()
+ continue
+
+ scheme = association_schemes[association_scheme][res_name]
+ ret_mask += [True if item is not None else False for item in scheme]
+ atom_names += [item.strip() if item is not None else "VX" for item in scheme]
+ invalid_mask += [False] * len(scheme)
+
+ if len(atom_names) != atom_array.array_length():
+ global_logger.warning(
+ f"{atom_names=}\n{atom_array.atom_name=}\nAtom names length {len(atom_names)} does not match original array length {atom_array.array_length()}."
+ "\nCould not cleanup atom array!!!"
+ )
+ if not warning_issued:
+ raise ValueError("Atom names length does not match original array length. ")
+ return atom_array
+ atom_array.atom_name = atom_names
+ atom_array.element = np.where(
+ atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME,
+ infer_elements(atom_names),
+ atom_array.element,
+ )
+ atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
+ return atom_array[ret_mask]
+
+
+def _readout_seq_from_struc(
+ atom_array, central_atom="CB", threshold=0.5, association_scheme: str = "atom14"
+):
+ cur_atom_array_list = []
+
+ # Iterate through each residue
+ res_ids = atom_array.res_id
+ res_start_indices = np.concatenate(
+ [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
+ )
+ res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
+
+ for start, end in zip(res_start_indices, res_end_indices):
+ # ... Check if the current residue is after padding (seq unknown):
+ cur_res_atom_array = atom_array[start:end]
+ is_seq_known = all(
+ np.array(cur_res_atom_array.is_motif_atom_with_fixed_seq, dtype=bool)
+ )
+
+ # Here it assumes that every non-protein part has its sequence shown (not padded)
+ if not is_seq_known:
+ # For Glycine: it doesn't have CB, so set the virtual atom as CA.
+ # The current way to handle this is to check if predicted CA and CB are too close, because in the case of glycine and we pad virtual atoms based on CB, CB's coords are set as CA.
+ # There might be a better way to do this.
+ CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
+ CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
+ if np.linalg.norm(CA_coord - CB_coord) < threshold:
+ cur_central_atom = "CA"
+ else:
+ cur_central_atom = central_atom
+
+ central_mask = cur_res_atom_array.atom_name == cur_central_atom
+
+ # ... Calculate the distance to the central atom
+ central_coord = cur_res_atom_array.coord[central_mask][
+ 0
+ ] # Should only have one central atom anyway
+ dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)
+
+ # ... Select virtual atom by the distance. Shouldn't count the central atom itself.
+ is_virtual = (dists < threshold) & ~central_mask
+
+ # ... Throw away virtual atoms
+ cur_res_atom_array_wo_virtual = cur_res_atom_array[~is_virtual]
+ cur_pred_res_atom_names = (
+ cur_res_atom_array_wo_virtual.atom_name
+ ) # e.g. [N, CA, C, O, CB, V6, V2]
+
+ # ... Iterate over the possible restypes and find the matched one if there is any
+ has_restype_assigned = False
+ for restype, atom_names in association_schemes_stripped[
+ association_scheme
+ ].items():
+ atom_names = np.array(atom_names)
+
+ # Shouldn't match these two
+ if restype in ["UNK", "MSK"]:
+ continue
+
+ # ... Find the index of virtual atom names in the standard atom14 names
+ atom_name_idx_in_atom14_scheme = np.array(
+ [
+ np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
+ for atom_name in cur_pred_res_atom_names
+ ]
+ ) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
+ atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
+ atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
+
+ # ... Find the matched restype by checking if all the non-None posititons and None positions match
+ # This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
+ if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
+ x is None for x in atom_names[~atom14_scheme_mask]
+ ):
+ cur_res_atom_array.res_name = np.array(
+ [restype] * len(cur_res_atom_array)
+ )
+ cur_atom_array_list.append(cur_res_atom_array)
+ has_restype_assigned = True
+ break
+ else:
+ cur_atom_array_list.append(cur_res_atom_array)
+ has_restype_assigned = True
+
+ # ... Give UNK as the residue name if the mapping fails (unrealistic sidechain)
+ if not has_restype_assigned:
+ cur_res_atom_array.res_name = np.array(["UNK"] * len(cur_res_atom_array))
+ cur_atom_array_list.append(cur_res_atom_array)
+
+ cur_atom_array = concatenate(cur_atom_array_list)
+
+ return cur_atom_array
+
+
+#######################################################################
+# Unindexed output parsing
+#######################################################################
+
+
+def _reassign_unindexed_token_chains(atom_array):
+ if np.any((mask := atom_array.is_motif_atom_unindexed)):
+ # HACK: Since res_ids are the same, we should save them with a different chain index.
+ atom_array.chain_id[mask] = "X"
+ atom_array.res_id[mask] = atom_array.orig_res_id[mask]
+
+ # Parse to separate chains
+ starts = get_token_starts(atom_array)
+ unindexed_starts = starts[mask[starts]]
+ token_breaks = atom_array[
+ unindexed_starts
+ ].is_motif_atom_unindexed_motif_breakpoint
+ token_group_id = np.cumsum(token_breaks, dtype=int) # Group by motif breaks
+ token_chain_id = np.array([f"X{i}" for i in token_group_id])
+
+ chains = atom_array.chain_id[starts]
+ chains[mask[starts]] = token_chain_id
+ atom_array.chain_id = spread_token_wise(atom_array, chains)
+ return atom_array
+
+
+def process_unindexed_outputs(
+ atom_array,
+ match_atom_names=True,
+ insert_guideposts=False,
+ verbose=False,
+):
+ """
+ Process design outputs containing unindexed tokens.
+ Returns metadata such as the assigned positional indices from the input indices
+ and the RMSD of the unindexed tokens.
+
+ Returns:
+ - Diffused atom array (without additional unindexed tokens)
+ - Metadata:
+ - diffused_indices: keys = original (contig) indices, values = diffused indices
+ - insertion_rmsd: overall RMSD of insertion
+ - insertion_rmsd_by_residue: RMSD of insertion for each token
+
+ TODO: Add additional geometry metrics such as bond angle non-ideality, clashes etc.
+ TODO: atom1d conditioning adherence - does the output contain HBonds in the right places, correct rasa values?
+ """
+ # ... Find assignments based on greedy search
+ starts = get_token_starts(atom_array, add_exclusive_stop=True)
+
+ # [N_diffused,]
+ atom_array_diffused = atom_array[~atom_array.is_motif_atom_unindexed].copy()
+ global_idx = np.arange(atom_array.array_length())[
+ ~atom_array.is_motif_atom_unindexed
+ ]
+
+ metadata = {
+ "diffused_index_map": {},
+ "insertion_rmsd_by_token": {},
+ "join_point_rmsd_by_token": {},
+ "insertion_rmsd_by_restype": {},
+ }
+ token_maes = []
+ token_rmcds = []
+ n_conjoined_residues = 0
+
+ # Initialize an empty array
+ inserted_mask = np.full_like(atom_array_diffused.is_motif_atom_unindexed, False)
+
+ for start, end in zip(starts[:-1], starts[1:]):
+ token = atom_array[start:end]
+ if not token.is_motif_atom_unindexed.all():
+ continue
+
+ if "src_component" in token.get_annotation_categories():
+ token_pdb_id = token.src_component[0]
+ else:
+ raise ValueError(
+ "Missing annotation 'src_component' in token. Is this inference?"
+ )
+
+ if "src_sym_component" in token.get_annotation_categories():
+ # if symmetry, token_pdb_id are updated to match the symmetrized component
+ token_pdb_id = token.src_sym_component[0]
+
+ res_name = token.res_name[0]
+
+ # ... Calculate [N_unindex, N_diffused] distance matrix
+ dists = np.linalg.norm(
+ token.coord[:, None] - atom_array_diffused.coord[None, :], axis=-1
+ )
+
+ # ... Match atom indices based on atom names (mask out non-identical) and remove already inserted
+ dists[:, inserted_mask.copy()] = np.inf
+ if match_atom_names:
+ matching_atom_name = (
+ token.atom_name[:, None] == atom_array_diffused.atom_name[None, :]
+ )
+ dists[~matching_atom_name] = np.inf
+
+ # ... Find the res_id's in the diffused regions belonging to the diffused indices
+ row_ind, col_ind = linear_sum_assignment(dists)
+ res_id, chain_id, is_conjoined = indices_to_components_(
+ atom_array_diffused, col_ind
+ )
+ n_conjoined_residues += int(is_conjoined)
+
+ # ... Recompute distance indices based on single residue pairings only
+ token_match = (atom_array_diffused.res_id == res_id) & (
+ atom_array_diffused.chain_id == chain_id
+ )
+ dists[:, ~token_match] = np.nan
+ BIG = 1e12
+ dists = np.nan_to_num(dists, nan=BIG, posinf=BIG, neginf=BIG)
+ row_ind, col_ind = linear_sum_assignment(dists)
+ res_id_, chain_id_, _ = indices_to_components_(atom_array_diffused, col_ind)
+
+ assert (res_id_ == res_id) & (chain_id_ == chain_id)
+ inserted_mask = np.logical_or(inserted_mask, token_match)
+
+ # ... Compute metrics based on the new distances
+ diff = token.coord[row_ind] - atom_array_diffused.coord[col_ind]
+ token_rmsd = float(np.sqrt((diff**2).sum(-1).mean()))
+ token_rmcd = float(np.cbrt((np.abs(diff) ** 3).sum(-1).mean()))
+ token_mae = float((np.abs(diff)).sum(-1).mean())
+
+ metadata["insertion_rmsd_by_token"][token_pdb_id] = token_rmsd
+ token_maes.append(token_mae)
+ token_rmcds.append(token_rmcd)
+
+ if res_name not in metadata["insertion_rmsd_by_restype"]:
+ metadata["insertion_rmsd_by_restype"][res_name] = []
+ metadata["insertion_rmsd_by_restype"][res_name].append(token_rmsd)
+ if not np.any(np.isin(token.atom_name, ["N", "CA", "C", "O"])):
+ if np.sum(token.atomize) == 1:
+ join_atom = np.where(token.atomize)[0][0]
+ elif "CB" in token.atom_name:
+ join_atom = np.where(token.atom_name == "CB")[0][0]
+ else:
+ join_atom = None
+
+ if join_atom is None:
+ global_logger.warning(
+ f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
+ )
+ else:
+ dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
+ metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
+
+ metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"
+
+ # ... Decide whether to cleanup guideposts or not
+ if insert_guideposts:
+ atom_array_diffused.coord[global_idx[col_ind]] = token.coord[row_ind]
+ if token.is_motif_atom_with_fixed_seq[0]:
+ atom_array_diffused.res_name[token_match] = token.res_name[0]
+ # atom_array_diffused.is_motif_token[token_match] = True
+ # atom_array_diffused.is_motif_atom[global_idx[col_ind]] = True
+ atom_array_diffused.is_motif_atom_with_fixed_coord[global_idx[col_ind]] = (
+ True
+ )
+
+ # ... Calculate global metrics
+ def safe_mean(x):
+ """Return nan-safe mean for empty or nan arrays."""
+ x = np.asarray(x, float)
+ if x.size == 0 or not np.isfinite(x).any():
+ return float("nan")
+ return float(np.nanmean(x))
+
+ metadata["insertion.mae"] = safe_mean(token_maes)
+ metadata["insertion.rmcd"] = safe_mean(token_rmcds)
+ metadata["insertion_rmsd"] = safe_mean(
+ list(metadata["insertion_rmsd_by_token"].values())
+ )
+ metadata["join_point_rmsd"] = safe_mean(
+ list(metadata["join_point_rmsd_by_token"].values())
+ )
+ metadata["insertion_rmsd_by_restype"] = {
+ a: safe_mean(v) for a, v in metadata["insertion_rmsd_by_restype"].items()
+ }
+ metadata["n_conjoined_residues"] = n_conjoined_residues
+
+ if not verbose:
+ metadata = {
+ k: v for k, v in metadata.items() if not k.startswith("insertion_rmsd_by_")
+ }
+
+ return atom_array_diffused, metadata
+
+
+def indices_to_components_(atom_array, col_ind):
+ """
+ Fetch chain and resids in atom array given a set of raw indices
+ will return 'conjoined' if indices to not map to a unique residue
+ """
+ res_ids, chain_ids = (
+ atom_array.res_id[col_ind],
+ atom_array.chain_id[col_ind],
+ )
+ if len(set(res_ids.tolist())) > 1 or len(set(chain_ids.tolist())) > 1:
+ global_logger.warning(
+ f"Unindexed token mapped its atoms to multiple diffused residues: {res_ids.tolist()} and chains {chain_ids.tolist()}."
+ )
+ # Handle by majority
+ pair_counts = Counter(zip(chain_ids.tolist(), res_ids.tolist()))
+ (chain_id, res_id), _ = pair_counts.most_common(1)[0]
+ conjoined = True
+ else:
+ res_id = res_ids[0]
+ chain_id = chain_ids[0]
+ conjoined = False
+
+ return res_id, chain_id, conjoined
diff --git a/models/rfd3/src/rfd3/transforms/bfactor_conditioned_transforms.py b/models/rfd3/src/rfd3/transforms/bfactor_conditioned_transforms.py
deleted file mode 100644
index 69e4a03..0000000
--- a/models/rfd3/src/rfd3/transforms/bfactor_conditioned_transforms.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from atomworks.ml.transforms._checks import (
- check_contains_keys,
- check_is_instance,
-)
-from atomworks.ml.transforms.base import Transform
-from biotite.structure import AtomArray
-
-
-class SetOccToZeroOnBfactor(Transform):
- """
- This component marks atoms as occ=0 based on bfactor values
-
- It takes as input 'brange', a list specifying the Mminimum and maximum B factors to
- keep.
-
- Example:
- brange = [-1.0,70.0] will mark with occ=0 any atom with b>70 or b<-1
- """
-
- def __init__(
- self,
- bmin=None,
- bmax=None,
- ):
- self.bmin = bmin
- self.bmax = bmax
-
- def check_input(self, data: dict):
- check_contains_keys(data, ["atom_array"])
- check_is_instance(data, "atom_array", AtomArray)
- # check_atom_array_annotation(data, ["b_factor", "occupancy"])
-
- def forward(self, data: dict) -> dict:
- atom_array = data["atom_array"]
-
- if (
- self.bmin is None and self.bmax is None
- ) or "b_factor" not in atom_array.get_annotation_categories():
- return data
-
- bfact = atom_array.get_annotation("b_factor")
- if self.bmin is not None:
- mask = bfact < self.bmin
- if self.bmax is not None:
- mask = mask | (bfact > self.bmax)
- else:
- mask = bfact > self.bmax
-
- occ = atom_array.get_annotation("occupancy")
- occ[mask] = 0.0
-
- atom_array.set_annotation("occupancy", occ)
-
- data["atom_array"] = atom_array
-
- return data
diff --git a/models/rfd3/src/rfd3/transforms/conditioning_utils.py b/models/rfd3/src/rfd3/transforms/conditioning_utils.py
index 14350a9..e439ef9 100644
--- a/models/rfd3/src/rfd3/transforms/conditioning_utils.py
+++ b/models/rfd3/src/rfd3/transforms/conditioning_utils.py
@@ -1,10 +1,6 @@
-from collections import Counter
-
import networkx as nx
import numpy as np
from atomworks.io.utils.bonds import _atom_array_to_networkx_graph
-from atomworks.ml.utils.token import get_token_starts
-from scipy.optimize import linear_sum_assignment
from modelhub.utils.ddp import RankedLogger
@@ -192,171 +188,6 @@ def choose_uniformly_random_atom_name(subarray):
#################################################################################
-def process_unindexed_outputs(
- atom_array,
- match_atom_names=True,
- insert_guideposts=False,
- verbose=False,
-):
- """
- Process design outputs containing unindexed tokens.
- Returns metadata such as the assigned positional indices from the input indices
- and the RMSD of the unindexed tokens.
-
- Returns:
- - Diffused atom array (without additional unindexed tokens)
- - Metadata:
- - diffused_indices: keys = original (contig) indices, values = diffused indices
- - insertion_rmsd: overall RMSD of insertion
- - insertion_rmsd_by_residue: RMSD of insertion for each token
-
- TODO: Add additional geometry metrics such as bond angle non-ideality, clashes etc.
- TODO: atom1d conditioning adherence - does the output contain HBonds in the right places, correct rasa values?
- """
- # ... Find assignments based on greedy search
- starts = get_token_starts(atom_array, add_exclusive_stop=True)
-
- # [N_diffused,]
- atom_array_diffused = atom_array[~atom_array.is_motif_atom_unindexed].copy()
- global_idx = np.arange(atom_array.array_length())[
- ~atom_array.is_motif_atom_unindexed
- ]
-
- metadata = {
- "diffused_index_map": {},
- "insertion_rmsd_by_token": {},
- "join_point_rmsd_by_token": {},
- "insertion_rmsd_by_restype": {},
- }
- token_maes = []
- token_rmcds = []
- n_conjoined_residues = 0
-
- # Initialize an empty array
- inserted_mask = np.full_like(atom_array_diffused.is_motif_atom_unindexed, False)
-
- for start, end in zip(starts[:-1], starts[1:]):
- token = atom_array[start:end]
- if not token.is_motif_atom_unindexed.all():
- continue
-
- if "src_component" in token.get_annotation_categories():
- token_pdb_id = token.src_component[0]
- else:
- raise ValueError(
- "Missing annotation 'src_component' in token. Is this inference?"
- )
-
- if "src_sym_component" in token.get_annotation_categories():
- # if symmetry, token_pdb_id are updated to match the symmetrized component
- token_pdb_id = token.src_sym_component[0]
-
- res_name = token.res_name[0]
-
- # ... Calculate [N_unindex, N_diffused] distance matrix
- dists = np.linalg.norm(
- token.coord[:, None] - atom_array_diffused.coord[None, :], axis=-1
- )
-
- # ... Match atom indices based on atom names (mask out non-identical) and remove already inserted
- dists[:, inserted_mask.copy()] = np.inf
- if match_atom_names:
- matching_atom_name = (
- token.atom_name[:, None] == atom_array_diffused.atom_name[None, :]
- )
- dists[~matching_atom_name] = np.inf
-
- # ... Find the res_id's in the diffused regions belonging to the diffused indices
- row_ind, col_ind = linear_sum_assignment(dists)
- res_id, chain_id, is_conjoined = indices_to_components(
- atom_array_diffused, col_ind
- )
- n_conjoined_residues += int(is_conjoined)
-
- # ... Recompute distance indices based on single residue pairings only
- token_match = (atom_array_diffused.res_id == res_id) & (
- atom_array_diffused.chain_id == chain_id
- )
- dists[:, ~token_match] = np.nan
- BIG = 1e12
- dists = np.nan_to_num(dists, nan=BIG, posinf=BIG, neginf=BIG)
- row_ind, col_ind = linear_sum_assignment(dists)
- res_id_, chain_id_, _ = indices_to_components(atom_array_diffused, col_ind)
-
- assert (res_id_ == res_id) & (chain_id_ == chain_id)
- inserted_mask = np.logical_or(inserted_mask, token_match)
-
- # ... Compute metrics based on the new distances
- diff = token.coord[row_ind] - atom_array_diffused.coord[col_ind]
- token_rmsd = float(np.sqrt((diff**2).sum(-1).mean()))
- token_rmcd = float(np.cbrt((np.abs(diff) ** 3).sum(-1).mean()))
- token_mae = float((np.abs(diff)).sum(-1).mean())
-
- metadata["insertion_rmsd_by_token"][token_pdb_id] = token_rmsd
- token_maes.append(token_mae)
- token_rmcds.append(token_rmcd)
-
- if res_name not in metadata["insertion_rmsd_by_restype"]:
- metadata["insertion_rmsd_by_restype"][res_name] = []
- metadata["insertion_rmsd_by_restype"][res_name].append(token_rmsd)
- if not np.any(np.isin(token.atom_name, ["N", "CA", "C", "O"])):
- if np.sum(token.atomize) == 1:
- join_atom = np.where(token.atomize)[0][0]
- elif "CB" in token.atom_name:
- join_atom = np.where(token.atom_name == "CB")[0][0]
- else:
- join_atom = None
-
- if join_atom is None:
- global_logger.warning(
- f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
- )
- else:
- dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
- metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
-
- metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"
-
- # ... Decide whether to cleanup guideposts or not
- if insert_guideposts:
- atom_array_diffused.coord[global_idx[col_ind]] = token.coord[row_ind]
- if token.is_motif_atom_with_fixed_seq[0]:
- atom_array_diffused.res_name[token_match] = token.res_name[0]
- # atom_array_diffused.is_motif_token[token_match] = True
- # atom_array_diffused.is_motif_atom[global_idx[col_ind]] = True
- atom_array_diffused.is_motif_atom_with_fixed_coord[global_idx[col_ind]] = (
- True
- )
-
- # ... Calculate global metrics
- def safe_mean(x):
- """Return nan-safe mean for empty or nan arrays."""
- x = np.asarray(x, float)
- if x.size == 0 or not np.isfinite(x).any():
- return float("nan")
- return float(np.nanmean(x))
-
- metadata["insertion.mae"] = safe_mean(token_maes)
- metadata["insertion.rmcd"] = safe_mean(token_rmcds)
- metadata["insertion_rmsd"] = safe_mean(
- list(metadata["insertion_rmsd_by_token"].values())
- )
- metadata["join_point_rmsd"] = safe_mean(
- list(metadata["join_point_rmsd_by_token"].values())
- )
- metadata["insertion_rmsd_by_restype"] = {
- a: safe_mean(v) for a, v in metadata["insertion_rmsd_by_restype"].items()
- }
- metadata["n_conjoined_residues"] = n_conjoined_residues
-
- if not verbose:
- metadata = {
- k: v for k, v in metadata.items() if not k.startswith("insertion_rmsd_by_")
- }
-
- return atom_array_diffused, metadata
-
-
def random_condition(p_cond):
"""
Made this function because I always get confused by which order the
@@ -367,28 +198,3 @@ def random_condition(p_cond):
return False
else:
return np.random.rand() < p_cond
-
-
-def indices_to_components(atom_array, col_ind):
- """
- Fetch chain and resids in atom array given a set of raw indices
- will return 'conjoined' if indices to not map to a unique residue
- """
- res_ids, chain_ids = (
- atom_array.res_id[col_ind],
- atom_array.chain_id[col_ind],
- )
- if len(set(res_ids.tolist())) > 1 or len(set(chain_ids.tolist())) > 1:
- global_logger.warning(
- f"Unindexed token mapped its atoms to multiple diffused residues: {res_ids.tolist()} and chains {chain_ids.tolist()}."
- )
- # Handle by majority
- pair_counts = Counter(zip(chain_ids.tolist(), res_ids.tolist()))
- (chain_id, res_id), _ = pair_counts.most_common(1)[0]
- conjoined = True
- else:
- res_id = res_ids[0]
- chain_id = chain_ids[0]
- conjoined = False
-
- return res_id, chain_id, conjoined
diff --git a/models/rfd3/src/rfd3/transforms/pipelines.py b/models/rfd3/src/rfd3/transforms/pipelines.py
index 6abc182..055e75e 100644
--- a/models/rfd3/src/rfd3/transforms/pipelines.py
+++ b/models/rfd3/src/rfd3/transforms/pipelines.py
@@ -624,35 +624,33 @@ def build_atom14_base_pipeline(
kwargs.setdefault("crop_spatial_probability", 0.0)
kwargs.setdefault("dna_contact_crop_probability", 0.0)
kwargs.setdefault("max_atoms_in_crop", None)
- kwargs.setdefault("b_factor_min", None)
- kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
- kwargs.setdefault("meta_conditioning_probabilities", {})
- kwargs.setdefault("association_scheme", "atom14")
-
- kwargs.setdefault("sigma_perturb", 0.0)
- kwargs.setdefault("sigma_perturb_com", 0.0)
- kwargs.setdefault("allowed_types", "ALL")
- kwargs.setdefault("train_conditions", {})
-
- # TODO: Delete these once all checkpoints are updated with the latest defaults
- kwargs.setdefault("generate_conformers_for_non_protein_only", True)
- kwargs.setdefault("atom_1d_features", None)
- kwargs.setdefault("token_1d_features", None)
- kwargs.setdefault("diffusion_batch_size", 16)
- kwargs.setdefault("sigma_data", 16)
- kwargs.setdefault("return_atom_array", True)
- kwargs.setdefault("provide_elements_for_unindexed_components", False)
- kwargs.setdefault("center_option", "all")
- kwargs.setdefault("use_element_for_atom_names_of_atomized_tokens", False)
-
- kwargs.setdefault("residue_cache_dir", None)
kwargs.setdefault("keep_full_binder_in_spatial_crop", True)
- kwargs.setdefault("max_binder_length", 999)
kwargs.setdefault("max_ppi_hotspots_frac_to_provide", 0)
kwargs.setdefault("ppi_hotspot_max_distance", 15)
kwargs.setdefault("max_ss_frac_to_provide", 0.0)
kwargs.setdefault("min_ss_island_len", 0)
kwargs.setdefault("max_ss_island_len", 999)
+ kwargs.setdefault("max_binder_length", 999)
+
+ kwargs.setdefault("b_factor_min", None)
+ kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
+ kwargs.setdefault("meta_conditioning_probabilities", {})
+ kwargs.setdefault("association_scheme", "dense")
+ kwargs.setdefault("sigma_perturb", 0.0)
+ kwargs.setdefault("sigma_perturb_com", 0.0)
+ kwargs.setdefault("allowed_types", "ALL")
+ kwargs.setdefault("train_conditions", {})
+ kwargs.setdefault("residue_cache_dir", None)
+
+ # TODO: Delete these once all checkpoints are updated with the latest defaults
+ kwargs.setdefault("generate_conformers_for_non_protein_only", True)
+ # kwargs.setdefault("atom_1d_features", None)
+ # kwargs.setdefault("token_1d_features", None)
+ # kwargs.setdefault("diffusion_batch_size", 16)
+ # kwargs.setdefault("sigma_data", 16)
+ kwargs.setdefault("return_atom_array", True)
+ kwargs.setdefault("provide_elements_for_unindexed_components", False)
+ kwargs.setdefault("center_option", "all")
return build_atom14_base_pipeline_(
is_inference=is_inference,
diff --git a/models/rfd3/src/rfd3/transforms/virtual_atoms.py b/models/rfd3/src/rfd3/transforms/virtual_atoms.py
index 64165a5..6d6dea6 100644
--- a/models/rfd3/src/rfd3/transforms/virtual_atoms.py
+++ b/models/rfd3/src/rfd3/transforms/virtual_atoms.py
@@ -37,7 +37,9 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato
raise ValueError(
f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}"
)
- atom_names = [str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
+ atom_names = (
+ [str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
+ )
idxs = np.array(
[
association_schemes_stripped[scheme][res_name].index(name)
diff --git a/models/rfd3/src/rfd3/inference/inference_utils.py b/models/rfd3/src/rfd3/utils/inference.py
similarity index 99%
rename from models/rfd3/src/rfd3/inference/inference_utils.py
rename to models/rfd3/src/rfd3/utils/inference.py
index c6dd052..3f04fb8 100644
--- a/models/rfd3/src/rfd3/inference/inference_utils.py
+++ b/models/rfd3/src/rfd3/utils/inference.py
@@ -10,6 +10,7 @@ import biotite.structure as struc
import numpy as np
from atomworks import parse
from atomworks.constants import STANDARD_DNA
+from atomworks.io.parser import STANDARD_PARSER_ARGS
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.preprocessing.utils.structure_utils import (
get_atom_mask_from_cell_list,
@@ -21,18 +22,18 @@ from atomworks.ml.utils.token import (
from rfd3.constants import (
REQUIRED_CONDITIONING_ANNOTATIONS,
)
-from rfd3.inference.components import (
- fetch_mask_from_component,
- get_name_mask,
- unravel_components,
-)
from rfd3.transforms.conditioning_base import (
convert_existing_annotations_to_bool,
set_default_conditioning_annotations,
)
from rfd3.transforms.conditioning_utils import sample_island_tokens
-from rfd3.constants import STANDARD_PARSER_ARGS
+
from modelhub.common import exists
+from modelhub.utils.components import (
+ fetch_mask_from_component,
+ get_name_mask,
+ unravel_components,
+)
from modelhub.utils.ddp import RankedLogger
logging.basicConfig(level=logging.INFO)
diff --git a/models/rfd3/src/rfd3/util/io.py b/models/rfd3/src/rfd3/utils/io.py
similarity index 99%
rename from models/rfd3/src/rfd3/util/io.py
rename to models/rfd3/src/rfd3/utils/io.py
index a11b94b..4806a69 100644
--- a/models/rfd3/src/rfd3/util/io.py
+++ b/models/rfd3/src/rfd3/utils/io.py
@@ -8,7 +8,8 @@ import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray, AtomArrayStack, stack
-from rfd3.util.alignment import weighted_rigid_align
+
+from modelhub.utils.alignment import weighted_rigid_align
DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}
diff --git a/models/rfd3/src/rfd3/util/vizualize.py b/models/rfd3/src/rfd3/utils/vizualize.py
similarity index 99%
rename from models/rfd3/src/rfd3/util/vizualize.py
rename to models/rfd3/src/rfd3/utils/vizualize.py
index 1f573ff..7a62c7f 100755
--- a/models/rfd3/src/rfd3/util/vizualize.py
+++ b/models/rfd3/src/rfd3/utils/vizualize.py
@@ -258,7 +258,7 @@ def _viz_from_file(
atom_array = pickle.load(f)
elif file_path.endswith((".cif", ".cif.gz", ".bcif", ".bcif.gz")):
from atomworks.io.utils.io_utils import get_structure, read_any
- from rfd3.inference.inference_utils import (
+ from rfd3.utils.inference import (
_add_design_annotations_from_cif_block_metadata,
)
diff --git a/models/rfd3/tests/test_aa_design.py b/models/rfd3/tests/test_aa_design.py
index 90a2dc0..4379a96 100644
--- a/models/rfd3/tests/test_aa_design.py
+++ b/models/rfd3/tests/test_aa_design.py
@@ -15,7 +15,6 @@ from rfd3.testing.testing_utils import (
TEST_CFG_INFERENCE,
TEST_CFG_TRAIN,
TEST_JSON_DATA,
- build_pipelines,
instantiate_example,
)
@@ -29,6 +28,23 @@ sys.path.append(PATH_TO_SRC)
smoke_test = list(TEST_JSON_DATA.keys())
+@pytest.mark.fast
+def test_imports():
+ import rfd3
+
+ import modelhub
+
+ print("Imported rfd3 version:", rfd3)
+ print("Imported modelhub version:", modelhub)
+
+ # Try imports from main modules
+ from rfd3.metrics.losses import DiffusionLoss
+ from rfd3.model.RFD3 import RFD3
+ from rfd3.trainer.rfd3 import AADesignTrainer
+
+ print("imported modules:", RFD3, AADesignTrainer, DiffusionLoss)
+
+
def test_inference():
# Silence outputs:
@@ -60,6 +76,6 @@ def test_training_pipeline(example_name):
example.get("training_condition_name"),
)
+
if __name__ == "__main__":
- # pytest.main(sys.argv)
- pytest.main(["-v", __file__ + "::test_train_pipeline[rfd3-NA-full]"])
+ pytest.main(sys.argv)
diff --git a/models/rfd3/tests/test_metrics.py b/models/rfd3/tests/test_metrics.py
index ecc93da..bff9f38 100644
--- a/models/rfd3/tests/test_metrics.py
+++ b/models/rfd3/tests/test_metrics.py
@@ -8,7 +8,7 @@ from rfd3.testing.testing_utils import (
build_pipelines,
instantiate_example,
)
-from rfd3.trainer.rfd3_trainer import (
+from rfd3.trainer.trainer_utils import (
_cleanup_virtual_atoms_and_assign_atom_name_elements,
)
diff --git a/models/rfd3/tests/test_selections.py b/models/rfd3/tests/test_selections.py
index 8eebe91..8b7086a 100644
--- a/models/rfd3/tests/test_selections.py
+++ b/models/rfd3/tests/test_selections.py
@@ -3,18 +3,19 @@ import sys
import numpy as np
import pytest
from atomworks.io.utils.io_utils import load_any
-from rfd3.inference.components import (
+from rfd3.inference.input_parsing import DesignInputSpecification
+from rfd3.inference.parsing import InputSelection
+from rfd3.testing.testing_utils import (
+ TEST_JSON_DATA,
+)
+
+from modelhub.utils.components import (
fetch_mask_from_component,
fetch_mask_from_idx,
fetch_mask_from_name,
get_name_mask,
unravel_components,
)
-from rfd3.inference.input_parsing import InputSpecification
-from rfd3.inference.parsing import InputSelection
-from rfd3.testing.testing_utils import (
- TEST_JSON_DATA,
-)
# TEST 1 - test the selections
args = TEST_JSON_DATA["amidase_helix"]
@@ -116,7 +117,7 @@ atom_array_ref_unindexed = load_any(file)[0]
def test_unindexed_break():
# Assert that unindexed selections with breaks work
sele = InputSelection.from_any(args["unindex"], atom_array=atom_array_ref_unindexed)
- comps, breaks = InputSpecification.break_unindexed(sele)
+ comps, breaks = DesignInputSpecification.break_unindexed(sele)
if __name__ == "__main__":
diff --git a/models/rfd3/tests/test_unindexing.py b/models/rfd3/tests/test_unindexing.py
index 0dece70..efd4530 100644
--- a/models/rfd3/tests/test_unindexing.py
+++ b/models/rfd3/tests/test_unindexing.py
@@ -10,7 +10,7 @@ from rfd3.testing.testing_utils import (
instantiate_example,
load_train_or_val_cfg,
)
-from rfd3.transforms.conditioning_utils import (
+from rfd3.trainer.trainer_utils import (
process_unindexed_outputs,
)
@@ -48,10 +48,10 @@ def test_unindexed_cleanup(example, is_inference):
) # spoof inference label
atom_array_cleaned, metadata = process_unindexed_outputs(atom_array)
- from rfd3.testing.debug_utils import save_pipe_out
+ # from rfd3.testing.debug_utils import save_pipe_out
- save_pipe_out(atom_array)
- print("Metadata:", metadata)
+ # save_pipe_out(atom_array)
+ # print("Metadata:", metadata)
xyz_cleaned = np.nan_to_num(atom_array_cleaned.coord)
xyz_diffused = np.nan_to_num(atom_array[~atom_array.is_motif_atom_unindexed].coord)
diff --git a/models/rfd3/tests/transforms/test_pipeline_regression.py b/models/rfd3/tests/transforms/test_pipeline_regression.py
index 275c3a9..106ca35 100644
--- a/models/rfd3/tests/transforms/test_pipeline_regression.py
+++ b/models/rfd3/tests/transforms/test_pipeline_regression.py
@@ -192,6 +192,8 @@ def test_atom14_pipeline_regression(
# Compare results
config_desc = f" ({config.name})" if config.name != "a14-base-train" else ""
mode_description = f"{'inference' if is_inference else 'training'}{config_desc}"
+ if "specification" in result:
+ expected_result["specification"] = result["specification"]
_assert_pipeline_results_equal(
result, expected_result, example_name, mode_description
diff --git a/pyproject.toml b/pyproject.toml
index 17fca46..81cfba1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
# ... typing & documentation
"jaxtyping>=0.2.17,<1",
"beartype>=0.18.0,<1",
- "typer>=0.9.0,<1",
+ "typer>=0.20.0,<1",
# ... ml tools (core)
"torch>=2.2.0,<3",
"lightning>=2.5.0",
@@ -41,12 +41,17 @@ dependencies = [
"einx>=0.1.0,<1",
"opt_einsum>=3.4.0,<4",
"dm-tree>=0.1.6,<1",
+ "zstandard",
+ "pandas",
+ # "biotite",
+ "atomworks"
]
[project.optional-dependencies]
rfd3 = [
"pydantic>=2.8",
+ "toolz",
]
rf3 = [
"cuequivariance_ops_cu12>=0.6.1; sys_platform == 'linux'",
@@ -72,7 +77,8 @@ dev = [
"pytest-dotenv>=0.5.2,<1", # load environment variables from .env file
"pytest-cov>=4.1.0,<5", # generate coverage report
"pytest-benchmark>=5.0.0,<6", # benchmark tests for speed
- "atomworks==1.0.2",
+ # "atomworks==1.0.2",
+ "pre-commit"
]
[project.scripts]
rf3 = "rf3.cli:app"
diff --git a/src/modelhub/__init__.py b/src/modelhub/__init__.py
index e0d7b10..b803c6a 100644
--- a/src/modelhub/__init__.py
+++ b/src/modelhub/__init__.py
@@ -49,5 +49,9 @@ try:
except ImportError:
logger.debug("cuEquivariance unavailable: import failed")
+
+# Whether to disable checkpointing globally
+DISABLE_CHECKPOINTING = False
+
# Export for easy access
-__all__ = ["SHOULD_USE_CUEQUIVARIANCE"]
+__all__ = ["SHOULD_USE_CUEQUIVARIANCE", "DISABLE_CHECKPOINTING"]
diff --git a/src/modelhub/constants.py b/src/modelhub/constants.py
new file mode 100644
index 0000000..59e2d03
--- /dev/null
+++ b/src/modelhub/constants.py
@@ -0,0 +1,28 @@
+# fmt: off
+# ... For convenience, define BKBN, or TIP to be used as a shortcut | TIP is the largest set of fixed atom given at least 2 tip atoms
+TIP_BY_RESTYPE = {
+ "TRP": ["CG","CD1","CD2","NE1","CE2","CE3","CZ2","CZ3","CH2"], # fix both rings
+ "HIS": ["CG","ND1","CD2","CE1","NE2"], # fixed ring
+ "TYR": ["CZ","OH"], # keeps ring dihedral flexible
+ "PHE": ["CG","CD1","CD2","CE1","CE2","CZ"],
+ "ASN": ["CB", "CG","OD1","ND2"],
+ "ASP": ["CB", "CG","OD1","OD2"],
+ "GLN": ["CG", "CD","OE1","NE2"],
+ "GLU": ["CG", "CD","OE1","OE2"],
+ "CYS": ["CB", "SG"],
+ "SER": ["CB", "OG"],
+ "THR": ["CB", "OG1"],
+ "LEU": ["CB", "CG", "CD1", "CD2"],
+ "VAL": ["CG1", "CG2"],
+ "ILE": ["CB", "CG2"],
+ "MET": ["SD", "CE"],
+ "LYS": ["CE","NZ"],
+ "ARG": ["CD","NE","CZ","NH1","NH2"],
+ "PRO": None,
+ "ALA": None,
+ "GLY": None,
+ "UNK": None,
+ "MSK": None
+}
+
+# fmt: on
diff --git a/src/modelhub/inference_engines/base.py b/src/modelhub/inference_engines/base.py
new file mode 100644
index 0000000..f7fe6d7
--- /dev/null
+++ b/src/modelhub/inference_engines/base.py
@@ -0,0 +1,216 @@
+import logging
+from os import PathLike
+from pathlib import Path
+from typing import Any, Dict
+
+import hydra
+import torch
+from biotite.structure import AtomArray
+from lightning.fabric import seed_everything
+from omegaconf import OmegaConf
+
+from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
+from modelhub.utils.logging import print_config_tree
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+ datefmt="%H:%M:%S",
+)
+ranked_logger = RankedLogger(__name__, rank_zero_only=True)
+
+
+def merge(cfg, overrides: dict):
+ return OmegaConf.merge(cfg, OmegaConf.create(overrides))
+
+
+class BaseInferenceEngine:
+ """
+ Base inference engine.
+ Separates model setup (expensive, once) from inference (can run multiple times).
+ """
+
+ def __init__(
+ self,
+ ckpt_path: PathLike,
+ num_nodes: int = 1,
+ devices_per_node: int = 1,
+ # Config overrides
+ transform_overrides={},
+ inference_sampler_overrides={},
+ trainer_overrides={},
+ # Debug
+ print_config: bool = False,
+ seed: int | None = None,
+ ):
+ """Initialize inference engine and load model.
+
+ Model config is loaded from checkpoint and overridden with parameters provided here.
+
+ Args:
+ ckpt_path: Path to model checkpoint.
+ seed: Random seed. If None, uses external RNG state. Defaults to ``None``.
+ num_nodes: Number of nodes for distributed inference. Defaults to ``1``.
+ devices_per_node: Number of devices per node. Defaults to ``1``.
+ print_config: Whether to print config trees. Defaults to ``False``.
+ """
+ # Set attrs
+ self.initialized_ = False
+ self.trainer = None
+ self.pipeline = None
+ self.print_config = print_config
+ self.ckpt_path = ckpt_path
+
+ # Set random seed (only if seed is not None)
+ if seed is not None:
+ ranked_logger.info(f"Seeding everything with seed={seed}...")
+ seed_everything(seed, workers=True, verbose=True)
+ else:
+ ranked_logger.info("Seed is None - using external RNG state")
+ self.seed = seed
+
+ # Stored for later;
+ self.transform_overrides = transform_overrides
+ self.overrides: dict[str, Any] = {}
+
+ base_overrides = {
+ "trainer.seed": seed,
+ "trainer.metrics": {},
+ "trainer.loss": None,
+ "trainer.num_nodes": num_nodes,
+ "trainer.devices_per_node": devices_per_node,
+ }
+ for key, value in base_overrides.items():
+ self._assign_override(key, value)
+
+ for key, value in trainer_overrides.items():
+ self._assign_override(f"trainer.{key}", value)
+
+ for key, value in inference_sampler_overrides.items():
+ self._assign_override(f"model.net.inference_sampler.{key}", value)
+
+ ###################################################################################
+ # Required subclasss methods
+ ###################################################################################
+
+ def initialize(self):
+ if self.initialized_:
+ return getattr(self, "cfg", None)
+
+ # Load checkpoint and config
+ ranked_logger.info(
+ f"Loading checkpoint from {Path(self.ckpt_path).resolve()}..."
+ )
+ checkpoint = torch.load(self.ckpt_path, "cpu", weights_only=False)
+ cfg = self._override_checkpoint_config(checkpoint["train_cfg"])
+
+ # Load pipeline first before trainer/model
+ self._construct_pipeline(cfg)
+ self._construct_trainer(cfg, checkpoint=checkpoint)
+
+ ranked_logger.info("Model loaded and ready for inference.")
+ self.initialized_ = True
+ return cfg
+
+ def run(
+ self,
+ inputs: (
+ Dict[str, dict] | AtomArray | list[AtomArray] | PathLike | list[PathLike]
+ ),
+ *_,
+ ) -> dict[str, dict] | None:
+ self.initialize()
+ raise NotImplementedError(
+ "Subclasses must implement inference logic in `run` method."
+ )
+
+ ###################################################################################
+ # Util methods
+ ###################################################################################
+
+ def _override_checkpoint_config(self, cfg):
+ cfg = merge(cfg, self.overrides)
+ cfg = set_accelerator_based_on_availability(cfg)
+ return cfg
+
+ def _construct_trainer(self, cfg, checkpoint=None):
+ """
+ Sets attr self.trainer
+ """
+ # Instantiate trainer
+ ranked_logger.info("Instantiating trainer...")
+ if self.print_config:
+ print_config_tree(
+ cfg.trainer, resolve=True, title="INFERENCE TRAINER CONFIGURATION"
+ )
+ trainer = hydra.utils.instantiate(
+ cfg.trainer,
+ _convert_="partial",
+ _recursive_=False,
+ )
+
+ # Setup model
+ ranked_logger.info("Setting up model...")
+ trainer.fabric.launch()
+ trainer.initialize_or_update_trainer_state(
+ {"train_cfg": cfg}
+ ) # config from training stores net params
+ trainer.construct_model()
+
+ ranked_logger.info("Loading model weights from checkpoint...")
+ trainer.load_checkpoint(checkpoint=checkpoint or self.ckpt_path)
+
+ # Ensure optimizer isn't loaded
+ trainer.state["optimizer"] = None
+ trainer.state["train_cfg"].model.optimizer = None
+ trainer.setup_model_optimizers_and_schedulers()
+ trainer.state["model"].eval()
+ self.trainer = trainer
+
+ def _assign_override(self, dotted_key: str, value: Any) -> None:
+ """Assign ``value`` into ``self.overrides`` using a dotted path."""
+ target = self.overrides
+ keys = dotted_key.split(".")
+ for key in keys[:-1]:
+ if key not in target or not isinstance(target[key], dict):
+ target[key] = {}
+ target = target[key]
+ target[keys[-1]] = value
+
+ def _construct_pipeline(self, cfg):
+ """
+ Sets attr self.pipeline
+ """
+ # Construct pipeline
+ ranked_logger.info("Building Transform pipeline...")
+ first_val_dataset_key, first_val_dataset = next(iter(cfg.datasets.val.items()))
+ ranked_logger.info(
+ f"Using settings from validation dataset: {first_val_dataset_key}."
+ )
+ transform = first_val_dataset.dataset.transform
+ transform = merge(transform, self.transform_overrides)
+
+ if self.print_config:
+ print_config_tree(
+ transform,
+ resolve=True,
+ title="INFERENCE TRANSFORM PIPELINE",
+ )
+
+ self.pipeline = hydra.utils.instantiate(transform)
+
+ # aliases for run
+ def forward(self, *args, **kwargs):
+ return self.run(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ return self.run(*args, **kwargs)
+
+ # for use as a context manager: e.g. `with BaseInferenceEngine(...) as engine:` to automatically cleanup
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc, tb):
+ self.trainer = None
+ self.pipeline = None
+ self.initialized_ = False
diff --git a/src/modelhub/metrics/losses.py b/src/modelhub/metrics/losses.py
new file mode 100644
index 0000000..3d5bb62
--- /dev/null
+++ b/src/modelhub/metrics/losses.py
@@ -0,0 +1,27 @@
+import hydra
+import torch.nn as nn
+
+
+class Loss(nn.Module):
+ def __init__(self, **losses):
+ super().__init__()
+ self.to_compute = []
+ for loss_name, loss in losses.items():
+ loss_fn = hydra.utils.instantiate(loss)
+ print(f"Adding loss {loss_name} to the loss function")
+ self.to_compute.append(loss_fn)
+
+ def forward(
+ self,
+ network_input,
+ network_output,
+ loss_input,
+ ):
+ loss_dict = {}
+ loss = 0
+ for loss_fn in self.to_compute:
+ loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
+ loss += loss_
+ loss_dict.update(loss_dict_)
+ loss_dict["total_loss"] = loss.detach()
+ return loss, loss_dict
diff --git a/src/modelhub/model/layers/blocks.py b/src/modelhub/model/layers/blocks.py
new file mode 100644
index 0000000..5c769f9
--- /dev/null
+++ b/src/modelhub/model/layers/blocks.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+
+pi = torch.acos(torch.zeros(1)).item() * 2
+
+
+class FourierEmbedding(nn.Module):
+ def __init__(self, c):
+ super().__init__()
+ self.c = c
+ self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
+ self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ # super().reset_parameters()
+ nn.init.normal_(self.w)
+ nn.init.normal_(self.b)
+
+ def forward(
+ self,
+ t, # [D]
+ ):
+ return torch.cos(2 * pi * (t[..., None] * self.w + self.b))
+
+
+class Dropout(nn.Module):
+ # Dropout entire row or column
+ def __init__(self, broadcast_dim=None, p_drop=0.15):
+ super(Dropout, self).__init__()
+ # give ones with probability of 1-p_drop / zeros with p_drop
+ self.sampler = torch.distributions.bernoulli.Bernoulli(
+ torch.tensor([1 - p_drop])
+ )
+ self.broadcast_dim = broadcast_dim
+ self.p_drop = p_drop
+
+ def forward(self, x):
+ if not self.training: # no drophead during evaluation mode
+ return x
+ shape = list(x.shape)
+ if self.broadcast_dim is not None:
+ shape[self.broadcast_dim] = 1
+ mask = self.sampler.sample(shape).to(x.device).view(shape)
+
+ x = mask * x / (1.0 - self.p_drop)
+ return x
diff --git a/src/modelhub/trainers/fabric.py b/src/modelhub/trainers/fabric.py
index 0d08a23..68fc6e0 100755
--- a/src/modelhub/trainers/fabric.py
+++ b/src/modelhub/trainers/fabric.py
@@ -69,11 +69,13 @@ class FabricTrainer(ABC):
checkpoint_every_n_steps: int | None = None,
clip_grad_max_norm: float | None = None,
skip_nan_grad: bool = False,
+ error_if_grad_nonfinite: bool = False,
limit_train_batches: int | float = float("inf"),
limit_val_batches: int | float = float("inf"),
prevalidate: bool = False,
nccl_timeout: int = 3_200,
find_unused_parameters: bool = False,
+ skip_optimizer_loading: bool = False,
) -> None:
"""Base Trainer class built around Lightning Fabric.
@@ -103,6 +105,7 @@ class FabricTrainer(ABC):
clip_grad_max_norm: Maximum gradient norm to clip to (default: None). If None, no gradient clipping is performed.
skip_nan_grad: Whether to skip optimizer updates when gradients contain NaN or Inf values (default: False).
Useful for training stability, especially with mixed precision or challenging datasets.
+ error_if_grad_nonfinite: Whether to raise when gradient clipping detects NaN or Inf gradients (default: False).
limit_train_batches: Limit on the number of training batches per epoch (default: float("inf")).
Helpful for debugging; should NOT be used when training production models.
limit_val_batches: Limit on the number of validation batches per epoch (default: float("inf")).
@@ -111,6 +114,7 @@ class FabricTrainer(ABC):
nccl_timeout: Timeout for NCCL operations (default: 3200). Only used with DDP strategy.
find_unused_parameters: Whether to let DDP find and skip gradient synchronization for unused parameters in the model (default: False). NOTE: Setting to True will incur a performance penalty,
but allow for training for bespoke use cases where parts of the model are frozen.
+ skip_optimizer_loading: Whether to skip loading the optimizer/scheduler state when restoring from checkpoints (default: False).
References:
(1) Fabric Arguments (https://lightning.ai/docs/fabric/stable/api/fabric_args.html)
@@ -145,6 +149,7 @@ class FabricTrainer(ABC):
# Training
self.clip_grad_max_norm = clip_grad_max_norm
self.skip_nan_grad = skip_nan_grad
+ self.error_if_grad_nonfinite = error_if_grad_nonfinite
self.grad_accum_steps = grad_accum_steps
# Stopping
@@ -162,6 +167,7 @@ class FabricTrainer(ABC):
self.output_dir = Path(output_dir) if output_dir else None
self.checkpoint_every_n_epochs = checkpoint_every_n_epochs
self.checkpoint_every_n_steps = checkpoint_every_n_steps
+ self.skip_optimizer_loading = skip_optimizer_loading
# Inject trainer reference into callbacks for easy access to trainer state
self._inject_trainer_into_callbacks()
@@ -262,14 +268,28 @@ class FabricTrainer(ABC):
)
self.initialize_or_update_trainer_state({"scheduler_cfg": scheduler_cfg})
- @abstractmethod
def construct_model(self):
"""Instantiate the model, updating the trainer state in-place.
This method must set the "model" key in the state dictionary using `self.initialize_or_update_trainer_state()`.
For an example, see the `construct_model` method in the `AF3Trainer`
+ Construct the model and optionally wrap with EMA.
"""
- raise NotImplementedError
+ # ... instantiate model with Hydra and Fabric
+ with self.fabric.init_module():
+ ranked_logger.info("Instantiating model...")
+
+ model = hydra.utils.instantiate(
+ self.state["train_cfg"].model.net,
+ _recursive_=False,
+ )
+
+ # Optionally, wrap the model with EMA
+ if self.state["train_cfg"].model.ema is not None:
+ ranked_logger.info("Wrapping model with EMA...")
+ model = EMA(model, **self.state["train_cfg"].model.ema)
+
+ self.initialize_or_update_trainer_state({"model": model})
def setup_model_optimizers_and_schedulers(self) -> None:
"""Setup the model, optimizer(s), and scheduler(s) with Fabric.
@@ -354,6 +374,11 @@ class FabricTrainer(ABC):
), "Checkpoint path not found in checkpoint configuration!"
ckpt_path = Path(ckpt_config.path)
+ reset_optimizer = bool(
+ getattr(ckpt_config, "reset_optimizer", False)
+ or self.skip_optimizer_loading
+ )
+
if ckpt_path.is_dir():
# If given a directory, load the latest checkpoint from the directory
ranked_logger.info(
@@ -362,14 +387,14 @@ class FabricTrainer(ABC):
self.load_checkpoint(
self.get_latest_checkpoint(ckpt_path),
weight_loading_config=ckpt_config.weight_loading_config,
- reset_optimizer=ckpt_config.reset_optimizer,
+ reset_optimizer=reset_optimizer,
)
else:
# If given a specific checkpoint file, load that checkpoint
self.load_checkpoint(
ckpt_path,
weight_loading_config=ckpt_config.weight_loading_config,
- reset_optimizer=ckpt_config.reset_optimizer,
+ reset_optimizer=reset_optimizer,
)
# Apply parameter freezing if a freezing config is provided
@@ -716,7 +741,7 @@ class FabricTrainer(ABC):
module=model,
optimizer=optimizer,
max_norm=self.clip_grad_max_norm,
- error_if_nonfinite=False, # Don't error on NaN/Inf if skip_nan_grad is True
+ error_if_nonfinite=self.error_if_grad_nonfinite,
)
# ... step the optimizer
diff --git a/models/rfd3/src/rfd3/util/alignment.py b/src/modelhub/utils/alignment.py
similarity index 88%
rename from models/rfd3/src/rfd3/util/alignment.py
rename to src/modelhub/utils/alignment.py
index 67bc4ee..97959c8 100644
--- a/models/rfd3/src/rfd3/util/alignment.py
+++ b/src/modelhub/utils/alignment.py
@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
def weighted_rigid_align(
X_L, # [B, L, 3]
X_gt_L, # [B, L, 3]
- X_exists_L, # [L]
- w_L, # [B, L]
+ X_exists_L=None, # [L]
+ w_L=None, # [B, L]
):
"""
Weighted rigid body alignment of X_gt_L onto X_L with weights w_L
@@ -21,6 +21,13 @@ def weighted_rigid_align(
assert X_L.shape == X_gt_L.shape
assert X_L.shape[:-1] == w_L.shape
+ if X_exists_L is None:
+ X_exists_L = torch.ones((X_L.shape[-2]), dtype=torch.bool)
+ if w_L is None:
+ w_L = torch.ones_like(X_L[..., 0])
+ else:
+ w_L = w_L.to(torch.float32)
+
# Assert `X_exists_L` is a boolean mask
assert (
X_exists_L.dtype == torch.bool
diff --git a/models/rfd3/src/rfd3/inference/components.py b/src/modelhub/utils/components.py
similarity index 100%
rename from models/rfd3/src/rfd3/inference/components.py
rename to src/modelhub/utils/components.py
diff --git a/src/modelhub/utils/ddp.py b/src/modelhub/utils/ddp.py
index bc7b015..8df0b29 100644
--- a/src/modelhub/utils/ddp.py
+++ b/src/modelhub/utils/ddp.py
@@ -44,6 +44,7 @@ def set_accelerator_based_on_availability(cfg: dict | DictConfig):
cfg.trainer.num_nodes = 1
else:
cfg.trainer.accelerator = "gpu"
+ return cfg
class RankedLogger(logging.LoggerAdapter):
diff --git a/models/rf3/src/rf3/flow_matching/rigid_utils.py b/src/modelhub/utils/rigid.py
similarity index 100%
rename from models/rf3/src/rf3/flow_matching/rigid_utils.py
rename to src/modelhub/utils/rigid.py
diff --git a/models/rf3/src/rf3/data/rotation_augmentation.py b/src/modelhub/utils/rotation_augmentation.py
similarity index 97%
rename from models/rf3/src/rf3/data/rotation_augmentation.py
rename to src/modelhub/utils/rotation_augmentation.py
index 857f52b..1f0ee71 100644
--- a/models/rf3/src/rf3/data/rotation_augmentation.py
+++ b/src/modelhub/utils/rotation_augmentation.py
@@ -1,7 +1,8 @@
import math
import torch
-from rf3.flow_matching.rigid_utils import rot_vec_mul
+
+from modelhub.utils.rigid import rot_vec_mul
def centre(X_L, X_exists_L):