refactor source files for open sourcing (#648)

* mc

* Add base class for inference engine

* refactor inference engine

* Move constants and components out of rfd3 folder

* Fixes to engine

* Update with working checkpoint

* revert layer utils

* Fix more imports

* Move alignment, conditiontransitionblock

* Update sampler name

* mc

* More import fixes

* make format

* Minor fixes

* mc

* Fix rf3 inference engine

* Fix inference sampler

* Fix modules

* Running inference

* Make format

* add pre-commit hook

* fix: RF3 inference (#670)

fix: make rf3 tests in new format

* Minor cleanup

---------

Co-authored-by: Nathaniel Corley <ncorley@uw.edu>
This commit is contained in:
Jasper Butcher
2025-11-20 16:29:47 -08:00
committed by GitHub
parent aa4cb6875f
commit 3dba499b6d
82 changed files with 2552 additions and 2318 deletions

8
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,8 @@
repos:
- repo: local
hooks:
- id: make-format
name: Run `make format`
entry: make format
language: system
pass_filenames: false

View File

@@ -111,3 +111,14 @@ To add a new model:
2. Add `modelhub` as a dependency
3. Implement model-specific code in `models/<model_name>/src/`
4. Users can install with: `uv pip install -e ./models/<model_name>`
### Pre-commit Formatting
We ship a `.pre-commit-config.yaml` that runs `make format` (via `ruff format`) before each commit. Enable it once per clone:
```bash
pip install pre-commit # if not already installed
pre-commit install
```
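After installation, the hook automatically formats the repo whenever you `git commit`. If the hook modifies any files, pre-commit aborts that commit; re-stage the reformatted files and commit again. Use `pre-commit run --all-files` to apply it manually.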

View File

@@ -621,3 +621,4 @@ view(atom_array)
**Alternative viewing options:**
- View in PyMol like normal, or using `pymol_remote`
- Use the `view_pymol()` function for direct PyMol integration

View File

@@ -24,7 +24,7 @@ classifiers = [
dependencies = [
# Core functionality shared across all models
"modelforge",
# "modelforge",
# CLI
"typer>=0.9.0,<1",
# RF3-specific ML dependencies

View File

@@ -7,7 +7,7 @@ from typing import Any
import numpy as np
from atomworks.common import exists
from atomworks.enums import ChainType
from atomworks.ml.datasets import logger, StructuralDatasetWrapper
from atomworks.ml.datasets import StructuralDatasetWrapper, logger
from atomworks.ml.datasets.parsers import (
MetadataRowParser,
load_example_from_metadata_row,

View File

@@ -1,9 +1,9 @@
import torch
from beartype.typing import Any, Literal
from jaxtyping import Float
from rf3.data.rotation_augmentation import centre_random_augmentation
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.rotation_augmentation import centre_random_augmentation
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -2,7 +2,6 @@ import logging
from os import PathLike
from pathlib import Path
import hydra
import pandas as pd
import torch
import torch.distributed as dist
@@ -12,13 +11,12 @@ from atomworks.ml.preprocessing.msa.finding import (
)
from atomworks.ml.samplers import LoadBalancedDistributedSampler
from biotite.structure import AtomArray
from lightning.fabric import seed_everything
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from modelhub.inference_engines.base import BaseInferenceEngine
from modelhub.metrics.metric import MetricManager
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
from modelhub.utils.logging import print_config_tree
from modelhub.utils.ddp import RankedLogger
from rf3.model.RF3 import ShouldEarlyStopFn
from rf3.utils.inference import (
InferenceInput,
@@ -63,7 +61,7 @@ def should_early_stop_by_mean_plddt(
return fn
class RF3InferenceEngine:
class RF3InferenceEngine(BaseInferenceEngine):
"""RF3 inference engine.
Separates model setup (expensive, once) from inference (can run multiple times).
@@ -84,71 +82,44 @@ class RF3InferenceEngine:
def __init__(
self,
ckpt_path: PathLike,
# Model parameters
n_recycles: int = 10,
diffusion_batch_size: int = 5,
num_steps: int = 50,
seed: int | None = None,
# Templating, MSAs, etc.
template_noise_scale: float = 1e-5,
early_stopping_plddt_threshold: float | None = None,
metrics_cfg: dict | OmegaConf | MetricManager | None = None,
num_nodes: int = 1,
devices_per_node: int = 1,
raise_if_missing_msa_for_protein_of_length_n: int | None = None,
# Output control
compress_outputs: bool = True,
# Debug
print_config: bool = False,
raise_if_missing_msa_for_protein_of_length_n: int | None = None,
early_stopping_plddt_threshold: float | None = None,
# Metrics
metrics_cfg: dict | OmegaConf | MetricManager | None = None,
**kwargs,
):
"""Initialize inference engine and load model.
Model config is loaded from checkpoint and overridden with parameters provided here.
Args:
ckpt_path: Path to model checkpoint.
n_recycles: Number of recycles. Defaults to ``10``.
diffusion_batch_size: Number of structures to generate per input. Defaults to ``5``.
num_steps: Number of diffusion steps. Defaults to ``50``.
seed: Random seed. If None, uses external RNG state. Defaults to ``None``.
template_noise_scale: Noise scale for template coordinates. Defaults to ``1e-5``.
raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
compress_outputs: Whether to gzip output files. Defaults to ``True``.
early_stopping_plddt_threshold: Stop early if pLDDT below threshold. Defaults to ``None``.
metrics_cfg: Metrics configuration. Can be:
- dict/OmegaConf with Hydra configs
- Pre-instantiated MetricManager
- None (no metrics).
Defaults to ``None``.
num_nodes: Number of nodes for distributed inference. Defaults to ``1``.
devices_per_node: Number of devices per node. Defaults to ``1``.
compress_outputs: Whether to gzip output files. Defaults to ``True``.
print_config: Whether to print config trees. Defaults to ``False``.
raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
**kwargs: Additional arguments passed to BaseInferenceEngine:
- ckpt_path (PathLike, required): Path to model checkpoint.
- seed (int | None): Random seed. If None, uses external RNG state. Defaults to ``None``.
- num_nodes (int): Number of nodes for distributed inference. Defaults to ``1``.
- devices_per_node (int): Number of devices per node. Defaults to ``1``.
- print_config (bool): Whether to print config trees. Defaults to ``False``.
"""
# Load checkpoint and config
ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...")
checkpoint = torch.load(ckpt_path, "cpu", weights_only=False)
self.cfg = OmegaConf.create(checkpoint["train_cfg"])
# Override config with inference parameters
self.cfg.model.net.inference_sampler.num_timesteps = num_steps
self.cfg.trainer.num_nodes = num_nodes
self.cfg.trainer.devices_per_node = devices_per_node
# Set metrics - can be dict/OmegaConf or MetricManager
# Store MetricManager separately since OmegaConf can't serialize it
self._custom_metric_manager = None
if isinstance(metrics_cfg, MetricManager):
# Already instantiated - store separately and pass to trainer later
self._custom_metric_manager = metrics_cfg
self.cfg.trainer["metrics"] = {} # Empty dict in config
elif metrics_cfg is not None:
# Hydra config dict
self.cfg.trainer["metrics"] = metrics_cfg
else:
self.cfg.trainer["metrics"] = {}
set_accelerator_based_on_availability(self.cfg)
# set MSA directories
if env_var_msa_dirs := get_msa_dirs_from_env(raise_if_not_set=False):
override_msa_dirs = [str(msa_dir) for msa_dir in env_var_msa_dirs]
@@ -163,112 +134,72 @@ class RF3InferenceEngine:
]
ranked_logger.info(f"Using default MSA directories: {override_msa_dirs}")
# Dataset overrides
self.dataset_overrides = {
"diffusion_batch_size": diffusion_batch_size,
"n_recycles": n_recycles,
"raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
"undesired_res_names": [],
"template_noise_scales": {
"atomized": template_noise_scale,
"not_atomized": template_noise_scale,
super().__init__(
transform_overrides={
"diffusion_batch_size": diffusion_batch_size,
"n_recycles": n_recycles,
"raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
"undesired_res_names": [],
"template_noise_scales": {
"atomized": template_noise_scale,
"not_atomized": template_noise_scale,
},
"allowed_chain_types_for_conditioning": None,
"protein_msa_dirs": [
{
"dir": msa_dir,
"extension": extension.value,
"directory_depth": depth,
}
for msa_dir, depth, extension in [
(msa_dir, *get_msa_depth_and_ext_from_folder(Path(msa_dir)))
for msa_dir in override_msa_dirs
]
],
"rna_msa_dirs": [],
# (Paranoia - in validation, these should be set correctly anyhow)
"p_give_polymer_ref_conf": 0.0,
"p_give_non_polymer_ref_conf": 0.0,
"p_dropout_ref_conf": 0.0,
"use_element_for_atom_names_of_atomized_tokens": True,
},
"allowed_chain_types_for_conditioning": None,
"protein_msa_dirs": [
{
"dir": msa_dir,
"extension": extension.value,
"directory_depth": depth,
}
for msa_dir, depth, extension in [
(msa_dir, *get_msa_depth_and_ext_from_folder(Path(msa_dir)))
for msa_dir in override_msa_dirs
]
],
"rna_msa_dirs": [],
# (Paranoia - in validation, these should be set correctly anyhow)
"p_give_polymer_ref_conf": 0.0,
"p_give_non_polymer_ref_conf": 0.0,
"p_dropout_ref_conf": 0.0,
"use_element_for_atom_names_of_atomized_tokens": True,
}
self.print_config = print_config
# Set random seed (only if seed is not None)
if seed is not None or self.cfg.seed is not None:
seed = seed or self.cfg.seed
ranked_logger.info(f"Seeding everything with seed={seed}...")
seed_everything(seed, workers=True, verbose=True)
else:
ranked_logger.info("Seed is None - using external RNG state")
# Instantiate trainer
ranked_logger.info("Instantiating trainer...")
if self.print_config:
print_config_tree(
self.cfg.trainer, resolve=True, title="INFERENCE TRAINER CONFIGURATION"
)
self.trainer = hydra.utils.instantiate(
self.cfg.trainer,
_convert_="partial",
_recursive_=False,
inference_sampler_overrides={
"num_timesteps": num_steps,
},
**kwargs,
)
# If we have a custom MetricManager, override the trainer's metrics
if self._custom_metric_manager is not None:
self.trainer.metrics = self._custom_metric_manager
# remove loss override if present (i.e. keep from checkpoint)
self.overrides["trainer"].pop("loss", None)
self.ckpt_path = ckpt_path
# Store metrics config for later - will be set directly on trainer in initialize()
self._metrics_cfg = metrics_cfg
# Dataset overrides
self.early_stopping_plddt_threshold = early_stopping_plddt_threshold
self.compress_outputs = compress_outputs
# Setup model
ranked_logger.info("Setting up model...")
self.trainer.fabric.launch()
self.trainer.initialize_or_update_trainer_state({"train_cfg": self.cfg})
self.trainer.construct_model()
def initialize(self):
cfg = super().initialize()
ranked_logger.info("Loading model weights from checkpoint...")
self.trainer.load_checkpoint(checkpoint=checkpoint)
if cfg is not None:
self.cfg = cfg # store for later use
# Ensure optimizer isn't loaded
self.trainer.state["optimizer"] = None
self.trainer.state["train_cfg"].model.optimizer = None
# Set trainer metrics directly based on what was requested
# This bypasses the OmegaConf merge issue with empty dicts
if isinstance(self._metrics_cfg, MetricManager):
# Already instantiated - use directly
self.trainer.metrics = self._metrics_cfg
elif self._metrics_cfg is not None:
# Hydra config dict - instantiate MetricManager
self.trainer.metrics = MetricManager.instantiate_from_hydra(
metrics_cfg=self._metrics_cfg
)
else:
# No metrics requested - disable them
self.trainer.metrics = None
self.trainer.setup_model_optimizers_and_schedulers()
self.trainer.state["model"].eval()
# Construct pipeline
ranked_logger.info("Building Transform pipeline...")
first_val_dataset_key, first_val_dataset = next(
iter(self.cfg.datasets.val.items())
)
ranked_logger.info(
f"Using settings from validation dataset: {first_val_dataset_key}."
)
assert (
first_val_dataset.dataset.transform.is_inference
), "Inference must be enabled for the validation dataset."
# Provide manual overrides to dataset config
for key, value in self.dataset_overrides.items():
first_val_dataset.dataset.transform[key] = value
if self.print_config:
print_config_tree(
first_val_dataset.dataset.transform,
resolve=True,
title="INFERENCE TRANSFORM PIPELINE",
)
self.pipeline = hydra.utils.instantiate(
first_val_dataset.dataset.transform,
)
ranked_logger.info("Model loaded and ready for inference.")
return cfg
def run(
self,
@@ -314,6 +245,8 @@ class RF3InferenceEngine:
If ``out_dir`` is None: Dict mapping example_id to results dict.
If ``out_dir`` is set: None (results saved to disk).
"""
self.initialize()
# Setup output directory if provided
out_dir = Path(out_dir) if out_dir else None
if out_dir:

View File

@@ -1,10 +1,9 @@
import hydra
import numpy as np
import torch
import torch.nn as nn
from rf3.alignment import weighted_rigid_align
from modelhub.training.checkpoint import activation_checkpointing
from modelhub.utils.alignment import weighted_rigid_align
# resolve residue-level symmetries in native vs pred
@@ -360,31 +359,6 @@ class SubunitSymmetryResolution(nn.Module):
return loss_input
class Loss(nn.Module):
    """Aggregate several loss terms into one total loss.

    Each keyword argument maps a loss name to a config that is handed to
    ``hydra.utils.instantiate``. The instantiated callables are evaluated in
    the order they were given and their loss values are summed.

    NOTE(review): the instantiated losses are kept in a plain Python list, so
    any of them that are ``nn.Module`` instances are not registered as
    submodules of this module — confirm that is intended.
    """

    def __init__(self, **losses):
        super().__init__()
        self.to_compute = []
        for loss_name, loss in losses.items():
            loss_fn = hydra.utils.instantiate(loss)
            print(f"Adding loss {loss_name} to the loss function")
            self.to_compute.append(loss_fn)

    def forward(
        self,
        network_input,
        network_output,
        loss_input,
    ):
        """Evaluate every configured loss and return the total plus a log dict.

        Args:
            network_input: Passed through unchanged to every loss callable.
            network_output: Passed through unchanged to every loss callable.
            loss_input: Passed through unchanged to every loss callable.

        Returns:
            A ``(loss, loss_dict)`` tuple. ``loss`` is the sum of the first
            return value of each loss callable (the integer ``0`` when no
            losses are configured). ``loss_dict`` merges the second return
            value of each callable and adds a ``"total_loss"`` entry holding
            the detached total.
        """
        loss_dict = {}
        loss = 0
        for loss_fn in self.to_compute:
            loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
            loss += loss_
            loss_dict.update(loss_dict_)
        # When no losses are configured the total is still the plain int 0,
        # which has no ``.detach()``; only detach real tensors.
        loss_dict["total_loss"] = loss.detach() if torch.is_tensor(loss) else loss
        return loss, loss_dict
class ProteinLigandBondLoss(nn.Module):
def __init__(self, weight):
super().__init__()

View File

@@ -15,6 +15,7 @@ from rf3.model.layers.pairformer_layers import (
RF3TemplateEmbedder,
)
from modelhub.model.layers.blocks import FourierEmbedding
from modelhub.training.checkpoint import activation_checkpointing
logger = logging.getLogger(__name__)
@@ -229,29 +230,6 @@ class DiffusionConditioning(nn.Module):
return _run_conditioning(Z_II, S_trunk_I, S_inputs_I)
pi = torch.acos(torch.zeros(1)).item() * 2


class FourierEmbedding(nn.Module):
    """Map a 1-D tensor to ``c`` cosine features using fixed random buffers.

    The frequencies ``w`` and phases ``b`` are registered as buffers (not
    parameters) and are filled from a standard normal distribution at
    construction time.
    """

    def __init__(self, c):
        super().__init__()
        self.c = c
        # Two independent float32 buffers of length ``c``; each needs its own
        # storage because they are re-initialised in place below.
        for buffer_name in ("w", "b"):
            self.register_buffer(buffer_name, torch.zeros(c, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Redraw ``w`` and then ``b`` in place from a standard normal."""
        nn.init.normal_(self.w)
        nn.init.normal_(self.b)

    def forward(
        self,
        t,  # [D]
    ):
        """Return ``cos(2 * pi * (t[:, None] * w + b))``, shape ``[D, c]``."""
        phase = t[:, None] * self.w + self.b
        return torch.cos(2 * pi * phase)
class DistogramHead(nn.Module):
def __init__(
self,

View File

@@ -16,10 +16,10 @@ from rf3.model.layers.outer_product import (
OuterProductMean_AF3,
)
from rf3.model.RF3_blocks import MSAPairWeightedAverage, MSASubsampleEmbedder
from rf3.util_module import Dropout
from torch import nn
from torch.nn.functional import one_hot, relu
from modelhub.model.layers.blocks import Dropout
from modelhub.training.checkpoint import activation_checkpointing

View File

@@ -5,7 +5,6 @@ from einops import repeat
from jaxtyping import Float, Int
from lightning_utilities import apply_to_collection
from omegaconf import DictConfig
from rf3.loss.af3_losses import Loss as AF3Loss
from rf3.loss.af3_losses import (
ResidueSymmetryResolution,
SubunitSymmetryResolution,
@@ -15,6 +14,7 @@ from rf3.utils.io import build_stack_from_atom_array_and_batched_coords
from rf3.utils.recycling import get_recycle_schedule
from modelhub.common import exists
from modelhub.metrics.losses import Loss
from modelhub.metrics.metric import MetricManager
from modelhub.trainers.fabric import FabricTrainer
from modelhub.training.EMA import EMA
@@ -42,6 +42,7 @@ class RF3Trainer(FabricTrainer):
n_recycles_train: int | None = None,
loss: DictConfig | dict | None = None,
metrics: DictConfig | dict | MetricManager | None = None,
seed=None, # dumped
**kwargs,
):
"""See `FabricTrainer` for the additional initialization arguments.
@@ -79,7 +80,7 @@ class RF3Trainer(FabricTrainer):
self.metrics = None
# Loss
self.loss = AF3Loss(**loss) if loss else None
self.loss = Loss(**loss) if loss else None
# (Symmetry resolution)
self.subunit_symm_resolve = SubunitSymmetryResolution()

View File

@@ -1,6 +1,5 @@
import numpy as np
import torch
import torch.nn as nn
def init_lecun_normal(module, scale=1.0):
@@ -38,29 +37,6 @@ def create_custom_forward(module, **kwargs):
return custom_forward
class Dropout(nn.Module):
    """Dropout whose mask can be shared along one dimension.

    When ``broadcast_dim`` is given, the keep/drop mask has size 1 along that
    dimension, so whole rows or columns are dropped together. Kept values are
    rescaled by ``1 / (1 - p_drop)``.
    """

    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super().__init__()
        # The sampler yields 1 with probability 1 - p_drop and 0 with p_drop.
        keep_probability = torch.tensor([1 - p_drop])
        self.sampler = torch.distributions.bernoulli.Bernoulli(keep_probability)
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop

    def forward(self, x):
        """Apply the mask in training mode; return ``x`` untouched otherwise."""
        if self.training:
            mask_shape = list(x.shape)
            if self.broadcast_dim is not None:
                mask_shape[self.broadcast_dim] = 1
            # Sampling appends the sampler's own trailing dimension of size 1;
            # ``view`` folds it back into the requested mask shape.
            keep = self.sampler.sample(mask_shape).to(x.device).view(mask_shape)
            return keep * x / (1.0 - self.p_drop)
        return x
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
# Distance radial basis function
D_max = D_min + (D_count - 1) * D_sigma

View File

@@ -8,8 +8,8 @@ from atomworks.ml.utils.io import apply_sharding_pattern
from atomworks.ml.utils.misc import hash_sequence
from beartype.typing import Literal
from biotite.structure import AtomArray, AtomArrayStack, stack
from rf3.alignment import weighted_rigid_align
from modelhub.utils.alignment import weighted_rigid_align
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -171,7 +171,7 @@ def test_inference_regression(example_id, rmsd_tolerance, csv_tolerance):
baseline_dir = TEST_DATA_DIR / "inference_regression_tests" / example_id
with (
initialize(config_path="../configs"),
initialize(config_path="../configs", version_base="1.3"),
tempfile.TemporaryDirectory() as temp_dir,
rng_state(create_rng_state_from_seeds(1, 1, 1)),
):
@@ -327,5 +327,6 @@ def test_inference_regression_in_memory(example_id, rmsd_tolerance, csv_toleranc
rmsd_difference < rmsd_tolerance
), f"Mean RMSD difference {rmsd_difference:.4f}Å exceeds {rmsd_tolerance}Å tolerance for {example_id}"
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -108,3 +108,23 @@ This is then passed through the same processing pipeline as in training with `is
<figcaption>Overview of important transforms in the Atom14 conditioning pipeline.
</figcaption>
</p>
## Citation
If you use this code or data in your work, please cite:
```bibtex
@article {butcher2025_rfdiffusion3,
author = {Butcher, Jasper and Krishna, Rohith and Mitra, Raktim and Brent, Rafael Isaac and Li, Yanjing and Corley, Nathaniel and Kim, Paul T and Funk, Jonathan and Mathis, Simon Valentin and Salike, Saman and Muraishi, Aiko and Eisenach, Helen and Thompson, Tuscan Rock and Chen, Jie and Politanska, Yuliya and Sehgal, Enisha and Coventry, Brian and Zhang, Odin and Qiang, Bo and Didi, Kieran and Kazman, Maxwell and DiMaio, Frank and Baker, David},
title = {De novo Design of All-atom Biomolecular Interactions with RFdiffusion3},
elocation-id = {2025.09.18.676967},
year = {2025},
doi = {10.1101/2025.09.18.676967},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2025/11/19/2025.09.18.676967},
eprint = {https://www.biorxiv.org/content/early/2025/11/19/2025.09.18.676967.full.pdf},
journal = {bioRxiv}
}
```

View File

@@ -35,17 +35,6 @@ dataset:
use_element_for_atom_names_of_atomized_tokens: ${datasets.global_transform_args.use_element_for_atom_names_of_atomized_tokens}
residue_cache_dir: ${paths.data.residue_cache_dir}
# PPI
keep_full_binder_in_spatial_crop: ${datasets.global_transform_args.keep_full_binder_in_spatial_crop}
max_binder_length: ${datasets.global_transform_args.max_binder_length}
max_ppi_hotspots_frac_to_provide: ${datasets.global_transform_args.max_ppi_hotspots_frac_to_provide}
ppi_hotspot_max_distance: ${datasets.global_transform_args.ppi_hotspot_max_distance}
# Secondary structure
max_ss_frac_to_provide: ${datasets.global_transform_args.max_ss_frac_to_provide}
min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len}
max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len}
# Other dataset-specific parameters
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
token_1d_features: ${model.net.token_initializer.token_1d_features}

View File

@@ -159,23 +159,6 @@ trainer:
prevalidate: False
validate_every_n_epochs: 4
precision: bf16-mixed
loss:
verbose_diffusion_loss: null
simple_diffusion_loss:
_target_: rfd3.metrics.losses.SimpleDiffusionLoss
sigma_data: ${model.net.diffusion_module.sigma_data}
weight: 4.0
lddt_weight: 0.25
alpha_virtual_atom: 1.0
alpha_polar_residues: 1.0
lp_weight: 0.0
unindexed_norm_p: 1.0
alpha_unindexed_diffused: 1.0
unindexed_t_alpha: 0.75
normalize_virtual_atom_weight: False
alpha_ligand: 10.0
# callbacks:
# activations_tracking_callback:
# _target_: modelhub.callbacks.health_logging.ActivationsGradientsWeightsTracker

View File

@@ -7,10 +7,9 @@ defaults:
ckpt_path: ???
num_nodes: 1
devices_per_node: 1
print_config: False
seed: null
# Parameters for RFD3InferenceEngine.run()
inputs: ???
out_dir: ???
dump_predictions: true
dump_trajectories: false
one_model_per_file: false

View File

@@ -3,14 +3,14 @@ defaults:
- base
- _self_
_target_: rfd3.inference.engine.RFD3InferenceEngine
_target_: rfd3.engine.RFD3InferenceEngine
out_dir: ???
inputs: ??? # null, json, pdb or
ckpt_path: /projects/ml/aa_design/models/rfd3_latest.ckpt
# ckpt_path: /projects/ml/aa_design/models/rfd3_latest.ckpt
ckpt_path: /projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt
json_keys_subset: null
skip_existing: True
seed: null # if null, a seed integer is sampled based on the timestamp
#########################################################
# Design spec args: overrides args from input json
@@ -21,7 +21,6 @@ specification: {}
diffusion_batch_size: 8
n_batches: 1
# Inference sampler args | set to None to use the default in the checkpoint's config
inference_sampler:
kind: "default" # "default" or "symmetry" to choose the sampler
@@ -35,14 +34,9 @@ inference_sampler:
cfg_t_max: null # max t to apply cfg guidance
cfg_scale: 1.5
center_option: "all" # Options are ["all", "motif", "diffuse"]
move_noise_to_reset_com: False # Reset the COM of the diffuse region after the re-noising operation in each diffusion step
s_trans: 1.0 # Translational noise scale for augmentation during inference
fraction_of_steps_to_fix_motif: 0.0 # Fraction of diffusion steps during which the model does not move the motif, e.g. with 10 steps, a value of 0.2 keeps the motif fixed for the last 2 steps.
skip_few_diffusion_steps: False # Choose to skip some diffusion steps based on the noise scheme
inference_noise_scaling_factor: 1.0
allow_realignment: False
zero_drift_noise: False
use_frame_guidance: False
# Diffusion args:
num_timesteps: 200
@@ -51,7 +45,6 @@ inference_sampler:
p: 7
gamma_0: 0.6 # Previously 1.0 | 0.0 for ODE sampling
gamma_min: 1.0
gamma_min2: 0.0
s_jitter_origin: 0.0 # Sigma of gaussian noise to jitter the motif offset (equivalent to ORI token Jitter)
# Saving args
@@ -66,13 +59,8 @@ output_full_json: True
# e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
# e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
global_prefix: null
dump_prediction_metadata_json: True
dump_trajectories: False
one_model_per_file: True
align_trajectory_structures: False
# Additional args
print_config: False
prevalidate_inputs: True
low_memory_mode: False # False for standard mode, True for memory efficient tokenization mode

View File

@@ -1,4 +1,4 @@
_target_: rfd3.model.aa_design.RFD3
_target_: rfd3.model.RFD3.RFD3
c_s: 384
c_z: 128
@@ -53,7 +53,7 @@ token_initializer: # formerly known as the trunk
n_attn_keys: 128
diffusion_module:
_target_: rfd3.model.rfd3_diffusion_module.RFD3DiffusionModule
_target_: rfd3.model.RFD3_diffusion_module.RFD3DiffusionModule
c_token: 768
c_t_embed: 256 # Time embedding dimension
sigma_data: 16

View File

@@ -26,6 +26,6 @@ design_benchmark_data_dir: /projects/ml/aa_design/benchmarks
design_model_weight_dir: /projects/ml/aa_design/models
# path to directory with cached residue data
residue_cache_dir: /net/tukwila/ncorley/datahub/MACE-OFF23_medium
residue_cache_dir: /net/tukwila/lschaaf/datahub/MACE-Egret-3-noH/mace_embeddings
cif_cache_dir: /net/tukwila/ncorley/cifutils/cache

View File

@@ -1,6 +1,7 @@
defaults:
- ddp
- loss/losses/diffusion_loss@loss.verbose_diffusion_loss
- loss/losses/diffusion_loss@loss.diffusion_loss
- loss/losses/sequence_loss@loss.sequence_loss
- metrics: design_metrics
- _self_
@@ -27,11 +28,4 @@ clip_grad_max_norm: 10.0
output_dir: ${paths.output_dir}
n_recycles_train: 2
grad_accum_steps: 3 # overridden by launch.sh
skip_optimizer_loading: True
loss:
verbose_diffusion_loss:
sequence_head_loss:
_target_: rfd3.metrics.losses.SequenceLoss
weight: 0.1
max_t: 1
skip_optimizer_loading: True

View File

@@ -1,21 +1,11 @@
# _target_: modelhub.loss.af3_losses.DiffusionLoss
# sigma_data: ${model.net.diffusion_module.sigma_data}
# alpha_dna: 5
# alpha_rna: 5
# alpha_ligand: 10
# edm_lambda: True
# se3_invariant_loss: True
# clamp_diffusion_loss: False
sigma_data: ${model.net.diffusion_module.sigma_data}
weight: 4.0
_target_: rfd3.metrics.losses.VerboseDiffusionLoss
alpha_motif: 1.0
alpha_ca_atom: 1.0
alpha_virtual_atom: 1.0
alpha_fixed_motif: 2.0
alpha_unindexed_diffused: 2.0
lddt_weight: 0.25
clamp_diffusion_loss: True
align_prediction: False
use_motif_aligned_loss: False
alpha_virtual_atom: 1.0
alpha_polar_residues: 1.0
lp_weight: 0.0
unindexed_norm_p: 1.0
alpha_unindexed_diffused: 1.0
unindexed_t_alpha: 0.75
normalize_virtual_atom_weight: False
alpha_ligand: 10.0

View File

@@ -0,0 +1,3 @@
_target_: rfd3.metrics.losses.SequenceLoss
weight: 0.1
max_t: 1

View File

@@ -3,7 +3,7 @@ name = "rfd3"
dynamic = ["version"]
description = "De novo Design of All-atom Biomolecular Interactions with RFdiffusion3"
readme = "README.md"
requires-python = ">= 3.12"
requires-python = ">=3.12"
authors = [
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
]

View File

@@ -19,26 +19,15 @@ def design(ctx: typer.Context):
# Get all arguments
args = ctx.params.get("args", []) + ctx.args
# Parse arguments
hydra_overrides = []
if len(args) == 1 and "=" not in args[0]:
# Old style: single positional argument assumed to be inputs
hydra_overrides.append(f"inputs={args[0]}")
else:
# New style: all arguments are hydra overrides
hydra_overrides.extend(args)
args = [a for a in args if a not in ["design", "fold"]]
# Ensure we have at least a default inference_engine if not specified
has_inference_engine = any(
arg.startswith("inference_engine=") for arg in hydra_overrides
)
has_inference_engine = any(arg.startswith("inference_engine=") for arg in args)
if not has_inference_engine:
hydra_overrides.append("inference_engine=rfdiffusion3")
args.append("inference_engine=rfdiffusion3")
with initialize_config_dir(config_dir=config_path, version_base="1.3"):
cfg = compose(config_name="inference", overrides=hydra_overrides)
cfg = compose(config_name="inference", overrides=args)
# Lazy import to avoid loading heavy dependencies at CLI startup
from rfd3.run_inference import run_inference

View File

@@ -1,5 +1,8 @@
import numpy as np
from atomworks.constants import CRYSTALLIZATION_AIDS
from modelhub.constants import TIP_BY_RESTYPE
TIP_BY_RESTYPE
# Annot: default (diffused default)
REQUIRED_CONDITIONING_ANNOTATION_VALUES = {
@@ -204,46 +207,3 @@ SELECTION_NONPROTEIN = [
"MACROLIDE",
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]
# fmt: off
# ... For convenience, we allow BKBN or TIP to be used as a shortcut | TIP is the largest set of fixed atoms given at least 2 tip atoms
TIP_BY_RESTYPE = {
"TRP": ["CG","CD1","CD2","NE1","CE2","CE3","CZ2","CZ3","CH2"], # fix both rings
"HIS": ["CG","ND1","CD2","CE1","NE2"], # fixed ring
"TYR": ["CZ","OH"], # keeps ring dihedral flexible
"PHE": ["CG","CD1","CD2","CE1","CE2","CZ"],
"ASN": ["CB", "CG","OD1","ND2"],
"ASP": ["CB", "CG","OD1","OD2"],
"GLN": ["CG", "CD","OE1","NE2"],
"GLU": ["CG", "CD","OE1","OE2"],
"CYS": ["CB", "SG"],
"SER": ["CB", "OG"],
"THR": ["CB", "OG1"],
"LEU": ["CB", "CG", "CD1", "CD2"],
"VAL": ["CG1", "CG2"],
"ILE": ["CB", "CG2"],
"MET": ["SD", "CE"],
"LYS": ["CE","NZ"],
"ARG": ["CD","NE","CZ","NH1","NH2"],
"PRO": None,
"ALA": None,
"GLY": None,
"UNK": None,
"MSK": None
}
# fmt: on
STANDARD_PARSER_ARGS = {
"add_missing_atoms": True,
"add_id_and_entity_annotations": True,
"add_bond_types_from_struct_conn": ("covale",),
"remove_ccds": tuple(CRYSTALLIZATION_AIDS),
"remove_waters": True,
"fix_ligands_at_symmetry_centers": True,
"fix_arginines": True,
"fix_formal_charges": True,
"fix_bond_types": True,
"convert_mse_to_met": True, # Changed from False to True vs. atomworks.io.parser.parse default
"hydrogen_policy": "keep",
"model": None, # all models
}

View File

@@ -1,7 +1,7 @@
# RFdiffusion3 — Input specification (dialect **2**)
> **TL;DR**
> Inputs are now defined with a single `InputSpecification` class.
> Inputs are now defined with a single `DesignInputSpecification` class.
> Selections like “what's fixed?”, “what's sequence-free?”, “which atoms are donors/acceptors?” are all expressed with the same **InputSelection** mini-language.
> Everything is reproducibly logged back out alongside your generation.
@@ -10,7 +10,7 @@
- [What changed (high level)](#what-changed-high-level)
- [Quick start](#quick-start)
- [The `InputSelection` mini-language](#the-inputselection-mini-language)
- [Full schema: `InputSpecification`](#full-schema-inputspecification)
- [Full schema: `DesignInputSpecification`](#full-schema-designinputspecification)
- [Common recipes (cookbook)](#common-recipes-cookbook)
- [Partial diffusion](#partial-diffusion)
- [Symmetry](#symmetry)
@@ -40,7 +40,7 @@
---
## InputSpecification
## DesignInputSpecification
| Field | Type | Description |
| -------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------- |

View File

@@ -0,0 +1,435 @@
import json
import logging
import os
import time
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List
import torch
import yaml
from biotite.structure import AtomArray
from toolz import merge_with
from modelhub.common import exists
from modelhub.inference_engines.base import BaseInferenceEngine
from modelhub.utils.ddp import RankedLogger
from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
from rfd3.inference.datasets import (
assemble_distributed_inference_loader_from_json,
)
from rfd3.inference.input_parsing import DesignInputSpecification
from rfd3.utils.inference import ensure_input_is_abspath
from rfd3.utils.io import (
CIF_LIKE_EXTENSIONS,
dump_metadata,
dump_structures,
dump_trajectories,
extract_example_id_from_path,
find_files_with_extension,
)
logging.basicConfig(level=logging.INFO)
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
class RFD3InferenceEngine(BaseInferenceEngine):
    """Inference engine for RFdiffusion3.

    Typical use is ``engine.run(inputs=..., n_batches=..., out_dir=...)``:
    inputs are canonicalized into a mapping of example prefix -> design
    specification, expanded into one entry per batch, and then pushed through
    the data pipeline and the trainer's ``validation_step``.
    """

    def __init__(
        self,
        *,
        # Default input handling args
        skip_existing: bool,
        json_keys_subset: None | List[str],
        prevalidate_inputs: bool,
        # Base inference engine args
        diffusion_batch_size: int,
        inference_sampler: dict,
        specification: dict | None,
        # Structure dumping arguments
        global_prefix: str | None,
        cleanup_guideposts: bool,
        cleanup_virtual_atoms: bool,
        read_sequence_from_sequence_head: bool,
        output_full_json: bool,
        dump_prediction_metadata_json: bool,
        dump_trajectories: bool,
        align_trajectory_structures: bool,
        low_memory_mode: bool,
        **kwargs,
    ):
        """Configure the engine; all arguments are keyword-only.

        Args:
            skip_existing: Skip example IDs already found in the output directory.
            json_keys_subset: If given, only these keys are read from JSON/YAML inputs.
            prevalidate_inputs: Validate every specification while parsing file inputs.
            diffusion_batch_size: Forwarded to the base class as a transform override.
            inference_sampler: Forwarded (copied) as the inference sampler overrides.
            specification: Default specification args applied on top of every input.
            global_prefix: Prefix for example IDs (see ``process_input``).
            cleanup_guideposts, cleanup_virtual_atoms,
            read_sequence_from_sequence_head, output_full_json:
                Forwarded to the base class as trainer overrides.
            dump_prediction_metadata_json: Write prediction metadata next to structures.
            dump_trajectories: Write denoised and noisy trajectories.
            align_trajectory_structures: Passed to the trajectory dumper.
            low_memory_mode: Sets the ``RFD3_LOW_MEMORY_MODE`` environment variable.
            **kwargs: Forwarded unchanged to ``BaseInferenceEngine.__init__``.
        """
        super().__init__(
            transform_overrides={"diffusion_batch_size": diffusion_batch_size},
            inference_sampler_overrides={**inference_sampler},
            trainer_overrides={
                "cleanup_guideposts": cleanup_guideposts,
                "cleanup_virtual_atoms": cleanup_virtual_atoms,
                "read_sequence_from_sequence_head": read_sequence_from_sequence_head,
                "output_full_json": output_full_json,
            },
            **kwargs,
        )
        # Copy so later mutation of the caller's dict cannot leak into the engine
        self.specification_overrides = dict(specification or {})

        # Input handling
        self.global_prefix = global_prefix
        self.json_keys_subset = json_keys_subset
        self.prevalidate_inputs = prevalidate_inputs
        self.skip_existing = skip_existing

        # Saving / other args
        self.dump_prediction_metadata_json = dump_prediction_metadata_json
        self.dump_trajectories = dump_trajectories
        self.align_trajectory_structures = align_trajectory_structures

        if not cleanup_guideposts:
            ranked_logger.warning(
                "Guideposts will not be cleaned up. This is intended for debugging purposes."
            )
        if not cleanup_virtual_atoms:
            ranked_logger.warning(
                "Virtual atoms will not be cleaned up. Some tools like MPNN may run, but outputs will not be like native structures."
            )

        if low_memory_mode:
            ranked_logger.info("Low memory mode enabled.")
            # HACK: the flag is communicated through the process environment
            os.environ["RFD3_LOW_MEMORY_MODE"] = "1"

    def run(
        self,
        *,
        inputs: str | PathLike | AtomArray | DesignInputSpecification,
        n_batches: int | None = None,
        out_dir: str | PathLike | None = None,
    ):
        """Run design for ``inputs``.

        Args:
            inputs: JSON/YAML/structure path(s), or DesignInputSpecification(s).
            n_batches: Number of batches per example; ``None`` means a single
                batch whose example ID carries no batch suffix.
            out_dir: Where to write outputs. When falsy, nothing is written and
                the results are returned instead.

        Returns:
            Dict of example ID -> outputs. Only populated when ``out_dir`` is
            not set; otherwise the results are written to disk and the dict is empty.
        """
        self._set_out_dir(out_dir)
        inputs = self._canonicalize_inputs(inputs)
        design_specifications = self._multiply_specifications(
            inputs=inputs,
            n_batches=n_batches,
        )
        # Initialize only after the inputs have been parsed
        self.initialize()
        outputs = self._run_multi(design_specifications)
        return outputs

    def _set_out_dir(self, out_dir: str | PathLike | None):
        """Store ``out_dir`` as a ``Path`` (created if needed), or ``None`` when falsy."""
        out_dir = Path(out_dir) if out_dir else None
        if out_dir:
            out_dir.mkdir(parents=True, exist_ok=True)
            ranked_logger.info(f"Outputs will be written to {out_dir.resolve()}.")
        self.out_dir = out_dir

    def _run_multi(self, specs):
        """Run inference over ``specs`` (example ID -> specification).

        Results are saved to ``self.out_dir`` when it is set; otherwise they are
        collected into the returned dict keyed by example ID.
        """
        # ==============================================================================
        # Prepare pipeline and inference loader
        # ==============================================================================
        loader = assemble_distributed_inference_loader_from_json(
            # Passed directly to ContigJSONDataset
            data=specs,
            transform=self.pipeline,
            name="inference-dataset",
            cif_parser_args=None,
            subset_to_keys=None,
            eval_every_n=1,
            # Sampler args
            world_size=self.trainer.fabric.world_size,
            rank=self.trainer.fabric.global_rank,
        )
        loader = self.trainer.fabric.setup_dataloaders(
            loader,
            use_distributed_sampler=False,
        )

        # ==============================================================================
        # Evaluate, using `validation_step`
        # ==============================================================================
        outputs = {}
        for batch in loader:
            pipeline_output = batch[0]
            output = self._model_forward(pipeline_output)
            example_id = pipeline_output["example_id"]
            if self.out_dir:
                self.save_batch_outputs(
                    out_dir=self.out_dir,
                    example_id=example_id,
                    network_output=output["network_output"],
                    prediction_metadata=output["prediction_metadata"],
                    predicted_atom_array_stack=output["predicted_atom_array_stack"],
                    pipeline_output=pipeline_output,
                )
            else:
                outputs[example_id] = {
                    "network_output": output["network_output"],
                    "prediction_metadata": output["prediction_metadata"],
                    "predicted_atom_array_stack": output["predicted_atom_array_stack"],
                }
        return outputs

    def _model_forward(self, pipeline_output):
        """Run one example through ``validation_step`` and stamp its metadata.

        Every entry of ``output["prediction_metadata"]`` gets the checkpoint
        path (symlinks resolved to their target) and the seed.
        """
        t0 = time.time()
        with torch.no_grad():
            pipeline_output = self.trainer.fabric.to_device(pipeline_output)
            output = self.trainer.validation_step(
                batch=pipeline_output,
                batch_idx=0,
                compute_metrics=False,
            )
        t_end = time.time()

        # Add additional information to prediction metadata
        prediction_metadata = output["prediction_metadata"]
        if prediction_metadata:
            # The checkpoint path is the same for every entry: resolve it once
            ckpt = Path(self.ckpt_path)
            if ckpt.is_symlink():
                ckpt = ckpt.resolve(strict=True)  # follow symlink to target
            for key in prediction_metadata:
                prediction_metadata[key]["ckpt_path"] = str(ckpt)
                prediction_metadata[key]["seed"] = self.seed

        ranked_logger.info(f"Finished inference batch in {t_end - t0:.2f} seconds.")
        return output

    ###############################################
    # Input merging
    ###############################################
    def _canonicalize_inputs(
        self, inputs
    ) -> Dict[str, dict | DesignInputSpecification]:
        """Convert any supported input into ``{example prefix: specification}``.

        Supported inputs:
            - ``None``: a single unnamed specification built from the overrides.
            - A path (``str`` / path-like) or list of paths: parsed by ``process_input``.
            - A ``DesignInputSpecification`` or list of them: keyed ``backbone_<i>``.

        Raises:
            NotImplementedError: for ``AtomArray`` inputs (not yet supported).
            ValueError: for any other input type, including mixed-type lists.
        """

        def _is_list_of(types) -> bool:
            return isinstance(inputs, list) and all(
                isinstance(item, types) for item in inputs
            )

        if inputs is None:
            # Create empty specification dictionary
            return {"": {**self.specification_overrides}}

        if isinstance(inputs, (str, PathLike)) or _is_list_of((str, PathLike)):
            # Downstream parsing works on plain strings, so coerce path-likes first
            if isinstance(inputs, list):
                paths = [os.fspath(path) for path in inputs]
            else:
                paths = os.fspath(inputs)
            return process_input(
                paths,
                json_keys_subset=self.json_keys_subset,
                global_prefix=self.global_prefix,
                specification_overrides=self.specification_overrides,
                validate=self.prevalidate_inputs,
            )  # any -> Dict[Name: DesignInputSpecification]

        if isinstance(inputs, DesignInputSpecification) or _is_list_of(
            DesignInputSpecification
        ):
            if isinstance(inputs, DesignInputSpecification):
                inputs = [inputs]
            return {f"backbone_{i}": spec for i, spec in enumerate(inputs)}

        if isinstance(inputs, AtomArray) or _is_list_of(AtomArray):
            raise NotImplementedError("AtomArray inputs not yet supported.")

        raise ValueError(
            f"Invalid input type: {type(inputs)}. Expected JSON/YAML file paths, AtomArray, or DesignInputSpecification.\nInput: {inputs}"
        )

    def _multiply_specifications(
        self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
    ) -> Dict[str, Dict[str, Any]]:
        """Expand each input specification into one entry per batch.

        Example IDs are ``{prefix}_{batch_id}`` when ``n_batches`` is given and
        just ``prefix`` otherwise. With ``skip_existing`` and an output
        directory set, IDs already present in that directory are dropped.
        """
        # Find existing example IDs in output directory
        existing_example_ids = set()
        if exists(self.out_dir):
            existing_example_ids = set(
                extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
                for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
            )
            ranked_logger.info(
                f"Found {len(existing_example_ids)} existing example IDs in the output directory."
            )
        may_skip = self.skip_existing and exists(self.out_dir)

        # Based on inputs, construct the specifications to loop through
        has_batches = exists(n_batches)
        design_specifications = {}
        for prefix, example_spec in inputs.items():
            for batch_id in range(n_batches if has_batches else 1):
                example_id = f"{prefix}_{batch_id}" if has_batches else prefix
                if may_skip and example_id in existing_example_ids:
                    ranked_logger.info(
                        f"Skipping design specification for example {example_id} | Already exists."
                    )
                    continue
                design_specifications[example_id] = example_spec
        return design_specifications

    def save_batch_outputs(
        self,
        *,
        out_dir,
        network_output,
        prediction_metadata,
        predicted_atom_array_stack,
        pipeline_output,
        example_id,
    ):
        """Write the structures for one example, plus metadata and trajectories if enabled.

        All files share the base path ``out_dir / example_id``; trajectories
        use the ``_denoised`` and ``_noisy`` suffixes.
        """
        out_dir = Path(out_dir)
        base_path = out_dir / example_id

        dump_structures(
            atom_arrays=predicted_atom_array_stack,
            base_path=base_path,
            one_model_per_file=True,
            extra_fields=SAVED_CONDITIONING_ANNOTATIONS,
        )

        if self.dump_prediction_metadata_json:
            dump_metadata(
                prediction_metadata=prediction_metadata,
                base_path=base_path,
                one_model_per_file=True,
            )

        if self.dump_trajectories:
            dump_trajectories(
                trajectory_list=network_output["X_denoised_L_traj"],
                atom_array=pipeline_output["atom_array"],
                base_path=out_dir / f"{example_id}_denoised",
                align_structures=self.align_trajectory_structures,
            )
            dump_trajectories(
                trajectory_list=network_output["X_noisy_L_traj"],
                atom_array=pipeline_output["atom_array"],
                base_path=out_dir / f"{example_id}_noisy",
                align_structures=self.align_trajectory_structures,
            )

        ranked_logger.info(
            f"Outputs for {example_id} written to {out_dir / example_id}."
        )
def normalize_inputs(inputs: str | list | None) -> list[str | None]:
"""
inputs: str | list[str] | None
- Can be:
- A single path to a JSON, YAML, or regular input file (cif or pdb)
- A comma-separated string of paths (e.g. "a.json,b.json")
- A list of file paths
- None or an empty list, in which case a dummy input is added (used for e.g. motif-only design)
- Returns list of paths or [None] if no inputs are provided
"""
if inputs is None or (isinstance(inputs, list) and len(inputs) == 0):
inputs = [None]
elif isinstance(inputs, str):
inputs = inputs.split(",")
elif not isinstance(inputs, list):
raise ValueError(
f"Invalid input type: {type(inputs)}. Expected str, list, or None.\nInput: {inputs}"
)
return inputs
def process_input(
    inputs: str | list | None,
    json_keys_subset: str | list | None = None,
    global_prefix: str | None = None,
    specification_overrides: dict | None = None,
    validate: bool = True,
) -> Dict[str, dict]:
    """Convert input sources into design specification dictionaries.

    Args:
        inputs: Any -> list[str | None] (see normalize_inputs)
        json_keys_subset: Extract only this subset of JSON/YAML keys (list or
            comma-separated string). None will keep all keys.
        global_prefix: If provided, prefix all example ids with it. Otherwise
            examples from a JSON/YAML file are prefixed with the file's basename.
        specification_overrides: Args merged on top of every example; on a key
            collision the override wins.
        validate: If True, every resulting specification is checked with
            ``DesignInputSpecification.safe_init``.

    Returns:
        Dictionaries of specification args pre-batching:
        {
            'jsonfile_jsonkey1': {
                **args_from_key1
            },
            'jsonfile_jsonkey2': {
                **args_from_key2
            }
        }
    """
    specification_overrides = dict(specification_overrides or {})

    def merge_args(example_args: dict) -> dict:
        # Later dicts win, so the overrides take precedence over the example's args
        return merge_with(lambda x: x[-1], example_args, specification_overrides)

    # Always a list after normalization ([None] when no inputs were given)
    inputs = normalize_inputs(inputs)

    # If global_prefix is not provided, then default to using the basename of the
    # JSON or YAML file (when provided)
    use_json_basename_prefix = global_prefix is None

    if isinstance(json_keys_subset, str):
        json_keys_subset = json_keys_subset.split(",")

    all_specs = {}
    for input_path in inputs:
        if not exists(input_path):
            # Dummy input: the specification consists of the overrides only
            all_specs["backbone"] = dict(specification_overrides)
            continue

        # Accept path objects inside lists as well as plain strings
        input_path = os.fspath(input_path)

        if input_path.endswith((".json", ".yaml")):
            # ... Load JSON or YAML file
            with open(input_path, "r") as f:
                data = (
                    json.load(f)
                    if input_path.endswith(".json")
                    else yaml.safe_load(f)
                )

            # ... Apply any global args for this input file
            if "global_args" in data:
                global_args = data.pop("global_args")
                for example in data:
                    data[example].update(global_args)

            # ... Subset to keys
            if json_keys_subset is not None:
                data = {
                    example: data[example]
                    for example in json_keys_subset
                    if example in data
                }

            # ... Extract each accumulated example in data.
            name = os.path.splitext(os.path.basename(input_path))[0]
            for example, args in data.items():
                args = ensure_input_is_abspath(args, input_path)
                if use_json_basename_prefix:
                    prefix = f"{name}_{example}"
                else:
                    prefix = f"{global_prefix}{example}"
                args["extra"] = args.get("extra", {}) | {"example": example}
                all_specs[prefix] = dict(merge_args(args))
        else:
            # Any other file is used directly as the `input` of a single example
            prefix = os.path.basename(os.path.splitext(input_path)[0])
            if global_prefix is not None:
                prefix = f"{global_prefix}{prefix}"
            all_specs[prefix] = dict(merge_args({"input": input_path}))

    if validate:
        for prefix, example_spec in all_specs.items():
            ranked_logger.info(
                f"Prevalidating design specification for example: {prefix}"
            )
            DesignInputSpecification.safe_init(**example_spec)
    return all_specs

View File

@@ -5,29 +5,17 @@
import json
import os
import textwrap
import time
from os import PathLike
from typing import Any, List
from typing import Any, Dict, List
import yaml
from atomworks.io.parser import parse_atom_array
from atomworks.ml.datasets import MolecularDataset
from atomworks.ml.transforms.base import Compose, Transform, TransformedDict
from biotite.structure import BondList
from atomworks.ml.transforms.base import Compose, Transform
from omegaconf import DictConfig, OmegaConf
from rfd3.constants import (
INFERENCE_ANNOTATIONS,
REQUIRED_INFERENCE_ANNOTATIONS,
)
from rfd3.inference.inference_utils import ensure_input_is_abspath
from rfd3.inference.input_parsing import (
create_atom_array_from_design_specification,
)
from rfd3.transforms.conditioning_base import (
check_has_required_conditioning_annotations,
convert_existing_annotations_to_bool,
DesignInputSpecification,
)
from rfd3.utils.inference import ensure_input_is_abspath
from torch.utils.data import (
DataLoader,
SequentialSampler,
@@ -49,7 +37,7 @@ class ContigJsonDataset(MolecularDataset):
def __init__(
self,
*,
data: PathLike | dict,
data: PathLike | Dict[str, dict | DesignInputSpecification],
cif_parser_args: dict | None,
transform: Transform | Compose | None,
name: str | None,
@@ -137,7 +125,7 @@ class ContigJsonDataset(MolecularDataset):
def _check_json_keys(self):
"""Check if the JSON keys are valid."""
for k, data in self.data.items():
if not isinstance(data, dict):
if not isinstance(data, (dict, DesignInputSpecification)):
raise ValueError("Each item in the JSON data must be a dictionary.")
@property
@@ -171,77 +159,19 @@ class ContigJsonDataset(MolecularDataset):
spec = self.data[example_id]
# if 'input' in metadata and not abspath, prepend the source json directory to the file path
spec = ensure_input_is_abspath(spec, self.json_path)
if not isinstance(spec, DesignInputSpecification):
spec = ensure_input_is_abspath(spec, self.json_path)
spec["cif_parser_args"] = self.cif_parser_args
spec = DesignInputSpecification(**spec)
# ... Create atom array with conditioning annotations
atom_array, spec_dict = create_atom_array_from_design_specification(
cif_parser_args=self.cif_parser_args, **spec
)
# Create pipeline input
data = spec.to_pipeline_input(example_id=example_id)
# ... Forward into
data = prepare_pipeline_input_from_atom_array(atom_array)
data["example_id"] = example_id
# ... Wrap up with additional features
if "extra" not in spec_dict:
spec_dict["extra"] = {}
spec_dict["extra"]["example_id"] = example_id
data["specification"] = spec_dict
# ... Send through pipeline
# Apply transforms and return
data = self.transform(data)
return data
def prepare_pipeline_input_from_atom_array(  # see atomworks.ml.datasets.parsers.base.load_example_from_metadata_row
    atom_array_orig,
) -> dict:
    """
    Build the pipeline input dictionary from an annotated atom array.

    The array is passed through ``parse_atom_array`` and the first entry of the
    resulting asymmetric unit is used.

    Args:
        atom_array_orig: Atom array instantiated with conditioning annotations.
            NOTE: if it has no bond list, an empty one is assigned in place.

    Returns:
        dict: A ``TransformedDict`` with the keys ``atom_array``, ``chain_info``,
        ``ligand_info`` and ``metadata``.
    """
    # HACK: Set empty bond graph:
    if atom_array_orig.bonds is None:
        atom_array_orig.bonds = BondList(atom_array_orig.array_length())

    result_dict = parse_atom_array(
        atom_array_orig,
        remove_ccds=[],
        fix_arginines=False,
        add_missing_atoms=False,
        extra_fields=INFERENCE_ANNOTATIONS,
        build_assembly=None,
        hydrogen_policy="remove",
    )
    atom_array = result_dict["asym_unit"][0]

    check_has_required_conditioning_annotations(
        atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
    )
    atom_array = convert_existing_annotations_to_bool(atom_array)

    # HACK: Set iid information manually
    # We currently do not preserve this information from the input,
    # if you want these we'd need to remove the spoofing here
    atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
    atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])

    return TransformedDict(
        {
            "atom_array": atom_array,  # First model
            "chain_info": result_dict["chain_info"],
            "ligand_info": result_dict["ligand_info"],
            "metadata": result_dict["metadata"],
        }
    )
def assemble_distributed_inference_loader_from_json(
*, rank: int, world_size: int, **dataset_kwargs
) -> DataLoader:

View File

@@ -1,448 +0,0 @@
import json
import logging
import os
import time
from os import PathLike
from pathlib import Path
from typing import Dict
import hydra
import torch
import yaml
from lightning.fabric import seed_everything
from omegaconf import OmegaConf
from rfd3.constants import (
SAVED_CONDITIONING_ANNOTATIONS,
)
from rfd3.inference.datasets import (
assemble_distributed_inference_loader_from_json,
)
from rfd3.inference.inference_utils import ensure_input_is_abspath
from rfd3.inference.input_parsing import InputSpecification
from toolz import merge_with
from modelhub.common import exists
from modelhub.inference_engines.af3 import AF3InferenceEngine
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
from modelhub.utils.io import (
CIF_LIKE_EXTENSIONS,
extract_example_id_from_path,
find_files_with_extension,
)
from modelhub.utils.logging import print_config_tree
logging.basicConfig(level=logging.INFO)
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
def normalize_inputs(inputs: str | list | None) -> list[str | None]:
"""
inputs: str | list[str] | None
- Can be:
- A single path to a JSON, YAML, or regular input file (cif or pdb)
- A comma-separated string of paths (e.g. "a.json,b.json")
- A list of file paths
- None or an empty list, in which case a dummy input is added (used for e.g. motif-only design)
- Returns list of paths or [None] if no inputs are provided
"""
if inputs is None or (isinstance(inputs, list) and len(inputs) == 0):
inputs = [None]
elif isinstance(inputs, str):
inputs = inputs.split(",")
elif not isinstance(inputs, list):
raise ValueError(
f"Invalid input type: {type(inputs)}. Expected str, list, or None.\nInput: {inputs}"
)
return inputs
def process_input(
    inputs: str | list | None,
    json_keys_subset: str | list | None = None,
    global_prefix: str | None = None,
    specification_overrides: dict | None = None,
) -> Dict[str, dict]:
    """
    Convert input sources into design specification dictionaries.

    inputs: Any -> list[str | None] (see normalize_inputs)
    json_keys_subset: extract only subset of JSON keys. None will keep all keys
    global_prefix: If provided, prefix all example ids with said prefix
    specification_overrides: args merged on top of every example (overrides win)

    returns: Dictionaries of specification args pre-batching:
        {
            'jsonfile_jsonkey1': {
                **args_from_key1
            },
            'jsonfile_jsonkey2': {
                **args_from_key2
            }
        }
    """
    # Work on a private copy: never share state with a default or the caller's dict
    specification_overrides = dict(specification_overrides or {})

    def merge_args(example_args: dict) -> dict:
        # Later dicts win, so the overrides take precedence
        return merge_with(lambda x: x[-1], example_args, specification_overrides)

    # Always a list after normalization ([None] when no inputs were given)
    inputs = normalize_inputs(inputs)

    # If global_prefix is not provided, then default to using the basename of the
    # JSON or YAML file (when provided)
    use_json_basename_prefix = global_prefix is None

    if isinstance(json_keys_subset, str):
        json_keys_subset = json_keys_subset.split(",")

    all_specs = {}
    for input_path in inputs:
        if not exists(input_path):
            # Dummy input: copy so the returned spec does not alias the overrides
            all_specs["backbone"] = dict(specification_overrides)
        elif input_path.endswith((".json", ".yaml")):
            # ... Load JSON or YAML file
            with open(input_path, "r") as f:
                data = (
                    json.load(f)
                    if input_path.endswith(".json")
                    else yaml.safe_load(f)
                )

            # ... Apply any global args for this input file
            if "global_args" in data:
                global_args = data.pop("global_args")
                for example in data:
                    data[example].update(global_args)

            # ... Subset to keys
            if json_keys_subset is not None:
                data = {
                    example: data[example]
                    for example in json_keys_subset
                    if example in data
                }

            # ... Extract each accumulated example in data.
            name = os.path.splitext(os.path.basename(input_path))[0]
            for example, args in data.items():
                args = ensure_input_is_abspath(args, input_path)
                if use_json_basename_prefix:
                    prefix = f"{name}_{example}"
                else:
                    prefix = f"{global_prefix}{example}"
                args["extra"] = args.get("extra", {}) | {"example": example}
                all_specs[prefix] = dict(merge_args(args))
        else:
            # Any other file is used directly as the `input` of a single example
            prefix = os.path.basename(os.path.splitext(input_path)[0])
            if global_prefix is not None:
                prefix = f"{global_prefix}{prefix}"
            all_specs[prefix] = dict(merge_args({"input": input_path}))
    return all_specs
class RFD3InferenceEngine(AF3InferenceEngine):
"""Inference engine for RFdiffusion3"""
def __init__(
self,
# Required args:
out_dir: str | PathLike,
ckpt_path: str | PathLike,
inputs: str | PathLike,
json_keys_subset: None | list[str],
n_batches: int,
*,
# Default design specification args:
specification: dict,
# Base inference engine args
diffusion_batch_size: int,
skip_existing: bool,
inference_sampler: dict,
# Structure dumping arguments
cleanup_guideposts: bool,
cleanup_virtual_atoms: bool,
read_sequence_from_sequence_head: bool,
output_full_json: bool,
dump_prediction_metadata_json: bool,
dump_trajectories: bool,
align_trajectory_structures: bool,
one_model_per_file: bool,
global_prefix: str | None,
###############################################
num_nodes: int,
devices_per_node: int,
print_config: bool,
seed: int | None,
temp_dir=None,
low_memory_mode: bool = False,
prevalidate_inputs: bool = True,
# Atom array instantiation args collapsed into default dict:
):
"""
Design specification args:
# Main:
inputs: JSON, YAML, PDB or CIF (comma-separated string or single) containing coordinate data to be parsed or args to override
length: length of designed structure (int or str like '20-100'). Default: None (specified by contig string)
contig: string of residues to use as backbone. Default: None (specified by length) Example: '10-20,A11-13,10-20'
fixed_atoms: Dict[str] of atom names to use as indexed motif atoms with fixed coordinates. Default: None
unindex: list of residues in input_src to unindex in design. Default: None
redesign_motif_sidechains: bool or str. Specifies which motif residues have fixed sidechains by default. Default: True
ligand: str. Ligands in input file to include as motif
atomwise_rasa: Dict[str] of atomwise rasas to use for design. Default: None
ori_token: list of 3 floats indicating the origin to center the design around. Default: None
seed: int or None. If None, a seed is derived from the current time.
Additional args:
ckpt_path: Path to checkpoint file
n_batches: Number of batches to create per input example; each batch gets the example ID `{prefix}_{batch index}`
"""
# Relative output directories are made absolute before anything else uses them
if not os.path.isabs(out_dir):
out_dir = os.path.abspath(out_dir)
ranked_logger.info("Using absolute path for out_dir: {}".format(out_dir))
# Convert input sources to design specification dictionaries
inputs = process_input(
inputs,
json_keys_subset=json_keys_subset,
global_prefix=global_prefix,
specification_overrides=specification,
) # any -> Dict[Name: InputSpecification]
# Maps example ID -> specification; every batch of an example stores the same spec object
self.design_specifications = {}
for prefix, example_spec in inputs.items():
# ... Set example key as the prefix
if prevalidate_inputs:
ranked_logger.info(
f"Prevalidating design specification for example: {prefix}"
)
InputSpecification.safe_init(**example_spec)
# ... Create n_batches for example
for batch_i in range(n_batches):
# ... Example ID
example_id = f"{prefix}_{batch_i}"
self.design_specifications[example_id] = example_spec
############################################################
# Feed-forward inputs similar to MH-AF3 inference engine
############################################################
# ... set the random seed for reproducibility (and for augmentation, e.g., for antibodies)
if not exists(seed):
# No seed given: derive one from the current time in milliseconds, kept below 2**31
seed = int(time.time() * 1000) % (2**31)
ranked_logger.info(f"Seeding everything with seed={seed}...")
seed_everything(seed, workers=True, verbose=True)
# Must be assigned before load_and_override_ckpt_config, which reads self.seed
self.seed = seed
# We only extract the `train_cfg` from the checkpoint initially
self.load_and_override_ckpt_config(
ckpt_path=ckpt_path,
num_nodes=num_nodes,
devices_per_node=devices_per_node,
inference_sampler=inference_sampler,
)
set_accelerator_based_on_availability(self.cfg)
# (b) based on the dataset (we will apply when constructing the pipeline)
self.dataset_overrides = {
"diffusion_batch_size": diffusion_batch_size,
}
# ... instantiate the trainer with the (modified) configuration
self.trainer = hydra.utils.instantiate(
self.cfg.trainer,
_convert_="partial",
_recursive_=False,
)
# Set the output directory for the CIF files (e.g., predicted structures)
self.cif_out_dir = Path(out_dir) if out_dir else Path("./")
# Structure dumping
self.dump_prediction_metadata_json = dump_prediction_metadata_json
self.dump_trajectories = dump_trajectories
self.one_model_per_file = one_model_per_file
self.align_trajectory_structures = align_trajectory_structures
# These options are set as attributes directly on the instantiated trainer
self.trainer.cleanup_virtual_atoms = cleanup_virtual_atoms
self.trainer.cleanup_guideposts = cleanup_guideposts
self.trainer.read_sequence_from_sequence_head = read_sequence_from_sequence_head
self.trainer.output_full_json = output_full_json
self.trainer.inference_sampler_overrides = inference_sampler
self.prediction_extra_fields = SAVED_CONDITIONING_ANNOTATIONS
self.skip_existing = skip_existing
self.dump_predictions = True
self.print_config = print_config
if not cleanup_guideposts:
ranked_logger.warning(
"Guideposts will not be cleaned up. This is intended for debugging purposes."
)
if not cleanup_virtual_atoms:
ranked_logger.warning(
"Virtual atoms will not be cleaned up. Some tools like MPNN may run, but outputs will not be like native structures."
)
# Check which example ids already exist in the output directory
self.existing_example_ids = set(
extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
for path in find_files_with_extension(out_dir, CIF_LIKE_EXTENSIONS)
)
ranked_logger.info(
f"Found {len(self.existing_example_ids)} existing example IDs in the output directory."
)
if low_memory_mode:
ranked_logger.info("Low memory mode enabled.")
# HACK: Set attribute to the diffusion module
os.environ["RFD3_LOW_MEMORY_MODE"] = "1"
# Loads `train_cfg` from the checkpoint into self.cfg and applies the inference-time overrides.
def load_and_override_ckpt_config(
self, ckpt_path, num_nodes, devices_per_node, inference_sampler
):
assert exists(ckpt_path), f"Checkpoint path ({ckpt_path}) not provided."
ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...")
self.ckpt_path = ckpt_path
# The checkpoint is loaded on CPU with weights_only=False; only its "train_cfg" entry is used here
self.cfg = OmegaConf.create(
torch.load(self.ckpt_path, "cpu", weights_only=False)["train_cfg"]
)
# Override specific parameters within the Hydra config:
# (a) based on the input arguments
self.cfg.trainer.num_nodes = num_nodes
self.cfg.trainer.devices_per_node = devices_per_node
# None-valued sampler entries are skipped, leaving the checkpoint's value in place
for k, v in inference_sampler.items():
if v is None:
continue
setattr(self.cfg.model.net.inference_sampler, k, v)
# Set metrics / callbacks to be null s.t. they aren't loaded
self.cfg.trainer.metrics = None
# Record the random seed to be dumped in the output JSON
self.cfg.trainer.seed = self.seed
# Returns whether the outputs for `example_id` are already among the IDs found in the output directory.
def example_id_exists(self, example_id, verbose=False):
# TODO: Move this to another file to standardize better with src
if not self.one_model_per_file:
# Check if one file exists
all_exist = example_id in self.existing_example_ids
if all_exist and verbose:
ranked_logger.info(
f"Model file for example {example_id} already exists in the output directory."
)
else:
# One file per model: require `{example_id}_model_{i}` for every i below the diffusion batch size
all_exist = all(
[
(f"{example_id}_model_{i}" in self.existing_example_ids)
for i in range(self.dataset_overrides["diffusion_batch_size"])
]
)
if all_exist and verbose:
ranked_logger.info(
f"All models for example {example_id} already exist in the output directory."
)
return all_exist
def eval(self):
"""
Run design on a set of specifications
"""
if self.print_config:
print_config_tree(
self.cfg.model, resolve=True, title="INFERENCE MODEL CONFIGURATION"
)
# ... spawn processes for distributed training, if using multiple GPUs
ranked_logger.info(
f"Spawning {self.trainer.fabric.world_size} processes from {self.trainer.fabric.global_rank}..."
)
# ==============================================================================
# Construct the model and load the checkpoint
# ==============================================================================
self.trainer.initialize_or_update_trainer_state({"train_cfg": self.cfg})
self.trainer.construct_model()
self.trainer.load_checkpoint(ckpt_path=self.ckpt_path, is_inference=True)
# Ensure optimizer isn't loaded
self.trainer.state["optimizer"] = None
self.trainer.state["train_cfg"].model.optimizer = None
self.trainer.setup_model_optimizers_and_schedulers()
self.trainer.state["model"].eval()
# ==============================================================================
# Prepare pipeline and inference loader
# ==============================================================================
# TODO: have name be the basename of the JSON or YAML file
loader = assemble_distributed_inference_loader_from_json(
# Passed directly to ContigJSONDataset
data=self.design_specifications,
transform=self.construct_pipeline(),
name="inference-dataset",
cif_parser_args={},
subset_to_keys=None,
eval_every_n=1,
# Sampler args
world_size=self.trainer.fabric.world_size,
rank=self.trainer.fabric.global_rank,
)
loader = self.trainer.fabric.setup_dataloaders(
loader,
use_distributed_sampler=False,
)
# ==============================================================================
# Evaluate, using `validation_step``
# ==============================================================================
for batch_idx, batch in enumerate(loader):
# NOTE(review): only the first element of each batch is used — presumably the loader yields one example per batch; confirm
pipeline_output = batch[0]
example_id = pipeline_output["example_id"]
if self.skip_existing:
if self.example_id_exists(example_id, verbose=True):
ranked_logger.info(
f"Skipping structure {batch_idx + 1}/{len(loader)}: {example_id} | Already exists."
)
continue
else:
ranked_logger.info(
f"Predicting structure {batch_idx + 1}/{len(loader)}: {example_id}"
)
# Model inference
t0 = time.time()
with torch.no_grad():
pipeline_output = self.trainer.fabric.to_device(pipeline_output)
output = self.trainer.validation_step(
batch=pipeline_output,
batch_idx=0,
compute_metrics=False,
)
t_end = time.time()
# Add additional information to prediction metadata
for key in output["prediction_metadata"].keys():
ckpt = Path(self.ckpt_path)
if ckpt.is_symlink():
ckpt = ckpt.resolve(strict=True) # follow symlink to target
output["prediction_metadata"][key]["ckpt_path"] = str(ckpt)
output["prediction_metadata"][key]["seed"] = self.seed
ranked_logger.info(f"Finished inference batch in {t_end - t0:.2f} seconds.")
self.save_batch_outputs(
example_id=example_id,
network_output=output["network_output"],
prediction_metadata=output["prediction_metadata"],
predicted_atom_array_stack=output["predicted_atom_array_stack"],
pipeline_output=pipeline_output,
)

View File

@@ -1,11 +1,18 @@
import copy
import json
import logging
import os
import time
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
import numpy as np
from atomworks.constants import STANDARD_AA
from atomworks.io.parser import parse_atom_array
# from atomworks.ml.datasets.datasets import BaseDataset
from atomworks.ml.transforms.base import TransformedDict
from atomworks.ml.utils.token import (
get_token_starts,
)
@@ -18,20 +25,10 @@ from pydantic import (
model_validator,
)
from rfd3.constants import (
INFERENCE_ANNOTATIONS,
REQUIRED_CONDITIONING_ANNOTATION_VALUES,
REQUIRED_INFERENCE_ANNOTATIONS,
)
from rfd3.inference.components import (
get_design_pattern_with_constraints,
get_motif_components_and_breaks,
)
from rfd3.inference.inference_utils import (
extract_ligand_array,
inference_load_,
set_com,
set_common_annotations,
set_indices,
)
from rfd3.inference.legacy_input_parsing import (
create_atom_array_from_design_specification_legacy,
)
@@ -48,8 +45,19 @@ from rfd3.transforms.conditioning_base import (
set_default_conditioning_annotations,
)
from rfd3.transforms.util_transforms import assign_types_
from rfd3.utils.inference import (
extract_ligand_array,
inference_load_,
set_com,
set_common_annotations,
set_indices,
)
from modelhub.common import exists
from modelhub.utils.components import (
get_design_pattern_with_constraints,
get_motif_components_and_breaks,
)
from modelhub.utils.ddp import RankedLogger
logging.basicConfig(level=logging.DEBUG)
@@ -62,216 +70,6 @@ logger = RankedLogger(__name__, rank_zero_only=True)
#################################################################################
@contextmanager
def validator_context(validator_name: str, data: dict = None):
    """Log the start, success, or failure of a named validator around a ``with`` body.

    Args:
        validator_name: Label interpolated into every log message.
        data: Accepted for call-site compatibility; not read by this function.

    Raises:
        Exception: Whatever the wrapped body raised, re-raised after the
            failure has been logged at error level.
    """
    logger.debug(f"Starting validator: {validator_name}")
    try:
        yield
        # Still inside the ``try`` so a failure while logging is reported too.
        logger.debug(f"✓ Completed validator: {validator_name}")
    except Exception as err:
        failure_report = (
            f"✗ Failed in validator: {validator_name}\n"
            f" Error: {str(err)}\n"
            f" Error type: {type(err).__name__}"
        )
        logger.error(failure_report)
        raise err
def create_diffused_residues(n, additional_annotations=None):
    """Build ``n`` placeholder ALA residues (5 atoms each) located at the origin.

    Every residue is given the atoms N, CA, C, O, CB with all coordinates set
    to zero and residue ids running from 1 to ``n``. The array is then passed
    through ``set_default_conditioning_annotations`` (with ``motif=False``)
    and ``set_common_annotations``.

    Args:
        n: Number of residues to create; must be strictly positive.
        additional_annotations: Forwarded as ``additional=`` to
            ``set_default_conditioning_annotations``.

    Returns:
        The annotated atom array returned by ``set_common_annotations``.

    Raises:
        ValueError: If ``n`` is zero or negative.
    """
    if n <= 0:
        raise ValueError(f"Negative/null residue count ({n}) not allowed.")

    atoms_per_residue = 5
    # One fresh coordinate array per atom, residue by residue.
    placeholder_atoms = [
        struc.Atom(
            np.array([0.0, 0.0, 0.0], dtype=np.float32),
            res_name="ALA",
            res_id=residue_number,
        )
        for residue_number in range(1, n + 1)
        for _ in range(atoms_per_residue)
    ]
    residues = struc.array(placeholder_atoms)

    # Per-residue templates, tiled across all n residues.
    for annotation_name, per_residue_values in (
        ("element", ["N", "C", "C", "O", "C"]),
        ("atom_name", ["N", "CA", "C", "O", "CB"]),
    ):
        residues.set_annotation(
            annotation_name, np.array(per_residue_values * n, dtype="<U2")
        )

    residues = set_default_conditioning_annotations(
        residues, motif=False, additional=additional_annotations
    )
    return set_common_annotations(residues)
def create_motif_residue(
    token,
    strip_sidechains_by_default: bool,
):
    """Prepare a single motif residue for insertion into the design array.

    When ``strip_sidechains_by_default`` is set and ``token.res_name`` passes
    the ``STANDARD_AA`` membership test, the residue is reduced to its
    N/CA/C/O atoms (an O is added via ``create_o_atoms`` when only three atoms
    are present), a CB from ``create_cb_atoms`` is appended, and the residue
    is renamed to ALA. Only the leading backbone atoms keep their
    ``is_motif_atom_with_fixed_coord`` flag; the rest are set to 0.

    In all cases the token is validated with
    ``check_has_required_conditioning_annotations``, passed through
    ``set_common_annotations``, and its ``res_id`` annotation is reset to 1.

    Args:
        token: Atom array for one residue, carrying conditioning annotations.
        strip_sidechains_by_default: Whether standard residues are reduced to
            a backbone + CB alanine.

    Returns:
        The processed token.

    Raises:
        ValueError: If stripping is requested and the token has fewer than
            three atoms.
    """
    # NOTE(review): `token.res_name` is treated as an array further down
    # (np.full_like); `in` against STANDARD_AA may be ambiguous for a
    # multi-atom token — confirm the intended membership check.
    if strip_sidechains_by_default and token.res_name in STANDARD_AA:
        n_atoms = token.shape[0]
        diffuse_oxygen = False
        if n_atoms < 3:
            # The message used to interpolate undefined names (src_chain /
            # src_resid), which turned this ValueError into a NameError.
            try:
                residue_label = f"{token.chain_id[0]}{token.res_id[0]}"
            except (AttributeError, IndexError):
                residue_label = "motif residue"
            raise ValueError(
                f"Not enough data for {residue_label} in input atom array."
            )
        if n_atoms == 3:
            # Handle cases with N, CA, C only;
            token = token + create_o_atoms(token.copy())
            diffuse_oxygen = True  # flag oxygen for generation

        # Subset to the first 4 atoms (N, CA, C, O) only
        token = token[np.isin(token.atom_name, ["N", "CA", "C", "O"])]

        # exactly N, CA, C, O but no CB. Place CB onto idealized position and convert to ALA
        # Sequence name ALA ensures the padded atoms to be diffused from the fixed backbone
        # are placed on the CB so as to not leak the identity of the residue.
        token = token + create_cb_atoms(token.copy())

        # Sequence name must be set to ALA such that the central atom is correctly CB
        token.res_name = np.full_like(token.res_name, "ALA", dtype=token.res_name.dtype)

        # Keep the fixed-coordinate flag on the first 4 atoms (3 when the O was
        # synthesized above); every later atom is marked as not fixed.
        n_fixed_backbone = 4 - int(diffuse_oxygen)
        token.set_annotation(
            "is_motif_atom_with_fixed_coord",
            np.where(
                np.arange(token.shape[0], dtype=int) < n_fixed_backbone,
                token.is_motif_atom_with_fixed_coord,
                0,
            ),
        )

    check_has_required_conditioning_annotations(token)
    token = set_common_annotations(token)
    token.set_annotation("res_id", np.full(token.shape[0], 1))  # Reset to 1
    return token
def accumulate_components(
    components_to_accumulate: List[Union[str, int]],
    *,
    # Tokens from input
    indexed_tokens: Dict[str, AtomArray],
    unindexed_tokens: Dict[str, AtomArray],
    # Additional parameters
    atom_array_accum=None,
    start_chain: str = "A",
    start_resid: int = 1,
    unindexed_breaks: Optional[List[bool]] = None,
    **kwargs,
) -> AtomArray:
    """Assemble an atom array from an ordered list of contig components.

    Components are handled one at a time, paired with ``unindexed_breaks``:

    * ``"/0"`` starts the next chain: the chain letter advances, the molecule
      id increments and the residue counter resets to 1.
    * A component whose string form starts with a letter (e.g. ``"A22"``) is
      looked up in the merged ``indexed_tokens`` / ``unindexed_tokens`` and
      inserted as one residue. Its
      ``is_motif_atom_unindexed_motif_breakpoint`` annotation is set to 1 when
      the paired break flag is truthy and to 0 otherwise. The first break
      resets the chain to ``start_chain`` and stops further diffused residues
      from being inserted.
    * Anything else is read as an integer count of diffused residues built by
      ``create_diffused_residues``.

    Afterwards the pieces are concatenated, ``pn_unit_iid`` is set from
    ``chain_id``, unindexed residue ids are shifted past the largest indexed
    residue id (when both kinds are present), and an empty ``BondList`` is
    attached if the result has no bonds.

    Args:
        components_to_accumulate: Ordered contig components.
        indexed_tokens: Motif tokens keyed by component name.
        unindexed_tokens: Unindexed motif tokens keyed by component name.
        atom_array_accum: Optional arrays to prepend to the result. The list
            passed in is copied, not modified.
        start_chain: Chain id for the first chain.
        start_resid: Residue id of the first inserted residue.
        unindexed_breaks: One flag (or ``None``) per component; must have the
            same length as ``components_to_accumulate``.
        **kwargs: Ignored.

    Returns:
        The concatenated atom array.

    Raises:
        AssertionError: If a token key is not among the components, if the
            breaks list length does not match, or if an inserted piece does
            not contain the expected number of residues.
    """
    # Fresh containers per call: the previous `[]` defaults were shared between
    # calls, so tokens appended in one call leaked into the next.
    atom_array_accum = [] if atom_array_accum is None else list(atom_array_accum)
    unindexed_breaks = [] if unindexed_breaks is None else unindexed_breaks

    # ... Create list of components
    token_keys = set(indexed_tokens) | set(unindexed_tokens)
    requested = set(components_to_accumulate)
    assert token_keys.issubset(
        requested
    ), "Unindexed and indexed set {} is not subset of components to accumulate {}".format(
        token_keys, requested
    )
    all_tokens = indexed_tokens | unindexed_tokens
    all_annots = {
        category
        for tok in all_tokens.values()
        for category in tok.get_annotation_categories()
    }

    # ... For-loop accum variables
    unindexed_components_started = (
        False  # once one unindexed component is added, stop adding diffused residues
    )
    chain = start_chain
    res_id = start_resid
    molecule_id = 0

    # ... Insert contig information one by one
    assert len(components_to_accumulate) == len(
        unindexed_breaks
    ), "Mismatch in number of components to accumulate and breaks"
    for component, is_break in zip(components_to_accumulate, unindexed_breaks):
        if component == "/0":
            # Reset iterators on next chain
            chain = chr(ord(chain) + 1)
            molecule_id += 1
            res_id = 1
            continue

        # ... Create array to insert
        if str(component)[0].isalpha():  # motif (e.g. "A22")
            n = 1

            # ... Fetch the motif residue
            token = all_tokens[component]

            # ... Insert breakpoint when break clause is met
            if exists(is_break) and is_break:
                if not unindexed_components_started:
                    chain = start_chain
                unindexed_components_started = True
                token.set_annotation(
                    "is_motif_atom_unindexed_motif_breakpoint",
                    np.ones(token.shape[0], dtype=int),
                )
            else:
                token.set_annotation(
                    "is_motif_atom_unindexed_motif_breakpoint",
                    np.zeros(token.shape[0], dtype=int),
                )
        else:
            n = int(component)

            # ... Skip if none or unindexed
            if n == 0 or unindexed_components_started:
                res_id += n
                continue

            # ... Create diffused residues
            token = create_diffused_residues(n, all_annots)

        # ... Set index of insertion
        token = set_indices(
            array=token,
            chain=chain,
            res_id_start=res_id,
            molecule_id=molecule_id,
            component=component,
        )
        assert (
            len(get_token_starts(token)) == n
        ), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"

        # ... Insert & Increment residue ID
        atom_array_accum.append(token)
        res_id += n

    # ... Concatenate all components
    atom_array_accum = struc.concatenate(atom_array_accum)
    atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)

    # Reset res_id for unindexed residues to avoid duplicates
    is_unindexed = atom_array_accum.is_motif_atom_unindexed.astype(bool)
    if np.any(is_unindexed) and not np.all(is_unindexed):
        max_id = np.max(atom_array_accum[~is_unindexed].res_id)
        atom_array_accum.res_id[is_unindexed] += max_id + 1

    # ... Bonds
    if atom_array_accum.bonds is None:
        atom_array_accum.bonds = BondList(atom_array_accum.array_length())
    return atom_array_accum
class LegacySpecification(BaseModel):
"""Legacy specification for compatibility with legacy input parsing."""
@@ -280,20 +78,34 @@ class LegacySpecification(BaseModel):
extra="allow",
)
def build(self):
def build(self, *args, **kwargs):
"""Build atom array using legacy input parsing."""
atom_array = create_atom_array_from_design_specification_legacy(
design_specification=self.model_dump(),
)
return atom_array, self.model_dump()
def to_pipeline_input(self, example_id):
    """Build this specification and package it as a pipeline input dictionary.

    Args:
        example_id: Identifier stored under ``"example_id"`` in the returned
            data and under ``["extra"]["example_id"]`` in the specification
            dictionary.

    Returns:
        The result of ``prepare_pipeline_input_from_atom_array`` with
        ``"example_id"`` and ``"specification"`` entries added.
    """
    atom_array, spec_dict = self.build(return_metadata=True)

    # ... Forward into
    pipeline_input = prepare_pipeline_input_from_atom_array(atom_array)
    pipeline_input["example_id"] = example_id

    # ... Wrap up with additional features
    spec_dict.setdefault("extra", {})["example_id"] = example_id
    pipeline_input["specification"] = spec_dict
    return pipeline_input
# ========================================================================
# Input specification
# ========================================================================
class InputSpecification(BaseModel):
class DesignInputSpecification(BaseModel):
"""Validated and parsed input specification before resolution."""
model_config = ConfigDict(
@@ -932,12 +744,75 @@ class InputSpecification(BaseModel):
else:
return cls(**spec_kwargs)
def to_pipeline_input(self, example_id):
    """Build this specification and package it as a pipeline input dictionary.

    Args:
        example_id: Identifier stored under ``"example_id"`` in the returned
            data and under ``["extra"]["example_id"]`` in the specification
            dictionary.

    Returns:
        The result of ``prepare_pipeline_input_from_atom_array`` with
        ``"example_id"`` and ``"specification"`` entries added.
    """
    atom_array, spec_dict = self.build(return_metadata=True)

    # ... Forward into
    pipeline_input = prepare_pipeline_input_from_atom_array(atom_array)
    pipeline_input["example_id"] = example_id

    # ... Wrap up with additional features
    spec_dict.setdefault("extra", {})["example_id"] = example_id
    pipeline_input["specification"] = spec_dict
    return pipeline_input
# ============================================================================
# Public API
# APIs and utils
# ============================================================================
def prepare_pipeline_input_from_atom_array(  # see atomworks.ml.datasets.parsers.base.load_example_from_metadata_row
    atom_array_orig,
) -> dict:
    """Parse an annotated atom array into the dictionary consumed by the pipeline.

    The array is run through ``parse_atom_array`` (keeping the inference
    annotations as extra fields, removing hydrogens, and without adding
    missing atoms or building an assembly). The first entry of the resulting
    ``"asym_unit"`` is then validated against
    ``REQUIRED_INFERENCE_ANNOTATIONS``, converted with
    ``convert_existing_annotations_to_bool``, and given ``chain_iid`` /
    ``pn_unit_iid`` annotations.

    Note:
        If ``atom_array_orig.bonds`` is ``None`` it is replaced in place by an
        empty ``BondList`` before parsing.

    Args:
        atom_array_orig: Atom array instantiated with conditioning annotations.

    Returns:
        A ``TransformedDict`` with the keys ``"atom_array"``,
        ``"chain_info"``, ``"ligand_info"`` and ``"metadata"``.
    """
    # HACK: Set empty bond graph:
    if atom_array_orig.bonds is None:
        atom_array_orig.bonds = BondList(atom_array_orig.array_length())

    result_dict = parse_atom_array(
        atom_array_orig,
        remove_ccds=[],
        fix_arginines=False,
        add_missing_atoms=False,
        extra_fields=INFERENCE_ANNOTATIONS,
        build_assembly=None,
        hydrogen_policy="remove",
    )
    atom_array = result_dict["asym_unit"][0]

    check_has_required_conditioning_annotations(
        atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
    )
    atom_array = convert_existing_annotations_to_bool(atom_array)

    # HACK: Set iid information manually
    # We currently do not preserve this information from the input,
    # if you want these we'd need to remove the spoofing here
    atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
    atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])

    data = {
        "atom_array": atom_array,  # First model
        "chain_info": result_dict["chain_info"],
        "ligand_info": result_dict["ligand_info"],
        "metadata": result_dict["metadata"],
    }
    return TransformedDict(data)
def create_atom_array_from_design_specification(
**spec_kwargs,
) -> tuple[AtomArray, dict]:
@@ -952,6 +827,216 @@ def create_atom_array_from_design_specification(
return atom_array, {}
# Create input specification and build
spec = InputSpecification(**spec_kwargs)
spec = DesignInputSpecification(**spec_kwargs)
atom_array, metadata = spec.build(return_metadata=True)
return atom_array, metadata
@contextmanager
def validator_context(validator_name: str, data: dict = None):
    """Log the start, success, or failure of a named validator around a ``with`` body.

    Args:
        validator_name: Label interpolated into every log message.
        data: Accepted for call-site compatibility; not read by this function.

    Raises:
        Exception: Whatever the wrapped body raised, re-raised after the
            failure has been logged at error level.
    """
    logger.debug(f"Starting validator: {validator_name}")
    try:
        yield
        # Still inside the ``try`` so a failure while logging is reported too.
        logger.debug(f"✓ Completed validator: {validator_name}")
    except Exception as err:
        failure_report = (
            f"✗ Failed in validator: {validator_name}\n"
            f" Error: {str(err)}\n"
            f" Error type: {type(err).__name__}"
        )
        logger.error(failure_report)
        raise err
def create_diffused_residues(n, additional_annotations=None):
    """Build ``n`` placeholder ALA residues (5 atoms each) located at the origin.

    Every residue is given the atoms N, CA, C, O, CB with all coordinates set
    to zero and residue ids running from 1 to ``n``. The array is then passed
    through ``set_default_conditioning_annotations`` (with ``motif=False``)
    and ``set_common_annotations``.

    Args:
        n: Number of residues to create; must be strictly positive.
        additional_annotations: Forwarded as ``additional=`` to
            ``set_default_conditioning_annotations``.

    Returns:
        The annotated atom array returned by ``set_common_annotations``.

    Raises:
        ValueError: If ``n`` is zero or negative.
    """
    if n <= 0:
        raise ValueError(f"Negative/null residue count ({n}) not allowed.")

    atoms_per_residue = 5
    # One fresh coordinate array per atom, residue by residue.
    placeholder_atoms = [
        struc.Atom(
            np.array([0.0, 0.0, 0.0], dtype=np.float32),
            res_name="ALA",
            res_id=residue_number,
        )
        for residue_number in range(1, n + 1)
        for _ in range(atoms_per_residue)
    ]
    residues = struc.array(placeholder_atoms)

    # Per-residue templates, tiled across all n residues.
    for annotation_name, per_residue_values in (
        ("element", ["N", "C", "C", "O", "C"]),
        ("atom_name", ["N", "CA", "C", "O", "CB"]),
    ):
        residues.set_annotation(
            annotation_name, np.array(per_residue_values * n, dtype="<U2")
        )

    residues = set_default_conditioning_annotations(
        residues, motif=False, additional=additional_annotations
    )
    return set_common_annotations(residues)
def create_motif_residue(
    token,
    strip_sidechains_by_default: bool,
):
    """Prepare a single motif residue for insertion into the design array.

    When ``strip_sidechains_by_default`` is set and ``token.res_name`` passes
    the ``STANDARD_AA`` membership test, the residue is reduced to its
    N/CA/C/O atoms (an O is added via ``create_o_atoms`` when only three atoms
    are present), a CB from ``create_cb_atoms`` is appended, and the residue
    is renamed to ALA. Only the leading backbone atoms keep their
    ``is_motif_atom_with_fixed_coord`` flag; the rest are set to 0.

    In all cases the token is validated with
    ``check_has_required_conditioning_annotations``, passed through
    ``set_common_annotations``, and its ``res_id`` annotation is reset to 1.

    Args:
        token: Atom array for one residue, carrying conditioning annotations.
        strip_sidechains_by_default: Whether standard residues are reduced to
            a backbone + CB alanine.

    Returns:
        The processed token.

    Raises:
        ValueError: If stripping is requested and the token has fewer than
            three atoms.
    """
    # NOTE(review): `token.res_name` is treated as an array further down
    # (np.full_like); `in` against STANDARD_AA may be ambiguous for a
    # multi-atom token — confirm the intended membership check.
    if strip_sidechains_by_default and token.res_name in STANDARD_AA:
        n_atoms = token.shape[0]
        diffuse_oxygen = False
        if n_atoms < 3:
            # The message used to interpolate undefined names (src_chain /
            # src_resid), which turned this ValueError into a NameError.
            try:
                residue_label = f"{token.chain_id[0]}{token.res_id[0]}"
            except (AttributeError, IndexError):
                residue_label = "motif residue"
            raise ValueError(
                f"Not enough data for {residue_label} in input atom array."
            )
        if n_atoms == 3:
            # Handle cases with N, CA, C only;
            token = token + create_o_atoms(token.copy())
            diffuse_oxygen = True  # flag oxygen for generation

        # Subset to the first 4 atoms (N, CA, C, O) only
        token = token[np.isin(token.atom_name, ["N", "CA", "C", "O"])]

        # exactly N, CA, C, O but no CB. Place CB onto idealized position and convert to ALA
        # Sequence name ALA ensures the padded atoms to be diffused from the fixed backbone
        # are placed on the CB so as to not leak the identity of the residue.
        token = token + create_cb_atoms(token.copy())

        # Sequence name must be set to ALA such that the central atom is correctly CB
        token.res_name = np.full_like(token.res_name, "ALA", dtype=token.res_name.dtype)

        # Keep the fixed-coordinate flag on the first 4 atoms (3 when the O was
        # synthesized above); every later atom is marked as not fixed.
        n_fixed_backbone = 4 - int(diffuse_oxygen)
        token.set_annotation(
            "is_motif_atom_with_fixed_coord",
            np.where(
                np.arange(token.shape[0], dtype=int) < n_fixed_backbone,
                token.is_motif_atom_with_fixed_coord,
                0,
            ),
        )

    check_has_required_conditioning_annotations(token)
    token = set_common_annotations(token)
    token.set_annotation("res_id", np.full(token.shape[0], 1))  # Reset to 1
    return token
def accumulate_components(
    components_to_accumulate: List[Union[str, int]],
    *,
    # Tokens from input
    indexed_tokens: Dict[str, AtomArray],
    unindexed_tokens: Dict[str, AtomArray],
    # Additional parameters
    atom_array_accum=None,
    start_chain: str = "A",
    start_resid: int = 1,
    unindexed_breaks: Optional[List[bool]] = None,
    **kwargs,
) -> AtomArray:
    """Assemble an atom array from an ordered list of contig components.

    Components are handled one at a time, paired with ``unindexed_breaks``:

    * ``"/0"`` starts the next chain: the chain letter advances, the molecule
      id increments and the residue counter resets to 1.
    * A component whose string form starts with a letter (e.g. ``"A22"``) is
      looked up in the merged ``indexed_tokens`` / ``unindexed_tokens`` and
      inserted as one residue. Its
      ``is_motif_atom_unindexed_motif_breakpoint`` annotation is set to 1 when
      the paired break flag is truthy and to 0 otherwise. The first break
      resets the chain to ``start_chain`` and stops further diffused residues
      from being inserted.
    * Anything else is read as an integer count of diffused residues built by
      ``create_diffused_residues``.

    Afterwards the pieces are concatenated, ``pn_unit_iid`` is set from
    ``chain_id``, unindexed residue ids are shifted past the largest indexed
    residue id (when both kinds are present), and an empty ``BondList`` is
    attached if the result has no bonds.

    Args:
        components_to_accumulate: Ordered contig components.
        indexed_tokens: Motif tokens keyed by component name.
        unindexed_tokens: Unindexed motif tokens keyed by component name.
        atom_array_accum: Optional arrays to prepend to the result. The list
            passed in is copied, not modified.
        start_chain: Chain id for the first chain.
        start_resid: Residue id of the first inserted residue.
        unindexed_breaks: One flag (or ``None``) per component; must have the
            same length as ``components_to_accumulate``.
        **kwargs: Ignored.

    Returns:
        The concatenated atom array.

    Raises:
        AssertionError: If a token key is not among the components, if the
            breaks list length does not match, or if an inserted piece does
            not contain the expected number of residues.
    """
    # Fresh containers per call: the previous `[]` defaults were shared between
    # calls, so tokens appended in one call leaked into the next.
    atom_array_accum = [] if atom_array_accum is None else list(atom_array_accum)
    unindexed_breaks = [] if unindexed_breaks is None else unindexed_breaks

    # ... Create list of components
    token_keys = set(indexed_tokens) | set(unindexed_tokens)
    requested = set(components_to_accumulate)
    assert token_keys.issubset(
        requested
    ), "Unindexed and indexed set {} is not subset of components to accumulate {}".format(
        token_keys, requested
    )
    all_tokens = indexed_tokens | unindexed_tokens
    all_annots = {
        category
        for tok in all_tokens.values()
        for category in tok.get_annotation_categories()
    }

    # ... For-loop accum variables
    unindexed_components_started = (
        False  # once one unindexed component is added, stop adding diffused residues
    )
    chain = start_chain
    res_id = start_resid
    molecule_id = 0

    # ... Insert contig information one by one
    assert len(components_to_accumulate) == len(
        unindexed_breaks
    ), "Mismatch in number of components to accumulate and breaks"
    for component, is_break in zip(components_to_accumulate, unindexed_breaks):
        if component == "/0":
            # Reset iterators on next chain
            chain = chr(ord(chain) + 1)
            molecule_id += 1
            res_id = 1
            continue

        # ... Create array to insert
        if str(component)[0].isalpha():  # motif (e.g. "A22")
            n = 1

            # ... Fetch the motif residue
            token = all_tokens[component]

            # ... Insert breakpoint when break clause is met
            if exists(is_break) and is_break:
                if not unindexed_components_started:
                    chain = start_chain
                unindexed_components_started = True
                token.set_annotation(
                    "is_motif_atom_unindexed_motif_breakpoint",
                    np.ones(token.shape[0], dtype=int),
                )
            else:
                token.set_annotation(
                    "is_motif_atom_unindexed_motif_breakpoint",
                    np.zeros(token.shape[0], dtype=int),
                )
        else:
            n = int(component)

            # ... Skip if none or unindexed
            if n == 0 or unindexed_components_started:
                res_id += n
                continue

            # ... Create diffused residues
            token = create_diffused_residues(n, all_annots)

        # ... Set index of insertion
        token = set_indices(
            array=token,
            chain=chain,
            res_id_start=res_id,
            molecule_id=molecule_id,
            component=component,
        )
        assert (
            len(get_token_starts(token)) == n
        ), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"

        # ... Insert & Increment residue ID
        atom_array_accum.append(token)
        res_id += n

    # ... Concatenate all components
    atom_array_accum = struc.concatenate(atom_array_accum)
    atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)

    # Reset res_id for unindexed residues to avoid duplicates
    is_unindexed = atom_array_accum.is_motif_atom_unindexed.astype(bool)
    if np.any(is_unindexed) and not np.all(is_unindexed):
        max_id = np.max(atom_array_accum[~is_unindexed].res_id)
        atom_array_accum.res_id[is_unindexed] += max_id + 1

    # ... Bonds
    if atom_array_accum.bonds is None:
        atom_array_accum.bonds = BondList(atom_array_accum.array_length())
    return atom_array_accum

View File

@@ -17,23 +17,6 @@ from rfd3.constants import (
OPTIONAL_CONDITIONING_VALUES,
REQUIRED_INFERENCE_ANNOTATIONS,
)
from rfd3.inference.components import (
fetch_mask_from_component,
fetch_mask_from_idx,
get_design_pattern_with_constraints,
get_motif_components_and_breaks,
get_name_mask,
split_contig,
)
from rfd3.inference.inference_utils import (
create_cb_atoms,
create_o_atoms,
extract_ligand_array,
inference_load_,
set_com,
set_common_annotations,
set_indices,
)
from rfd3.inference.symmetry.symmetry_utils import (
center_symmetric_src_atom_array,
make_symmetric_atom_array,
@@ -45,8 +28,25 @@ from rfd3.transforms.conditioning_base import (
set_default_conditioning_annotations,
)
from rfd3.transforms.util_transforms import assign_types_
from rfd3.utils.inference import (
create_cb_atoms,
create_o_atoms,
extract_ligand_array,
inference_load_,
set_com,
set_common_annotations,
set_indices,
)
from modelhub.common import exists
from modelhub.utils.components import (
fetch_mask_from_component,
fetch_mask_from_idx,
get_design_pattern_with_constraints,
get_motif_components_and_breaks,
get_name_mask,
split_contig,
)
from modelhub.utils.ddp import RankedLogger
logging.basicConfig(level=logging.INFO)

View File

@@ -9,7 +9,8 @@ from pydantic import (
model_serializer,
model_validator,
)
from rfd3.inference.components import (
from modelhub.utils.components import (
ComponentStr,
fetch_mask_from_idx,
get_name_mask,

View File

@@ -1,5 +1,6 @@
import numpy as np
from rfd3.inference.components import fetch_mask_from_idx
from modelhub.utils.components import fetch_mask_from_idx
def expand_contig_to_resid_from_string(contig_string):

View File

@@ -8,7 +8,6 @@ from pydantic import (
ConfigDict,
Field,
)
from rfd3.inference.components import fetch_mask_from_component
from rfd3.inference.symmetry.atom_array import (
FIXED_ENTITY_ID,
FIXED_TRANSFORM_ID,
@@ -33,6 +32,7 @@ from rfd3.inference.symmetry.frames import (
)
from rfd3.transforms.conditioning_base import get_motif_features
from modelhub.utils.components import fetch_mask_from_component
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -4,30 +4,6 @@ import torch.nn as nn
from modelhub.training.checkpoint import activation_checkpointing
class Loss(nn.Module):
    """Sum of several loss terms, each instantiated from a hydra config.

    Every keyword argument is one loss config; the instantiated callables are
    evaluated in order and their values added together.
    """

    def __init__(self, **losses):
        """Instantiate each configured loss term with ``hydra.utils.instantiate``.

        Args:
            **losses: Mapping of loss name to the config that builds it.
        """
        super().__init__()
        # NOTE(review): a plain list does not register the terms as submodules
        # (nn.ModuleList would); left unchanged to keep parameter and
        # state-dict behaviour as it is — confirm whether that is intended.
        self.to_compute = []
        for loss_name, loss in losses.items():
            loss_fn = hydra.utils.instantiate(loss)
            print(f"Adding loss {loss_name} to the loss function")
            self.to_compute.append(loss_fn)

    def forward(
        self,
        network_input,
        network_output,
        loss_input,
    ):
        """Evaluate every loss term and return their sum plus a merged log dict.

        Args:
            network_input: Passed unchanged to every loss term.
            network_output: Passed unchanged to every loss term.
            loss_input: Passed unchanged to every loss term.

        Returns:
            ``(loss, loss_dict)`` where ``loss`` is the sum of all terms and
            ``loss_dict`` merges the per-term dictionaries and adds a detached
            ``"total_loss"`` entry. With no terms configured ``loss`` is ``0``
            and ``"total_loss"`` is a zero tensor.
        """
        loss_dict = {}
        loss = 0
        for loss_fn in self.to_compute:
            loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
            loss += loss_
            loss_dict.update(loss_dict_)
        # Without any tensor-valued term `loss` is still a plain number, which
        # has no `.detach()`; wrap it so the logged value is always a tensor.
        if torch.is_tensor(loss):
            loss_dict["total_loss"] = loss.detach()
        else:
            loss_dict["total_loss"] = torch.as_tensor(float(loss))
        return loss, loss_dict
class SequenceLoss(nn.Module):
def __init__(self, weight, min_t=0, max_t=torch.inf):
super().__init__()
@@ -80,7 +56,7 @@ class SequenceLoss(nn.Module):
return self.weight * token_loss, outs
class SimpleDiffusionLoss(nn.Module):
class DiffusionLoss(nn.Module):
def __init__(
self,
*,
@@ -94,8 +70,7 @@ class SimpleDiffusionLoss(nn.Module):
unindexed_t_alpha=1.0,
unindexed_norm_p=1.0,
lp_weight=0.0,
normalize_virtual_atom_weight=False,
alpha_normalized_virtual_atom_weight=3.9,
**_, # dump args from old configs
):
super().__init__()
self.weight = weight
@@ -108,8 +83,6 @@ class SimpleDiffusionLoss(nn.Module):
self.unindexed_t_alpha = unindexed_t_alpha
self.lp_weight = lp_weight
self.alpha_ligand = alpha_ligand
self.normalize_virtual_atom_weight = normalize_virtual_atom_weight
self.alpha_normalized_virtual_atom_weight = alpha_normalized_virtual_atom_weight
self.alpha_polar_residues = alpha_polar_residues
self.get_lambda = (
@@ -123,18 +96,10 @@ class SimpleDiffusionLoss(nn.Module):
crd_mask_L = loss_input["crd_mask_L"] # (D, L)
crd_mask_L = crd_mask_L.unsqueeze(0).expand(D, -1)
tok_idx = network_input["f"]["atom_to_token_map"]
is_motif_token = network_input["f"]["is_motif_token"] # N
t = network_input["t"] # (D,)
is_original_unindexed_token = loss_input["is_original_unindexed_token"][tok_idx]
is_polar_atom = network_input["f"]["is_polar"][tok_idx]
is_ligand = network_input["f"]["is_ligand"][tok_idx]
# Treat fully fixed atoms as non-lossable atoms to provide stable normalization
# is_motif_atom_with_fixed_coord = network_input["f"][
# "is_motif_atom_with_fixed_coord"
# ]
# crd_mask_L = crd_mask_L * ~is_motif_atom_with_fixed_coord[None]
is_virtual_atom = network_input["f"]["is_virtual"] # L
is_sidechain_atom = network_input["f"]["is_sidechain"] # L
is_sidechain_atom = is_sidechain_atom & ~is_virtual_atom
@@ -148,29 +113,6 @@ class SimpleDiffusionLoss(nn.Module):
# Upweight polar residues
w_L[is_polar_atom] *= self.alpha_polar_residues
if self.normalize_virtual_atom_weight:
# Divide by the number of virtual atoms within a token
n_virtual_atoms_per_token = torch.zeros_like(is_motif_token).float()
n_virtual_atoms_per_token.scatter_add_(
0, tok_idx.long(), is_virtual_atom.float()
)
n_virtual_atoms_per_token = n_virtual_atoms_per_token.clamp(min=1)
# Also divide by the number of heavy atoms within a token
n_sc_atoms_per_token = torch.zeros_like(is_motif_token).float()
n_sc_atoms_per_token.scatter_add_(
0, tok_idx.long(), is_sidechain_atom.float()
)
n_sc_atoms_per_token = n_sc_atoms_per_token.clamp(min=1)
w_L[is_virtual_atom] /= (
self.alpha_normalized_virtual_atom_weight
) * n_virtual_atoms_per_token[tok_idx][is_virtual_atom]
w_L[is_sidechain_atom] /= (
10 - self.alpha_normalized_virtual_atom_weight
) * n_sc_atoms_per_token[tok_idx][is_sidechain_atom]
w_L = w_L[None].expand(D, -1) * crd_mask_L
X_gt_L = torch.nan_to_num(loss_input["X_gt_L_in_input_frame"])

View File

@@ -3,7 +3,7 @@ import numpy as np
from biotite.structure.info import residue
from scipy.spatial.distance import cdist
from modelhub.metrics.base import Metric
from modelhub.metrics.metric import Metric
def collapsing_virtual_atoms_batched(

View File

@@ -0,0 +1,105 @@
import os
import hydra
import torch
from omegaconf import DictConfig
from rfd3.model.cfg_utils import (
strip_f,
)
from rfd3.model.inference_sampler import ConditionalDiffusionSampler
from rfd3.model.layers.encoders import TokenInitializer
from torch import nn
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
class RFD3(nn.Module):
    """
    Simplified model for generation

    This module level serves to wrap the diffusion module of AF3
    to be roughly equivalent to the AF3 model w/o trunk processing.
    Allows the same sampler to be used
    """

    def __init__(
        self,
        *,
        # Channel dimensions ('global' features)
        c_s: int,
        c_z: int,
        c_atom: int,
        c_atompair: int,
        # Arguments for modules that will be instantiated
        token_initializer: DictConfig | dict,
        diffusion_module: DictConfig | dict,
        inference_sampler: DictConfig | dict,
        **_,
    ):
        """Build the token initializer, diffusion module and inference sampler.

        Args:
            c_s, c_z, c_atom, c_atompair: Channel sizes forwarded to the token
                initializer and the diffusion module.
            token_initializer: Extra keyword arguments for ``TokenInitializer``.
            diffusion_module: Hydra config instantiated into the diffusion module.
            inference_sampler: Keyword arguments for
                ``ConditionalDiffusionSampler``. Must contain
                ``use_classifier_free_guidance`` and ``cfg_scale``; an optional
                ``cfg_features`` entry is consumed here and not forwarded.
                The mapping itself is left unmodified.
            **_: Ignored.
        """
        super().__init__()

        # Check for chunked P_LL mode via environment variable
        use_chunked_pll = os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"
        ranked_logger.info(f"RFD3 initialized with chunked_pll={use_chunked_pll}")

        # Simple constant-feature initializer
        self.token_initializer = TokenInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            use_chunked_pll=use_chunked_pll,
            **token_initializer,
        )

        # Diffusion module instantiated to allow for config scripting
        self.diffusion_module = hydra.utils.instantiate(
            diffusion_module, c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z
        )

        self.use_classifier_free_guidance = (
            inference_sampler["use_classifier_free_guidance"]
            and inference_sampler["cfg_scale"] != 1.0
        )

        # Work on a shallow copy: popping from the caller's config would make a
        # second model built from the same config silently lose `cfg_features`.
        sampler_kwargs = dict(inference_sampler)
        self.cfg_features = sampler_kwargs.pop("cfg_features", [])

        # ... initialize the inference sampler, which performs a full diffusion rollout during inference
        self.inference_sampler = ConditionalDiffusionSampler(**sampler_kwargs)

    def forward(
        self,
        input: dict,
        coord_atom_lvl_to_be_noised: torch.Tensor = None,
        n_cycle=None,
        **_,
    ) -> dict:
        """Run one denoising step (training) or a full sampling rollout (eval).

        Args:
            input: Dictionary holding the features ``"f"``; in training mode
                also ``"X_noisy_L"`` and ``"t"``.
            coord_atom_lvl_to_be_noised: Coordinates handed to the sampler;
                their leading dimension is used as the diffusion batch size.
                Required in eval mode.
            n_cycle: Forwarded to the diffusion module as ``n_recycle``
                (training mode only).
            **_: Ignored.

        Returns:
            The diffusion module output in training mode, otherwise the result
            of ``sample_diffusion_like_af3``.

        Raises:
            ValueError: In eval mode when ``coord_atom_lvl_to_be_noised`` is
                not provided.
        """
        initializer_outputs = self.token_initializer(input["f"])

        if self.training:
            # Single denoising step
            return self.diffusion_module(
                X_noisy_L=input["X_noisy_L"],
                t=input["t"],
                f=input["f"],
                n_recycle=n_cycle,
                **initializer_outputs,
            )  # [D, L, 3]

        if coord_atom_lvl_to_be_noised is None:
            # The batch size below is read from this tensor, so fail with a
            # clear message instead of an AttributeError on None.
            raise ValueError(
                "coord_atom_lvl_to_be_noised must be provided when sampling in eval mode."
            )

        if self.use_classifier_free_guidance:
            f_ref = strip_f(input["f"], self.cfg_features)
            ref_initializer_outputs = self.token_initializer(f_ref)
        else:
            f_ref = None
            ref_initializer_outputs = None

        return self.inference_sampler.sample_diffusion_like_af3(
            f=input["f"],
            f_ref=f_ref,  # for cfg
            diffusion_module=self.diffusion_module,
            diffusion_batch_size=coord_atom_lvl_to_be_noised.shape[0],
            coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
            # Forwarded as **kwargs:
            initializer_outputs=initializer_outputs,
            ref_initializer_outputs=ref_initializer_outputs,  # for cfg
        )

View File

@@ -5,11 +5,11 @@ from contextlib import ExitStack
import torch
import torch.nn as nn
from rfd3.model.block_utils import (
from rfd3.model.layers.block_utils import (
bucketize_scaled_distogram,
create_attention_indices,
)
from rfd3.model.blocks import (
from rfd3.model.layers.blocks import (
CompactStreamingDecoder,
Downcast,
LinearEmbedWithPool,
@@ -17,17 +17,14 @@ from rfd3.model.blocks import (
LocalAtomTransformer,
LocalTokenTransformer,
)
from rfd3.model.encoders import (
from rfd3.model.layers.encoders import (
DiffusionTokenEncoder,
)
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
from modelhub.model.AF3_structure import (
from modelhub.model.layers.blocks import (
FourierEmbedding,
)
from modelhub.model.layers.af3_diffusion_transformer import (
DiffusionTransformer,
)
from modelhub.model.layers.layer_utils import RMSNorm, linearNoBias
logger = logging.getLogger(__name__)
@@ -53,7 +50,7 @@ class RFD3DiffusionModule(nn.Module):
atom_attention_decoder,
# upcast,
downcast,
use_local_token_attention=False,
use_local_token_attention=True,
**_,
):
super().__init__()
@@ -111,20 +108,12 @@ class RFD3DiffusionModule(nn.Module):
**diffusion_token_encoder,
)
if not use_local_token_attention:
self.diffusion_transformer = DiffusionTransformer(
c_token=c_token,
c_tokenpair=c_z,
c_s=c_s,
**diffusion_transformer,
)
else:
self.diffusion_transformer = LocalTokenTransformer(
c_token=c_token,
c_tokenpair=c_z,
c_s=c_s,
**diffusion_transformer,
)
self.diffusion_transformer = LocalTokenTransformer(
c_token=c_token,
c_tokenpair=c_z,
c_s=c_s,
**diffusion_transformer,
)
self.decoder = CompactStreamingDecoder(
c_atom=c_atom,
@@ -342,21 +331,18 @@ class RFD3DiffusionModule(nn.Module):
)
# ... Diffusion transformer
if not self.use_local_token_attention:
A_I = self.diffusion_transformer(A_I, S_I, Z_II, Beta_II=None)
else:
A_I = self.diffusion_transformer(
A_I,
S_I,
Z_II,
f=f,
X_L=(
X_noisy_L[..., f["is_ca"], :]
if X_L_self is None
else X_L_self[..., f["is_ca"], :]
),
full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"),
)
A_I = self.diffusion_transformer(
A_I,
S_I,
Z_II,
f=f,
X_L=(
X_noisy_L[..., f["is_ca"], :]
if X_L_self is None
else X_L_self[..., f["is_ca"], :]
),
full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"),
)
# ... Decoder readout
# Check if using chunked P_LL mode

View File

@@ -1,12 +0,0 @@
from rfd3.model.encoders import SimpleRecycler
from modelhub.model.AF3 import AF3
class AF3DesignTrunk(AF3):
    """``AF3`` subclass that installs a ``SimpleRecycler`` and clears the distogram head."""

    def __init__(self, **kwargs):
        """Initialise the parent with ``kwargs``, then replace the recycler.

        Args:
            **kwargs: Forwarded to ``AF3``; must include ``c_s``, ``c_z`` and
                a ``recycler`` mapping of extra ``SimpleRecycler`` arguments.
        """
        super().__init__(**kwargs)
        single_dim = kwargs["c_s"]
        pair_dim = kwargs["c_z"]
        recycler_options = kwargs["recycler"]
        self.recycler = SimpleRecycler(c_s=single_dim, c_z=pair_dim, **recycler_options)
        # This trunk carries no distogram head.
        self.distogram_head = None

View File

@@ -1,143 +1,20 @@
import inspect
import os
from dataclasses import dataclass
from typing import Any, Literal
import hydra
import torch
from beartype.typing import Any
from jaxtyping import Float
from omegaconf import DictConfig
from rfd3.inference.symmetry.symmetry_utils import (
apply_symmetry_to_xyz_atomwise,
)
from rfd3.model.cfg_utils import (
strip_f,
strip_X,
)
from rfd3.model.encoders import TokenInitializer
from torch import nn
from modelhub import SWAP_LAYER_NORM_FOR_RMS_NORM
from modelhub.alignment import weighted_rigid_align
from modelhub.common import exists
from modelhub.data.rotation_augmentation import (
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.rotation_augmentation import (
rot_vec_mul,
uniform_random_rotation,
)
from modelhub.diffusion_samplers.inference_sampler import SampleDiffusion
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
class RFD3(nn.Module):
    """
    Simplified model for generation

    This module level serves to wrap the diffusion module of AF3
    to be roughly equivalent to the AF3 model w/o trunk processing.
    Allows the same sampler to be used
    """

    def __init__(
        self,
        *,
        # Channel dimensions ('global' features)
        c_s: int,
        c_z: int,
        c_atom: int,
        c_atompair: int,
        # Arguments for modules that will be instantiated
        token_initializer: DictConfig | dict,
        diffusion_module: DictConfig | dict,
        inference_sampler: DictConfig | dict,
        **_,
    ):
        """Build the token initializer, diffusion module and inference sampler.

        Args:
            c_s, c_z, c_atom, c_atompair: Channel sizes forwarded to the token
                initializer and the diffusion module.
            token_initializer: Extra keyword arguments for ``TokenInitializer``.
            diffusion_module: Hydra config instantiated into the diffusion module.
            inference_sampler: Keyword arguments for
                ``ConditionalDiffusionSampler``. Must contain
                ``use_classifier_free_guidance`` and ``cfg_scale``; an optional
                ``cfg_features`` entry is consumed here and not forwarded.
                The mapping itself is left unmodified.
            **_: Ignored.
        """
        super().__init__()

        # Register whether the model uses RMSNorms or LayerNorms
        self.register_buffer(
            "use_rmsnorm",
            torch.tensor(SWAP_LAYER_NORM_FOR_RMS_NORM, dtype=torch.bool),
        )

        # Check for chunked P_LL mode via environment variable
        use_chunked_pll = os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"
        ranked_logger.info(f"RFD3 initialized with chunked_pll={use_chunked_pll}")

        # Simple constant-feature initializer
        self.token_initializer = TokenInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            use_chunked_pll=use_chunked_pll,
            **token_initializer,
        )

        # Diffusion module instantiated to allow for config scripting
        self.diffusion_module = hydra.utils.instantiate(
            diffusion_module, c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z
        )

        self.use_classifier_free_guidance = (
            inference_sampler["use_classifier_free_guidance"]
            and inference_sampler["cfg_scale"] != 1.0
        )

        # Work on a shallow copy: popping from the caller's config would make a
        # second model built from the same config silently lose `cfg_features`.
        sampler_kwargs = dict(inference_sampler)
        self.cfg_features = sampler_kwargs.pop("cfg_features", [])

        # ... initialize the inference sampler, which performs a full diffusion rollout during inference
        self.inference_sampler = ConditionalDiffusionSampler(**sampler_kwargs)

    def forward(
        self,
        input: dict,
        coord_atom_lvl_to_be_noised: torch.Tensor = None,
        n_cycle=None,
        **_,
    ) -> dict:
        """Run one denoising step (training) or a full sampling rollout (eval).

        Args:
            input: Dictionary holding the features ``"f"`` and ``"t"``; in
                training mode also ``"X_noisy_L"``.
            coord_atom_lvl_to_be_noised: Coordinates handed to the sampler in
                eval mode.
            n_cycle: Forwarded to the diffusion module as ``n_recycle``
                (training mode only).
            **_: Ignored.

        Returns:
            The diffusion module output in training mode, otherwise the result
            of ``sample_diffusion_like_af3``.

        Raises:
            ValueError: If the ``use_rmsnorm`` buffer disagrees with the
                current ``SWAP_LAYER_NORM_FOR_RMS_NORM`` setting.
        """
        # Assert that the correct swap is used
        checkpoint_uses_rmsnorm = self.use_rmsnorm.item()
        if bool(checkpoint_uses_rmsnorm) != SWAP_LAYER_NORM_FOR_RMS_NORM:
            raise ValueError(
                "Loaded checkpoint with use RMSNorm {} but environment variable set expects {}".format(
                    checkpoint_uses_rmsnorm,
                    SWAP_LAYER_NORM_FOR_RMS_NORM,
                )
                + " Set environment variable SWAP_LAYER_NORM_FOR_RMS_NORM to {}".format(
                    {True: "1", False: "0"}[checkpoint_uses_rmsnorm]
                )
            )

        initializer_outputs = self.token_initializer(input["f"])

        if self.training:
            # Single denoising step
            return self.diffusion_module(
                X_noisy_L=input["X_noisy_L"],
                t=input["t"],
                f=input["f"],
                n_recycle=n_cycle,
                **initializer_outputs,
            )  # [D, L, 3]

        if self.use_classifier_free_guidance:
            f_ref = strip_f(input["f"], self.cfg_features)
            ref_initializer_outputs = self.token_initializer(f_ref)
        else:
            f_ref = None
            ref_initializer_outputs = None

        return self.inference_sampler.sample_diffusion_like_af3(
            f=input["f"],
            f_ref=f_ref,  # for cfg
            diffusion_module=self.diffusion_module,
            diffusion_batch_size=input["t"].shape[0],
            coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
            # Forwarded as **kwargs:
            initializer_outputs=initializer_outputs,
            ref_initializer_outputs=ref_initializer_outputs,  # for cfg
        )
def centre_random_augment_around_motif(
X_L: torch.Tensor, # (D, L, 3) noisy diffused coordinates
coord_atom_lvl_to_be_noised: torch.Tensor, # (D, L, 3) original coordinates
@@ -193,159 +70,70 @@ def centre_random_augment_around_motif(
return X_L, R
class SampleDiffusionWithMotif(SampleDiffusion):
def __init__(
self,
center_option: str = "all",
move_noise_to_reset_com: bool = False, # Reset the COM of the diffuse region after the re-noising operation in each diffusion step
s_trans: float = 1.0, # Translational noise scale for augmentation during inference
s_jitter_origin: float = 0.0, # Random translation of motif at the start of diffusion
fraction_of_steps_to_fix_motif: float = 0.0, # Fraction of steps to let the model not move the motif. e.g. if we have 10 steps, set this value to 0.2 will make model not move motif for the first 2 steps.
skip_few_diffusion_steps: bool = False, # Choose to skip some diffusion steps based on the noise scheme
inference_noise_scaling_factor: float = 1.0,
# Additional arguments
gamma_min2: float = 0.0,
allow_realignment: bool = False,
insert_motif_at_end: bool = True,
use_classifier_free_guidance: bool = False,
cfg_scale: float = 2.0,
use_frame_guidance: bool = False, # Use frame guidance to align the virtual atoms to the central atom
fg_scale: float = 1.5,
zero_drift_noise: bool = False,
cfg_t_max: float
| None = None, # If not None, use classifier-free guidance only for t < cfg_t_max
**kwargs,
):
self.gamma_min2 = gamma_min2
self.allow_realignment = allow_realignment
self.insert_motif_at_end = insert_motif_at_end
self.use_classifier_free_guidance = use_classifier_free_guidance
self.cfg_scale = cfg_scale
self.cfg_t_max = cfg_t_max
@dataclass(kw_only=True)
class SampleDiffusionWithMotif:
"""Diffusion sampler that supports optional motif alignment."""
self.center_option = center_option
self.fraction_of_steps_to_fix_motif = fraction_of_steps_to_fix_motif
self.move_noise_to_reset_com = move_noise_to_reset_com
self.s_trans = s_trans
self.skip_few_diffusion_steps = skip_few_diffusion_steps
self.s_jitter_origin = s_jitter_origin
self.inference_noise_scaling_factor = inference_noise_scaling_factor
self.zero_drift_noise = zero_drift_noise
# Standard EDM args
num_timesteps: int # AF-3: 200
min_t: int # AF-3: 0
max_t: int # AF-3: 1
sigma_data: int # AF-3: 16
s_min: float # AF-3: 4e-4
s_max: int # AF-3: 160
p: int # AF-3: 7
gamma_0: float # AF-3: 0.8
gamma_min: float # AF-3: 1.0
noise_scale: float # AF-3: 1.003
step_scale: float # AF-3: 1.5
solver: Literal["af3"]
self.use_frame_guidance = use_frame_guidance
self.fg_scale = fg_scale
# RFD3 / design args
center_option: str = "all"
s_trans: float = 1.0
s_jitter_origin: float = 0.0
fraction_of_steps_to_fix_motif: float = 0.0
skip_few_diffusion_steps: bool = False
allow_realignment: bool = False
insert_motif_at_end: bool = True
use_classifier_free_guidance: bool = False
cfg_scale: float = 2.0
cfg_t_max: float | None = None
use_frame_guidance: bool = False
fg_scale: float = 1.0
super().__init__(**kwargs)
# TODO: Make this a properly-parametrized function in terms of instance variables provided in the configs
# For now, it's just hard-coded for early testing
def modify_noise_schedule(self, noise_schedule: torch.Tensor) -> torch.Tensor:
    """Thin out a noise schedule, skipping more steps early and none at the end.

    With ``n = len(noise_schedule)`` the steps are filtered by position:

    * indices ``[0, n // 4)``: keep those divisible by 5
    * indices ``[n // 4, n // 2)``: keep those divisible by 3
    * indices ``[n // 2, n - n // 4)``: keep the even ones
    * indices ``[n - n // 4, n)``: keep all

    Args:
        noise_schedule: Schedule to filter; indexed along its first axis
            (assumed 1-D -- TODO confirm against callers).

    Returns:
        The retained entries of ``noise_schedule`` in their original order.
    """
    n_steps = len(noise_schedule)
    quarter = n_steps // 4
    half = n_steps // 2
    # Explicit end index instead of the negative slice bound ``-n_steps // 4``:
    # floor division of a negative number rounds away from zero, so the slice
    # was one element too short (and the assignment raised a shape mismatch)
    # whenever n_steps was not a multiple of 4.
    last_quarter_start = n_steps - quarter

    step_idx = torch.arange(n_steps, device=noise_schedule.device)
    keep = torch.ones_like(noise_schedule, dtype=torch.bool)
    keep[:quarter] = step_idx[:quarter] % 5 == 0
    keep[quarter:half] = step_idx[quarter:half] % 3 == 0
    keep[half:last_quarter_start] = step_idx[half:last_quarter_start] % 2 == 0
    return noise_schedule[keep]
def _get_initial_structure(
self,
c0: torch.Tensor,
D: int,
L: int,
coord_atom_lvl_to_be_noised: torch.Tensor,
is_motif_atom_with_fixed_coord,
def _construct_inference_noise_schedule(
self, device: torch.device, partial_t: float = None
) -> torch.Tensor:
noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
noise[..., is_motif_atom_with_fixed_coord, :] = 0 # Zero out noise going in
X_L = noise + coord_atom_lvl_to_be_noised
return X_L
"""Constructs a noise schedule for use during inference.
def _move_noise_to_reset_com(self, X_noisy_L, is_motif_atom_with_fixed_coord):
The inference noise schedule is defined in the AF-3 supplement as:
t_hat = sigma_data * (s_max**(1/p) + t * (s_min**(1/p) - s_max**(1/p)))**p
Returns:
torch.Tensor: A tensor representing the noise schedule `t_hat`.
Reference:
AlphaFold 3 Supplement, Section 3.7.1.
"""
Reset the COM of the diffuse region after the re-noising operation in each diffusion step.
"""
if self.center_option == "motif":
print(
"Warning: Moving the noise is not relevant when centering on the motif! Will be ignored."
# Create a linearly spaced tensor of timesteps between min_t and max_t
t = torch.linspace(self.min_t, self.max_t, self.num_timesteps, device=device)
# Construct the noise schedule, using the formula provided in the reference
t_hat = (
self.sigma_data
* (
(self.s_max) ** (1 / self.p)
+ t * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
)
elif self.center_option == "diffuse":
displacement_vec = torch.mean(
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :],
dim=-2,
keepdim=True,
) # (D, 1, 3) - COM of noisy diffused atoms
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] = (
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] - displacement_vec
)
else:
n_diffused = (~is_motif_atom_with_fixed_coord).sum()
displacement_vec = (
torch.sum(
X_noisy_L,
dim=-2,
keepdim=True,
)
/ n_diffused
)
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] = (
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] - displacement_vec
)
return X_noisy_L
def _skip_few_diffusion_steps(self, noise_schedule: torch.Tensor) -> torch.Tensor:
"""
Modify the noise schedule to skip more steps at high noise and fewer at low noise.
i.e. When the noise is high (first few diffusion steps), skip more steps;
When the noise is lower, skip fewer steps;
When the noise is low, keep all the steps.
"""
mask = torch.ones_like(noise_schedule, dtype=bool)
mask_len = len(mask)
mask[: mask_len // 4] = torch.arange(mask_len // 4) % 5 == 0
mask[mask_len // 4 : mask_len // 2] = (
torch.arange(mask_len // 4, mask_len // 2) % 3 == 0
)
mask[mask_len // 2 : -mask_len // 4] = (
torch.arange(mask_len // 2, mask_len - mask_len // 4) % 2 == 0
)
return noise_schedule[mask]
def sample_diffusion_like_af3(
self,
*,
f: dict[str, Any],
diffusion_module: torch.nn.Module,
diffusion_batch_size: int,
coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
initializer_outputs,
ref_initializer_outputs: dict[str, Any] | None,
f_ref: dict[str, Any] | None,
) -> dict[str, Any]:
# Motif setup to recenter the motif at every step
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
# Book-keeping
noise_schedule = self._construct_inference_noise_schedule(
device=coord_atom_lvl_to_be_noised.device
** self.p
)
# Choose to adjust the noise schedule
if self.skip_few_diffusion_steps:
noise_schedule = self._skip_few_diffusion_steps(noise_schedule)
if "partial_t" in f:
if partial_t is not None:
# For now, partial t is a global parameter
partial_t = f["partial_t"].mean()
partial_t = float(partial_t.mean())
noise_schedule = t_hat
ranked_logger.info("Using partial diffusion with t={}".format(partial_t))
# Debug the noise schedule filtering
@@ -382,11 +170,44 @@ class SampleDiffusionWithMotif(SampleDiffusion):
f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
)
return t_hat
def _get_initial_structure(
self,
c0: torch.Tensor,
D: int,
L: int,
coord_atom_lvl_to_be_noised: torch.Tensor,
is_motif_atom_with_fixed_coord,
) -> torch.Tensor:
noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
noise[..., is_motif_atom_with_fixed_coord, :] = 0 # Zero out noise going in
X_L = noise + coord_atom_lvl_to_be_noised
return X_L
def sample_diffusion_like_af3(
self,
*,
f: dict[str, Any],
diffusion_module: torch.nn.Module,
diffusion_batch_size: int,
coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
initializer_outputs,
ref_initializer_outputs: dict[str, Any] | None,
f_ref: dict[str, Any] | None,
) -> dict[str, Any]:
# Motif setup to recenter the motif at every step
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
# Book-keeping
noise_schedule = self._construct_inference_noise_schedule(
device=coord_atom_lvl_to_be_noised.device,
partial_t=f.get("partial_t", None),
)
L = f["ref_element"].shape[0]
D = diffusion_batch_size
noise_schedule = noise_schedule * self.inference_noise_scaling_factor
X_L = self._get_initial_structure(
c0=noise_schedule[0],
D=D,
@@ -433,7 +254,7 @@ class SampleDiffusionWithMotif(SampleDiffusion):
# Update gamma & step scale
gamma = self.gamma_0 if c_t > self.gamma_min else 0
step_scale = self.step_scale if c_t > self.gamma_min2 else 3.0
step_scale = self.step_scale
# Compute the value of t_hat
t_hat = c_t_minus_1 * (gamma + 1)
@@ -444,19 +265,11 @@ class SampleDiffusionWithMotif(SampleDiffusion):
* torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
* torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
)
if self.zero_drift_noise:
epsilon_L = epsilon_L - torch.mean(epsilon_L, dim=-2, keepdim=True)
epsilon_L[..., is_motif_atom_with_fixed_coord, :] = (
0 # No noise injection for fixed atoms
)
X_noisy_L = X_L + epsilon_L
# Adjust the center of mass
if self.move_noise_to_reset_com:
X_noisy_L = self._move_noise_to_reset_com(
X_noisy_L, is_motif_atom_with_fixed_coord
)
# Denoise the coordinates
# Handle chunked mode vs standard mode
if "chunked_pairwise_embedder" in initializer_outputs:
@@ -624,53 +437,10 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
# Book-keeping
noise_schedule = self._construct_inference_noise_schedule(
device=coord_atom_lvl_to_be_noised.device
device=coord_atom_lvl_to_be_noised.device,
partial_t=f.get("partial_t", None),
)
# Handle partial_t for symmetry sampler (same as regular sampler)
if "partial_t" in f:
# For now, partial t is a global parameter
partial_t = f["partial_t"].mean()
ranked_logger.info(
"Symmetry sampler: Using partial diffusion with t={}".format(partial_t)
)
# Debug the noise schedule filtering
original_schedule_len = len(noise_schedule)
original_max = noise_schedule.max().item()
original_min = noise_schedule.min().item()
noise_schedule = noise_schedule[noise_schedule <= partial_t]
new_schedule_len = len(noise_schedule)
if new_schedule_len > 0:
new_max = noise_schedule.max().item()
new_min = noise_schedule.min().item()
ranked_logger.info(
f"Symmetry noise schedule: {original_schedule_len}{new_schedule_len} steps"
)
ranked_logger.info(
f"Symmetry original range: [{original_min:.3f}, {original_max:.3f}]"
)
ranked_logger.info(
f"Symmetry filtered range: [{new_min:.3f}, {new_max:.3f}]"
)
else:
ranked_logger.warning(
f"Symmetry sampler: No noise schedule steps found with t <= {partial_t}!"
)
ranked_logger.info(
f"Symmetry original schedule range: [{original_min:.3f}, {original_max:.3f}]"
)
# Fallback to smallest available step
noise_schedule_original = self._construct_inference_noise_schedule(
device=coord_atom_lvl_to_be_noised.device
)
noise_schedule = noise_schedule_original[-1:] # Just use the final step
ranked_logger.info(
f"Symmetry using fallback: final step with t={noise_schedule[0].item():.6f}"
)
L = f["ref_element"].shape[0]
D = diffusion_batch_size
X_L = self._get_initial_structure(
@@ -711,7 +481,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
# Update gamma & step scale
gamma = self.gamma_0 if c_t > self.gamma_min else 0
step_scale = self.step_scale if c_t > self.gamma_min2 else 1.05
step_scale = self.step_scale
# Compute the value of t_hat
t_hat = c_t_minus_1 * (gamma + 1)

View File

@@ -6,18 +6,18 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from opt_einsum import contract as einsum
from rfd3.model.block_utils import (
from rfd3.model.layers.block_utils import (
create_attention_indices,
indices_to_mask,
)
from modelhub.common import exists
from modelhub.model.layers.layer_utils import (
from rfd3.model.layers.layer_utils import (
AdaLN,
LinearBiasInit,
RMSNorm,
linearNoBias,
)
from modelhub.common import exists
from modelhub.training.checkpoint import activation_checkpointing
from modelhub.utils.ddp import RankedLogger
@@ -417,7 +417,6 @@ def sparse_pairbias_attention(
if B.ndim == 3:
B_gathered = B[query_idx, indices, :] # (D, L, k, H)
elif B.ndim == 4: # (D, L, L, H)
# import ipdb; ipdb.set_trace()
B_gathered = B[batch_idx, query_idx, indices, :] # (D, L, k, H)
else:
assert B.shape == (D, L, k, H), "B must be batched with shape (D, L, k, H)"

View File

@@ -6,35 +6,62 @@ import torch.nn as nn
import torch.nn.functional as F
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from einops import rearrange
from rfd3.model.attention import (
from rfd3.model.layers.attention import (
GatedCrossAttention,
LocalAttentionPairBias,
)
from rfd3.model.block_utils import (
from rfd3.model.layers.block_utils import (
build_valid_mask,
create_attention_indices,
group_atoms,
ungroup_atoms,
)
from torch.nn.functional import one_hot
from modelhub import DISABLE_CHECKPOINTING
from modelhub.common import exists
from modelhub.model.layers.af3_diffusion_transformer import (
ConditionedTransitionBlock,
)
from modelhub.model.layers.layer_utils import (
from rfd3.model.layers.layer_utils import (
AdaLN,
EmbeddingLayer,
LinearBiasInit,
RMSNorm,
Transition,
collapse,
linearNoBias,
)
from modelhub.model.layers.pairformer_layers import PairformerBlock
from rfd3.model.layers.pairformer_layers import PairformerBlock
from torch.nn.functional import one_hot
from modelhub import DISABLE_CHECKPOINTING
from modelhub.common import exists
logger = logging.getLogger(__name__)
# SwiGLU transition block with adaptive layernorm
class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition on ``Ai``, conditioned on ``Si`` via AdaLN and an output gate.

    Args:
        c_token: Channel width of the activations ``Ai``.
        c_s: Channel width of the conditioning signal ``Si``.
        n: Expansion factor of the hidden layer (hidden width is ``c_token * n``).
    """

    def __init__(self, c_token, c_s, n=2):
        super().__init__()
        hidden_width = c_token * n
        self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
        self.linear_1 = linearNoBias(c_token, hidden_width)
        self.linear_2 = linearNoBias(c_token, hidden_width)
        self.linear_output_project = nn.Sequential(
            LinearBiasInit(c_s, c_token, biasinit=-2.0),
            nn.Sigmoid(),
        )
        self.linear_3 = linearNoBias(hidden_width, c_token)

    def forward(
        self,
        Ai,  # [B, I, C_token]
        Si,  # [B, I, C_s]
    ):
        conditioned = self.ada_ln(Ai, Si)
        # SwiGLU gates with SiLU (swish); a plain sigmoid here would not be SwiGLU
        swish_branch = torch.nn.functional.silu(self.linear_1(conditioned))
        linear_branch = self.linear_2(conditioned)
        # Output projection (from adaLN-Zero)
        output_gate = self.linear_output_project(Si)
        return output_gate * self.linear_3(swish_branch * linear_branch)
class PositionPairDistEmbedder(nn.Module):
def __init__(self, c_atompair, embed_frame=True):
super().__init__()

View File

@@ -10,8 +10,7 @@ from typing import Optional
import torch
import torch.nn as nn
from modelhub.model.layers.layer_utils import RMSNorm, linearNoBias
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
class ChunkedPositionPairDistEmbedder(nn.Module):

View File

@@ -3,11 +3,11 @@ import logging
import torch
import torch.nn as nn
from rfd3.model.block_utils import (
from rfd3.model.layers.block_utils import (
bucketize_scaled_distogram,
pairwise_mean_pool,
)
from rfd3.model.blocks import (
from rfd3.model.layers.blocks import (
Downcast,
LocalAtomTransformer,
OneDFeatureEmbedder,
@@ -15,19 +15,19 @@ from rfd3.model.blocks import (
RelativePositionEncodingWithIndexRemoval,
SinusoidalDistEmbed,
)
from rfd3.model.chunked_pairwise import (
from rfd3.model.layers.chunked_pairwise import (
ChunkedPairwiseEmbedder,
ChunkedPositionPairDistEmbedder,
ChunkedSinusoidalDistEmbed,
)
from modelhub.common import exists
from modelhub.model.layers.layer_utils import (
from rfd3.model.layers.layer_utils import (
RMSNorm,
Transition,
linearNoBias,
)
from modelhub.model.layers.pairformer_layers import PairformerBlock
from rfd3.model.layers.pairformer_layers import PairformerBlock
from modelhub.common import exists
from modelhub.training.checkpoint import activation_checkpointing
logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,197 @@
import math
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import silu
from modelhub.training.checkpoint import activation_checkpointing
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
# Pick the RMSNorm implementation once at import time: apex's fused kernel when
# it can be imported, otherwise torch's built-in nn.RMSNorm.
try:
    from apex.normalization.fused_layer_norm import FusedRMSNorm
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    ranked_logger.warning(
        "Using nn.RMSNorm instead of apex.normalization.fused_layer_norm.FusedRMSNorm. "
        "Ensure you're using the correct apptainer"
    )
    RMSNorm_ = nn.RMSNorm
else:
    ranked_logger.info("Fused RMSNorm enabled!")
    RMSNorm_ = FusedRMSNorm
# Allow bias=False to be passed for RMSNorm
def RMSNorm(*args, **kwargs):
    """Construct an ``RMSNorm_`` layer, silently discarding any ``bias`` keyword.

    All other positional and keyword arguments are forwarded unchanged.
    """
    kwargs.pop("bias", None)
    return RMSNorm_(*args, **kwargs)
# Module-level switch: when True, the name ``RMSNorm`` keeps pointing at the
# factory defined above; when False, it is rebound to ``nn.LayerNorm``.
SWAP_LAYER_NORM_FOR_RMS_NORM = True
RMSNorm = RMSNorm if SWAP_LAYER_NORM_FOR_RMS_NORM else nn.LayerNorm
# ``torch.nn.Linear`` constructor with ``bias=False`` pre-applied.
linearNoBias = partial(torch.nn.Linear, bias=False)
class EmbeddingLayer(nn.Linear):
    """
    Specialized linear layer for correct weight initialization for embedding layers.

    Embedding layers are functionally a multiplication of an N channel input by an NxC weight matrix to produce an
    embedding of length C. However, we compute the components separately with a ModuleDict, then sum at the end, for
    embedding reusability and interoperability purposes.

    This layer uses Xavier initialization as described in [1]_: weights are drawn uniformly from ``[-a, a]`` with
    ``a = sqrt(6 / (total_embedding_features + out_features))``. The layer has no bias.

    References
    ----------
    .. [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty
        of training deep feedforward neural networks." (2010)
        http://proceedings.mlr.press/v9/glorot10a.html
    """

    def __init__(
        self,
        this_in_features,
        total_embedding_features,
        out_features,
        device=None,
        dtype=None,
    ):
        # Both attributes must exist before the parent constructor runs:
        # nn.Linear.__init__ calls reset_parameters(), which reads them.
        self.total_embedding_features = total_embedding_features
        self.out_features = out_features
        super().__init__(
            this_in_features, out_features, bias=False, device=device, dtype=dtype
        )
        # Explicit second reset kept from the original so that random-number
        # consumption (and hence seeded initialisation) is unchanged.
        self.reset_parameters()

    def reset_parameters(self, **kwargs):
        """Re-draw the weights from the Xavier-uniform range described on the class."""
        super().reset_parameters()
        a = math.sqrt(6.0 / float(self.total_embedding_features + self.out_features))
        # Public initialiser instead of the private nn.init._no_grad_uniform_ helper
        nn.init.uniform_(self.weight, -a, a)
def collapse(x, L):
    """Reshape ``x`` to two dimensions with ``L`` rows and ``x.numel() // L`` columns."""
    n_cols = x.numel() // L
    target_shape = (L, n_cols)
    return x.reshape(target_shape)
class MultiDimLinear(nn.Linear):
    """Linear layer whose output feature axis is reshaped to ``out_shape``.

    Args:
        in_features: Size of the last input dimension.
        out_shape: Sequence of ints; the layer produces ``prod(out_shape)``
            features which ``forward`` reshapes to ``(*x.shape[:-1], *out_shape)``.
        norm: If truthy, apply ``RMSNorm`` over the flat output features
            before reshaping.
        **kwargs: Forwarded to ``nn.Linear`` (e.g. ``bias``).
    """

    def __init__(self, in_features, out_shape, norm=False, **kwargs):
        # Normalise to a tuple: forward() concatenates it with x.shape[:-1],
        # which fails for a list.
        self.out_shape = tuple(out_shape)
        # np.prod returns a NumPy scalar; store a plain int on the module instead.
        out_features = int(np.prod(self.out_shape))
        super().__init__(in_features, out_features, **kwargs)
        self.use_ln = bool(norm)
        if self.use_ln:
            self.ln = RMSNorm((out_features,))
        # Explicit second reset kept from the original so that random-number
        # consumption (and hence seeded initialisation) is unchanged.
        self.reset_parameters()

    def reset_parameters(self, **kwargs) -> None:
        """Default ``nn.Linear`` reset, then Xavier-uniform on the weight."""
        super().reset_parameters()
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        out = super().forward(x)
        if self.use_ln:
            out = self.ln(out)
        return out.reshape(x.shape[:-1] + self.out_shape)
class LinearBiasInit(nn.Linear):
    """``nn.Linear`` whose bias is filled with ``biasinit`` on every parameter reset.

    ``biasinit`` is keyword-only and is currently asserted to be ``-2.0``.
    """

    def __init__(self, *args, biasinit, **kwargs):
        assert biasinit == -2.0  # Sanity check
        # Stored before the parent constructor because nn.Linear.__init__
        # calls reset_parameters(), which reads it.
        self.biasinit = biasinit
        super().__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        bias_values = self.bias.data
        bias_values.fill_(self.biasinit)
class Transition(nn.Module):
    """Normalise, expand by a factor ``n`` through a SiLU-gated pair of projections, project back.

    Args:
        n: Expansion factor of the hidden layer.
        c: Channel width of the input and output.
    """

    def __init__(self, n, c):
        super().__init__()
        hidden_width = n * c
        self.layer_norm_1 = RMSNorm(c)
        self.linear_1 = linearNoBias(c, hidden_width)
        self.linear_2 = linearNoBias(c, hidden_width)
        self.linear_3 = linearNoBias(hidden_width, c)

    @activation_checkpointing
    def forward(
        self,
        X,
    ):
        normed = self.layer_norm_1(X)
        gate = silu(self.linear_1(normed))
        value = self.linear_2(normed)
        return self.linear_3(gate * value)
class AdaLN(nn.Module):
    """Adaptive layer norm: scale and shift normalised ``Ai`` using projections of ``Si``.

    Args:
        c_a: Channel width of the activations ``Ai``.
        c_s: Channel width of the conditioning signal ``Si``.
        n: Accepted but not used by this class.
    """

    def __init__(self, c_a, c_s, n=2):
        super().__init__()
        self.ln_a = RMSNorm(normalized_shape=(c_a,), elementwise_affine=False)
        self.ln_s = RMSNorm(normalized_shape=(c_s,), bias=False)
        self.to_gain = nn.Sequential(
            nn.Linear(c_s, c_a),
            nn.Sigmoid(),
        )
        self.to_bias = linearNoBias(c_s, c_a)

    def forward(
        self,
        Ai,  # [B, I, C_a]
        Si,  # [B, I, C_s]
    ):
        """
        Output:
            [B, I, C_a]
        """
        normed_a = self.ln_a(Ai)
        normed_s = self.ln_s(Si)
        gain = self.to_gain(normed_s)
        shift = self.to_bias(normed_s)
        return gain * normed_a + shift
def create_batch_dimension_if_not_present(batched_n_dim):
    """
    Decorator for adapting a function which expects batched arguments with ndim `batched_n_dim` also
    accept unbatched arguments.

    The wrapped function takes a single argument ``arg``:

    * ``arg.ndim == batched_n_dim``: ``f(arg)`` is returned unchanged.
    * ``arg.ndim == batched_n_dim - 1``: a leading axis of size 1 is inserted,
      ``f`` is called, and the leading axis is stripped from the result
      (which must itself have a leading axis of size 1).
    * otherwise an ``Exception`` is raised.
    """
    # Local import: only `partial` is imported from functools at module level.
    from functools import wraps

    def wrap(f):
        @wraps(f)  # keep the wrapped function's name and docstring
        def _wrap(arg):
            if arg.ndim == batched_n_dim:
                return f(arg)
            if arg.ndim != batched_n_dim - 1:
                raise Exception(
                    f"arg must have {batched_n_dim - 1} or {batched_n_dim} dimensions, got shape {arg.shape=}"
                )
            # Unbatched input: add the batch axis, then remove it from the output
            o = f(arg[None])
            assert o.shape[0] == 1, f"{o.shape=}[0] != 1"
            return o[0]

        return _wrap

    return wrap
def unpack_args_for_checkpointing(arg_names):
    """Decorator factory producing a wrapper that re-packs positional args as keywords.

    The returned wrapper takes only positional arguments, pairs them in order
    with ``arg_names`` and calls its *first positional argument* with the
    resulting keyword arguments (the first argument itself is passed under
    ``arg_names[0]``).

    NOTE(review): the decorated function ``f`` is never called -- the wrapper
    always dispatches to ``args[0]``. This mirrors the existing behaviour;
    confirm against callers whether that is intended.
    """

    def wrap(f):
        def _wrap(*args):
            target = args[0]
            keyword_args = dict(zip(arg_names, args))
            return target(**keyword_args)

        return _wrap

    return wrap

View File

@@ -0,0 +1,128 @@
import torch
from rfd3.model.layers.layer_utils import (
MultiDimLinear,
RMSNorm,
Transition,
linearNoBias,
)
from torch import nn
from modelhub.training.checkpoint import activation_checkpointing
from modelhub.utils.torch import device_of
class AttentionPairBiasPairformerDeepspeed(nn.Module):
    """Gated multi-head self-attention over ``A_I`` with an additive bias derived from ``Z_II``.

    Args:
        c_a: Channel width of the token activations; split across heads as
            ``c_a // n_head`` channels per head.
        c_s: Accepted but not used by this class.
        c_pair: Channel width of the pair representation ``Z_II``.
        n_head: Number of attention heads.
        kq_norm: If truthy, the key and value projections apply ``RMSNorm``
            to their flat outputs.
    """

    def __init__(self, c_a, c_s, c_pair, n_head, kq_norm=False):
        super().__init__()
        self.n_head = n_head
        self.c_a = c_a
        self.c_pair = c_pair
        self.c = c_a // n_head

        self.to_q = MultiDimLinear(c_a, (n_head, self.c))
        self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
        self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(
            MultiDimLinear(c_a, (n_head, self.c), bias=False),
            nn.Sigmoid(),
        )
        self.to_a = linearNoBias(c_a, c_a)
        self.ln_0 = RMSNorm((c_pair,))
        self.ln_1 = RMSNorm((c_a,))
        self.use_deepspeed_evo = False
        self.force_bfloat16 = True

    def forward(
        self,
        A_I,  # [I, C_a]
        S_I,  # [I, C_a] | None
        Z_II,  # [I, I, C_z]
        Beta_II=None,  # [I, I]
    ):
        """Apply pair-biased attention.

        Args:
            A_I: Token activations.
            S_I: Must be ``None``; conditioning is not supported here.
            Z_II: Pair representation, projected to one bias value per head.
            Beta_II: Optional extra bias added to every head; skipped when ``None``.

        Returns:
            The attended activations after the output projection ``to_a``.

        Raises:
            NotImplementedError: If ``use_deepspeed_evo`` is set and the
                second dimension of the pair bias exceeds 24.
        """
        # Input projections
        assert S_I is None
        A_I = self.ln_1(A_I)

        if self.use_deepspeed_evo or self.force_bfloat16:
            A_I = A_I.to(torch.bfloat16)

        Q_IH = self.to_q(A_I)  # / np.sqrt(self.c)
        K_IH = self.to_k(A_I)
        V_IH = self.to_v(A_I)
        B_IIH = self.to_b(self.ln_0(Z_II))
        # Beta_II defaults to None; only add it when the caller supplied one
        # (previously the default crashed on ``None[..., None]``).
        if Beta_II is not None:
            B_IIH = B_IIH + Beta_II[..., None]
        G_IH = self.to_g(A_I)

        L = B_IIH.shape[1]
        if self.use_deepspeed_evo and L > 24:
            raise NotImplementedError

        Q_IH = Q_IH / torch.sqrt(
            torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
        )
        # Attention
        A_IIH = torch.softmax(
            torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2
        )  # softmax over j
        ## G_IH: [I, H, C]
        ## A_IIH: [I, I, H]
        ## V_IH: [I, H, C]
        A_I = torch.einsum("...ijh,...jhc->...ihc", A_IIH, V_IH)
        A_I = G_IH * A_I  # [B, I, H, C]
        A_I = A_I.flatten(start_dim=-2)  # [B, I, Ca]

        A_I = self.to_a(A_I)
        return A_I
class PairformerBlock(nn.Module):
    """
    Attempt to replicate AF3 architecture from scratch.

    Only ``c_s``, ``c_z``, ``attention_pair_bias`` and ``n_transition`` are
    used by this class; the remaining constructor arguments are accepted and
    ignored. The single-representation submodules are created only when
    ``c_s > 0``.
    """

    def __init__(
        self,
        c_s,
        c_z,
        attention_pair_bias,
        p_drop=0.1,
        triangle_multiplication=None,
        triangle_attention=None,
        n_transition=4,
        use_deepspeed_evo=True,
        use_triangle_mult=False,
        use_triangle_attn=False,
    ):
        super().__init__()
        self.z_transition = Transition(c=c_z, n=n_transition)

        has_single_track = c_s > 0
        if has_single_track:
            self.s_transition = Transition(c=c_s, n=n_transition)
            self.attention_pair_bias = AttentionPairBiasPairformerDeepspeed(
                c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
            )

    @activation_checkpointing
    def forward(self, S_I, Z_II):
        """Residually update the pair track and, when ``S_I`` is given, the single track.

        Returns:
            The tuple ``(S_I, Z_II)``; ``S_I`` is passed through as ``None``
            when no single representation was supplied.
        """
        autocast_device = device_of(self).type
        with torch.amp.autocast(
            device_type=autocast_device, enabled=True, dtype=torch.bfloat16
        ):
            Z_II = Z_II + self.z_transition(Z_II)
            if S_I is None:
                return S_I, Z_II

            zero_beta = torch.tensor([0.0], device=Z_II.device)
            attended = self.attention_pair_bias(S_I, None, Z_II, Beta_II=zero_beta)
            S_I = S_I + attended
            S_I = S_I + self.s_transition(S_I)
        return S_I, Z_II

View File

@@ -1,13 +1,12 @@
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rfd3_exec.sh" "$0" "$@"'
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
import os
import tempfile
from pathlib import Path
import hydra
import rootutils
from dotenv import load_dotenv
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from modelhub.utils.logging import suppress_warnings
@@ -15,6 +14,8 @@ from modelhub.utils.logging import suppress_warnings
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
load_dotenv(override=True)
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
@@ -27,16 +28,18 @@ _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
def run_inference(cfg: DictConfig) -> None:
"""Execute the specified inference pipeline"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)
temp_dir.mkdir(parents=True, exist_ok=True)
import ipdb
run_params_set = {"inputs", "n_batches", "out_dir"}
run_params = {k: v for k, v in cfg.items() if k in run_params_set}
ipdb.set_trace()
# Create __init__ args by filtering for all configs not in run_params
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
init_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in run_params_set}
init_cfg = OmegaConf.create(init_cfg_dict)
inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)
inference_engine = instantiate(cfg, temp_dir=temp_dir, _convert_="partial")
with suppress_warnings(is_inference=True):
inference_engine.eval()
# # Run inference
with suppress_warnings(is_inference=True):
inference_engine.run(**run_params)
if __name__ == "__main__":

View File

@@ -4,7 +4,6 @@ import logging
import os
import sys
import time
from pathlib import Path
import hydra
import ipdb # noqa: F401
@@ -38,7 +37,7 @@ print(f"Project root: {os.environ.get('PROJECT_ROOT', '../..')}")
is_inference = True
outdir = Path("/home/jbutch/Projects/HT25/af3/modelhub_refactor/rfd3/tests/outs")
# outdir = Path("/home/jbutch/Projects/HT25/af3/modelhub_refactor/rfd3/tests/outs")
args = TEST_JSON_DATA["1qys-1-refactored"]
input = instantiate_example(args, is_inference=is_inference)

View File

@@ -3,7 +3,7 @@ from atomworks.common import sum_string_arrays
from atomworks.io.utils.io_utils import to_cif_file
from atomworks.ml.transforms.center_random_augmentation import CenterRandomAugmentation
from biotite.structure import AtomArrayStack
from rfd3.trainer.rfd3_trainer import _reassign_unindexed_token_chains
from rfd3.trainer.rfd3 import _reassign_unindexed_token_chains
from rfd3.transforms.design_transforms import (
MotifCenterRandomAugmentation,
)

View File

@@ -1,6 +1,5 @@
import copy
import getpass
import inspect
import json
import logging
import os
@@ -23,14 +22,12 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../src")
import atomworks
from atomworks import parse
from atomworks.io.parser import STANDARD_PARSER_ARGS
from atomworks.io.utils.io_utils import to_cif_file
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
from rfd3.constants import STANDARD_PARSER_ARGS
from rfd3.inference.datasets import (
prepare_pipeline_input_from_atom_array,
)
from rfd3.inference.input_parsing import (
DesignInputSpecification,
create_atom_array_from_design_specification,
)
from rfd3.transforms.pipelines import (
@@ -106,7 +103,7 @@ TEST_CFG_TRAIN = load_train_or_val_cfg()
##########################################################################################
DIRS = [
os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../../tests'),
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../tests"),
os.path.join(os.path.dirname(os.path.abspath(__file__))),
TEST_CFG_TRAIN.paths.data.design_benchmark_data_dir,
]
@@ -156,9 +153,6 @@ def load_test_json():
TEST_JSON_DATA = load_test_json()
assert TEST_JSON_DATA, "No test json data loaded!"
sig = inspect.signature(create_atom_array_from_design_specification)
valid_keys_ = sig.parameters.keys()
def filter_inference_args(args):
return {k: v for k, v in args.items() if k in valid_keys_}
@@ -171,8 +165,11 @@ def instantiate_example(args, is_inference=True):
if is_inference:
# Keep only the kwargs that the function actually accepts
# args = filter_inference_args(args)
atom_array, spec = create_atom_array_from_design_specification(**args)
input = prepare_pipeline_input_from_atom_array(atom_array)
# atom_array, spec = create_atom_array_from_design_specification(**args)
# input = prepare_pipeline_input_from_atom_array(atom_array)
input = DesignInputSpecification.safe_init(**args).to_pipeline_input(
example_id=args.get("example_id", "example")
)
else:
file = args.get("input")
if file is None:

View File

@@ -219,7 +219,6 @@ class FabricTrainer(ABC):
)
self.initialize_or_update_trainer_state({"scheduler_cfg": scheduler_cfg})
@abstractmethod
def construct_model(self):
"""Instantiate the model, updating the trainer state in-place.

View File

@@ -1,54 +1,34 @@
from collections import OrderedDict
import hydra
import numpy as np
import torch
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.utils.token import (
get_token_starts,
spread_token_wise,
)
from beartype.typing import Any, List, Union
from biotite.structure import AtomArray, AtomArrayStack, concatenate, infer_elements
from biotite.structure import AtomArray, AtomArrayStack
from biotite.structure.residues import get_residue_starts
from einops import repeat
from jaxtyping import Float, Int
from lightning_utilities import apply_to_collection
from omegaconf import DictConfig
from rfd3.constants import (
ATOM14_ATOM_NAMES,
VIRTUAL_ATOM_ELEMENT_NAME,
association_schemes,
association_schemes_stripped,
)
from rfd3.metrics.design_metrics import get_all_backbone_metrics
from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics
from rfd3.trainer.fabric_trainer import FabricTrainer
from rfd3.trainer.recycling import get_recycle_schedule
from rfd3.transforms.conditioning_utils import (
from rfd3.trainer.trainer_utils import (
_build_atom_array_stack,
_cleanup_virtual_atoms_and_assign_atom_name_elements,
_reassign_unindexed_token_chains,
_reorder_dict,
process_unindexed_outputs,
)
from rfd3.util.io import (
from rfd3.utils.io import (
build_stack_from_atom_array_and_batched_coords,
)
from modelhub.common import exists
from modelhub.metrics.losses import Loss
from modelhub.metrics.metric import MetricManager
from modelhub.training.EMA import EMA
from modelhub.trainers.fabric import FabricTrainer
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.torch import assert_no_nans, assert_same_shape
global_logger = RankedLogger(__name__, rank_zero_only=False)
def _remap_outputs(
    xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
) -> Float[torch.Tensor, "D L 3"]:
    """Scatter each batch element's rows to the positions given by ``mapping``.

    For every batch index ``b``, row ``j`` of ``xyz[b]`` is written to row
    ``mapping[b][j]``. ``xyz`` is modified in place and also returned.
    """
    n_batch = xyz.shape[0]
    for b in range(n_batch):
        # Snapshot first so the scatter reads the pre-update values
        snapshot = xyz[b].clone()
        xyz[b, mapping[b]] = snapshot
    return xyz
class AADesignTrainer(FabricTrainer):
"""Mostly for unique things like saving outputs and parsing inputs
@@ -90,9 +70,6 @@ class AADesignTrainer(FabricTrainer):
)
self.association_scheme = association_scheme
self.seed = None
self.inference_sampler_overrides = None
super().__init__(**kwargs)
# (Initialize recycle schedule upfront so all GPU's can sample the same number of recycles within a batch)
self.n_recycles_train = n_recycles_train
@@ -110,31 +87,10 @@ class AADesignTrainer(FabricTrainer):
if metrics
else None
)
# Loss (full precision)
with torch.autocast(device_type=self.fabric.device.type, enabled=False):
self.loss = AF3Loss(**loss) if loss else None
self.loss = Loss(**loss) if loss else None
# FROM RF3
def construct_model(self):
    """Instantiate the network from the training config and register it in the trainer state.

    The network is built with Hydra (non-recursive instantiation) inside Fabric's
    ``init_module`` context. If the config carries a non-None ``model.ema`` entry,
    the network is wrapped in ``EMA`` using that entry as keyword arguments.
    """
    # NOTE(review): the EMA wrap is assumed to sit inside the ``init_module``
    # context, as the flattened source does not show nesting -- confirm.
    with self.fabric.init_module():
        ranked_logger.info("Instantiating model...")
        net_cfg = self.state["train_cfg"].model.net
        model = hydra.utils.instantiate(net_cfg, _recursive_=False)

        # Optional exponential-moving-average wrapper
        if self.state["train_cfg"].model.ema is not None:
            ranked_logger.info("Wrapping model with EMA...")
            ema_kwargs = self.state["train_cfg"].model.ema
            model = EMA(model, **ema_kwargs)

    self.initialize_or_update_trainer_state({"model": model})
# ~~~FROM RF3
def _assemble_network_inputs(self, example: dict) -> dict:
"""Assemble and validate the network inputs."""
assert_same_shape(example["coord_atom_lvl_to_be_noised"], example["noise"])
@@ -159,7 +115,7 @@ class AADesignTrainer(FabricTrainer):
network_input["X_noisy_L"] = torch.nan_to_num(
network_input["X_noisy_L"]
)
ranked_logger.warning(str(e))
global_logger.warning(str(e))
else:
# During validation, since we do not crop, there should be no NaN's in the coordinates to noise
# (They were either removed, as is done with fully unresolved chains, or resolved according to our pipeline's rules)
@@ -172,40 +128,6 @@ class AADesignTrainer(FabricTrainer):
return network_input
# FROM RF3
def _assemble_metrics_extra_info(self, example: dict, network_output: dict) -> dict:
"""Prepares the extra info for the metrics"""
# We need the same information as for the loss...
metrics_extra_info = self._assemble_loss_extra_info(example)
# ... and possibly some additional metadata from the example dictionary
# TODO: Generalize, so we always use the `extra_info` key, rather than unpacking the ground truth as well
metrics_extra_info.update(
{
# TODO: Remove, instead using `extra_info` for all keys
**{
k: example["ground_truth"][k]
for k in [
"interfaces_to_score",
"pn_units_to_score",
"chain_iid_token_lvl",
]
if k in example["ground_truth"]
},
"example_id": example[
"example_id"
], # We require the example ID for logging
# (From the parser)
**example.get("extra_info", {}),
}
)
# Record metrics_tags for this example
metrics_extra_info["metrics_tags"] = example.get("metrics_tags", set())
# (Create a shallow copy to avoid modifying the original dictionary)
return {**metrics_extra_info}
def training_step(
self,
batch: Any,
@@ -296,9 +218,6 @@ class AADesignTrainer(FabricTrainer):
# (Note that forward() passes to the EMA/shadow model if the model is not training)
network_output = model.forward(
input=network_input,
n_cycle=example["feats"]["msa_stack"].shape[
0
], # Determine the number of recycles from the MSA stack shape
coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"],
)
@@ -307,28 +226,19 @@ class AADesignTrainer(FabricTrainer):
msg=f"network_output for example_id: {example['example_id']}",
)
# ... Convert output to a stack of atom arrays
predicted_atom_array_stack, prediction_metadata = (
self._build_predicted_atom_array_stack(network_output, example)
)
metrics_output = {}
if compute_metrics and exists(self.metrics):
if compute_metrics:
assert self.metrics is not None, "Metrics are not defined!"
metrics_extra_info = self._assemble_metrics_extra_info(
example, network_output
)
# Symmetry resolution
# TODO: Refactor such that symmetry returns the ideal coordinate permutation, we apply permutation, and pass adjusted prediction to metrics
# (without needing to use `extra_info` as we are now)
# TODO: Update symmetry resolution to be functional (vs. using class variable), take explicit inputs (vs. all from network_output), and use extra_info for the keys it needs
metrics_extra_info = self.subunit_symm_resolve(
network_output,
metrics_extra_info,
example["symmetry_resolution"],
)
metrics_extra_info = self.residue_symm_resolve(
network_output,
metrics_extra_info,
example["automorphisms"],
)
metrics_output = self.metrics(
network_input=network_input,
network_output=network_output,
@@ -337,22 +247,33 @@ class AADesignTrainer(FabricTrainer):
ground_truth_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
metrics_extra_info["X_gt_L"], example.get("atom_array", None)
),
predicted_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
network_output["X_L"], example.get("atom_array", None)
),
predicted_atom_array_stack=predicted_atom_array_stack,
prediction_metadata=prediction_metadata,
)
if "X_gt_index_to_X" in metrics_extra_info:
# Remap outputs to minimize error with ground truth
# TODO: Remap before computing metrics, so that we can avoid pass `extra_info` to metrics (we instead just pass the remapped prediction)
mapping = metrics_extra_info["X_gt_index_to_X"] # [D, L]
network_output["X_L"] = _remap_outputs(network_output["X_L"], mapping)
# Avoid gradients in stored values to prevent memory leaks
if metrics_output is not None:
metrics_output = apply_to_collection(
metrics_output, torch.Tensor, lambda x: x.detach()
)
network_output = apply_to_collection(
network_output, torch.Tensor, lambda x: x.detach()
)
if network_output is not None:
network_output = apply_to_collection(
network_output, torch.Tensor, lambda x: x.detach()
)
return {"metrics_output": metrics_output, "network_output": network_output}
return {
"metrics_output": metrics_output,
"network_output": network_output,
"predicted_atom_array_stack": predicted_atom_array_stack,
"prediction_metadata": prediction_metadata,
}
def _assemble_loss_extra_info(self, example: dict) -> dict:
"""Assembles metadata arguments to the loss function (incremental to the network inputs and outputs)."""
@@ -501,7 +422,7 @@ class AADesignTrainer(FabricTrainer):
# Align ca and calculate RMSD:
if xyz_ca_input.shape == xyz_ca_output.shape:
try:
from rfd3.util.alignment import weighted_rigid_align
from rfd3.utils.alignment import weighted_rigid_align
xyz_ca_output_aligned = (
weighted_rigid_align(
@@ -532,266 +453,3 @@ class AADesignTrainer(FabricTrainer):
# Reorder metadata dictionaries to ensure 'metrics' and 'specification' are last
metadata_dict = {k: _reorder_dict(d) for k, d in metadata_dict.items()}
return atom_array_stack, metadata_dict
def _reorder_dict(d: dict) -> OrderedDict:
"""
Reorders keys in the dictionary to ensure 'metrics' and 'specification' are last (in that order if both present).
"""
ordered = OrderedDict()
first_keys = ["task", "diffused_index_map"]
last_keys = ["metrics", "specification", "inference_sampler"]
# First
for k in first_keys:
if k in d:
ordered[k] = d[k]
# Middle
for k in d:
if k not in last_keys and k not in first_keys:
ordered[k] = d[k]
# Last
for k in last_keys:
if k in d:
ordered[k] = d[k]
return ordered
def _build_atom_array_stack(
    coords,
    src_atom_array,
    sequence_indices,
    sequence_logits,
    allow_sequence_outputs=True,
    read_sequence_from_sequence_head=True,
    association_scheme: str = "atom14",
):
    """
    Build predicted atom arrays from batched coordinates and attach a sequence to them.

    A stack is created with ``build_stack_from_atom_array_and_batched_coords`` from
    ``coords`` and a copy of ``src_atom_array``; protein atoms named "UNK" are renamed
    to "ALA".

    Returns:
        - the stack itself when ``allow_sequence_outputs`` is false;
        - otherwise a list with one atom array per stack entry. When
          ``read_sequence_from_sequence_head`` is set and ``sequence_logits`` exists,
          residue names of atoms without a fixed sequence are decoded from
          ``sequence_indices`` and ``b_factor`` holds the entropy of the softmaxed
          logits; if not, residue names come from ``_readout_seq_from_struc``.
    """
    stack = build_stack_from_atom_array_and_batched_coords(
        coords, src_atom_array.copy()
    )

    # ... Spoof empty sequences to alanines
    unknown_protein = stack.is_protein & (stack.res_name == "UNK")
    stack.res_name[unknown_protein] = "ALA"

    if not allow_sequence_outputs:
        return stack

    use_sequence_head = read_sequence_from_sequence_head and exists(sequence_logits)
    if not use_sequence_head:
        # Infer residue names from the predicted structure instead
        return [
            _readout_seq_from_struc(model, association_scheme=association_scheme)
            for model in stack
        ]

    encoding = AF3SequenceEncoding()
    per_model_arrays = []
    for model, token_indices, token_logits in zip(
        stack, sequence_indices, sequence_logits
    ):
        # Residue names: only overwrite atoms whose sequence is not fixed
        diffused = ~model.is_motif_atom_with_fixed_seq
        decoded = encoding.decode(token_indices.cpu().numpy().astype(int))
        model.res_name[diffused] = decoded[model.token_id][diffused]

        # B-factor column: entropy of the sequence distribution, spread to atoms
        probs = torch.softmax(token_logits, dim=-1).cpu().numpy()
        entropy = -np.sum(probs * np.log(probs + 1e-10), axis=-1)
        model.b_factor = spread_token_wise(model, entropy)

        per_model_arrays.append(model.copy())

    return per_model_arrays
def _reassign_unindexed_token_chains(atom_array):
    """
    Move unindexed motif atoms onto their own chains (``X0``, ``X1``, ...).

    Does nothing when no atom is flagged ``is_motif_atom_unindexed``. Otherwise the
    flagged atoms get their ``orig_res_id`` back as ``res_id`` and a chain ID of the
    form ``X<group>``, where the group counter is the cumulative sum of
    ``is_motif_atom_unindexed_motif_breakpoint`` over the unindexed token starts.
    The array is modified in place and returned.
    """
    unindexed = atom_array.is_motif_atom_unindexed
    if not np.any(unindexed):
        return atom_array

    # HACK: Since res_ids are the same, we should save them with a different chain index.
    atom_array.chain_id[unindexed] = "X"
    atom_array.res_id[unindexed] = atom_array.orig_res_id[unindexed]

    # Parse to separate chains
    token_starts = get_token_starts(atom_array)
    start_is_unindexed = unindexed[token_starts]
    breakpoints = atom_array[
        token_starts[start_is_unindexed]
    ].is_motif_atom_unindexed_motif_breakpoint
    group_ids = np.cumsum(breakpoints, dtype=int)  # Group by motif breaks
    group_chain_ids = np.array([f"X{group}" for group in group_ids])

    per_token_chain = atom_array.chain_id[token_starts]
    per_token_chain[start_is_unindexed] = group_chain_ids
    atom_array.chain_id = spread_token_wise(atom_array, per_token_chain)
    return atom_array
def _cleanup_virtual_atoms_and_assign_atom_name_elements(
    atom_array, association_scheme: str = "atom14"
):
    """
    Drop virtual atoms and assign atom names / elements according to each residue's name.

    Residues are delimited by changes in ``res_id``. For residues whose sequence is
    known (all atoms have a fixed sequence, or all are unindexed motif atoms) every atom
    is kept and ``gt_atom_name`` is used. For the others, the residue's entry in
    ``association_schemes[association_scheme]`` decides which slots are real atoms
    (named from the scheme) and which are virtual (dropped). Residues whose name is not
    in the scheme are kept untouched and renamed "UNK".

    Returns:
        The cleaned atom array. If the collected atom names do not line up with the
        array length and an unknown residue type was seen, the input is returned as is.

    Raises:
        ValueError: if the atom names do not line up with the array length and no
            unknown residue type was seen.
    """
    # NOTE(review): residue boundaries only look at res_id, so neighbouring residues
    # from different chains that share a res_id would be merged -- confirm callers
    # never produce that layout.
    ret_mask = []  # True for atoms that survive the cleanup
    atom_names = []  # new atom name per original atom
    invalid_mask = []  # True for atoms of residues that get the UNK resname
    warning_issued = False

    res_ids = atom_array.res_id
    boundaries = np.where(res_ids[1:] != res_ids[:-1])[0] + 1
    res_start_indices = np.concatenate([[0], boundaries])
    res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

    for start, end in zip(res_start_indices, res_end_indices):
        residue = atom_array[start:end]
        n_atoms = len(residue)

        has_fixed_seq = all(np.array(residue.is_motif_atom_with_fixed_seq, dtype=bool))
        is_seq_known = has_fixed_seq or all(
            np.array(residue.is_motif_atom_unindexed, dtype=bool)
        )
        res_name = residue[0].res_name

        if is_seq_known:
            # ... Sequence known in the original atom array: keep everything
            ret_mask.extend([True] * n_atoms)
            invalid_mask.extend([False] * n_atoms)
            atom_names.extend(residue.gt_atom_name.tolist())
        elif res_name not in association_schemes[association_scheme]:
            # ... Predicted residue type has no scheme: keep atoms, flag as invalid
            global_logger.warning(
                "Model predicted non-protein sequence for diffused residue. Cannot clean up outputs. Assigning unknown residue token."
            )
            warning_issued = True
            ret_mask.extend([True] * n_atoms)
            invalid_mask.extend([True] * n_atoms)
            atom_names.extend(residue.atom_name.tolist())
        else:
            # ... Use the predicted / inferred sequence to pick real vs. virtual slots
            scheme = association_schemes[association_scheme][res_name]
            for slot_name in scheme:
                is_real = slot_name is not None
                ret_mask.append(is_real)
                atom_names.append(slot_name.strip() if is_real else "VX")
            invalid_mask.extend([False] * len(scheme))

    if len(atom_names) != atom_array.array_length():
        global_logger.warning(
            f"{atom_names=}\n{atom_array.atom_name=}\nAtom names length {len(atom_names)} does not match original array length {atom_array.array_length()}."
            "\nCould not cleanup atom array!!!"
        )
        if not warning_issued:
            raise ValueError("Atom names length does not match original array length. ")
        return atom_array

    atom_array.atom_name = atom_names
    # Only atoms still carrying the virtual-atom element get an inferred element
    atom_array.element = np.where(
        atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME,
        infer_elements(atom_names),
        atom_array.element,
    )
    atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
    return atom_array[ret_mask]
def _readout_seq_from_struc(
    atom_array, central_atom="CB", threshold=0.5, association_scheme: str = "atom14"
):
    """
    Infer residue names from the predicted structure by locating collapsed virtual atoms.

    Residues are delimited by changes in ``res_id``. Residues whose atoms all have a
    fixed sequence are passed through unchanged. For the rest, atoms closer than
    ``threshold`` to the central atom are treated as virtual; the remaining atom names
    are compared against every entry of ``association_schemes_stripped[association_scheme]``
    (skipping "UNK" and "MSK"), and the first residue type whose occupied slots match
    exactly becomes the residue name. Residues without a match are named "UNK".

    Virtual atoms are NOT removed and atom names are NOT reassigned here.

    Args:
        atom_array: Atom array to annotate.
        central_atom: Atom name virtual atoms are expected to collapse onto.
        threshold: Distance below which an atom counts as collapsed.
        association_scheme: Key into ``association_schemes_stripped``.

    Returns:
        Concatenation of the per-residue atom arrays, in the original order.
    """
    residue_arrays = []

    # Iterate through each residue
    res_ids = atom_array.res_id
    res_start_indices = np.concatenate(
        [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
    )
    res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

    for start, end in zip(res_start_indices, res_end_indices):
        cur_res_atom_array = atom_array[start:end]

        # ... Check if the current residue is after padding (seq unknown).
        # Here it assumes that every non-protein part has its sequence shown (not padded)
        is_seq_known = all(
            np.array(cur_res_atom_array.is_motif_atom_with_fixed_seq, dtype=bool)
        )
        if is_seq_known:
            residue_arrays.append(cur_res_atom_array)
            continue

        # For Glycine: it doesn't have CB, so the virtual atoms sit on CA instead.
        # This is detected by the predicted CA and CB being (nearly) on top of each other.
        # There might be a better way to do this.
        CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
        CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
        if np.linalg.norm(CA_coord - CB_coord) < threshold:
            cur_central_atom = "CA"
        else:
            cur_central_atom = central_atom
        central_mask = cur_res_atom_array.atom_name == cur_central_atom

        # ... Calculate the distance to the central atom (should only be one)
        central_coord = cur_res_atom_array.coord[central_mask][0]
        dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)

        # ... Select virtual atoms by distance. Shouldn't count the central atom itself.
        is_virtual = (dists < threshold) & ~central_mask

        # ... Names of the atoms that remain, e.g. [N, CA, C, O, CB, V6, V2]
        cur_pred_res_atom_names = cur_res_atom_array[~is_virtual].atom_name

        # ... Slots of the atom14 layout occupied by the remaining atoms,
        # e.g. [0, 1, 2, 3, 4, 11, 7]. This depends only on the residue, so it is
        # computed once instead of once per candidate residue type.
        occupied_idx = np.array(
            [
                np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
                for atom_name in cur_pred_res_atom_names
            ]
        )
        occupied_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
        occupied_mask[occupied_idx] = True

        # ... Find the first restype whose None / non-None positions match exactly.
        # Virtual atoms are kept and atom names are left alone; that is handled later.
        matched_restype = "UNK"  # fallback when the mapping fails (unrealistic sidechain)
        for restype, atom_names in association_schemes_stripped[
            association_scheme
        ].items():
            # Shouldn't match these two
            if restype in ["UNK", "MSK"]:
                continue
            atom_names = np.array(atom_names)
            if all(x is not None for x in atom_names[occupied_mask]) and all(
                x is None for x in atom_names[~occupied_mask]
            ):
                matched_restype = restype
                break

        cur_res_atom_array.res_name = np.array(
            [matched_restype] * len(cur_res_atom_array)
        )
        residue_arrays.append(cur_res_atom_array)

    return concatenate(residue_arrays)

View File

@@ -0,0 +1,502 @@
from collections import Counter, OrderedDict
import numpy as np
import torch
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.utils.token import (
get_token_starts,
spread_token_wise,
)
from biotite.structure import concatenate, infer_elements
from jaxtyping import Float, Int
from rfd3.constants import (
ATOM14_ATOM_NAMES,
VIRTUAL_ATOM_ELEMENT_NAME,
association_schemes,
association_schemes_stripped,
)
from rfd3.utils.io import (
build_stack_from_atom_array_and_batched_coords,
)
from scipy.optimize import linear_sum_assignment
from modelhub.common import exists
from modelhub.utils.ddp import RankedLogger
global_logger = RankedLogger(__name__, rank_zero_only=False)
#######################################################################
# Pythonic Helper functions
#######################################################################
def _remap_outputs(
xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
) -> Float[torch.Tensor, "D L 3"]:
"""Helper function to remap outputs using a mapping tensor."""
for i in range(xyz.shape[0]):
xyz[i, mapping[i]] = xyz[i].clone()
return xyz
def _reorder_dict(d: dict) -> OrderedDict:
"""
Reorders keys in the dictionary to ensure 'metrics' and 'specification' are last (in that order if both present).
"""
ordered = OrderedDict()
first_keys = ["task", "diffused_index_map"]
last_keys = ["metrics", "specification", "inference_sampler"]
# First
for k in first_keys:
if k in d:
ordered[k] = d[k]
# Middle
for k in d:
if k not in last_keys and k not in first_keys:
ordered[k] = d[k]
# Last
for k in last_keys:
if k in d:
ordered[k] = d[k]
return ordered
#######################################################################
# Biotite-related helper functions
#######################################################################
def _build_atom_array_stack(
    coords,
    src_atom_array,
    sequence_indices,
    sequence_logits,
    allow_sequence_outputs=True,
    read_sequence_from_sequence_head=True,
    association_scheme: str = "atom14",
):
    """
    Build predicted atom arrays from batched coordinates and attach a sequence to them.

    A stack is created with ``build_stack_from_atom_array_and_batched_coords`` from
    ``coords`` and a copy of ``src_atom_array``; protein atoms named "UNK" are renamed
    to "ALA".

    Returns:
        - the stack itself when ``allow_sequence_outputs`` is false;
        - otherwise a list with one atom array per stack entry. When
          ``read_sequence_from_sequence_head`` is set and ``sequence_logits`` exists,
          residue names of atoms without a fixed sequence are decoded from
          ``sequence_indices`` and ``b_factor`` holds the entropy of the softmaxed
          logits; if not, residue names come from ``_readout_seq_from_struc``.
    """
    stack = build_stack_from_atom_array_and_batched_coords(
        coords, src_atom_array.copy()
    )

    # ... Spoof empty sequences to alanines
    unknown_protein = stack.is_protein & (stack.res_name == "UNK")
    stack.res_name[unknown_protein] = "ALA"

    if not allow_sequence_outputs:
        return stack

    use_sequence_head = read_sequence_from_sequence_head and exists(sequence_logits)
    if not use_sequence_head:
        # Infer residue names from the predicted structure instead
        return [
            _readout_seq_from_struc(model, association_scheme=association_scheme)
            for model in stack
        ]

    encoding = AF3SequenceEncoding()
    per_model_arrays = []
    for model, token_indices, token_logits in zip(
        stack, sequence_indices, sequence_logits
    ):
        # Residue names: only overwrite atoms whose sequence is not fixed
        diffused = ~model.is_motif_atom_with_fixed_seq
        decoded = encoding.decode(token_indices.cpu().numpy().astype(int))
        model.res_name[diffused] = decoded[model.token_id][diffused]

        # B-factor column: entropy of the sequence distribution, spread to atoms
        probs = torch.softmax(token_logits, dim=-1).cpu().numpy()
        entropy = -np.sum(probs * np.log(probs + 1e-10), axis=-1)
        model.b_factor = spread_token_wise(model, entropy)

        per_model_arrays.append(model.copy())

    return per_model_arrays
def _cleanup_virtual_atoms_and_assign_atom_name_elements(
    atom_array, association_scheme: str = "atom14"
):
    """
    Drop virtual atoms and assign atom names / elements according to each residue's name.

    Residues are delimited by changes in ``res_id``. For residues whose sequence is
    known (all atoms have a fixed sequence, or all are unindexed motif atoms) every atom
    is kept and ``gt_atom_name`` is used. For the others, the residue's entry in
    ``association_schemes[association_scheme]`` decides which slots are real atoms
    (named from the scheme) and which are virtual (dropped). Residues whose name is not
    in the scheme are kept untouched and renamed "UNK".

    Returns:
        The cleaned atom array. If the collected atom names do not line up with the
        array length and an unknown residue type was seen, the input is returned as is.

    Raises:
        ValueError: if the atom names do not line up with the array length and no
            unknown residue type was seen.
    """
    # NOTE(review): residue boundaries only look at res_id, so neighbouring residues
    # from different chains that share a res_id would be merged -- confirm callers
    # never produce that layout.
    ret_mask = []  # True for atoms that survive the cleanup
    atom_names = []  # new atom name per original atom
    invalid_mask = []  # True for atoms of residues that get the UNK resname
    warning_issued = False

    res_ids = atom_array.res_id
    boundaries = np.where(res_ids[1:] != res_ids[:-1])[0] + 1
    res_start_indices = np.concatenate([[0], boundaries])
    res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

    for start, end in zip(res_start_indices, res_end_indices):
        residue = atom_array[start:end]
        n_atoms = len(residue)

        has_fixed_seq = all(np.array(residue.is_motif_atom_with_fixed_seq, dtype=bool))
        is_seq_known = has_fixed_seq or all(
            np.array(residue.is_motif_atom_unindexed, dtype=bool)
        )
        res_name = residue[0].res_name

        if is_seq_known:
            # ... Sequence known in the original atom array: keep everything
            ret_mask.extend([True] * n_atoms)
            invalid_mask.extend([False] * n_atoms)
            atom_names.extend(residue.gt_atom_name.tolist())
        elif res_name not in association_schemes[association_scheme]:
            # ... Predicted residue type has no scheme: keep atoms, flag as invalid
            global_logger.warning(
                "Model predicted non-protein sequence for diffused residue. Cannot clean up outputs. Assigning unknown residue token."
            )
            warning_issued = True
            ret_mask.extend([True] * n_atoms)
            invalid_mask.extend([True] * n_atoms)
            atom_names.extend(residue.atom_name.tolist())
        else:
            # ... Use the predicted / inferred sequence to pick real vs. virtual slots
            scheme = association_schemes[association_scheme][res_name]
            for slot_name in scheme:
                is_real = slot_name is not None
                ret_mask.append(is_real)
                atom_names.append(slot_name.strip() if is_real else "VX")
            invalid_mask.extend([False] * len(scheme))

    if len(atom_names) != atom_array.array_length():
        global_logger.warning(
            f"{atom_names=}\n{atom_array.atom_name=}\nAtom names length {len(atom_names)} does not match original array length {atom_array.array_length()}."
            "\nCould not cleanup atom array!!!"
        )
        if not warning_issued:
            raise ValueError("Atom names length does not match original array length. ")
        return atom_array

    atom_array.atom_name = atom_names
    # Only atoms still carrying the virtual-atom element get an inferred element
    atom_array.element = np.where(
        atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME,
        infer_elements(atom_names),
        atom_array.element,
    )
    atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
    return atom_array[ret_mask]
def _readout_seq_from_struc(
    atom_array, central_atom="CB", threshold=0.5, association_scheme: str = "atom14"
):
    """
    Infer residue names from the predicted structure by locating collapsed virtual atoms.

    Residues are delimited by changes in ``res_id``. Residues whose atoms all have a
    fixed sequence are passed through unchanged. For the rest, atoms closer than
    ``threshold`` to the central atom are treated as virtual; the remaining atom names
    are compared against every entry of ``association_schemes_stripped[association_scheme]``
    (skipping "UNK" and "MSK"), and the first residue type whose occupied slots match
    exactly becomes the residue name. Residues without a match are named "UNK".

    Virtual atoms are NOT removed and atom names are NOT reassigned here.

    Args:
        atom_array: Atom array to annotate.
        central_atom: Atom name virtual atoms are expected to collapse onto.
        threshold: Distance below which an atom counts as collapsed.
        association_scheme: Key into ``association_schemes_stripped``.

    Returns:
        Concatenation of the per-residue atom arrays, in the original order.
    """
    residue_arrays = []

    # Iterate through each residue
    res_ids = atom_array.res_id
    res_start_indices = np.concatenate(
        [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
    )
    res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

    for start, end in zip(res_start_indices, res_end_indices):
        cur_res_atom_array = atom_array[start:end]

        # ... Check if the current residue is after padding (seq unknown).
        # Here it assumes that every non-protein part has its sequence shown (not padded)
        is_seq_known = all(
            np.array(cur_res_atom_array.is_motif_atom_with_fixed_seq, dtype=bool)
        )
        if is_seq_known:
            residue_arrays.append(cur_res_atom_array)
            continue

        # For Glycine: it doesn't have CB, so the virtual atoms sit on CA instead.
        # This is detected by the predicted CA and CB being (nearly) on top of each other.
        # There might be a better way to do this.
        CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
        CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
        if np.linalg.norm(CA_coord - CB_coord) < threshold:
            cur_central_atom = "CA"
        else:
            cur_central_atom = central_atom
        central_mask = cur_res_atom_array.atom_name == cur_central_atom

        # ... Calculate the distance to the central atom (should only be one)
        central_coord = cur_res_atom_array.coord[central_mask][0]
        dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)

        # ... Select virtual atoms by distance. Shouldn't count the central atom itself.
        is_virtual = (dists < threshold) & ~central_mask

        # ... Names of the atoms that remain, e.g. [N, CA, C, O, CB, V6, V2]
        cur_pred_res_atom_names = cur_res_atom_array[~is_virtual].atom_name

        # ... Slots of the atom14 layout occupied by the remaining atoms,
        # e.g. [0, 1, 2, 3, 4, 11, 7]. This depends only on the residue, so it is
        # computed once instead of once per candidate residue type.
        occupied_idx = np.array(
            [
                np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
                for atom_name in cur_pred_res_atom_names
            ]
        )
        occupied_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
        occupied_mask[occupied_idx] = True

        # ... Find the first restype whose None / non-None positions match exactly.
        # Virtual atoms are kept and atom names are left alone; that is handled later.
        matched_restype = "UNK"  # fallback when the mapping fails (unrealistic sidechain)
        for restype, atom_names in association_schemes_stripped[
            association_scheme
        ].items():
            # Shouldn't match these two
            if restype in ["UNK", "MSK"]:
                continue
            atom_names = np.array(atom_names)
            if all(x is not None for x in atom_names[occupied_mask]) and all(
                x is None for x in atom_names[~occupied_mask]
            ):
                matched_restype = restype
                break

        cur_res_atom_array.res_name = np.array(
            [matched_restype] * len(cur_res_atom_array)
        )
        residue_arrays.append(cur_res_atom_array)

    return concatenate(residue_arrays)
#######################################################################
# Unindexed output parsing
#######################################################################
def _reassign_unindexed_token_chains(atom_array):
    """
    Move unindexed motif atoms onto their own chains (``X0``, ``X1``, ...).

    Does nothing when no atom is flagged ``is_motif_atom_unindexed``. Otherwise the
    flagged atoms get their ``orig_res_id`` back as ``res_id`` and a chain ID of the
    form ``X<group>``, where the group counter is the cumulative sum of
    ``is_motif_atom_unindexed_motif_breakpoint`` over the unindexed token starts.
    The array is modified in place and returned.
    """
    unindexed = atom_array.is_motif_atom_unindexed
    if not np.any(unindexed):
        return atom_array

    # HACK: Since res_ids are the same, we should save them with a different chain index.
    atom_array.chain_id[unindexed] = "X"
    atom_array.res_id[unindexed] = atom_array.orig_res_id[unindexed]

    # Parse to separate chains
    token_starts = get_token_starts(atom_array)
    start_is_unindexed = unindexed[token_starts]
    breakpoints = atom_array[
        token_starts[start_is_unindexed]
    ].is_motif_atom_unindexed_motif_breakpoint
    group_ids = np.cumsum(breakpoints, dtype=int)  # Group by motif breaks
    group_chain_ids = np.array([f"X{group}" for group in group_ids])

    per_token_chain = atom_array.chain_id[token_starts]
    per_token_chain[start_is_unindexed] = group_chain_ids
    atom_array.chain_id = spread_token_wise(atom_array, per_token_chain)
    return atom_array
def process_unindexed_outputs(
    atom_array,
    match_atom_names=True,
    insert_guideposts=False,
    verbose=False,
):
    """
    Process design outputs containing unindexed tokens.

    Every token made up entirely of unindexed motif atoms is assigned to one residue of
    the diffused part of the structure by solving a linear sum assignment on atom-atom
    distances (first over all not-yet-used diffused atoms, then restricted to the
    selected residue). Deviation metrics between the token and its assigned atoms are
    collected along the way.

    Args:
        atom_array: Atom array holding both diffused atoms and unindexed motif atoms
            (flagged by ``is_motif_atom_unindexed``).
        match_atom_names: If true, an unindexed atom may only be paired with a diffused
            atom carrying the same atom name.
        insert_guideposts: If true, the matched diffused atoms take over the token's
            coordinates and are flagged ``is_motif_atom_with_fixed_coord``; the residue
            name is also copied when the token has a fixed sequence.
        verbose: If false, the ``insertion_rmsd_by_*`` entries are dropped from the
            returned metadata.

    Returns:
        - Diffused atom array (a copy, without the additional unindexed tokens)
        - Metadata:
            - diffused_index_map: token source ID -> "<chain_id><res_id>" it was assigned to
            - insertion_rmsd_by_token / insertion_rmsd_by_restype: RMSD per token /
              mean RMSD per residue name (only when ``verbose``)
            - join_point_rmsd_by_token: distance at the join atom for tokens
              without backbone atoms
            - insertion_rmsd, insertion.mae, insertion.rmcd, join_point_rmsd:
              nan-safe means over tokens
            - n_conjoined_residues: number of tokens whose first assignment
              spanned more than one residue

    Raises:
        ValueError: if an unindexed token lacks the ``src_component`` annotation.

    TODO: Add additional geometry metrics such as bond angle non-ideality, clashes etc.
    TODO: atom1d conditioning adherence - does the output contain HBonds in the right places, correct rasa values?
    """
    starts = get_token_starts(atom_array, add_exclusive_stop=True)

    # [N_diffused,]
    atom_array_diffused = atom_array[~atom_array.is_motif_atom_unindexed].copy()

    metadata = {
        "diffused_index_map": {},
        "insertion_rmsd_by_token": {},
        "join_point_rmsd_by_token": {},
        "insertion_rmsd_by_restype": {},
    }
    token_maes = []
    token_rmcds = []
    n_conjoined_residues = 0

    # Diffused atoms already claimed by a previously processed token
    inserted_mask = np.full_like(atom_array_diffused.is_motif_atom_unindexed, False)

    for start, end in zip(starts[:-1], starts[1:]):
        token = atom_array[start:end]
        if not token.is_motif_atom_unindexed.all():
            continue

        if "src_component" in token.get_annotation_categories():
            token_pdb_id = token.src_component[0]
        else:
            raise ValueError(
                "Missing annotation 'src_component' in token. Is this inference?"
            )
        if "src_sym_component" in token.get_annotation_categories():
            # if symmetry, token_pdb_id are updated to match the symmetrized component
            token_pdb_id = token.src_sym_component[0]
        res_name = token.res_name[0]

        # ... Calculate [N_unindex, N_diffused] distance matrix
        dists = np.linalg.norm(
            token.coord[:, None] - atom_array_diffused.coord[None, :], axis=-1
        )

        # ... Remove already inserted atoms and (optionally) atoms with a different name
        dists[:, inserted_mask.copy()] = np.inf
        if match_atom_names:
            matching_atom_name = (
                token.atom_name[:, None] == atom_array_diffused.atom_name[None, :]
            )
            dists[~matching_atom_name] = np.inf

        # ... Find the residue in the diffused region that the token maps to
        row_ind, col_ind = linear_sum_assignment(dists)
        res_id, chain_id, is_conjoined = indices_to_components_(
            atom_array_diffused, col_ind
        )
        n_conjoined_residues += int(is_conjoined)

        # ... Recompute the assignment based on that single residue only
        token_match = (atom_array_diffused.res_id == res_id) & (
            atom_array_diffused.chain_id == chain_id
        )
        dists[:, ~token_match] = np.nan
        # Large finite cost keeps the assignment solvable while making
        # excluded pairs a last resort.
        BIG = 1e12
        dists = np.nan_to_num(dists, nan=BIG, posinf=BIG, neginf=BIG)
        row_ind, col_ind = linear_sum_assignment(dists)
        res_id_, chain_id_, _ = indices_to_components_(atom_array_diffused, col_ind)
        assert (res_id_ == res_id) & (chain_id_ == chain_id)
        inserted_mask = np.logical_or(inserted_mask, token_match)

        # ... Compute metrics based on the new assignment
        diff = token.coord[row_ind] - atom_array_diffused.coord[col_ind]
        token_rmsd = float(np.sqrt((diff**2).sum(-1).mean()))
        token_rmcd = float(np.cbrt((np.abs(diff) ** 3).sum(-1).mean()))
        token_mae = float((np.abs(diff)).sum(-1).mean())

        metadata["insertion_rmsd_by_token"][token_pdb_id] = token_rmsd
        token_maes.append(token_mae)
        token_rmcds.append(token_rmcd)
        if res_name not in metadata["insertion_rmsd_by_restype"]:
            metadata["insertion_rmsd_by_restype"][res_name] = []
        metadata["insertion_rmsd_by_restype"][res_name].append(token_rmsd)

        # ... Join point distance, only for tokens without any backbone atom
        if not np.any(np.isin(token.atom_name, ["N", "CA", "C", "O"])):
            if np.sum(token.atomize) == 1:
                join_atom = np.where(token.atomize)[0][0]
            elif "CB" in token.atom_name:
                join_atom = np.where(token.atom_name == "CB")[0][0]
            else:
                join_atom = None

            if join_atom is None:
                global_logger.warning(
                    f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
                )
            else:
                dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
                metadata["join_point_rmsd_by_token"][token_pdb_id] = dist

        metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"

        # ... Decide whether to insert the guideposts or not
        if insert_guideposts:
            # `col_ind` already indexes `atom_array_diffused` (it comes from the
            # columns of `dists`), so it must be used directly here. Translating it
            # to positions in the unfiltered input array would address the wrong
            # atoms whenever unindexed atoms are not all at the end of the input.
            atom_array_diffused.coord[col_ind] = token.coord[row_ind]
            if token.is_motif_atom_with_fixed_seq[0]:
                atom_array_diffused.res_name[token_match] = token.res_name[0]
            atom_array_diffused.is_motif_atom_with_fixed_coord[col_ind] = True

    # ... Calculate global metrics
    def safe_mean(x):
        """Return nan-safe mean for empty or nan arrays."""
        x = np.asarray(x, float)
        if x.size == 0 or not np.isfinite(x).any():
            return float("nan")
        return float(np.nanmean(x))

    metadata["insertion.mae"] = safe_mean(token_maes)
    metadata["insertion.rmcd"] = safe_mean(token_rmcds)
    metadata["insertion_rmsd"] = safe_mean(
        list(metadata["insertion_rmsd_by_token"].values())
    )
    metadata["join_point_rmsd"] = safe_mean(
        list(metadata["join_point_rmsd_by_token"].values())
    )
    metadata["insertion_rmsd_by_restype"] = {
        a: safe_mean(v) for a, v in metadata["insertion_rmsd_by_restype"].items()
    }
    metadata["n_conjoined_residues"] = n_conjoined_residues

    if not verbose:
        metadata = {
            k: v for k, v in metadata.items() if not k.startswith("insertion_rmsd_by_")
        }
    return atom_array_diffused, metadata
def indices_to_components_(atom_array, col_ind):
    """
    Look up the residue ID and chain ID that a set of raw atom indices belongs to.

    Returns:
        ``(res_id, chain_id, conjoined)``. ``conjoined`` is True when the indices do
        not map to a single (chain, residue) pair; in that case a warning is logged and
        the most common pair is returned.
    """
    res_ids = atom_array.res_id[col_ind]
    chain_ids = atom_array.chain_id[col_ind]
    res_list = res_ids.tolist()
    chain_list = chain_ids.tolist()

    # Common case: all indices agree on one residue and one chain
    if len(set(res_list)) <= 1 and len(set(chain_list)) <= 1:
        return res_ids[0], chain_ids[0], False

    global_logger.warning(
        f"Unindexed token mapped its atoms to multiple diffused residues: {res_list} and chains {chain_list}."
    )
    # Handle by majority
    votes = Counter(zip(chain_list, res_list))
    (chain_id, res_id), _ = votes.most_common(1)[0]
    return res_id, chain_id, True

View File

@@ -1,56 +0,0 @@
from atomworks.ml.transforms._checks import (
check_contains_keys,
check_is_instance,
)
from atomworks.ml.transforms.base import Transform
from biotite.structure import AtomArray
class SetOccToZeroOnBfactor(Transform):
    """
    Mark atoms as occ=0 when their B-factor falls outside an allowed range.

    ``bmin`` and ``bmax`` are the minimum and maximum B-factors to keep. Either
    may be ``None`` to leave that side unbounded; if both are ``None`` (or the
    atom array has no ``b_factor`` annotation) the data is returned untouched.

    Example:
        bmin=-1.0, bmax=70.0 will mark with occ=0 any atom with b>70 or b<-1
    """

    def __init__(self, bmin=None, bmax=None):
        self.bmin = bmin
        self.bmax = bmax

    def check_input(self, data: dict):
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        # check_atom_array_annotation(data, ["b_factor", "occupancy"])

    def forward(self, data: dict) -> dict:
        atom_array = data["atom_array"]

        # Nothing to do without bounds or without B-factors.
        unbounded = self.bmin is None and self.bmax is None
        if unbounded or "b_factor" not in atom_array.get_annotation_categories():
            return data

        bfact = atom_array.get_annotation("b_factor")
        below = None if self.bmin is None else bfact < self.bmin
        above = None if self.bmax is None else bfact > self.bmax

        # At least one bound is set here, so at least one mask exists.
        if below is None:
            mask = above
        elif above is None:
            mask = below
        else:
            mask = below | above

        occ = atom_array.get_annotation("occupancy")
        occ[mask] = 0.0
        atom_array.set_annotation("occupancy", occ)

        data["atom_array"] = atom_array
        return data

View File

@@ -1,10 +1,6 @@
from collections import Counter
import networkx as nx
import numpy as np
from atomworks.io.utils.bonds import _atom_array_to_networkx_graph
from atomworks.ml.utils.token import get_token_starts
from scipy.optimize import linear_sum_assignment
from modelhub.utils.ddp import RankedLogger
@@ -192,171 +188,6 @@ def choose_uniformly_random_atom_name(subarray):
#################################################################################
def process_unindexed_outputs(
    atom_array,
    match_atom_names=True,
    insert_guideposts=False,
    verbose=False,
):
    """
    Process design outputs containing unindexed tokens.

    Each fully-unindexed token is assigned to one residue of the diffused
    structure by solving a linear sum assignment on atom-atom distances, and
    per-token deviation metrics are collected.

    Args:
        atom_array: Atom array holding both diffused atoms and unindexed motif
            atoms (distinguished by ``is_motif_atom_unindexed``).
        match_atom_names: If True, only atoms with identical atom names may be
            paired.
        insert_guideposts: If True, the matched diffused atoms are overwritten
            with the unindexed token coordinates and flagged as fixed-coordinate
            motif atoms (and the residue name is copied when the token has a
            fixed sequence).
        verbose: If False, metadata keys starting with ``insertion_rmsd_by_``
            are removed from the returned metadata.

    Returns:
        - Diffused atom array (without additional unindexed tokens)
        - Metadata:
            - diffused_index_map: keys = token source component ids, values = "{chain_id}{res_id}" of the matched residue
            - insertion_rmsd: mean of the per-token insertion RMSDs
            - insertion_rmsd_by_token / insertion_rmsd_by_restype: per-token and per-residue-name RMSDs (verbose only)
            - join_point_rmsd_by_token / join_point_rmsd: join-atom distances and their mean
            - insertion.mae, insertion.rmcd: means of the per-token MAE / RMCD
            - n_conjoined_residues: number of tokens whose first-pass match spanned several residues

    Raises:
        ValueError: If an unindexed token lacks the ``src_component`` annotation.

    TODO: Add additional geometry metrics such as bond angle non-ideality, clashes etc.
    TODO: atom1d conditioning adherence - does the output contain HBonds in the right places, correct rasa values?
    """
    # ... Find assignments based on greedy search
    starts = get_token_starts(atom_array, add_exclusive_stop=True)
    # [N_diffused,]
    atom_array_diffused = atom_array[~atom_array.is_motif_atom_unindexed].copy()
    # Positions of the diffused atoms within the full (unfiltered) atom array
    global_idx = np.arange(atom_array.array_length())[
        ~atom_array.is_motif_atom_unindexed
    ]
    metadata = {
        "diffused_index_map": {},
        "insertion_rmsd_by_token": {},
        "join_point_rmsd_by_token": {},
        "insertion_rmsd_by_restype": {},
    }
    token_maes = []
    token_rmcds = []
    n_conjoined_residues = 0
    # Initialize an empty array
    # (tracks diffused atoms whose residue has already been claimed by a token)
    inserted_mask = np.full_like(atom_array_diffused.is_motif_atom_unindexed, False)
    for start, end in zip(starts[:-1], starts[1:]):
        token = atom_array[start:end]
        # Only tokens made up entirely of unindexed motif atoms are processed
        if not token.is_motif_atom_unindexed.all():
            continue
        if "src_component" in token.get_annotation_categories():
            token_pdb_id = token.src_component[0]
        else:
            raise ValueError(
                "Missing annotation 'src_component' in token. Is this inference?"
            )
        if "src_sym_component" in token.get_annotation_categories():
            # if symmetry, token_pdb_id are updated to match the symmetrized component
            token_pdb_id = token.src_sym_component[0]
        res_name = token.res_name[0]
        # ... Calculate [N_unindex, N_diffused] distance matrix
        dists = np.linalg.norm(
            token.coord[:, None] - atom_array_diffused.coord[None, :], axis=-1
        )
        # ... Match atom indices based on atom names (mask out non-identical) and remove already inserted
        dists[:, inserted_mask.copy()] = np.inf
        if match_atom_names:
            matching_atom_name = (
                token.atom_name[:, None] == atom_array_diffused.atom_name[None, :]
            )
            dists[~matching_atom_name] = np.inf
        # ... Find the res_id's in the diffused regions belonging to the diffused indices
        # First pass: unrestricted assignment, used only to pick the target residue
        row_ind, col_ind = linear_sum_assignment(dists)
        res_id, chain_id, is_conjoined = indices_to_components(
            atom_array_diffused, col_ind
        )
        n_conjoined_residues += int(is_conjoined)
        # ... Recompute distance indices based on single residue pairings only
        token_match = (atom_array_diffused.res_id == res_id) & (
            atom_array_diffused.chain_id == chain_id
        )
        dists[:, ~token_match] = np.nan
        # Swap nan/inf for one large finite cost before re-solving
        # (presumably so the second assignment always has finite costs - TODO confirm)
        BIG = 1e12
        dists = np.nan_to_num(dists, nan=BIG, posinf=BIG, neginf=BIG)
        # Second pass: assignment restricted to the chosen residue
        row_ind, col_ind = linear_sum_assignment(dists)
        res_id_, chain_id_, _ = indices_to_components(atom_array_diffused, col_ind)
        assert (res_id_ == res_id) & (chain_id_ == chain_id)
        # Claim the whole residue so later tokens cannot be matched to it
        inserted_mask = np.logical_or(inserted_mask, token_match)
        # ... Compute metrics based on the new distances
        diff = token.coord[row_ind] - atom_array_diffused.coord[col_ind]
        # Per-atom squared / cubed / absolute deviations summed over xyz, then averaged over atoms
        token_rmsd = float(np.sqrt((diff**2).sum(-1).mean()))
        token_rmcd = float(np.cbrt((np.abs(diff) ** 3).sum(-1).mean()))
        token_mae = float((np.abs(diff)).sum(-1).mean())
        metadata["insertion_rmsd_by_token"][token_pdb_id] = token_rmsd
        token_maes.append(token_mae)
        token_rmcds.append(token_rmcd)
        if res_name not in metadata["insertion_rmsd_by_restype"]:
            metadata["insertion_rmsd_by_restype"][res_name] = []
        metadata["insertion_rmsd_by_restype"][res_name].append(token_rmsd)
        # Join-point distance is only computed for tokens without any backbone atom names
        if not np.any(np.isin(token.atom_name, ["N", "CA", "C", "O"])):
            # Join atom: the single atomized atom if there is exactly one, otherwise CB
            if np.sum(token.atomize) == 1:
                join_atom = np.where(token.atomize)[0][0]
            elif "CB" in token.atom_name:
                join_atom = np.where(token.atom_name == "CB")[0][0]
            else:
                join_atom = None
            if join_atom is None:
                global_logger.warning(
                    f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
                )
            else:
                dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
                metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
        metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"
        # ... Decide whether to cleanup guideposts or not
        if insert_guideposts:
            # NOTE(review): global_idx holds positions in the full atom_array but is used
            # here to index atom_array_diffused; these coincide only if every unindexed
            # atom comes after all diffused atoms - confirm against callers.
            atom_array_diffused.coord[global_idx[col_ind]] = token.coord[row_ind]
            if token.is_motif_atom_with_fixed_seq[0]:
                atom_array_diffused.res_name[token_match] = token.res_name[0]
            # atom_array_diffused.is_motif_token[token_match] = True
            # atom_array_diffused.is_motif_atom[global_idx[col_ind]] = True
            atom_array_diffused.is_motif_atom_with_fixed_coord[global_idx[col_ind]] = (
                True
            )

    # ... Calculate global metrics
    def safe_mean(x):
        """Return nan-safe mean for empty or nan arrays."""
        x = np.asarray(x, float)
        if x.size == 0 or not np.isfinite(x).any():
            return float("nan")
        return float(np.nanmean(x))

    metadata["insertion.mae"] = safe_mean(token_maes)
    metadata["insertion.rmcd"] = safe_mean(token_rmcds)
    metadata["insertion_rmsd"] = safe_mean(
        list(metadata["insertion_rmsd_by_token"].values())
    )
    metadata["join_point_rmsd"] = safe_mean(
        list(metadata["join_point_rmsd_by_token"].values())
    )
    # Collapse the per-residue-name lists of RMSDs into their means
    metadata["insertion_rmsd_by_restype"] = {
        a: safe_mean(v) for a, v in metadata["insertion_rmsd_by_restype"].items()
    }
    metadata["n_conjoined_residues"] = n_conjoined_residues
    # Non-verbose output drops the per-token / per-restype breakdowns
    if not verbose:
        metadata = {
            k: v for k, v in metadata.items() if not k.startswith("insertion_rmsd_by_")
        }
    return atom_array_diffused, metadata
def random_condition(p_cond):
"""
Made this function because I always get confused by which order the
@@ -367,28 +198,3 @@ def random_condition(p_cond):
return False
else:
return np.random.rand() < p_cond
def indices_to_components(atom_array, col_ind):
    """
    Look up the residue id and chain id that a set of raw atom indices points to.

    Returns a ``(res_id, chain_id, conjoined)`` tuple. ``conjoined`` is True when
    the indices do not all map to a single (chain, residue) pair; in that case a
    warning is logged and the most frequent pair is returned.
    """
    res_ids = atom_array.res_id[col_ind]
    chain_ids = atom_array.chain_id[col_ind]

    n_distinct_residues = len(set(res_ids.tolist()))
    if n_distinct_residues <= 1 and len(set(chain_ids.tolist())) <= 1:
        # All indices agree on one residue: report it directly.
        return res_ids[0], chain_ids[0], False

    global_logger.warning(
        f"Unindexed token mapped its atoms to multiple diffused residues: {res_ids.tolist()} and chains {chain_ids.tolist()}."
    )
    # Handle by majority
    votes = Counter(zip(chain_ids.tolist(), res_ids.tolist()))
    winner, _ = votes.most_common(1)[0]
    chain_id, res_id = winner
    return res_id, chain_id, True

View File

@@ -624,35 +624,33 @@ def build_atom14_base_pipeline(
kwargs.setdefault("crop_spatial_probability", 0.0)
kwargs.setdefault("dna_contact_crop_probability", 0.0)
kwargs.setdefault("max_atoms_in_crop", None)
kwargs.setdefault("b_factor_min", None)
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
kwargs.setdefault("meta_conditioning_probabilities", {})
kwargs.setdefault("association_scheme", "atom14")
kwargs.setdefault("sigma_perturb", 0.0)
kwargs.setdefault("sigma_perturb_com", 0.0)
kwargs.setdefault("allowed_types", "ALL")
kwargs.setdefault("train_conditions", {})
# TODO: Delete these once all checkpoints are updated with the latest defaults
kwargs.setdefault("generate_conformers_for_non_protein_only", True)
kwargs.setdefault("atom_1d_features", None)
kwargs.setdefault("token_1d_features", None)
kwargs.setdefault("diffusion_batch_size", 16)
kwargs.setdefault("sigma_data", 16)
kwargs.setdefault("return_atom_array", True)
kwargs.setdefault("provide_elements_for_unindexed_components", False)
kwargs.setdefault("center_option", "all")
kwargs.setdefault("use_element_for_atom_names_of_atomized_tokens", False)
kwargs.setdefault("residue_cache_dir", None)
kwargs.setdefault("keep_full_binder_in_spatial_crop", True)
kwargs.setdefault("max_binder_length", 999)
kwargs.setdefault("max_ppi_hotspots_frac_to_provide", 0)
kwargs.setdefault("ppi_hotspot_max_distance", 15)
kwargs.setdefault("max_ss_frac_to_provide", 0.0)
kwargs.setdefault("min_ss_island_len", 0)
kwargs.setdefault("max_ss_island_len", 999)
kwargs.setdefault("max_binder_length", 999)
kwargs.setdefault("b_factor_min", None)
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
kwargs.setdefault("meta_conditioning_probabilities", {})
kwargs.setdefault("association_scheme", "dense")
kwargs.setdefault("sigma_perturb", 0.0)
kwargs.setdefault("sigma_perturb_com", 0.0)
kwargs.setdefault("allowed_types", "ALL")
kwargs.setdefault("train_conditions", {})
kwargs.setdefault("residue_cache_dir", None)
# TODO: Delete these once all checkpoints are updated with the latest defaults
kwargs.setdefault("generate_conformers_for_non_protein_only", True)
# kwargs.setdefault("atom_1d_features", None)
# kwargs.setdefault("token_1d_features", None)
# kwargs.setdefault("diffusion_batch_size", 16)
# kwargs.setdefault("sigma_data", 16)
kwargs.setdefault("return_atom_array", True)
kwargs.setdefault("provide_elements_for_unindexed_components", False)
kwargs.setdefault("center_option", "all")
return build_atom14_base_pipeline_(
is_inference=is_inference,

View File

@@ -37,7 +37,9 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato
raise ValueError(
f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}"
)
atom_names = [str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
atom_names = (
[str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
)
idxs = np.array(
[
association_schemes_stripped[scheme][res_name].index(name)

View File

@@ -10,6 +10,7 @@ import biotite.structure as struc
import numpy as np
from atomworks import parse
from atomworks.constants import STANDARD_DNA
from atomworks.io.parser import STANDARD_PARSER_ARGS
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.preprocessing.utils.structure_utils import (
get_atom_mask_from_cell_list,
@@ -21,18 +22,18 @@ from atomworks.ml.utils.token import (
from rfd3.constants import (
REQUIRED_CONDITIONING_ANNOTATIONS,
)
from rfd3.inference.components import (
fetch_mask_from_component,
get_name_mask,
unravel_components,
)
from rfd3.transforms.conditioning_base import (
convert_existing_annotations_to_bool,
set_default_conditioning_annotations,
)
from rfd3.transforms.conditioning_utils import sample_island_tokens
from rfd3.constants import STANDARD_PARSER_ARGS
from modelhub.common import exists
from modelhub.utils.components import (
fetch_mask_from_component,
get_name_mask,
unravel_components,
)
from modelhub.utils.ddp import RankedLogger
logging.basicConfig(level=logging.INFO)

View File

@@ -8,7 +8,8 @@ import numpy as np
import torch
from atomworks.io.utils.io_utils import to_cif_file
from biotite.structure import AtomArray, AtomArrayStack, stack
from rfd3.util.alignment import weighted_rigid_align
from modelhub.utils.alignment import weighted_rigid_align
DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}

View File

@@ -258,7 +258,7 @@ def _viz_from_file(
atom_array = pickle.load(f)
elif file_path.endswith((".cif", ".cif.gz", ".bcif", ".bcif.gz")):
from atomworks.io.utils.io_utils import get_structure, read_any
from rfd3.inference.inference_utils import (
from rfd3.utils.inference import (
_add_design_annotations_from_cif_block_metadata,
)

View File

@@ -15,7 +15,6 @@ from rfd3.testing.testing_utils import (
TEST_CFG_INFERENCE,
TEST_CFG_TRAIN,
TEST_JSON_DATA,
build_pipelines,
instantiate_example,
)
@@ -29,6 +28,23 @@ sys.path.append(PATH_TO_SRC)
smoke_test = list(TEST_JSON_DATA.keys())
@pytest.mark.fast
def test_imports():
    """Smoke test: the top-level packages and the main rfd3 modules import cleanly."""
    import rfd3

    import modelhub

    for label, module in (
        ("Imported rfd3 version:", rfd3),
        ("Imported modelhub version:", modelhub),
    ):
        print(label, module)

    # Try imports from main modules
    from rfd3.metrics.losses import DiffusionLoss
    from rfd3.model.RFD3 import RFD3
    from rfd3.trainer.rfd3 import AADesignTrainer

    imported = (RFD3, AADesignTrainer, DiffusionLoss)
    print("imported modules:", *imported)
def test_inference():
# Silence outputs:
@@ -60,6 +76,6 @@ def test_training_pipeline(example_name):
example.get("training_condition_name"),
)
if __name__ == "__main__":
# pytest.main(sys.argv)
pytest.main(["-v", __file__ + "::test_train_pipeline[rfd3-NA-full]"])
pytest.main(sys.argv)

View File

@@ -8,7 +8,7 @@ from rfd3.testing.testing_utils import (
build_pipelines,
instantiate_example,
)
from rfd3.trainer.rfd3_trainer import (
from rfd3.trainer.trainer_utils import (
_cleanup_virtual_atoms_and_assign_atom_name_elements,
)

View File

@@ -3,18 +3,19 @@ import sys
import numpy as np
import pytest
from atomworks.io.utils.io_utils import load_any
from rfd3.inference.components import (
from rfd3.inference.input_parsing import DesignInputSpecification
from rfd3.inference.parsing import InputSelection
from rfd3.testing.testing_utils import (
TEST_JSON_DATA,
)
from modelhub.utils.components import (
fetch_mask_from_component,
fetch_mask_from_idx,
fetch_mask_from_name,
get_name_mask,
unravel_components,
)
from rfd3.inference.input_parsing import InputSpecification
from rfd3.inference.parsing import InputSelection
from rfd3.testing.testing_utils import (
TEST_JSON_DATA,
)
# TEST 1 - test the selections
args = TEST_JSON_DATA["amidase_helix"]
@@ -116,7 +117,7 @@ atom_array_ref_unindexed = load_any(file)[0]
def test_unindexed_break():
# Assert that unindexed selections with breaks work
sele = InputSelection.from_any(args["unindex"], atom_array=atom_array_ref_unindexed)
comps, breaks = InputSpecification.break_unindexed(sele)
comps, breaks = DesignInputSpecification.break_unindexed(sele)
if __name__ == "__main__":

View File

@@ -10,7 +10,7 @@ from rfd3.testing.testing_utils import (
instantiate_example,
load_train_or_val_cfg,
)
from rfd3.transforms.conditioning_utils import (
from rfd3.trainer.trainer_utils import (
process_unindexed_outputs,
)
@@ -48,10 +48,10 @@ def test_unindexed_cleanup(example, is_inference):
) # spoof inference label
atom_array_cleaned, metadata = process_unindexed_outputs(atom_array)
from rfd3.testing.debug_utils import save_pipe_out
# from rfd3.testing.debug_utils import save_pipe_out
save_pipe_out(atom_array)
print("Metadata:", metadata)
# save_pipe_out(atom_array)
# print("Metadata:", metadata)
xyz_cleaned = np.nan_to_num(atom_array_cleaned.coord)
xyz_diffused = np.nan_to_num(atom_array[~atom_array.is_motif_atom_unindexed].coord)

View File

@@ -192,6 +192,8 @@ def test_atom14_pipeline_regression(
# Compare results
config_desc = f" ({config.name})" if config.name != "a14-base-train" else ""
mode_description = f"{'inference' if is_inference else 'training'}{config_desc}"
if "specification" in result:
expected_result["specification"] = result["specification"]
_assert_pipeline_results_equal(
result, expected_result, example_name, mode_description

View File

@@ -32,7 +32,7 @@ dependencies = [
# ... typing & documentation
"jaxtyping>=0.2.17,<1",
"beartype>=0.18.0,<1",
"typer>=0.9.0,<1",
"typer>=0.20.0,<1",
# ... ml tools (core)
"torch>=2.2.0,<3",
"lightning>=2.5.0",
@@ -41,12 +41,17 @@ dependencies = [
"einx>=0.1.0,<1",
"opt_einsum>=3.4.0,<4",
"dm-tree>=0.1.6,<1",
"zstandard",
"pandas",
# "biotite",
"atomworks"
]
[project.optional-dependencies]
rfd3 = [
"pydantic>=2.8",
"toolz",
]
rf3 = [
"cuequivariance_ops_cu12>=0.6.1; sys_platform == 'linux'",
@@ -72,7 +77,8 @@ dev = [
"pytest-dotenv>=0.5.2,<1", # load environment variables from .env file
"pytest-cov>=4.1.0,<5", # generate coverage report
"pytest-benchmark>=5.0.0,<6", # benchmark tests for speed
"atomworks==1.0.2",
# "atomworks==1.0.2",
"pre-commit"
]
[project.scripts]
rf3 = "rf3.cli:app"

View File

@@ -49,5 +49,9 @@ try:
except ImportError:
logger.debug("cuEquivariance unavailable: import failed")
# Whether to disable checkpointing globally
DISABLE_CHECKPOINTING = False
# Export for easy access
__all__ = ["SHOULD_USE_CUEQUIVARIANCE"]
__all__ = ["SHOULD_USE_CUEQUIVARIANCE", "DISABLE_CHECKPOINTING"]

28
src/modelhub/constants.py Normal file
View File

@@ -0,0 +1,28 @@
# fmt: off
# ... For convenience, define BKBN or TIP to be used as shortcuts | TIP is the largest set of fixed atoms per residue type, given at least 2 tip atoms
TIP_BY_RESTYPE = {
"TRP": ["CG","CD1","CD2","NE1","CE2","CE3","CZ2","CZ3","CH2"], # fix both rings
"HIS": ["CG","ND1","CD2","CE1","NE2"], # fixed ring
"TYR": ["CZ","OH"], # keeps ring dihedral flexible
"PHE": ["CG","CD1","CD2","CE1","CE2","CZ"],
"ASN": ["CB", "CG","OD1","ND2"],
"ASP": ["CB", "CG","OD1","OD2"],
"GLN": ["CG", "CD","OE1","NE2"],
"GLU": ["CG", "CD","OE1","OE2"],
"CYS": ["CB", "SG"],
"SER": ["CB", "OG"],
"THR": ["CB", "OG1"],
"LEU": ["CB", "CG", "CD1", "CD2"],
"VAL": ["CG1", "CG2"],
"ILE": ["CB", "CG2"],
"MET": ["SD", "CE"],
"LYS": ["CE","NZ"],
"ARG": ["CD","NE","CZ","NH1","NH2"],
"PRO": None,
"ALA": None,
"GLY": None,
"UNK": None,
"MSK": None
}
# fmt: on

View File

@@ -0,0 +1,216 @@
import logging
from os import PathLike
from pathlib import Path
from typing import Any, Dict
import hydra
import torch
from biotite.structure import AtomArray
from lightning.fabric import seed_everything
from omegaconf import OmegaConf
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
from modelhub.utils.logging import print_config_tree
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
datefmt="%H:%M:%S",
)
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
def merge(cfg, overrides: dict):
    """Return ``cfg`` merged with ``overrides`` (converted to an OmegaConf config first)."""
    override_cfg = OmegaConf.create(overrides)
    return OmegaConf.merge(cfg, override_cfg)
class BaseInferenceEngine:
    """
    Base inference engine.

    Separates model setup (expensive, once) from inference (can run multiple times).
    Construction only records configuration overrides; the checkpoint is loaded and
    the pipeline/trainer are built lazily by :meth:`initialize`.
    """

    def __init__(
        self,
        ckpt_path: PathLike,
        num_nodes: int = 1,
        devices_per_node: int = 1,
        # Config overrides
        transform_overrides: dict | None = None,
        inference_sampler_overrides: dict | None = None,
        trainer_overrides: dict | None = None,
        # Debug
        print_config: bool = False,
        seed: int | None = None,
    ):
        """Initialize inference engine and record config overrides.

        Model config is loaded from checkpoint and overridden with parameters provided here.

        Args:
            ckpt_path: Path to model checkpoint.
            num_nodes: Number of nodes for distributed inference. Defaults to ``1``.
            devices_per_node: Number of devices per node. Defaults to ``1``.
            transform_overrides: Overrides merged into the transform config of the
                first validation dataset. Defaults to ``None`` (no overrides).
            inference_sampler_overrides: Overrides placed under
                ``model.net.inference_sampler``. Defaults to ``None`` (no overrides).
            trainer_overrides: Overrides placed under ``trainer``. Defaults to
                ``None`` (no overrides).
            print_config: Whether to print config trees. Defaults to ``False``.
            seed: Random seed. If None, uses external RNG state. Defaults to ``None``.
        """
        # ``None`` instead of ``{}`` defaults: a literal dict default is created once
        # and shared by every instance, so a mutation would leak between engines.
        if transform_overrides is None:
            transform_overrides = {}
        if inference_sampler_overrides is None:
            inference_sampler_overrides = {}
        if trainer_overrides is None:
            trainer_overrides = {}

        # Set attrs
        self.initialized_ = False
        self.trainer = None
        self.pipeline = None
        self.print_config = print_config
        self.ckpt_path = ckpt_path

        # Set random seed (only if seed is not None)
        if seed is not None:
            ranked_logger.info(f"Seeding everything with seed={seed}...")
            seed_everything(seed, workers=True, verbose=True)
        else:
            ranked_logger.info("Seed is None - using external RNG state")
        self.seed = seed

        # Stored for later;
        self.transform_overrides = transform_overrides
        self.overrides: dict[str, Any] = {}
        base_overrides = {
            "trainer.seed": seed,
            "trainer.metrics": {},
            "trainer.loss": None,
            "trainer.num_nodes": num_nodes,
            "trainer.devices_per_node": devices_per_node,
        }
        # User-supplied trainer overrides are applied after the base ones, so they win.
        for key, value in base_overrides.items():
            self._assign_override(key, value)
        for key, value in trainer_overrides.items():
            self._assign_override(f"trainer.{key}", value)
        for key, value in inference_sampler_overrides.items():
            self._assign_override(f"model.net.inference_sampler.{key}", value)

    ###################################################################################
    # Required subclass methods
    ###################################################################################

    def initialize(self):
        """Load the checkpoint and build pipeline and trainer (idempotent).

        Returns:
            The merged config. Repeated calls return the config cached by the
            first call without reloading anything.
        """
        if self.initialized_:
            return getattr(self, "cfg", None)

        # Load checkpoint and config
        ranked_logger.info(
            f"Loading checkpoint from {Path(self.ckpt_path).resolve()}..."
        )
        checkpoint = torch.load(self.ckpt_path, "cpu", weights_only=False)
        cfg = self._override_checkpoint_config(checkpoint["train_cfg"])

        # Load pipeline first before trainer/model
        self._construct_pipeline(cfg)
        self._construct_trainer(cfg, checkpoint=checkpoint)
        ranked_logger.info("Model loaded and ready for inference.")

        # Cache the config so the early-return above hands back the same object
        # instead of ``None`` on subsequent calls.
        self.cfg = cfg
        self.initialized_ = True
        return cfg

    def run(
        self,
        inputs: (
            Dict[str, dict] | AtomArray | list[AtomArray] | PathLike | list[PathLike]
        ),
        *_,
    ) -> dict[str, dict] | None:
        """Run inference on ``inputs``; must be implemented by subclasses.

        The base implementation only triggers :meth:`initialize`.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        self.initialize()
        raise NotImplementedError(
            "Subclasses must implement inference logic in `run` method."
        )

    ###################################################################################
    # Util methods
    ###################################################################################

    def _override_checkpoint_config(self, cfg):
        """Merge the stored overrides into ``cfg`` and pick the accelerator."""
        cfg = merge(cfg, self.overrides)
        cfg = set_accelerator_based_on_availability(cfg)
        return cfg

    def _construct_trainer(self, cfg, checkpoint=None):
        """
        Sets attr self.trainer
        """
        # Instantiate trainer
        ranked_logger.info("Instantiating trainer...")
        if self.print_config:
            print_config_tree(
                cfg.trainer, resolve=True, title="INFERENCE TRAINER CONFIGURATION"
            )
        trainer = hydra.utils.instantiate(
            cfg.trainer,
            _convert_="partial",
            _recursive_=False,
        )

        # Setup model
        ranked_logger.info("Setting up model...")
        trainer.fabric.launch()
        trainer.initialize_or_update_trainer_state(
            {"train_cfg": cfg}
        )  # config from training stores net params
        trainer.construct_model()

        ranked_logger.info("Loading model weights from checkpoint...")
        # Prefer the already-loaded checkpoint object; fall back to the path.
        trainer.load_checkpoint(checkpoint=checkpoint or self.ckpt_path)

        # Ensure optimizer isn't loaded
        trainer.state["optimizer"] = None
        trainer.state["train_cfg"].model.optimizer = None
        trainer.setup_model_optimizers_and_schedulers()
        trainer.state["model"].eval()
        self.trainer = trainer

    def _assign_override(self, dotted_key: str, value: Any) -> None:
        """Assign ``value`` into ``self.overrides`` using a dotted path."""
        target = self.overrides
        keys = dotted_key.split(".")
        for key in keys[:-1]:
            # Create (or replace a non-dict value with) an intermediate dict.
            if key not in target or not isinstance(target[key], dict):
                target[key] = {}
            target = target[key]
        target[keys[-1]] = value

    def _construct_pipeline(self, cfg):
        """
        Sets attr self.pipeline
        """
        # Construct pipeline
        ranked_logger.info("Building Transform pipeline...")
        first_val_dataset_key, first_val_dataset = next(iter(cfg.datasets.val.items()))
        ranked_logger.info(
            f"Using settings from validation dataset: {first_val_dataset_key}."
        )
        transform = first_val_dataset.dataset.transform
        transform = merge(transform, self.transform_overrides)
        if self.print_config:
            print_config_tree(
                transform,
                resolve=True,
                title="INFERENCE TRANSFORM PIPELINE",
            )
        self.pipeline = hydra.utils.instantiate(transform)

    # aliases for run
    def forward(self, *args, **kwargs):
        return self.run(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.run(*args, **kwargs)

    # for use as a context manager: e.g. `with BaseInferenceEngine(...) as engine:` to automatically cleanup
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Drop references to the heavy objects; a later `initialize()` rebuilds them.
        self.trainer = None
        self.pipeline = None
        self.initialized_ = False

View File

@@ -0,0 +1,27 @@
import hydra
import torch.nn as nn
class Loss(nn.Module):
    """Aggregate loss: sums the outputs of several Hydra-instantiated loss terms."""

    def __init__(self, **losses):
        """Instantiate one loss term per keyword argument (name -> Hydra config)."""
        super().__init__()
        self.to_compute = []
        for loss_name, loss_cfg in losses.items():
            term = hydra.utils.instantiate(loss_cfg)
            print(f"Adding loss {loss_name} to the loss function")
            self.to_compute.append(term)

    def forward(self, network_input, network_output, loss_input):
        """Evaluate every term and return ``(summed_loss, merged_loss_dict)``.

        The merged dict additionally carries the detached sum under ``"total_loss"``.
        """
        total = 0
        merged = {}
        for term in self.to_compute:
            term_value, term_dict = term(network_input, network_output, loss_input)
            total += term_value
            merged.update(term_dict)
        merged["total_loss"] = total.detach()
        return total, merged

View File

@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
pi = torch.acos(torch.zeros(1)).item() * 2
class FourierEmbedding(nn.Module):
    """Embed values as ``cos(2 * pi * (t * w + b))`` with ``c`` random, non-trainable (w, b) pairs."""

    def __init__(self, c):
        super().__init__()
        self.c = c
        # Buffers (not parameters): saved with the module but never optimized.
        for buffer_name in ("w", "b"):
            self.register_buffer(buffer_name, torch.zeros(c, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # super().reset_parameters()
        # Draw w first, then b, from a standard normal.
        for buffer in (self.w, self.b):
            nn.init.normal_(buffer)

    def forward(
        self,
        t,  # [D]
    ):
        # A trailing axis is added to t so it broadcasts against the c features.
        phase = t[..., None] * self.w + self.b
        return torch.cos(2 * pi * phase)
class Dropout(nn.Module):
    """Dropout whose mask can be shared along one dimension (drops an entire row or column)."""

    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super().__init__()
        # give ones with probability of 1-p_drop / zeros with p_drop
        keep_probability = torch.tensor([1 - p_drop])
        self.sampler = torch.distributions.bernoulli.Bernoulli(keep_probability)
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop

    def forward(self, x):
        # no drophead during evaluation mode
        if not self.training:
            return x

        # The mask has size 1 along broadcast_dim so it is shared across that axis.
        mask_shape = list(x.shape)
        if self.broadcast_dim is not None:
            mask_shape[self.broadcast_dim] = 1

        keep_mask = self.sampler.sample(mask_shape).to(x.device).view(mask_shape)
        # Rescale the kept entries by 1 / (1 - p_drop).
        return keep_mask * x / (1.0 - self.p_drop)

View File

@@ -69,11 +69,13 @@ class FabricTrainer(ABC):
checkpoint_every_n_steps: int | None = None,
clip_grad_max_norm: float | None = None,
skip_nan_grad: bool = False,
error_if_grad_nonfinite: bool = False,
limit_train_batches: int | float = float("inf"),
limit_val_batches: int | float = float("inf"),
prevalidate: bool = False,
nccl_timeout: int = 3_200,
find_unused_parameters: bool = False,
skip_optimizer_loading: bool = False,
) -> None:
"""Base Trainer class built around Lightning Fabric.
@@ -103,6 +105,7 @@ class FabricTrainer(ABC):
clip_grad_max_norm: Maximum gradient norm to clip to (default: None). If None, no gradient clipping is performed.
skip_nan_grad: Whether to skip optimizer updates when gradients contain NaN or Inf values (default: False).
Useful for training stability, especially with mixed precision or challenging datasets.
error_if_grad_nonfinite: Whether to raise when gradient clipping detects NaN or Inf gradients (default: False).
limit_train_batches: Limit on the number of training batches per epoch (default: float("inf")).
Helpful for debugging; should NOT be used when training production models.
limit_val_batches: Limit on the number of validation batches per epoch (default: float("inf")).
@@ -111,6 +114,7 @@ class FabricTrainer(ABC):
nccl_timeout: Timeout for NCCL operations (default: 3200). Only used with DDP strategy.
find_unused_parameters: Whether to let DDP find and skip gradient synchronization for unused parameters in the model (default: False). NOTE: Setting to True will incur a performance penalty,
but allow for training for bespoke use cases where parts of the model are frozen.
skip_optimizer_loading: Whether to skip loading the optimizer/scheduler state when restoring from checkpoints (default: False).
References:
(1) Fabric Arguments (https://lightning.ai/docs/fabric/stable/api/fabric_args.html)
@@ -145,6 +149,7 @@ class FabricTrainer(ABC):
# Training
self.clip_grad_max_norm = clip_grad_max_norm
self.skip_nan_grad = skip_nan_grad
self.error_if_grad_nonfinite = error_if_grad_nonfinite
self.grad_accum_steps = grad_accum_steps
# Stopping
@@ -162,6 +167,7 @@ class FabricTrainer(ABC):
self.output_dir = Path(output_dir) if output_dir else None
self.checkpoint_every_n_epochs = checkpoint_every_n_epochs
self.checkpoint_every_n_steps = checkpoint_every_n_steps
self.skip_optimizer_loading = skip_optimizer_loading
# Inject trainer reference into callbacks for easy access to trainer state
self._inject_trainer_into_callbacks()
@@ -262,14 +268,28 @@ class FabricTrainer(ABC):
)
self.initialize_or_update_trainer_state({"scheduler_cfg": scheduler_cfg})
@abstractmethod
def construct_model(self):
"""Instantiate the model, updating the trainer state in-place.
This method must set the "model" key in the state dictionary using `self.initialize_or_update_trainer_state()`.
For an example, see the `construct_model` method in the `AF3Trainer`
Construct the model and optionally wrap with EMA.
"""
raise NotImplementedError
# ... instantiate model with Hydra and Fabric
with self.fabric.init_module():
ranked_logger.info("Instantiating model...")
model = hydra.utils.instantiate(
self.state["train_cfg"].model.net,
_recursive_=False,
)
# Optionally, wrap the model with EMA
if self.state["train_cfg"].model.ema is not None:
ranked_logger.info("Wrapping model with EMA...")
model = EMA(model, **self.state["train_cfg"].model.ema)
self.initialize_or_update_trainer_state({"model": model})
def setup_model_optimizers_and_schedulers(self) -> None:
"""Setup the model, optimizer(s), and scheduler(s) with Fabric.
@@ -354,6 +374,11 @@ class FabricTrainer(ABC):
), "Checkpoint path not found in checkpoint configuration!"
ckpt_path = Path(ckpt_config.path)
reset_optimizer = bool(
getattr(ckpt_config, "reset_optimizer", False)
or self.skip_optimizer_loading
)
if ckpt_path.is_dir():
# If given a directory, load the latest checkpoint from the directory
ranked_logger.info(
@@ -362,14 +387,14 @@ class FabricTrainer(ABC):
self.load_checkpoint(
self.get_latest_checkpoint(ckpt_path),
weight_loading_config=ckpt_config.weight_loading_config,
reset_optimizer=ckpt_config.reset_optimizer,
reset_optimizer=reset_optimizer,
)
else:
# If given a specific checkpoint file, load that checkpoint
self.load_checkpoint(
ckpt_path,
weight_loading_config=ckpt_config.weight_loading_config,
reset_optimizer=ckpt_config.reset_optimizer,
reset_optimizer=reset_optimizer,
)
# Apply parameter freezing if a freezing config is provided
@@ -716,7 +741,7 @@ class FabricTrainer(ABC):
module=model,
optimizer=optimizer,
max_norm=self.clip_grad_max_norm,
error_if_nonfinite=False, # Don't error on NaN/Inf if skip_nan_grad is True
error_if_nonfinite=self.error_if_grad_nonfinite,
)
# ... step the optimizer

View File

@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
def weighted_rigid_align(
X_L, # [B, L, 3]
X_gt_L, # [B, L, 3]
X_exists_L, # [L]
w_L, # [B, L]
X_exists_L=None, # [L]
w_L=None, # [B, L]
):
"""
Weighted rigid body alignment of X_gt_L onto X_L with weights w_L
@@ -21,6 +21,13 @@ def weighted_rigid_align(
assert X_L.shape == X_gt_L.shape
assert X_L.shape[:-1] == w_L.shape
if X_exists_L is None:
X_exists_L = torch.ones((X_L.shape[-2]), dtype=torch.bool)
if w_L is None:
w_L = torch.ones_like(X_L[..., 0])
else:
w_L = w_L.to(torch.float32)
# Assert `X_exists_L` is a boolean mask
assert (
X_exists_L.dtype == torch.bool

View File

@@ -44,6 +44,7 @@ def set_accelerator_based_on_availability(cfg: dict | DictConfig):
cfg.trainer.num_nodes = 1
else:
cfg.trainer.accelerator = "gpu"
return cfg
class RankedLogger(logging.LoggerAdapter):

View File

@@ -1,7 +1,8 @@
import math
import torch
from rf3.flow_matching.rigid_utils import rot_vec_mul
from modelhub.utils.rigid import rot_vec_mul
def centre(X_L, X_exists_L):