mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
refactor source files for open sourcing (#648)
* mc * Add base class for inference engine * refactor inference engine * Move constants and components out of rfd3 folder * Fixes to engine * Update with working checkpoint * revert layer utils * Fix more imports * Move alignment, conditiontransitionblock * Update sampler name * mc * More import fixes * make format * Minor fixes * mc * Fix rf3 inference engine * Fix inference sampler * Fix modules * Running inference * Make format * add pre-commit hook * fix: RF3 inference (#670) fix: make rf3 tests in new format * Minor cleanup --------- Co-authored-by: Nathaniel Corley <ncorley@uw.edu>
This commit is contained in:
8
.pre-commit-config.yaml
Normal file
8
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: make-format
|
||||
name: Run `make format`
|
||||
entry: make format
|
||||
language: system
|
||||
pass_filenames: false
|
||||
11
README.md
11
README.md
@@ -111,3 +111,14 @@ To add a new model:
|
||||
2. Add `modelhub` as a dependency
|
||||
3. Implement model-specific code in `models/<model_name>/src/`
|
||||
4. Users can install with: `uv pip install -e ./models/<model_name>`
|
||||
|
||||
### Pre-commit Formatting
|
||||
|
||||
We ship a `.pre-commit-config.yaml` that runs `make format` (via `ruff format`) before each commit. Enable it once per clone:
|
||||
|
||||
```bash
|
||||
pip install pre-commit # if not already installed
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
After installation the hook automatically formats the repo whenever you `git commit`. Use `pre-commit run --all-files` to apply it manually.
|
||||
|
||||
Submodule lib/atomworks updated: 7e12a8a11b...04da3547b2
@@ -621,3 +621,4 @@ view(atom_array)
|
||||
**Alternative viewing options:**
|
||||
- View in PyMol like normal, or using `pymol_remote`
|
||||
- Use the `view_pymol()` function for direct PyMol integration
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ classifiers = [
|
||||
|
||||
dependencies = [
|
||||
# Core functionality shared across all models
|
||||
"modelforge",
|
||||
# "modelforge",
|
||||
# CLI
|
||||
"typer>=0.9.0,<1",
|
||||
# RF3-specific ML dependencies
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
from atomworks.common import exists
|
||||
from atomworks.enums import ChainType
|
||||
from atomworks.ml.datasets import logger, StructuralDatasetWrapper
|
||||
from atomworks.ml.datasets import StructuralDatasetWrapper, logger
|
||||
from atomworks.ml.datasets.parsers import (
|
||||
MetadataRowParser,
|
||||
load_example_from_metadata_row,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
from beartype.typing import Any, Literal
|
||||
from jaxtyping import Float
|
||||
from rf3.data.rotation_augmentation import centre_random_augmentation
|
||||
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.rotation_augmentation import centre_random_augmentation
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -12,13 +11,12 @@ from atomworks.ml.preprocessing.msa.finding import (
|
||||
)
|
||||
from atomworks.ml.samplers import LoadBalancedDistributedSampler
|
||||
from biotite.structure import AtomArray
|
||||
from lightning.fabric import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from modelhub.inference_engines.base import BaseInferenceEngine
|
||||
from modelhub.metrics.metric import MetricManager
|
||||
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
|
||||
from modelhub.utils.logging import print_config_tree
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from rf3.model.RF3 import ShouldEarlyStopFn
|
||||
from rf3.utils.inference import (
|
||||
InferenceInput,
|
||||
@@ -63,7 +61,7 @@ def should_early_stop_by_mean_plddt(
|
||||
return fn
|
||||
|
||||
|
||||
class RF3InferenceEngine:
|
||||
class RF3InferenceEngine(BaseInferenceEngine):
|
||||
"""RF3 inference engine.
|
||||
|
||||
Separates model setup (expensive, once) from inference (can run multiple times).
|
||||
@@ -84,71 +82,44 @@ class RF3InferenceEngine:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ckpt_path: PathLike,
|
||||
# Model parameters
|
||||
n_recycles: int = 10,
|
||||
diffusion_batch_size: int = 5,
|
||||
num_steps: int = 50,
|
||||
seed: int | None = None,
|
||||
# Templating, MSAs, etc.
|
||||
template_noise_scale: float = 1e-5,
|
||||
early_stopping_plddt_threshold: float | None = None,
|
||||
metrics_cfg: dict | OmegaConf | MetricManager | None = None,
|
||||
num_nodes: int = 1,
|
||||
devices_per_node: int = 1,
|
||||
raise_if_missing_msa_for_protein_of_length_n: int | None = None,
|
||||
# Output control
|
||||
compress_outputs: bool = True,
|
||||
# Debug
|
||||
print_config: bool = False,
|
||||
raise_if_missing_msa_for_protein_of_length_n: int | None = None,
|
||||
early_stopping_plddt_threshold: float | None = None,
|
||||
# Metrics
|
||||
metrics_cfg: dict | OmegaConf | MetricManager | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize inference engine and load model.
|
||||
|
||||
Model config is loaded from checkpoint and overridden with parameters provided here.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to model checkpoint.
|
||||
n_recycles: Number of recycles. Defaults to ``10``.
|
||||
diffusion_batch_size: Number of structures to generate per input. Defaults to ``5``.
|
||||
num_steps: Number of diffusion steps. Defaults to ``50``.
|
||||
seed: Random seed. If None, uses external RNG state. Defaults to ``None``.
|
||||
template_noise_scale: Noise scale for template coordinates. Defaults to ``1e-5``.
|
||||
raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
|
||||
compress_outputs: Whether to gzip output files. Defaults to ``True``.
|
||||
early_stopping_plddt_threshold: Stop early if pLDDT below threshold. Defaults to ``None``.
|
||||
metrics_cfg: Metrics configuration. Can be:
|
||||
- dict/OmegaConf with Hydra configs
|
||||
- Pre-instantiated MetricManager
|
||||
- None (no metrics).
|
||||
Defaults to ``None``.
|
||||
num_nodes: Number of nodes for distributed inference. Defaults to ``1``.
|
||||
devices_per_node: Number of devices per node. Defaults to ``1``.
|
||||
compress_outputs: Whether to gzip output files. Defaults to ``True``.
|
||||
print_config: Whether to print config trees. Defaults to ``False``.
|
||||
raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
|
||||
**kwargs: Additional arguments passed to BaseInferenceEngine:
|
||||
- ckpt_path (PathLike, required): Path to model checkpoint.
|
||||
- seed (int | None): Random seed. If None, uses external RNG state. Defaults to ``None``.
|
||||
- num_nodes (int): Number of nodes for distributed inference. Defaults to ``1``.
|
||||
- devices_per_node (int): Number of devices per node. Defaults to ``1``.
|
||||
- print_config (bool): Whether to print config trees. Defaults to ``False``.
|
||||
"""
|
||||
# Load checkpoint and config
|
||||
ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...")
|
||||
checkpoint = torch.load(ckpt_path, "cpu", weights_only=False)
|
||||
self.cfg = OmegaConf.create(checkpoint["train_cfg"])
|
||||
|
||||
# Override config with inference parameters
|
||||
self.cfg.model.net.inference_sampler.num_timesteps = num_steps
|
||||
self.cfg.trainer.num_nodes = num_nodes
|
||||
self.cfg.trainer.devices_per_node = devices_per_node
|
||||
|
||||
# Set metrics - can be dict/OmegaConf or MetricManager
|
||||
# Store MetricManager separately since OmegaConf can't serialize it
|
||||
self._custom_metric_manager = None
|
||||
if isinstance(metrics_cfg, MetricManager):
|
||||
# Already instantiated - store separately and pass to trainer later
|
||||
self._custom_metric_manager = metrics_cfg
|
||||
self.cfg.trainer["metrics"] = {} # Empty dict in config
|
||||
elif metrics_cfg is not None:
|
||||
# Hydra config dict
|
||||
self.cfg.trainer["metrics"] = metrics_cfg
|
||||
else:
|
||||
self.cfg.trainer["metrics"] = {}
|
||||
|
||||
set_accelerator_based_on_availability(self.cfg)
|
||||
|
||||
# set MSA directories
|
||||
if env_var_msa_dirs := get_msa_dirs_from_env(raise_if_not_set=False):
|
||||
override_msa_dirs = [str(msa_dir) for msa_dir in env_var_msa_dirs]
|
||||
@@ -163,112 +134,72 @@ class RF3InferenceEngine:
|
||||
]
|
||||
ranked_logger.info(f"Using default MSA directories: {override_msa_dirs}")
|
||||
|
||||
# Dataset overrides
|
||||
self.dataset_overrides = {
|
||||
"diffusion_batch_size": diffusion_batch_size,
|
||||
"n_recycles": n_recycles,
|
||||
"raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
|
||||
"undesired_res_names": [],
|
||||
"template_noise_scales": {
|
||||
"atomized": template_noise_scale,
|
||||
"not_atomized": template_noise_scale,
|
||||
super().__init__(
|
||||
transform_overrides={
|
||||
"diffusion_batch_size": diffusion_batch_size,
|
||||
"n_recycles": n_recycles,
|
||||
"raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
|
||||
"undesired_res_names": [],
|
||||
"template_noise_scales": {
|
||||
"atomized": template_noise_scale,
|
||||
"not_atomized": template_noise_scale,
|
||||
},
|
||||
"allowed_chain_types_for_conditioning": None,
|
||||
"protein_msa_dirs": [
|
||||
{
|
||||
"dir": msa_dir,
|
||||
"extension": extension.value,
|
||||
"directory_depth": depth,
|
||||
}
|
||||
for msa_dir, depth, extension in [
|
||||
(msa_dir, *get_msa_depth_and_ext_from_folder(Path(msa_dir)))
|
||||
for msa_dir in override_msa_dirs
|
||||
]
|
||||
],
|
||||
"rna_msa_dirs": [],
|
||||
# (Paranoia - in validation, these should be set correctly anyhow)
|
||||
"p_give_polymer_ref_conf": 0.0,
|
||||
"p_give_non_polymer_ref_conf": 0.0,
|
||||
"p_dropout_ref_conf": 0.0,
|
||||
"use_element_for_atom_names_of_atomized_tokens": True,
|
||||
},
|
||||
"allowed_chain_types_for_conditioning": None,
|
||||
"protein_msa_dirs": [
|
||||
{
|
||||
"dir": msa_dir,
|
||||
"extension": extension.value,
|
||||
"directory_depth": depth,
|
||||
}
|
||||
for msa_dir, depth, extension in [
|
||||
(msa_dir, *get_msa_depth_and_ext_from_folder(Path(msa_dir)))
|
||||
for msa_dir in override_msa_dirs
|
||||
]
|
||||
],
|
||||
"rna_msa_dirs": [],
|
||||
# (Paranoia - in validation, these should be set correctly anyhow)
|
||||
"p_give_polymer_ref_conf": 0.0,
|
||||
"p_give_non_polymer_ref_conf": 0.0,
|
||||
"p_dropout_ref_conf": 0.0,
|
||||
"use_element_for_atom_names_of_atomized_tokens": True,
|
||||
}
|
||||
|
||||
self.print_config = print_config
|
||||
|
||||
# Set random seed (only if seed is not None)
|
||||
if seed is not None or self.cfg.seed is not None:
|
||||
seed = seed or self.cfg.seed
|
||||
ranked_logger.info(f"Seeding everything with seed={seed}...")
|
||||
seed_everything(seed, workers=True, verbose=True)
|
||||
else:
|
||||
ranked_logger.info("Seed is None - using external RNG state")
|
||||
|
||||
# Instantiate trainer
|
||||
ranked_logger.info("Instantiating trainer...")
|
||||
if self.print_config:
|
||||
print_config_tree(
|
||||
self.cfg.trainer, resolve=True, title="INFERENCE TRAINER CONFIGURATION"
|
||||
)
|
||||
|
||||
self.trainer = hydra.utils.instantiate(
|
||||
self.cfg.trainer,
|
||||
_convert_="partial",
|
||||
_recursive_=False,
|
||||
inference_sampler_overrides={
|
||||
"num_timesteps": num_steps,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# If we have a custom MetricManager, override the trainer's metrics
|
||||
if self._custom_metric_manager is not None:
|
||||
self.trainer.metrics = self._custom_metric_manager
|
||||
# remove loss override if present (i.e. keep from checkpoint)
|
||||
self.overrides["trainer"].pop("loss", None)
|
||||
|
||||
self.ckpt_path = ckpt_path
|
||||
# Store metrics config for later - will be set directly on trainer in initialize()
|
||||
self._metrics_cfg = metrics_cfg
|
||||
|
||||
# Dataset overrides
|
||||
self.early_stopping_plddt_threshold = early_stopping_plddt_threshold
|
||||
self.compress_outputs = compress_outputs
|
||||
|
||||
# Setup model
|
||||
ranked_logger.info("Setting up model...")
|
||||
self.trainer.fabric.launch()
|
||||
self.trainer.initialize_or_update_trainer_state({"train_cfg": self.cfg})
|
||||
self.trainer.construct_model()
|
||||
def initialize(self):
|
||||
cfg = super().initialize()
|
||||
|
||||
ranked_logger.info("Loading model weights from checkpoint...")
|
||||
self.trainer.load_checkpoint(checkpoint=checkpoint)
|
||||
if cfg is not None:
|
||||
self.cfg = cfg # store for later use
|
||||
|
||||
# Ensure optimizer isn't loaded
|
||||
self.trainer.state["optimizer"] = None
|
||||
self.trainer.state["train_cfg"].model.optimizer = None
|
||||
# Set trainer metrics directly based on what was requested
|
||||
# This bypasses the OmegaConf merge issue with empty dicts
|
||||
if isinstance(self._metrics_cfg, MetricManager):
|
||||
# Already instantiated - use directly
|
||||
self.trainer.metrics = self._metrics_cfg
|
||||
elif self._metrics_cfg is not None:
|
||||
# Hydra config dict - instantiate MetricManager
|
||||
self.trainer.metrics = MetricManager.instantiate_from_hydra(
|
||||
metrics_cfg=self._metrics_cfg
|
||||
)
|
||||
else:
|
||||
# No metrics requested - disable them
|
||||
self.trainer.metrics = None
|
||||
|
||||
self.trainer.setup_model_optimizers_and_schedulers()
|
||||
self.trainer.state["model"].eval()
|
||||
|
||||
# Construct pipeline
|
||||
ranked_logger.info("Building Transform pipeline...")
|
||||
first_val_dataset_key, first_val_dataset = next(
|
||||
iter(self.cfg.datasets.val.items())
|
||||
)
|
||||
ranked_logger.info(
|
||||
f"Using settings from validation dataset: {first_val_dataset_key}."
|
||||
)
|
||||
|
||||
assert (
|
||||
first_val_dataset.dataset.transform.is_inference
|
||||
), "Inference must be enabled for the validation dataset."
|
||||
|
||||
# Provide manual overrides to dataset config
|
||||
for key, value in self.dataset_overrides.items():
|
||||
first_val_dataset.dataset.transform[key] = value
|
||||
|
||||
if self.print_config:
|
||||
print_config_tree(
|
||||
first_val_dataset.dataset.transform,
|
||||
resolve=True,
|
||||
title="INFERENCE TRANSFORM PIPELINE",
|
||||
)
|
||||
|
||||
self.pipeline = hydra.utils.instantiate(
|
||||
first_val_dataset.dataset.transform,
|
||||
)
|
||||
|
||||
ranked_logger.info("Model loaded and ready for inference.")
|
||||
return cfg
|
||||
|
||||
def run(
|
||||
self,
|
||||
@@ -314,6 +245,8 @@ class RF3InferenceEngine:
|
||||
If ``out_dir`` is None: Dict mapping example_id to results dict.
|
||||
If ``out_dir`` is set: None (results saved to disk).
|
||||
"""
|
||||
self.initialize()
|
||||
|
||||
# Setup output directory if provided
|
||||
out_dir = Path(out_dir) if out_dir else None
|
||||
if out_dir:
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rf3.alignment import weighted_rigid_align
|
||||
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
from modelhub.utils.alignment import weighted_rigid_align
|
||||
|
||||
|
||||
# resolve residue-level symmetries in native vs pred
|
||||
@@ -360,31 +359,6 @@ class SubunitSymmetryResolution(nn.Module):
|
||||
return loss_input
|
||||
|
||||
|
||||
class Loss(nn.Module):
|
||||
def __init__(self, **losses):
|
||||
super().__init__()
|
||||
self.to_compute = []
|
||||
for loss_name, loss in losses.items():
|
||||
loss_fn = hydra.utils.instantiate(loss)
|
||||
print(f"Adding loss {loss_name} to the loss function")
|
||||
self.to_compute.append(loss_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
network_input,
|
||||
network_output,
|
||||
loss_input,
|
||||
):
|
||||
loss_dict = {}
|
||||
loss = 0
|
||||
for loss_fn in self.to_compute:
|
||||
loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
|
||||
loss += loss_
|
||||
loss_dict.update(loss_dict_)
|
||||
loss_dict["total_loss"] = loss.detach()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class ProteinLigandBondLoss(nn.Module):
|
||||
def __init__(self, weight):
|
||||
super().__init__()
|
||||
|
||||
@@ -15,6 +15,7 @@ from rf3.model.layers.pairformer_layers import (
|
||||
RF3TemplateEmbedder,
|
||||
)
|
||||
|
||||
from modelhub.model.layers.blocks import FourierEmbedding
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -229,29 +230,6 @@ class DiffusionConditioning(nn.Module):
|
||||
return _run_conditioning(Z_II, S_trunk_I, S_inputs_I)
|
||||
|
||||
|
||||
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||
|
||||
|
||||
class FourierEmbedding(nn.Module):
|
||||
def __init__(self, c):
|
||||
super().__init__()
|
||||
self.c = c
|
||||
self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
|
||||
self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
# super().reset_parameters()
|
||||
nn.init.normal_(self.w)
|
||||
nn.init.normal_(self.b)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
t, # [D]
|
||||
):
|
||||
return torch.cos(2 * pi * (t[:, None] * self.w + self.b))
|
||||
|
||||
|
||||
class DistogramHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -16,10 +16,10 @@ from rf3.model.layers.outer_product import (
|
||||
OuterProductMean_AF3,
|
||||
)
|
||||
from rf3.model.RF3_blocks import MSAPairWeightedAverage, MSASubsampleEmbedder
|
||||
from rf3.util_module import Dropout
|
||||
from torch import nn
|
||||
from torch.nn.functional import one_hot, relu
|
||||
|
||||
from modelhub.model.layers.blocks import Dropout
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from einops import repeat
|
||||
from jaxtyping import Float, Int
|
||||
from lightning_utilities import apply_to_collection
|
||||
from omegaconf import DictConfig
|
||||
from rf3.loss.af3_losses import Loss as AF3Loss
|
||||
from rf3.loss.af3_losses import (
|
||||
ResidueSymmetryResolution,
|
||||
SubunitSymmetryResolution,
|
||||
@@ -15,6 +14,7 @@ from rf3.utils.io import build_stack_from_atom_array_and_batched_coords
|
||||
from rf3.utils.recycling import get_recycle_schedule
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.metrics.losses import Loss
|
||||
from modelhub.metrics.metric import MetricManager
|
||||
from modelhub.trainers.fabric import FabricTrainer
|
||||
from modelhub.training.EMA import EMA
|
||||
@@ -42,6 +42,7 @@ class RF3Trainer(FabricTrainer):
|
||||
n_recycles_train: int | None = None,
|
||||
loss: DictConfig | dict | None = None,
|
||||
metrics: DictConfig | dict | MetricManager | None = None,
|
||||
seed=None, # dumped
|
||||
**kwargs,
|
||||
):
|
||||
"""See `FabricTrainer` for the additional initialization arguments.
|
||||
@@ -79,7 +80,7 @@ class RF3Trainer(FabricTrainer):
|
||||
self.metrics = None
|
||||
|
||||
# Loss
|
||||
self.loss = AF3Loss(**loss) if loss else None
|
||||
self.loss = Loss(**loss) if loss else None
|
||||
|
||||
# (Symmetry resolution)
|
||||
self.subunit_symm_resolve = SubunitSymmetryResolution()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def init_lecun_normal(module, scale=1.0):
|
||||
@@ -38,29 +37,6 @@ def create_custom_forward(module, **kwargs):
|
||||
return custom_forward
|
||||
|
||||
|
||||
class Dropout(nn.Module):
|
||||
# Dropout entire row or column
|
||||
def __init__(self, broadcast_dim=None, p_drop=0.15):
|
||||
super(Dropout, self).__init__()
|
||||
# give ones with probability of 1-p_drop / zeros with p_drop
|
||||
self.sampler = torch.distributions.bernoulli.Bernoulli(
|
||||
torch.tensor([1 - p_drop])
|
||||
)
|
||||
self.broadcast_dim = broadcast_dim
|
||||
self.p_drop = p_drop
|
||||
|
||||
def forward(self, x):
|
||||
if not self.training: # no drophead during evaluation mode
|
||||
return x
|
||||
shape = list(x.shape)
|
||||
if self.broadcast_dim is not None:
|
||||
shape[self.broadcast_dim] = 1
|
||||
mask = self.sampler.sample(shape).to(x.device).view(shape)
|
||||
|
||||
x = mask * x / (1.0 - self.p_drop)
|
||||
return x
|
||||
|
||||
|
||||
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
|
||||
# Distance radial basis function
|
||||
D_max = D_min + (D_count - 1) * D_sigma
|
||||
|
||||
@@ -8,8 +8,8 @@ from atomworks.ml.utils.io import apply_sharding_pattern
|
||||
from atomworks.ml.utils.misc import hash_sequence
|
||||
from beartype.typing import Literal
|
||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
||||
from rf3.alignment import weighted_rigid_align
|
||||
|
||||
from modelhub.utils.alignment import weighted_rigid_align
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
@@ -171,7 +171,7 @@ def test_inference_regression(example_id, rmsd_tolerance, csv_tolerance):
|
||||
baseline_dir = TEST_DATA_DIR / "inference_regression_tests" / example_id
|
||||
|
||||
with (
|
||||
initialize(config_path="../configs"),
|
||||
initialize(config_path="../configs", version_base="1.3"),
|
||||
tempfile.TemporaryDirectory() as temp_dir,
|
||||
rng_state(create_rng_state_from_seeds(1, 1, 1)),
|
||||
):
|
||||
@@ -327,5 +327,6 @@ def test_inference_regression_in_memory(example_id, rmsd_tolerance, csv_toleranc
|
||||
rmsd_difference < rmsd_tolerance
|
||||
), f"Mean RMSD difference {rmsd_difference:.4f}Å exceeds {rmsd_tolerance}Å tolerance for {example_id}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@@ -108,3 +108,23 @@ This is then passed through the same processing pipeline as in training with `is
|
||||
<figcaption>Overview of important transforms in the Atom14 conditioning pipeline.
|
||||
</figcaption>
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this code or data in your work, please cite:
|
||||
|
||||
```bibtex
|
||||
@article {butcher2025_rfdiffusion3,
|
||||
author = {Butcher, Jasper and Krishna, Rohith and Mitra, Raktim and Brent, Rafael Isaac and Li, Yanjing and Corley, Nathaniel and Kim, Paul T and Funk, Jonathan and Mathis, Simon Valentin and Salike, Saman and Muraishi, Aiko and Eisenach, Helen and Thompson, Tuscan Rock and Chen, Jie and Politanska, Yuliya and Sehgal, Enisha and Coventry, Brian and Zhang, Odin and Qiang, Bo and Didi, Kieran and Kazman, Maxwell and DiMaio, Frank and Baker, David},
|
||||
title = {De novo Design of All-atom Biomolecular Interactions with RFdiffusion3},
|
||||
elocation-id = {2025.09.18.676967},
|
||||
year = {2025},
|
||||
doi = {10.1101/2025.09.18.676967},
|
||||
publisher = {Cold Spring Harbor Laboratory},
|
||||
URL = {https://www.biorxiv.org/content/early/2025/11/19/2025.09.18.676967},
|
||||
eprint = {https://www.biorxiv.org/content/early/2025/11/19/2025.09.18.676967.full.pdf},
|
||||
journal = {bioRxiv}
|
||||
}
|
||||
```
|
||||
@@ -35,17 +35,6 @@ dataset:
|
||||
use_element_for_atom_names_of_atomized_tokens: ${datasets.global_transform_args.use_element_for_atom_names_of_atomized_tokens}
|
||||
residue_cache_dir: ${paths.data.residue_cache_dir}
|
||||
|
||||
# PPI
|
||||
keep_full_binder_in_spatial_crop: ${datasets.global_transform_args.keep_full_binder_in_spatial_crop}
|
||||
max_binder_length: ${datasets.global_transform_args.max_binder_length}
|
||||
max_ppi_hotspots_frac_to_provide: ${datasets.global_transform_args.max_ppi_hotspots_frac_to_provide}
|
||||
ppi_hotspot_max_distance: ${datasets.global_transform_args.ppi_hotspot_max_distance}
|
||||
|
||||
# Secondary structure
|
||||
max_ss_frac_to_provide: ${datasets.global_transform_args.max_ss_frac_to_provide}
|
||||
min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len}
|
||||
max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len}
|
||||
|
||||
# Other dataset-specific parameters
|
||||
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
|
||||
token_1d_features: ${model.net.token_initializer.token_1d_features}
|
||||
@@ -159,23 +159,6 @@ trainer:
|
||||
prevalidate: False
|
||||
validate_every_n_epochs: 4
|
||||
precision: bf16-mixed
|
||||
loss:
|
||||
verbose_diffusion_loss: null
|
||||
simple_diffusion_loss:
|
||||
_target_: rfd3.metrics.losses.SimpleDiffusionLoss
|
||||
sigma_data: ${model.net.diffusion_module.sigma_data}
|
||||
weight: 4.0
|
||||
lddt_weight: 0.25
|
||||
alpha_virtual_atom: 1.0
|
||||
alpha_polar_residues: 1.0
|
||||
|
||||
lp_weight: 0.0
|
||||
unindexed_norm_p: 1.0
|
||||
alpha_unindexed_diffused: 1.0
|
||||
unindexed_t_alpha: 0.75
|
||||
normalize_virtual_atom_weight: False
|
||||
alpha_ligand: 10.0
|
||||
|
||||
# callbacks:
|
||||
# activations_tracking_callback:
|
||||
# _target_: modelhub.callbacks.health_logging.ActivationsGradientsWeightsTracker
|
||||
|
||||
@@ -7,10 +7,9 @@ defaults:
|
||||
ckpt_path: ???
|
||||
num_nodes: 1
|
||||
devices_per_node: 1
|
||||
print_config: False
|
||||
seed: null
|
||||
|
||||
# Parameters for RFD3InferenceEngine.run()
|
||||
inputs: ???
|
||||
out_dir: ???
|
||||
dump_predictions: true
|
||||
dump_trajectories: false
|
||||
one_model_per_file: false
|
||||
|
||||
@@ -3,14 +3,14 @@ defaults:
|
||||
- base
|
||||
- _self_
|
||||
|
||||
_target_: rfd3.inference.engine.RFD3InferenceEngine
|
||||
_target_: rfd3.engine.RFD3InferenceEngine
|
||||
|
||||
out_dir: ???
|
||||
inputs: ??? # null, json, pdb or
|
||||
ckpt_path: /projects/ml/aa_design/models/rfd3_latest.ckpt
|
||||
# ckpt_path: /projects/ml/aa_design/models/rfd3_latest.ckpt
|
||||
ckpt_path: /projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt
|
||||
json_keys_subset: null
|
||||
skip_existing: True
|
||||
seed: null # if null samples seed integer based on timestamp
|
||||
|
||||
#########################################################
|
||||
# Design spec args: overrides args from input json
|
||||
@@ -21,7 +21,6 @@ specification: {}
|
||||
diffusion_batch_size: 8
|
||||
n_batches: 1
|
||||
|
||||
|
||||
# Inference sampler args | set to None to use the default in the checkpoint's config
|
||||
inference_sampler:
|
||||
kind: "default" # "default" or "symmetry" to choose the sampler
|
||||
@@ -35,14 +34,9 @@ inference_sampler:
|
||||
cfg_t_max: null # max t to apply cfg guidance
|
||||
cfg_scale: 1.5
|
||||
center_option: "all" # Options are ["all", "motif", "diffuse"]
|
||||
move_noise_to_reset_com: False # Reset the COM of the diffuse region after the re-noising operation in each diffusion step
|
||||
s_trans: 1.0 # Translational noise scale for augmentation during inference
|
||||
fraction_of_steps_to_fix_motif: 0.0 # Fraction of steps to let the model not move the motif. e.g. if we have 10 steps, set this value to 0.2 will make model not move motif for the last 2 steps.
|
||||
skip_few_diffusion_steps: False # Choose to skip some diffusion steps based on the noise scheme
|
||||
inference_noise_scaling_factor: 1.0
|
||||
allow_realignment: False
|
||||
zero_drift_noise: False
|
||||
use_frame_guidance: False
|
||||
|
||||
# Diffusion args:
|
||||
num_timesteps: 200
|
||||
@@ -51,7 +45,6 @@ inference_sampler:
|
||||
p: 7
|
||||
gamma_0: 0.6 # Previously 1.0 | 0.0 for ODE sampling
|
||||
gamma_min: 1.0
|
||||
gamma_min2: 0.0
|
||||
s_jitter_origin: 0.0 # Sigma of gaussian noise to jitter the motif offset (equivalent to ORI token Jitter)
|
||||
|
||||
# Saving args
|
||||
@@ -66,13 +59,8 @@ output_full_json: True
|
||||
# e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
|
||||
# e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
|
||||
global_prefix: null
|
||||
|
||||
dump_prediction_metadata_json: True
|
||||
dump_trajectories: False
|
||||
one_model_per_file: True
|
||||
align_trajectory_structures: False
|
||||
|
||||
# Additional args
|
||||
print_config: False
|
||||
prevalidate_inputs: True
|
||||
low_memory_mode: False # False for standard mode, True for memory efficient tokenization mode
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
_target_: rfd3.model.aa_design.RFD3
|
||||
_target_: rfd3.model.RFD3.RFD3
|
||||
|
||||
c_s: 384
|
||||
c_z: 128
|
||||
@@ -53,7 +53,7 @@ token_initializer: # formerly known as the trunk
|
||||
n_attn_keys: 128
|
||||
|
||||
diffusion_module:
|
||||
_target_: rfd3.model.rfd3_diffusion_module.RFD3DiffusionModule
|
||||
_target_: rfd3.model.RFD3_diffusion_module.RFD3DiffusionModule
|
||||
c_token: 768
|
||||
c_t_embed: 256 # Time embedding dimension
|
||||
sigma_data: 16
|
||||
|
||||
@@ -26,6 +26,6 @@ design_benchmark_data_dir: /projects/ml/aa_design/benchmarks
|
||||
design_model_weight_dir: /projects/ml/aa_design/models
|
||||
|
||||
# path to directory with cached residue data
|
||||
residue_cache_dir: /net/tukwila/ncorley/datahub/MACE-OFF23_medium
|
||||
residue_cache_dir: /net/tukwila/lschaaf/datahub/MACE-Egret-3-noH/mace_embeddings
|
||||
|
||||
cif_cache_dir: /net/tukwila/ncorley/cifutils/cache
|
||||
@@ -1,6 +1,7 @@
|
||||
defaults:
|
||||
- ddp
|
||||
- loss/losses/diffusion_loss@loss.verbose_diffusion_loss
|
||||
- loss/losses/diffusion_loss@loss.diffusion_loss
|
||||
- loss/losses/sequence_loss@loss.sequence_loss
|
||||
- metrics: design_metrics
|
||||
- _self_
|
||||
|
||||
@@ -27,11 +28,4 @@ clip_grad_max_norm: 10.0
|
||||
output_dir: ${paths.output_dir}
|
||||
n_recycles_train: 2
|
||||
grad_accum_steps: 3 # overridden by launch.sh
|
||||
skip_optimizer_loading: True
|
||||
|
||||
loss:
|
||||
verbose_diffusion_loss:
|
||||
sequence_head_loss:
|
||||
_target_: rfd3.metrics.losses.SequenceLoss
|
||||
weight: 0.1
|
||||
max_t: 1
|
||||
skip_optimizer_loading: True
|
||||
@@ -1,21 +1,11 @@
|
||||
# _target_: modelhub.loss.af3_losses.DiffusionLoss
|
||||
# sigma_data: ${model.net.diffusion_module.sigma_data}
|
||||
# alpha_dna: 5
|
||||
# alpha_rna: 5
|
||||
# alpha_ligand: 10
|
||||
# edm_lambda: True
|
||||
# se3_invariant_loss: True
|
||||
# clamp_diffusion_loss: False
|
||||
sigma_data: ${model.net.diffusion_module.sigma_data}
|
||||
weight: 4.0
|
||||
|
||||
_target_: rfd3.metrics.losses.VerboseDiffusionLoss
|
||||
alpha_motif: 1.0
|
||||
alpha_ca_atom: 1.0
|
||||
alpha_virtual_atom: 1.0
|
||||
alpha_fixed_motif: 2.0
|
||||
alpha_unindexed_diffused: 2.0
|
||||
lddt_weight: 0.25
|
||||
clamp_diffusion_loss: True
|
||||
align_prediction: False
|
||||
use_motif_aligned_loss: False
|
||||
alpha_virtual_atom: 1.0
|
||||
alpha_polar_residues: 1.0
|
||||
lp_weight: 0.0
|
||||
unindexed_norm_p: 1.0
|
||||
alpha_unindexed_diffused: 1.0
|
||||
unindexed_t_alpha: 0.75
|
||||
normalize_virtual_atom_weight: False
|
||||
alpha_ligand: 10.0
|
||||
@@ -0,0 +1,3 @@
|
||||
_target_: rfd3.metrics.losses.SequenceLoss
|
||||
weight: 0.1
|
||||
max_t: 1
|
||||
@@ -3,7 +3,7 @@ name = "rfd3"
|
||||
dynamic = ["version"]
|
||||
description = "De novo Design of All-atom Biomolecular Interactions with RFdiffusion3"
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.12"
|
||||
requires-python = ">=3.12"
|
||||
authors = [
|
||||
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
|
||||
]
|
||||
|
||||
@@ -19,26 +19,15 @@ def design(ctx: typer.Context):
|
||||
|
||||
# Get all arguments
|
||||
args = ctx.params.get("args", []) + ctx.args
|
||||
|
||||
# Parse arguments
|
||||
hydra_overrides = []
|
||||
|
||||
if len(args) == 1 and "=" not in args[0]:
|
||||
# Old style: single positional argument assumed to be inputs
|
||||
hydra_overrides.append(f"inputs={args[0]}")
|
||||
else:
|
||||
# New style: all arguments are hydra overrides
|
||||
hydra_overrides.extend(args)
|
||||
args = [a for a in args if a not in ["design", "fold"]]
|
||||
|
||||
# Ensure we have at least a default inference_engine if not specified
|
||||
has_inference_engine = any(
|
||||
arg.startswith("inference_engine=") for arg in hydra_overrides
|
||||
)
|
||||
has_inference_engine = any(arg.startswith("inference_engine=") for arg in args)
|
||||
if not has_inference_engine:
|
||||
hydra_overrides.append("inference_engine=rfdiffusion3")
|
||||
args.append("inference_engine=rfdiffusion3")
|
||||
|
||||
with initialize_config_dir(config_dir=config_path, version_base="1.3"):
|
||||
cfg = compose(config_name="inference", overrides=hydra_overrides)
|
||||
cfg = compose(config_name="inference", overrides=args)
|
||||
# Lazy import to avoid loading heavy dependencies at CLI startup
|
||||
from rfd3.run_inference import run_inference
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import numpy as np
|
||||
from atomworks.constants import CRYSTALLIZATION_AIDS
|
||||
|
||||
from modelhub.constants import TIP_BY_RESTYPE
|
||||
|
||||
TIP_BY_RESTYPE
|
||||
|
||||
# Annot: default (diffused default)
|
||||
REQUIRED_CONDITIONING_ANNOTATION_VALUES = {
|
||||
@@ -204,46 +207,3 @@ SELECTION_NONPROTEIN = [
|
||||
"MACROLIDE",
|
||||
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
|
||||
]
|
||||
|
||||
# fmt: off
|
||||
# ... For convenience, we allow BKBN, or TIP to be used as a shortcut | TIP is the largest set of fixed atom given at least 2 tip atoms
|
||||
TIP_BY_RESTYPE = {
|
||||
"TRP": ["CG","CD1","CD2","NE1","CE2","CE3","CZ2","CZ3","CH2"], # fix both rings
|
||||
"HIS": ["CG","ND1","CD2","CE1","NE2"], # fixed ring
|
||||
"TYR": ["CZ","OH"], # keeps ring dihedral flexible
|
||||
"PHE": ["CG","CD1","CD2","CE1","CE2","CZ"],
|
||||
"ASN": ["CB", "CG","OD1","ND2"],
|
||||
"ASP": ["CB", "CG","OD1","OD2"],
|
||||
"GLN": ["CG", "CD","OE1","NE2"],
|
||||
"GLU": ["CG", "CD","OE1","OE2"],
|
||||
"CYS": ["CB", "SG"],
|
||||
"SER": ["CB", "OG"],
|
||||
"THR": ["CB", "OG1"],
|
||||
"LEU": ["CB", "CG", "CD1", "CD2"],
|
||||
"VAL": ["CG1", "CG2"],
|
||||
"ILE": ["CB", "CG2"],
|
||||
"MET": ["SD", "CE"],
|
||||
"LYS": ["CE","NZ"],
|
||||
"ARG": ["CD","NE","CZ","NH1","NH2"],
|
||||
"PRO": None,
|
||||
"ALA": None,
|
||||
"GLY": None,
|
||||
"UNK": None,
|
||||
"MSK": None
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
STANDARD_PARSER_ARGS = {
|
||||
"add_missing_atoms": True,
|
||||
"add_id_and_entity_annotations": True,
|
||||
"add_bond_types_from_struct_conn": ("covale",),
|
||||
"remove_ccds": tuple(CRYSTALLIZATION_AIDS),
|
||||
"remove_waters": True,
|
||||
"fix_ligands_at_symmetry_centers": True,
|
||||
"fix_arginines": True,
|
||||
"fix_formal_charges": True,
|
||||
"fix_bond_types": True,
|
||||
"convert_mse_to_met": True, # Changed from False to True vs. atomworks.io.parser.parse default
|
||||
"hydrogen_policy": "keep",
|
||||
"model": None, # all models
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# RFdiffusion3 — Input specification (dialect **2**)
|
||||
|
||||
> **TL;DR**
|
||||
> Inputs are now defined with a single `InputSpecification` class.
|
||||
> Inputs are now defined with a single `DesignInputSpecification` class.
|
||||
> Selections like “what’s fixed?”, “what’s sequence-free?”, “which atoms are donors/acceptors?” are all expressed with the same **InputSelection** mini-language.
|
||||
> Everything is reproducibly logged back out alongside your generation.
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
- [What changed (high level)](#what-changed-high-level)
|
||||
- [Quick start](#quick-start)
|
||||
- [The `InputSelection` mini-language](#the-inputselection-mini-language)
|
||||
- [Full schema: `InputSpecification`](#full-schema-inputspecification)
|
||||
- [Full schema: `DesignInputSpecification`](#full-schema-DesignInputSpecification)
|
||||
- [Common recipes (cookbook)](#common-recipes-cookbook)
|
||||
- [Partial diffusion](#partial-diffusion)
|
||||
- [Symmetry](#symmetry)
|
||||
@@ -40,7 +40,7 @@
|
||||
|
||||
---
|
||||
|
||||
## InputSpecification
|
||||
## DesignInputSpecification
|
||||
|
||||
| Field | Type | Description |
|
||||
| -------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------- |
|
||||
|
||||
435
models/rfd3/src/rfd3/engine.py
Normal file
435
models/rfd3/src/rfd3/engine.py
Normal file
@@ -0,0 +1,435 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from biotite.structure import AtomArray
|
||||
from toolz import merge_with
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.inference_engines.base import BaseInferenceEngine
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
|
||||
from rfd3.inference.datasets import (
|
||||
assemble_distributed_inference_loader_from_json,
|
||||
)
|
||||
from rfd3.inference.input_parsing import DesignInputSpecification
|
||||
from rfd3.utils.inference import ensure_input_is_abspath
|
||||
from rfd3.utils.io import (
|
||||
CIF_LIKE_EXTENSIONS,
|
||||
dump_metadata,
|
||||
dump_structures,
|
||||
dump_trajectories,
|
||||
extract_example_id_from_path,
|
||||
find_files_with_extension,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
class RFD3InferenceEngine(BaseInferenceEngine):
|
||||
"""Inference engine for RFdiffusion3"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# Default input handling args
|
||||
skip_existing: bool,
|
||||
json_keys_subset: None | List[str],
|
||||
prevalidate_inputs: bool,
|
||||
# Base inference engine args
|
||||
diffusion_batch_size: int,
|
||||
inference_sampler: dict,
|
||||
specification: dict | None,
|
||||
# Structure dumping arguments
|
||||
global_prefix: str | None,
|
||||
cleanup_guideposts: bool,
|
||||
cleanup_virtual_atoms: bool,
|
||||
read_sequence_from_sequence_head: bool,
|
||||
output_full_json: bool,
|
||||
dump_prediction_metadata_json: bool,
|
||||
dump_trajectories: bool,
|
||||
align_trajectory_structures: bool,
|
||||
low_memory_mode: bool,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
transform_overrides={"diffusion_batch_size": diffusion_batch_size},
|
||||
inference_sampler_overrides={**inference_sampler},
|
||||
trainer_overrides={
|
||||
"cleanup_guideposts": cleanup_guideposts,
|
||||
"cleanup_virtual_atoms": cleanup_virtual_atoms,
|
||||
"read_sequence_from_sequence_head": read_sequence_from_sequence_head,
|
||||
"output_full_json": output_full_json,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
# save
|
||||
self.specification_overrides = dict(specification or {})
|
||||
|
||||
# Setup output directories and args
|
||||
self.global_prefix = global_prefix
|
||||
self.json_keys_subset = json_keys_subset
|
||||
self.prevalidate_inputs = prevalidate_inputs
|
||||
self.skip_existing = skip_existing
|
||||
|
||||
# Saving / other args
|
||||
self.dump_prediction_metadata_json = dump_prediction_metadata_json
|
||||
self.dump_trajectories = dump_trajectories
|
||||
self.align_trajectory_structures = align_trajectory_structures
|
||||
if not cleanup_guideposts:
|
||||
ranked_logger.warning(
|
||||
"Guideposts will not be cleaned up. This is intended for debugging purposes."
|
||||
)
|
||||
if not cleanup_virtual_atoms:
|
||||
ranked_logger.warning(
|
||||
"Virtual atoms will not be cleaned up. Some tools like MPNN may run, but outputs will not be like native structures."
|
||||
)
|
||||
|
||||
# Check which example ids already exist in the output directory
|
||||
if low_memory_mode:
|
||||
ranked_logger.info("Low memory mode enabled.")
|
||||
# HACK: Set attribute to the diffusion module
|
||||
os.environ["RFD3_LOW_MEMORY_MODE"] = "1"
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
inputs: str | PathLike | AtomArray | DesignInputSpecification,
|
||||
n_batches: int | None = None,
|
||||
out_dir: str | PathLike | None = None,
|
||||
):
|
||||
self._set_out_dir(out_dir)
|
||||
inputs = self._canonicalize_inputs(inputs)
|
||||
design_specifications = self._multiply_specifications(
|
||||
inputs=inputs,
|
||||
n_batches=n_batches,
|
||||
)
|
||||
# init before
|
||||
self.initialize()
|
||||
outputs = self._run_multi(design_specifications)
|
||||
return outputs
|
||||
|
||||
def _set_out_dir(self, out_dir: str | PathLike | None):
|
||||
out_dir = Path(out_dir) if out_dir else None
|
||||
if out_dir:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
ranked_logger.info(f"Outputs will be written to {out_dir.resolve()}.")
|
||||
self.out_dir = out_dir
|
||||
|
||||
def _run_multi(self, specs):
|
||||
# ==============================================================================
|
||||
# Prepare pipeline and inference loader
|
||||
# ==============================================================================
|
||||
loader = assemble_distributed_inference_loader_from_json(
|
||||
# Passed directly to ContigJSONDataset
|
||||
# data={spec.example_id: spec for spec in spec.values()},
|
||||
data=specs,
|
||||
transform=self.pipeline,
|
||||
name="inference-dataset",
|
||||
cif_parser_args=None,
|
||||
subset_to_keys=None,
|
||||
eval_every_n=1,
|
||||
# Sampler args
|
||||
world_size=self.trainer.fabric.world_size,
|
||||
rank=self.trainer.fabric.global_rank,
|
||||
)
|
||||
loader = self.trainer.fabric.setup_dataloaders(
|
||||
loader,
|
||||
use_distributed_sampler=False,
|
||||
)
|
||||
|
||||
# ==============================================================================
|
||||
# Evaluate, using `validation_step`
|
||||
# ==============================================================================
|
||||
outputs = {}
|
||||
for batch_idx, batch in enumerate(loader):
|
||||
pipeline_output = batch[0]
|
||||
output = self._model_forward(pipeline_output)
|
||||
|
||||
if self.out_dir:
|
||||
self.save_batch_outputs(
|
||||
out_dir=self.out_dir,
|
||||
example_id=pipeline_output["example_id"],
|
||||
network_output=output["network_output"],
|
||||
prediction_metadata=output["prediction_metadata"],
|
||||
predicted_atom_array_stack=output["predicted_atom_array_stack"],
|
||||
pipeline_output=pipeline_output,
|
||||
)
|
||||
else:
|
||||
outputs[pipeline_output["example_id"]] = {
|
||||
"network_output": output["network_output"],
|
||||
"prediction_metadata": output["prediction_metadata"],
|
||||
"predicted_atom_array_stack": output["predicted_atom_array_stack"],
|
||||
}
|
||||
return outputs
|
||||
|
||||
def _model_forward(self, pipeline_output):
|
||||
t0 = time.time()
|
||||
with torch.no_grad():
|
||||
pipeline_output = self.trainer.fabric.to_device(pipeline_output)
|
||||
output = self.trainer.validation_step(
|
||||
batch=pipeline_output,
|
||||
batch_idx=0,
|
||||
compute_metrics=False,
|
||||
)
|
||||
t_end = time.time()
|
||||
|
||||
# Add additional information to prediction metadata
|
||||
for key in output["prediction_metadata"].keys():
|
||||
ckpt = Path(self.ckpt_path)
|
||||
if ckpt.is_symlink():
|
||||
ckpt = ckpt.resolve(strict=True) # follow symlink to target
|
||||
output["prediction_metadata"][key]["ckpt_path"] = str(ckpt)
|
||||
output["prediction_metadata"][key]["seed"] = self.seed
|
||||
|
||||
ranked_logger.info(f"Finished inference batch in {t_end - t0:.2f} seconds.")
|
||||
return output
|
||||
|
||||
###############################################
|
||||
# Input merging
|
||||
###############################################
|
||||
|
||||
def _canonicalize_inputs(
|
||||
self, inputs
|
||||
) -> Dict[str, dict | DesignInputSpecification]:
|
||||
is_json_like = (isinstance(inputs, (str, PathLike, Path))) or (
|
||||
isinstance(inputs, list)
|
||||
and all([isinstance(i, (str, PathLike, Path)) for i in inputs])
|
||||
)
|
||||
is_specification_like = isinstance(inputs, DesignInputSpecification) or (
|
||||
isinstance(inputs, list)
|
||||
and all([isinstance(i, DesignInputSpecification) for i in inputs])
|
||||
)
|
||||
is_atom_array_like = isinstance(inputs, (AtomArray, list)) or (
|
||||
isinstance(inputs, list) and all([isinstance(i, AtomArray) for i in inputs])
|
||||
)
|
||||
if inputs is None:
|
||||
# Create empty specification dictionary
|
||||
return {"": {**self.specification_overrides}}
|
||||
elif is_json_like:
|
||||
# List of file paths
|
||||
inputs = process_input(
|
||||
inputs,
|
||||
json_keys_subset=self.json_keys_subset,
|
||||
global_prefix=self.global_prefix,
|
||||
specification_overrides=self.specification_overrides,
|
||||
validate=self.prevalidate_inputs,
|
||||
) # any -> Dict[Name: DesignInputSpecification]
|
||||
elif is_specification_like:
|
||||
# List of DesignInputSpecifications
|
||||
if isinstance(inputs, DesignInputSpecification):
|
||||
inputs = [inputs]
|
||||
inputs = {f"backbone_{i}": spec for i, spec in enumerate(inputs)}
|
||||
elif is_atom_array_like:
|
||||
raise NotImplementedError("AtomArray inputs not yet supported.")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid input type: {type(inputs)}. Expected JSON/YAML file paths, AtomArray, or DesignInputSpecification.\nInput: {inputs}"
|
||||
)
|
||||
|
||||
return design_specifications
|
||||
|
||||
def _multiply_specifications(
|
||||
self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
# Find existing example IDS in output directory
|
||||
if exists(self.out_dir):
|
||||
existing_example_ids = set(
|
||||
extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
|
||||
for path in find_files_with_extension(self.out_dir, CIF_LIKE_EXTENSIONS)
|
||||
)
|
||||
ranked_logger.info(
|
||||
f"Found {len(existing_example_ids)} existing example IDs in the output directory."
|
||||
)
|
||||
|
||||
# Based on inputs, construct the specifications to loop through
|
||||
design_specifications = {}
|
||||
for prefix, example_spec in inputs.items():
|
||||
# ... Create n_batches for example
|
||||
for batch_id in range((n_batches) if exists(n_batches) else 1):
|
||||
# ... Example ID
|
||||
example_id = f"{prefix}_{batch_id}" if exists(n_batches) else prefix
|
||||
|
||||
if (
|
||||
self.skip_existing
|
||||
and exists(self.out_dir)
|
||||
and example_id in existing_example_ids
|
||||
):
|
||||
ranked_logger.info(
|
||||
f"Skipping design specification for example {example_id} | Already exists."
|
||||
)
|
||||
continue
|
||||
design_specifications[example_id] = example_spec
|
||||
return design_specifications
|
||||
|
||||
def save_batch_outputs(
|
||||
self,
|
||||
*,
|
||||
out_dir,
|
||||
network_output,
|
||||
prediction_metadata,
|
||||
predicted_atom_array_stack,
|
||||
pipeline_output,
|
||||
example_id,
|
||||
):
|
||||
out_dir = Path(out_dir)
|
||||
dump_structures(
|
||||
atom_arrays=predicted_atom_array_stack,
|
||||
base_path=out_dir / example_id,
|
||||
one_model_per_file=True,
|
||||
extra_fields=SAVED_CONDITIONING_ANNOTATIONS,
|
||||
)
|
||||
|
||||
if self.dump_prediction_metadata_json:
|
||||
dump_metadata(
|
||||
prediction_metadata=prediction_metadata,
|
||||
base_path=out_dir / example_id,
|
||||
one_model_per_file=True,
|
||||
)
|
||||
|
||||
if self.dump_trajectories:
|
||||
dump_trajectories(
|
||||
trajectory_list=network_output["X_denoised_L_traj"],
|
||||
atom_array=pipeline_output["atom_array"],
|
||||
base_path=out_dir / f"{example_id}_denoised",
|
||||
align_structures=self.align_trajectory_structures,
|
||||
)
|
||||
dump_trajectories(
|
||||
trajectory_list=network_output["X_noisy_L_traj"],
|
||||
atom_array=pipeline_output["atom_array"],
|
||||
base_path=out_dir / f"{example_id}_noisy",
|
||||
align_structures=self.align_trajectory_structures,
|
||||
)
|
||||
|
||||
ranked_logger.info(
|
||||
f"Outputs for {example_id} written to {out_dir / example_id}."
|
||||
)
|
||||
|
||||
|
||||
def normalize_inputs(inputs: str | list | None) -> list[str | None]:
|
||||
"""
|
||||
inputs: str | list[str] | None
|
||||
- Can be:
|
||||
- A single path to a JSON, YAML, or regular input file (cif or pdb)
|
||||
- A comma-separated string of paths (e.g. "a.json,b.json")
|
||||
- A list of file paths
|
||||
- None or an empty list, in which case a dummy input is added (used for e.g. motif-only design)
|
||||
- Returns list of paths or [None] if no inputs are provided
|
||||
"""
|
||||
if inputs is None or (isinstance(inputs, list) and len(inputs) == 0):
|
||||
inputs = [None]
|
||||
elif isinstance(inputs, str):
|
||||
inputs = inputs.split(",")
|
||||
elif not isinstance(inputs, list):
|
||||
raise ValueError(
|
||||
f"Invalid input type: {type(inputs)}. Expected str, list, or None.\nInput: {inputs}"
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def process_input(
|
||||
inputs: str | list | None,
|
||||
json_keys_subset: str | list | None = None,
|
||||
global_prefix: str | None = None,
|
||||
specification_overrides: dict | None = None,
|
||||
validate: bool = True,
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
inputs: Any -> list[str | None] (see normalize_inputs)
|
||||
json_keys_subset: extract only subset of JSON keys. None will keep all keys
|
||||
prefix: If provided, prefix all example ids with said prefix
|
||||
|
||||
returns: Dictionaries of specifcation args pre-batching:
|
||||
{
|
||||
'jsonfile_jsonkey1': {
|
||||
**args_from_key1
|
||||
},
|
||||
'jsonfile_jsonkey2': {
|
||||
**args_from_key2
|
||||
}
|
||||
}
|
||||
"""
|
||||
specification_overrides = dict(specification_overrides or {})
|
||||
|
||||
def merge_args(example_args: dict) -> dict:
|
||||
return merge_with(lambda x: x[-1], example_args, specification_overrides)
|
||||
|
||||
inputs = normalize_inputs(inputs)
|
||||
|
||||
# If global_prefix is not provided, then default to using the basename of the JSON or YAML file (when provided)
|
||||
if global_prefix is None:
|
||||
use_json_basename_prefix = True
|
||||
else:
|
||||
use_json_basename_prefix = False
|
||||
|
||||
# ... Convert all inputs to list of inputs (e.g. if comma-separated)
|
||||
if exists(inputs) and "," in inputs:
|
||||
inputs = inputs.split(",")
|
||||
elif not exists(inputs):
|
||||
# If inputs is None or empty, we will create a dummy input
|
||||
inputs = []
|
||||
inputs = inputs if isinstance(inputs, list) else [inputs]
|
||||
if len(inputs) == 0:
|
||||
inputs = [None]
|
||||
|
||||
# ... Determine prefix of sample to create
|
||||
all_specs = {}
|
||||
for input in inputs:
|
||||
if exists(input) and (input.endswith(".json") or input.endswith(".yaml")):
|
||||
# ... Load JSON or YAML file
|
||||
with open(input, "r") as f:
|
||||
data = json.load(f) if input.endswith(".json") else yaml.safe_load(f)
|
||||
|
||||
# ... Apply any global args for this input file
|
||||
if "global_args" in data:
|
||||
global_args = data.pop("global_args")
|
||||
for example in data:
|
||||
data[example].update(global_args)
|
||||
|
||||
# ... Subset to keys
|
||||
if json_keys_subset is not None:
|
||||
json_keys_subset = (
|
||||
json_keys_subset.split(",")
|
||||
if isinstance(json_keys_subset, str)
|
||||
else json_keys_subset
|
||||
)
|
||||
data = {
|
||||
example: data[example]
|
||||
for example in json_keys_subset
|
||||
if example in data
|
||||
}
|
||||
|
||||
# ... Extract each accumulated example in data.
|
||||
for example, args in data.items():
|
||||
args = ensure_input_is_abspath(args, input)
|
||||
if use_json_basename_prefix:
|
||||
name = os.path.splitext(os.path.basename(input))[0]
|
||||
prefix = f"{name}_{example}"
|
||||
else:
|
||||
prefix = f"{global_prefix}{example}"
|
||||
args["extra"] = args.get("extra", {}) | {"example": example}
|
||||
all_specs[prefix] = dict(merge_args(args))
|
||||
|
||||
elif exists(input):
|
||||
prefix = os.path.basename(os.path.splitext(input)[0])
|
||||
if global_prefix is not None:
|
||||
prefix = f"{global_prefix}{prefix}"
|
||||
all_specs[prefix] = dict(merge_args({"input": input}))
|
||||
else:
|
||||
all_specs["backbone"] = dict(specification_overrides)
|
||||
|
||||
if validate:
|
||||
for prefix, example_spec in all_specs.items():
|
||||
ranked_logger.info(
|
||||
f"Prevalidating design specification for example: {prefix}"
|
||||
)
|
||||
DesignInputSpecification.safe_init(**example_spec)
|
||||
|
||||
return all_specs
|
||||
@@ -5,29 +5,17 @@
|
||||
import json
|
||||
import os
|
||||
import textwrap
|
||||
import time
|
||||
from os import PathLike
|
||||
from typing import Any, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import yaml
|
||||
from atomworks.io.parser import parse_atom_array
|
||||
|
||||
from atomworks.ml.datasets import MolecularDataset
|
||||
from atomworks.ml.transforms.base import Compose, Transform, TransformedDict
|
||||
from biotite.structure import BondList
|
||||
from atomworks.ml.transforms.base import Compose, Transform
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from rfd3.constants import (
|
||||
INFERENCE_ANNOTATIONS,
|
||||
REQUIRED_INFERENCE_ANNOTATIONS,
|
||||
)
|
||||
from rfd3.inference.inference_utils import ensure_input_is_abspath
|
||||
from rfd3.inference.input_parsing import (
|
||||
create_atom_array_from_design_specification,
|
||||
)
|
||||
from rfd3.transforms.conditioning_base import (
|
||||
check_has_required_conditioning_annotations,
|
||||
convert_existing_annotations_to_bool,
|
||||
DesignInputSpecification,
|
||||
)
|
||||
from rfd3.utils.inference import ensure_input_is_abspath
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
SequentialSampler,
|
||||
@@ -49,7 +37,7 @@ class ContigJsonDataset(MolecularDataset):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
data: PathLike | dict,
|
||||
data: PathLike | Dict[str, dict | DesignInputSpecification],
|
||||
cif_parser_args: dict | None,
|
||||
transform: Transform | Compose | None,
|
||||
name: str | None,
|
||||
@@ -137,7 +125,7 @@ class ContigJsonDataset(MolecularDataset):
|
||||
def _check_json_keys(self):
|
||||
"""Check if the JSON keys are valid."""
|
||||
for k, data in self.data.items():
|
||||
if not isinstance(data, dict):
|
||||
if not isinstance(data, (dict, DesignInputSpecification)):
|
||||
raise ValueError("Each item in the JSON data must be a dictionary.")
|
||||
|
||||
@property
|
||||
@@ -171,77 +159,19 @@ class ContigJsonDataset(MolecularDataset):
|
||||
spec = self.data[example_id]
|
||||
|
||||
# if 'input' in metadata and not abspath, prepend the source json directory to the file path
|
||||
spec = ensure_input_is_abspath(spec, self.json_path)
|
||||
if not isinstance(spec, DesignInputSpecification):
|
||||
spec = ensure_input_is_abspath(spec, self.json_path)
|
||||
spec["cif_parser_args"] = self.cif_parser_args
|
||||
spec = DesignInputSpecification(**spec)
|
||||
|
||||
# ... Create atom array with conditioning annotations
|
||||
atom_array, spec_dict = create_atom_array_from_design_specification(
|
||||
cif_parser_args=self.cif_parser_args, **spec
|
||||
)
|
||||
# Create pipeline input
|
||||
data = spec.to_pipeline_input(example_id=example_id)
|
||||
|
||||
# ... Forward into
|
||||
data = prepare_pipeline_input_from_atom_array(atom_array)
|
||||
data["example_id"] = example_id
|
||||
|
||||
# ... Wrap up with additional features
|
||||
if "extra" not in spec_dict:
|
||||
spec_dict["extra"] = {}
|
||||
spec_dict["extra"]["example_id"] = example_id
|
||||
data["specification"] = spec_dict
|
||||
|
||||
# ... Send through pipeline
|
||||
# Apply transforms and return
|
||||
data = self.transform(data)
|
||||
return data
|
||||
|
||||
|
||||
def prepare_pipeline_input_from_atom_array( # see atomworks.ml.datasets.parsers.base.load_example_from_metadata_row
|
||||
atom_array_orig,
|
||||
) -> dict:
|
||||
"""
|
||||
Load or create an example from a metadata dictionary.
|
||||
If the file path is not provided in the metadata dictionary, create a spoofed CIF file based on the length.
|
||||
Args:
|
||||
atom_array_orig: Atom array instantiated with conditioning annotations
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the parsed row data and additional loaded CIF data.
|
||||
"""
|
||||
_start_parse_time = time.time()
|
||||
# HACK: Set empty bond graph:
|
||||
if atom_array_orig.bonds is None:
|
||||
atom_array_orig.bonds = BondList(atom_array_orig.array_length())
|
||||
|
||||
# Temporary spoof of chain IDs to ensure duplicates aren't dropped:
|
||||
result_dict = parse_atom_array(
|
||||
atom_array_orig,
|
||||
remove_ccds=[],
|
||||
fix_arginines=False,
|
||||
add_missing_atoms=False,
|
||||
extra_fields=INFERENCE_ANNOTATIONS,
|
||||
build_assembly=None,
|
||||
hydrogen_policy="remove",
|
||||
)
|
||||
atom_array = result_dict["asym_unit"][0]
|
||||
|
||||
# HACK: Set iid information manually
|
||||
# We currently do not preserve this information from the input,
|
||||
# if you want these we'd need to remove the spoofing here
|
||||
check_has_required_conditioning_annotations(
|
||||
atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
|
||||
)
|
||||
atom_array = convert_existing_annotations_to_bool(atom_array)
|
||||
atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
|
||||
atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])
|
||||
data = {
|
||||
"atom_array": atom_array, # First model
|
||||
"chain_info": result_dict["chain_info"],
|
||||
"ligand_info": result_dict["ligand_info"],
|
||||
"metadata": result_dict["metadata"],
|
||||
}
|
||||
_stop_parse_time = time.time()
|
||||
data = TransformedDict(data)
|
||||
return data
|
||||
|
||||
|
||||
def assemble_distributed_inference_loader_from_json(
|
||||
*, rank: int, world_size: int, **dataset_kwargs
|
||||
) -> DataLoader:
|
||||
|
||||
@@ -1,448 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import yaml
|
||||
from lightning.fabric import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
from rfd3.constants import (
|
||||
SAVED_CONDITIONING_ANNOTATIONS,
|
||||
)
|
||||
from rfd3.inference.datasets import (
|
||||
assemble_distributed_inference_loader_from_json,
|
||||
)
|
||||
from rfd3.inference.inference_utils import ensure_input_is_abspath
|
||||
from rfd3.inference.input_parsing import InputSpecification
|
||||
from toolz import merge_with
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.inference_engines.af3 import AF3InferenceEngine
|
||||
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
|
||||
from modelhub.utils.io import (
|
||||
CIF_LIKE_EXTENSIONS,
|
||||
extract_example_id_from_path,
|
||||
find_files_with_extension,
|
||||
)
|
||||
from modelhub.utils.logging import print_config_tree
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
def normalize_inputs(inputs: str | list | None) -> list[str | None]:
|
||||
"""
|
||||
inputs: str | list[str] | None
|
||||
- Can be:
|
||||
- A single path to a JSON, YAML, or regular input file (cif or pdb)
|
||||
- A comma-separated string of paths (e.g. "a.json,b.json")
|
||||
- A list of file paths
|
||||
- None or an empty list, in which case a dummy input is added (used for e.g. motif-only design)
|
||||
- Returns list of paths or [None] if no inputs are provided
|
||||
"""
|
||||
if inputs is None or (isinstance(inputs, list) and len(inputs) == 0):
|
||||
inputs = [None]
|
||||
elif isinstance(inputs, str):
|
||||
inputs = inputs.split(",")
|
||||
elif not isinstance(inputs, list):
|
||||
raise ValueError(
|
||||
f"Invalid input type: {type(inputs)}. Expected str, list, or None.\nInput: {inputs}"
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def process_input(
|
||||
inputs: str | list | None,
|
||||
json_keys_subset: str | list | None = None,
|
||||
global_prefix: str = None,
|
||||
specification_overrides: dict = {},
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
inputs: Any -> list[str | None] (see normalize_inputs)
|
||||
json_keys_subset: extract only subset of JSON keys. None will keep all keys
|
||||
prefix: If provided, prefix all example ids with said prefix
|
||||
|
||||
returns: Dictionaries of specifcation args pre-batching:
|
||||
{
|
||||
'jsonfile_jsonkey1': {
|
||||
**args_from_key1
|
||||
},
|
||||
'jsonfile_jsonkey2': {
|
||||
**args_from_key2
|
||||
}
|
||||
}
|
||||
"""
|
||||
merge_args = lambda d: merge_with(lambda x: x[-1], d, specification_overrides) # noqa
|
||||
inputs = normalize_inputs(inputs)
|
||||
|
||||
# If global_prefix is not provided, then default to using the basename of the JSON or YAML file (when provided)
|
||||
if global_prefix is None:
|
||||
use_json_basename_prefix = True
|
||||
else:
|
||||
use_json_basename_prefix = False
|
||||
|
||||
# ... Convert all inputs to list of inputs (e.g. if comma-separated)
|
||||
if exists(inputs) and "," in inputs:
|
||||
inputs = inputs.split(",")
|
||||
elif not exists(inputs):
|
||||
# If inputs is None or empty, we will create a dummy input
|
||||
inputs = []
|
||||
inputs = inputs if isinstance(inputs, list) else [inputs]
|
||||
if len(inputs) == 0:
|
||||
inputs = [None]
|
||||
|
||||
# ... Determine prefix of sample to create
|
||||
all_specs = {}
|
||||
for input in inputs:
|
||||
if exists(input) and (input.endswith(".json") or input.endswith(".yaml")):
|
||||
# ... Load JSON or YAML file
|
||||
with open(input, "r") as f:
|
||||
data = json.load(f) if input.endswith(".json") else yaml.safe_load(f)
|
||||
|
||||
# ... Apply any global args for this input file
|
||||
if "global_args" in data:
|
||||
global_args = data.pop("global_args")
|
||||
for example in data:
|
||||
data[example].update(global_args)
|
||||
|
||||
# ... Subset to keys
|
||||
if json_keys_subset is not None:
|
||||
json_keys_subset = (
|
||||
json_keys_subset.split(",")
|
||||
if isinstance(json_keys_subset, str)
|
||||
else json_keys_subset
|
||||
)
|
||||
data = {
|
||||
example: data[example]
|
||||
for example in json_keys_subset
|
||||
if example in data
|
||||
}
|
||||
|
||||
# ... Extract each accumulated example in data.
|
||||
for example, args in data.items():
|
||||
args = ensure_input_is_abspath(args, input)
|
||||
if use_json_basename_prefix:
|
||||
name = os.path.splitext(os.path.basename(input))[0]
|
||||
prefix = f"{name}_{example}"
|
||||
else:
|
||||
prefix = f"{global_prefix}{example}"
|
||||
args["extra"] = args.get("extra", {}) | {"example": example}
|
||||
all_specs[prefix] = dict(merge_args(args))
|
||||
|
||||
elif exists(input):
|
||||
prefix = os.path.basename(os.path.splitext(input)[0])
|
||||
if global_prefix is not None:
|
||||
prefix = f"{global_prefix}{prefix}"
|
||||
all_specs[prefix] = dict(merge_args({"input": input}))
|
||||
else:
|
||||
all_specs["backbone"] = specification_overrides
|
||||
|
||||
return all_specs
|
||||
|
||||
|
||||
class RFD3InferenceEngine(AF3InferenceEngine):
|
||||
"""Inference engine for RFdiffusion3"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# Required args:
|
||||
out_dir: str | PathLike,
|
||||
ckpt_path: str | PathLike,
|
||||
inputs: str | PathLike,
|
||||
json_keys_subset: None | list[str],
|
||||
n_batches: int,
|
||||
*,
|
||||
# Default design specification args:
|
||||
specification: dict,
|
||||
# Base inference engine args
|
||||
diffusion_batch_size: int,
|
||||
skip_existing: bool,
|
||||
inference_sampler: dict,
|
||||
# Structure dumping arguments
|
||||
cleanup_guideposts: bool,
|
||||
cleanup_virtual_atoms: bool,
|
||||
read_sequence_from_sequence_head: bool,
|
||||
output_full_json: bool,
|
||||
dump_prediction_metadata_json: bool,
|
||||
dump_trajectories: bool,
|
||||
align_trajectory_structures: bool,
|
||||
one_model_per_file: bool,
|
||||
global_prefix: str | None,
|
||||
###############################################
|
||||
num_nodes: int,
|
||||
devices_per_node: int,
|
||||
print_config: bool,
|
||||
seed: int | None,
|
||||
temp_dir=None,
|
||||
low_memory_mode: bool = False,
|
||||
prevalidate_inputs: bool = True,
|
||||
# Atom array instantiation args collapsed into default dict:
|
||||
):
|
||||
"""
|
||||
Design specification args:
|
||||
# Main:
|
||||
inputs: JSON, YAML, PDB or CIF (comma-separated string or single) containing coordinate data to be parsed or args to override
|
||||
length: length of designed structure (int or str like '20-100'). Default: None (specified by contig string)
|
||||
contig: string of residues to use as backbone. Default: None (specified by length) Example: '10-20,A11-13,10-20'
|
||||
fixed_atoms: Dict[str] of atom names to use as indexed motif atoms with fixed coordinates. Default: None
|
||||
unindex: list of residues in input_src to unindex in design. Default: None
|
||||
|
||||
redesign_motif_sidechains: bool or str. Specifies which motif residues have fixed sidechains by default. Default: True
|
||||
ligand: str. Ligands in input file to include as motif
|
||||
atomwise_rasa: Dict[str] of atomwise rasas to use for design. Default: None
|
||||
ori_token: list of 3 floats indicating the origin to center the design around. Default: None
|
||||
seed: int or None. If None, a random seed will be sampled.
|
||||
|
||||
Additional args:
|
||||
ckpt_path: Path to checkpoint file
|
||||
n_batches: Number of samples to
|
||||
"""
|
||||
if not os.path.isabs(out_dir):
|
||||
out_dir = os.path.abspath(out_dir)
|
||||
ranked_logger.info("Using absolute path for out_dir: {}".format(out_dir))
|
||||
|
||||
# Convert input sources to design specification dictionaries
|
||||
inputs = process_input(
|
||||
inputs,
|
||||
json_keys_subset=json_keys_subset,
|
||||
global_prefix=global_prefix,
|
||||
specification_overrides=specification,
|
||||
) # any -> Dict[Name: InputSpecification]
|
||||
self.design_specifications = {}
|
||||
for prefix, example_spec in inputs.items():
|
||||
# ... Set example key as the prefix
|
||||
if prevalidate_inputs:
|
||||
ranked_logger.info(
|
||||
f"Prevalidating design specification for example: {prefix}"
|
||||
)
|
||||
InputSpecification.safe_init(**example_spec)
|
||||
|
||||
# ... Create n_batches for example
|
||||
for batch_i in range(n_batches):
|
||||
# ... Example ID
|
||||
example_id = f"{prefix}_{batch_i}"
|
||||
self.design_specifications[example_id] = example_spec
|
||||
|
||||
############################################################
|
||||
# Feed-forward inputs similar to MH-AF3 inference engine
|
||||
############################################################
|
||||
|
||||
# ... set the random seed for reproducibility (and for augmentation, e.g., for antibodies)
|
||||
if not exists(seed):
|
||||
seed = int(time.time() * 1000) % (2**31)
|
||||
ranked_logger.info(f"Seeding everything with seed={seed}...")
|
||||
seed_everything(seed, workers=True, verbose=True)
|
||||
self.seed = seed
|
||||
|
||||
# We only extract the `train_cfg` from the checkpoint initially
|
||||
self.load_and_override_ckpt_config(
|
||||
ckpt_path=ckpt_path,
|
||||
num_nodes=num_nodes,
|
||||
devices_per_node=devices_per_node,
|
||||
inference_sampler=inference_sampler,
|
||||
)
|
||||
|
||||
set_accelerator_based_on_availability(self.cfg)
|
||||
|
||||
# (b) based on the dataset (we will apply when constructing the pipeline)
|
||||
self.dataset_overrides = {
|
||||
"diffusion_batch_size": diffusion_batch_size,
|
||||
}
|
||||
|
||||
# ... instantiate the trainer with the (modified) configuration
|
||||
self.trainer = hydra.utils.instantiate(
|
||||
self.cfg.trainer,
|
||||
_convert_="partial",
|
||||
_recursive_=False,
|
||||
)
|
||||
|
||||
# Set the output directory for the CIF files (e.g., predicted structures)
|
||||
self.cif_out_dir = Path(out_dir) if out_dir else Path("./")
|
||||
|
||||
# Structure dumping
|
||||
self.dump_prediction_metadata_json = dump_prediction_metadata_json
|
||||
self.dump_trajectories = dump_trajectories
|
||||
self.one_model_per_file = one_model_per_file
|
||||
self.align_trajectory_structures = align_trajectory_structures
|
||||
|
||||
self.trainer.cleanup_virtual_atoms = cleanup_virtual_atoms
|
||||
self.trainer.cleanup_guideposts = cleanup_guideposts
|
||||
self.trainer.read_sequence_from_sequence_head = read_sequence_from_sequence_head
|
||||
self.trainer.output_full_json = output_full_json
|
||||
self.trainer.inference_sampler_overrides = inference_sampler
|
||||
self.prediction_extra_fields = SAVED_CONDITIONING_ANNOTATIONS
|
||||
self.skip_existing = skip_existing
|
||||
self.dump_predictions = True
|
||||
self.print_config = print_config
|
||||
|
||||
if not cleanup_guideposts:
|
||||
ranked_logger.warning(
|
||||
"Guideposts will not be cleaned up. This is intended for debugging purposes."
|
||||
)
|
||||
if not cleanup_virtual_atoms:
|
||||
ranked_logger.warning(
|
||||
"Virtual atoms will not be cleaned up. Some tools like MPNN may run, but outputs will not be like native structures."
|
||||
)
|
||||
|
||||
# Check which example ids already exist in the output directory
|
||||
self.existing_example_ids = set(
|
||||
extract_example_id_from_path(path, CIF_LIKE_EXTENSIONS)
|
||||
for path in find_files_with_extension(out_dir, CIF_LIKE_EXTENSIONS)
|
||||
)
|
||||
ranked_logger.info(
|
||||
f"Found {len(self.existing_example_ids)} existing example IDs in the output directory."
|
||||
)
|
||||
|
||||
if low_memory_mode:
|
||||
ranked_logger.info("Low memory mode enabled.")
|
||||
# HACK: Set attribute to the diffusion module
|
||||
os.environ["RFD3_LOW_MEMORY_MODE"] = "1"
|
||||
|
||||
def load_and_override_ckpt_config(
|
||||
self, ckpt_path, num_nodes, devices_per_node, inference_sampler
|
||||
):
|
||||
assert exists(ckpt_path), f"Checkpoint path ({ckpt_path}) not provided."
|
||||
ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...")
|
||||
self.ckpt_path = ckpt_path
|
||||
|
||||
self.cfg = OmegaConf.create(
|
||||
torch.load(self.ckpt_path, "cpu", weights_only=False)["train_cfg"]
|
||||
)
|
||||
|
||||
# Override specific parameters within the Hydra config:
|
||||
# (a) based on the input arguments
|
||||
self.cfg.trainer.num_nodes = num_nodes
|
||||
self.cfg.trainer.devices_per_node = devices_per_node
|
||||
for k, v in inference_sampler.items():
|
||||
if v is None:
|
||||
continue
|
||||
setattr(self.cfg.model.net.inference_sampler, k, v)
|
||||
|
||||
# Set metrics / callbacks to be null s.t. they aren't loaded
|
||||
self.cfg.trainer.metrics = None
|
||||
|
||||
# Record the random seed to be dumped in the output JSON
|
||||
self.cfg.trainer.seed = self.seed
|
||||
|
||||
def example_id_exists(self, example_id, verbose=False):
|
||||
# TODO: Move this to another file to standardize better with src
|
||||
if not self.one_model_per_file:
|
||||
# Check if one file exists
|
||||
all_exist = example_id in self.existing_example_ids
|
||||
if all_exist and verbose:
|
||||
ranked_logger.info(
|
||||
f"Model file for example {example_id} already exists in the output directory."
|
||||
)
|
||||
else:
|
||||
all_exist = all(
|
||||
[
|
||||
(f"{example_id}_model_{i}" in self.existing_example_ids)
|
||||
for i in range(self.dataset_overrides["diffusion_batch_size"])
|
||||
]
|
||||
)
|
||||
if all_exist and verbose:
|
||||
ranked_logger.info(
|
||||
f"All models for example {example_id} already exist in the output directory."
|
||||
)
|
||||
return all_exist
|
||||
|
||||
def eval(self):
|
||||
"""
|
||||
Run design on a set of specifications
|
||||
"""
|
||||
if self.print_config:
|
||||
print_config_tree(
|
||||
self.cfg.model, resolve=True, title="INFERENCE MODEL CONFIGURATION"
|
||||
)
|
||||
|
||||
# ... spawn processes for distributed training, if using multiple GPUs
|
||||
ranked_logger.info(
|
||||
f"Spawning {self.trainer.fabric.world_size} processes from {self.trainer.fabric.global_rank}..."
|
||||
)
|
||||
|
||||
# ==============================================================================
|
||||
# Construct the model and load the checkpoint
|
||||
# ==============================================================================
|
||||
self.trainer.initialize_or_update_trainer_state({"train_cfg": self.cfg})
|
||||
self.trainer.construct_model()
|
||||
self.trainer.load_checkpoint(ckpt_path=self.ckpt_path, is_inference=True)
|
||||
|
||||
# Ensure optimizer isn't loaded
|
||||
self.trainer.state["optimizer"] = None
|
||||
self.trainer.state["train_cfg"].model.optimizer = None
|
||||
|
||||
self.trainer.setup_model_optimizers_and_schedulers()
|
||||
self.trainer.state["model"].eval()
|
||||
|
||||
# ==============================================================================
|
||||
# Prepare pipeline and inference loader
|
||||
# ==============================================================================
|
||||
# TODO: have name be the basename of the JSON or YAML file
|
||||
loader = assemble_distributed_inference_loader_from_json(
|
||||
# Passed directly to ContigJSONDataset
|
||||
data=self.design_specifications,
|
||||
transform=self.construct_pipeline(),
|
||||
name="inference-dataset",
|
||||
cif_parser_args={},
|
||||
subset_to_keys=None,
|
||||
eval_every_n=1,
|
||||
# Sampler args
|
||||
world_size=self.trainer.fabric.world_size,
|
||||
rank=self.trainer.fabric.global_rank,
|
||||
)
|
||||
loader = self.trainer.fabric.setup_dataloaders(
|
||||
loader,
|
||||
use_distributed_sampler=False,
|
||||
)
|
||||
|
||||
# ==============================================================================
|
||||
# Evaluate, using `validation_step``
|
||||
# ==============================================================================
|
||||
|
||||
for batch_idx, batch in enumerate(loader):
|
||||
pipeline_output = batch[0]
|
||||
example_id = pipeline_output["example_id"]
|
||||
|
||||
if self.skip_existing:
|
||||
if self.example_id_exists(example_id, verbose=True):
|
||||
ranked_logger.info(
|
||||
f"Skipping structure {batch_idx + 1}/{len(loader)}: {example_id} | Already exists."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
ranked_logger.info(
|
||||
f"Predicting structure {batch_idx + 1}/{len(loader)}: {example_id}"
|
||||
)
|
||||
|
||||
# Model inference
|
||||
t0 = time.time()
|
||||
with torch.no_grad():
|
||||
pipeline_output = self.trainer.fabric.to_device(pipeline_output)
|
||||
output = self.trainer.validation_step(
|
||||
batch=pipeline_output,
|
||||
batch_idx=0,
|
||||
compute_metrics=False,
|
||||
)
|
||||
t_end = time.time()
|
||||
|
||||
# Add additional information to prediction metadata
|
||||
for key in output["prediction_metadata"].keys():
|
||||
ckpt = Path(self.ckpt_path)
|
||||
if ckpt.is_symlink():
|
||||
ckpt = ckpt.resolve(strict=True) # follow symlink to target
|
||||
output["prediction_metadata"][key]["ckpt_path"] = str(ckpt)
|
||||
output["prediction_metadata"][key]["seed"] = self.seed
|
||||
|
||||
ranked_logger.info(f"Finished inference batch in {t_end - t0:.2f} seconds.")
|
||||
self.save_batch_outputs(
|
||||
example_id=example_id,
|
||||
network_output=output["network_output"],
|
||||
prediction_metadata=output["prediction_metadata"],
|
||||
predicted_atom_array_stack=output["predicted_atom_array_stack"],
|
||||
pipeline_output=pipeline_output,
|
||||
)
|
||||
@@ -1,11 +1,18 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from atomworks.constants import STANDARD_AA
|
||||
from atomworks.io.parser import parse_atom_array
|
||||
|
||||
# from atomworks.ml.datasets.datasets import BaseDataset
|
||||
from atomworks.ml.transforms.base import TransformedDict
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
)
|
||||
@@ -18,20 +25,10 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from rfd3.constants import (
|
||||
INFERENCE_ANNOTATIONS,
|
||||
REQUIRED_CONDITIONING_ANNOTATION_VALUES,
|
||||
REQUIRED_INFERENCE_ANNOTATIONS,
|
||||
)
|
||||
from rfd3.inference.components import (
|
||||
get_design_pattern_with_constraints,
|
||||
get_motif_components_and_breaks,
|
||||
)
|
||||
from rfd3.inference.inference_utils import (
|
||||
extract_ligand_array,
|
||||
inference_load_,
|
||||
set_com,
|
||||
set_common_annotations,
|
||||
set_indices,
|
||||
)
|
||||
from rfd3.inference.legacy_input_parsing import (
|
||||
create_atom_array_from_design_specification_legacy,
|
||||
)
|
||||
@@ -48,8 +45,19 @@ from rfd3.transforms.conditioning_base import (
|
||||
set_default_conditioning_annotations,
|
||||
)
|
||||
from rfd3.transforms.util_transforms import assign_types_
|
||||
from rfd3.utils.inference import (
|
||||
extract_ligand_array,
|
||||
inference_load_,
|
||||
set_com,
|
||||
set_common_annotations,
|
||||
set_indices,
|
||||
)
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.utils.components import (
|
||||
get_design_pattern_with_constraints,
|
||||
get_motif_components_and_breaks,
|
||||
)
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -62,216 +70,6 @@ logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
#################################################################################
|
||||
|
||||
|
||||
@contextmanager
|
||||
def validator_context(validator_name: str, data: dict = None):
|
||||
"""Context manager for validator execution with logging."""
|
||||
logger.debug(f"Starting validator: {validator_name}")
|
||||
try:
|
||||
yield
|
||||
logger.debug(f"✓ Completed validator: {validator_name}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"✗ Failed in validator: {validator_name}\n"
|
||||
f" Error: {str(e)}\n"
|
||||
f" Error type: {type(e).__name__}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def create_diffused_residues(n, additional_annotations=None):
|
||||
if n <= 0:
|
||||
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
|
||||
|
||||
atoms = []
|
||||
[
|
||||
atoms.extend(
|
||||
[
|
||||
struc.Atom(
|
||||
np.array([0.0, 0.0, 0.0], dtype=np.float32),
|
||||
res_name="ALA",
|
||||
res_id=idx,
|
||||
)
|
||||
for _ in range(5)
|
||||
]
|
||||
)
|
||||
for idx in range(1, n + 1)
|
||||
]
|
||||
array = struc.array(atoms)
|
||||
array.set_annotation(
|
||||
"element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
|
||||
)
|
||||
array.set_annotation(
|
||||
"atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
|
||||
)
|
||||
array = set_default_conditioning_annotations(
|
||||
array, motif=False, additional=additional_annotations
|
||||
)
|
||||
array = set_common_annotations(array)
|
||||
return array
|
||||
|
||||
|
||||
def create_motif_residue(
|
||||
token,
|
||||
strip_sidechains_by_default: bool,
|
||||
):
|
||||
if strip_sidechains_by_default and token.res_name in STANDARD_AA:
|
||||
n_atoms = token.shape[0]
|
||||
diffuse_oxygen = False
|
||||
if n_atoms < 3:
|
||||
raise ValueError(
|
||||
f"Not enough data for {src_chain}{src_resid} in input atom array."
|
||||
)
|
||||
if n_atoms == 3:
|
||||
# Handle cases with N, CA, C only;
|
||||
token = token + create_o_atoms(token.copy())
|
||||
diffuse_oxygen = True # flag oxygen for generation
|
||||
|
||||
# Subset to the first 4 atoms (N, CA, C, O) only
|
||||
token = token[np.isin(token.atom_name, ["N", "CA", "C", "O"])]
|
||||
|
||||
# exactly N, CA, C, O but no CB. Place CB onto idealized position and conver to ALA
|
||||
# Sequence name ALA ensures the padded atoms to be diffused from the fixed backbone
|
||||
# are placed on the CB so as to not leak the identity of the residue.
|
||||
token = token + create_cb_atoms(token.copy())
|
||||
|
||||
# Sequence name must be set to ALA such that the central atom is correctly CB
|
||||
token.res_name = np.full_like(token.res_name, "ALA", dtype=token.res_name.dtype)
|
||||
token.set_annotation(
|
||||
"is_motif_atom_with_fixed_coord",
|
||||
np.where(
|
||||
np.arange(token.shape[0], dtype=int) < (4 - int(diffuse_oxygen)),
|
||||
token.is_motif_atom_with_fixed_coord,
|
||||
0,
|
||||
),
|
||||
)
|
||||
|
||||
check_has_required_conditioning_annotations(token)
|
||||
token = set_common_annotations(token)
|
||||
token.set_annotation("res_id", np.full(token.shape[0], 1)) # Reset to 1
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def accumulate_components(
|
||||
components_to_accumulate: List[Union[str, int]],
|
||||
*,
|
||||
# Tokens from input
|
||||
indexed_tokens: Dict[str, AtomArray],
|
||||
unindexed_tokens: Dict[str, AtomArray],
|
||||
# Additional parameters
|
||||
atom_array_accum=[],
|
||||
start_chain: str = "A",
|
||||
start_resid: int = 1,
|
||||
unindexed_breaks: Optional[List[bool]] = [],
|
||||
**kwargs,
|
||||
) -> AtomArray:
|
||||
# ... Create list of components
|
||||
assert (
|
||||
x := (set(list(indexed_tokens.keys()) + list(unindexed_tokens.keys())))
|
||||
).issubset(
|
||||
(y := set(components_to_accumulate))
|
||||
), "Unindexed and indexed set {} is not subset of components to accumulate {}".format(
|
||||
x, y
|
||||
)
|
||||
all_tokens = indexed_tokens | unindexed_tokens
|
||||
all_annots = []
|
||||
[
|
||||
all_annots.extend(list(tok.get_annotation_categories()))
|
||||
for tok in all_tokens.values()
|
||||
]
|
||||
all_annots = set(all_annots)
|
||||
|
||||
# ... For-loop accum variables
|
||||
unindexed_components_started = (
|
||||
False # once one unindexed component is added, stop adding diffused residues
|
||||
)
|
||||
chain = start_chain
|
||||
res_id = start_resid
|
||||
molecule_id = 0
|
||||
|
||||
# ... Insert contig information one- by one-
|
||||
assert len(components_to_accumulate) == len(
|
||||
unindexed_breaks
|
||||
), "Mismatch in number of components to accumulate and breaks"
|
||||
for component, is_break in zip(components_to_accumulate, unindexed_breaks):
|
||||
if component == "/0":
|
||||
# Reset iterators on next chain
|
||||
chain = chr(ord(chain) + 1)
|
||||
molecule_id += 1
|
||||
res_id = 1
|
||||
continue
|
||||
|
||||
# ... Create array to insert
|
||||
if str(component)[0].isalpha(): # motif (e.g. "A22")
|
||||
n = 1
|
||||
|
||||
# ... Fetch the motif residue
|
||||
token = all_tokens[component]
|
||||
|
||||
# ... Insert breakpoint when break clause is met
|
||||
if exists(is_break) and is_break:
|
||||
if not unindexed_components_started:
|
||||
chain = start_chain
|
||||
unindexed_components_started = True
|
||||
token.set_annotation(
|
||||
"is_motif_atom_unindexed_motif_breakpoint",
|
||||
np.ones(token.shape[0], dtype=int),
|
||||
)
|
||||
else:
|
||||
token.set_annotation(
|
||||
"is_motif_atom_unindexed_motif_breakpoint",
|
||||
np.zeros(token.shape[0], dtype=int),
|
||||
)
|
||||
else:
|
||||
n = int(component)
|
||||
# ... Skip if none or unindexed
|
||||
if n == 0 or unindexed_components_started:
|
||||
res_id += n
|
||||
continue
|
||||
|
||||
# ... Create diffused residues
|
||||
token = create_diffused_residues(n, all_annots)
|
||||
|
||||
# ... Set index of insertion
|
||||
token = set_indices(
|
||||
array=token,
|
||||
chain=chain,
|
||||
res_id_start=res_id,
|
||||
molecule_id=molecule_id,
|
||||
component=component,
|
||||
)
|
||||
|
||||
assert (
|
||||
len(get_token_starts(token)) == n
|
||||
), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"
|
||||
|
||||
# ... Insert & Increment residue ID
|
||||
atom_array_accum.append(token)
|
||||
res_id += n
|
||||
|
||||
# ... Concatenate all components
|
||||
atom_array_accum = struc.concatenate(atom_array_accum)
|
||||
atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
|
||||
|
||||
# Reset res_id for unindexed residues to avoid duplicates
|
||||
if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
):
|
||||
max_id = np.max(
|
||||
atom_array_accum[
|
||||
~atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
].res_id
|
||||
)
|
||||
atom_array_accum.res_id[
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
] += max_id + 1
|
||||
|
||||
# ... Bonds
|
||||
if atom_array_accum.bonds is None:
|
||||
atom_array_accum.bonds = BondList(atom_array_accum.array_length())
|
||||
return atom_array_accum
|
||||
|
||||
|
||||
class LegacySpecification(BaseModel):
|
||||
"""Legacy specification for compatibility with legacy input parsing."""
|
||||
|
||||
@@ -280,20 +78,34 @@ class LegacySpecification(BaseModel):
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
def build(self):
|
||||
def build(self, *args, **kwargs):
|
||||
"""Build atom array using legacy input parsing."""
|
||||
atom_array = create_atom_array_from_design_specification_legacy(
|
||||
design_specification=self.model_dump(),
|
||||
)
|
||||
return atom_array, self.model_dump()
|
||||
|
||||
def to_pipeline_input(self, example_id):
|
||||
atom_array, spec_dict = self.build(return_metadata=True)
|
||||
|
||||
# ... Forward into
|
||||
data = prepare_pipeline_input_from_atom_array(atom_array)
|
||||
data["example_id"] = example_id
|
||||
|
||||
# ... Wrap up with additional features
|
||||
if "extra" not in spec_dict:
|
||||
spec_dict["extra"] = {}
|
||||
spec_dict["extra"]["example_id"] = example_id
|
||||
data["specification"] = spec_dict
|
||||
return data
|
||||
|
||||
|
||||
# ========================================================================
|
||||
# Input specification
|
||||
# ========================================================================
|
||||
|
||||
|
||||
class InputSpecification(BaseModel):
|
||||
class DesignInputSpecification(BaseModel):
|
||||
"""Validated and parsed input specification before resolution."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
@@ -932,12 +744,75 @@ class InputSpecification(BaseModel):
|
||||
else:
|
||||
return cls(**spec_kwargs)
|
||||
|
||||
def to_pipeline_input(self, example_id):
|
||||
atom_array, spec_dict = self.build(return_metadata=True)
|
||||
|
||||
# ... Forward into
|
||||
data = prepare_pipeline_input_from_atom_array(atom_array)
|
||||
data["example_id"] = example_id
|
||||
|
||||
# ... Wrap up with additional features
|
||||
if "extra" not in spec_dict:
|
||||
spec_dict["extra"] = {}
|
||||
spec_dict["extra"]["example_id"] = example_id
|
||||
data["specification"] = spec_dict
|
||||
return data
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Public API
|
||||
# APIs and utils
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def prepare_pipeline_input_from_atom_array( # see atomworks.ml.datasets.parsers.base.load_example_from_metadata_row
|
||||
atom_array_orig,
|
||||
) -> dict:
|
||||
"""
|
||||
Load or create an example from a metadata dictionary.
|
||||
If the file path is not provided in the metadata dictionary, create a spoofed CIF file based on the length.
|
||||
Args:
|
||||
atom_array_orig: Atom array instantiated with conditioning annotations
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the parsed row data and additional loaded CIF data.
|
||||
"""
|
||||
_start_parse_time = time.time()
|
||||
# HACK: Set empty bond graph:
|
||||
if atom_array_orig.bonds is None:
|
||||
atom_array_orig.bonds = BondList(atom_array_orig.array_length())
|
||||
|
||||
# Temporary spoof of chain IDs to ensure duplicates aren't dropped:
|
||||
result_dict = parse_atom_array(
|
||||
atom_array_orig,
|
||||
remove_ccds=[],
|
||||
fix_arginines=False,
|
||||
add_missing_atoms=False,
|
||||
extra_fields=INFERENCE_ANNOTATIONS,
|
||||
build_assembly=None,
|
||||
hydrogen_policy="remove",
|
||||
)
|
||||
atom_array = result_dict["asym_unit"][0]
|
||||
|
||||
# HACK: Set iid information manually
|
||||
# We currently do not preserve this information from the input,
|
||||
# if you want these we'd need to remove the spoofing here
|
||||
check_has_required_conditioning_annotations(
|
||||
atom_array, required=REQUIRED_INFERENCE_ANNOTATIONS
|
||||
)
|
||||
atom_array = convert_existing_annotations_to_bool(atom_array)
|
||||
atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
|
||||
atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])
|
||||
data = {
|
||||
"atom_array": atom_array, # First model
|
||||
"chain_info": result_dict["chain_info"],
|
||||
"ligand_info": result_dict["ligand_info"],
|
||||
"metadata": result_dict["metadata"],
|
||||
}
|
||||
_stop_parse_time = time.time()
|
||||
data = TransformedDict(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_atom_array_from_design_specification(
|
||||
**spec_kwargs,
|
||||
) -> tuple[AtomArray, dict]:
|
||||
@@ -952,6 +827,216 @@ def create_atom_array_from_design_specification(
|
||||
return atom_array, {}
|
||||
|
||||
# Create input specfication and build
|
||||
spec = InputSpecification(**spec_kwargs)
|
||||
spec = DesignInputSpecification(**spec_kwargs)
|
||||
atom_array, metadata = spec.build(return_metadata=True)
|
||||
return atom_array, metadata
|
||||
|
||||
|
||||
@contextmanager
|
||||
def validator_context(validator_name: str, data: dict = None):
|
||||
"""Context manager for validator execution with logging."""
|
||||
logger.debug(f"Starting validator: {validator_name}")
|
||||
try:
|
||||
yield
|
||||
logger.debug(f"✓ Completed validator: {validator_name}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"✗ Failed in validator: {validator_name}\n"
|
||||
f" Error: {str(e)}\n"
|
||||
f" Error type: {type(e).__name__}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def create_diffused_residues(n, additional_annotations=None):
|
||||
if n <= 0:
|
||||
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
|
||||
|
||||
atoms = []
|
||||
[
|
||||
atoms.extend(
|
||||
[
|
||||
struc.Atom(
|
||||
np.array([0.0, 0.0, 0.0], dtype=np.float32),
|
||||
res_name="ALA",
|
||||
res_id=idx,
|
||||
)
|
||||
for _ in range(5)
|
||||
]
|
||||
)
|
||||
for idx in range(1, n + 1)
|
||||
]
|
||||
array = struc.array(atoms)
|
||||
array.set_annotation(
|
||||
"element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
|
||||
)
|
||||
array.set_annotation(
|
||||
"atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
|
||||
)
|
||||
array = set_default_conditioning_annotations(
|
||||
array, motif=False, additional=additional_annotations
|
||||
)
|
||||
array = set_common_annotations(array)
|
||||
return array
|
||||
|
||||
|
||||
def create_motif_residue(
|
||||
token,
|
||||
strip_sidechains_by_default: bool,
|
||||
):
|
||||
if strip_sidechains_by_default and token.res_name in STANDARD_AA:
|
||||
n_atoms = token.shape[0]
|
||||
diffuse_oxygen = False
|
||||
if n_atoms < 3:
|
||||
raise ValueError(
|
||||
f"Not enough data for {src_chain}{src_resid} in input atom array."
|
||||
)
|
||||
if n_atoms == 3:
|
||||
# Handle cases with N, CA, C only;
|
||||
token = token + create_o_atoms(token.copy())
|
||||
diffuse_oxygen = True # flag oxygen for generation
|
||||
|
||||
# Subset to the first 4 atoms (N, CA, C, O) only
|
||||
token = token[np.isin(token.atom_name, ["N", "CA", "C", "O"])]
|
||||
|
||||
# exactly N, CA, C, O but no CB. Place CB onto idealized position and conver to ALA
|
||||
# Sequence name ALA ensures the padded atoms to be diffused from the fixed backbone
|
||||
# are placed on the CB so as to not leak the identity of the residue.
|
||||
token = token + create_cb_atoms(token.copy())
|
||||
|
||||
# Sequence name must be set to ALA such that the central atom is correctly CB
|
||||
token.res_name = np.full_like(token.res_name, "ALA", dtype=token.res_name.dtype)
|
||||
token.set_annotation(
|
||||
"is_motif_atom_with_fixed_coord",
|
||||
np.where(
|
||||
np.arange(token.shape[0], dtype=int) < (4 - int(diffuse_oxygen)),
|
||||
token.is_motif_atom_with_fixed_coord,
|
||||
0,
|
||||
),
|
||||
)
|
||||
|
||||
check_has_required_conditioning_annotations(token)
|
||||
token = set_common_annotations(token)
|
||||
token.set_annotation("res_id", np.full(token.shape[0], 1)) # Reset to 1
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def accumulate_components(
|
||||
components_to_accumulate: List[Union[str, int]],
|
||||
*,
|
||||
# Tokens from input
|
||||
indexed_tokens: Dict[str, AtomArray],
|
||||
unindexed_tokens: Dict[str, AtomArray],
|
||||
# Additional parameters
|
||||
atom_array_accum=[],
|
||||
start_chain: str = "A",
|
||||
start_resid: int = 1,
|
||||
unindexed_breaks: Optional[List[bool]] = [],
|
||||
**kwargs,
|
||||
) -> AtomArray:
|
||||
# ... Create list of components
|
||||
assert (
|
||||
x := (set(list(indexed_tokens.keys()) + list(unindexed_tokens.keys())))
|
||||
).issubset(
|
||||
(y := set(components_to_accumulate))
|
||||
), "Unindexed and indexed set {} is not subset of components to accumulate {}".format(
|
||||
x, y
|
||||
)
|
||||
all_tokens = indexed_tokens | unindexed_tokens
|
||||
all_annots = []
|
||||
[
|
||||
all_annots.extend(list(tok.get_annotation_categories()))
|
||||
for tok in all_tokens.values()
|
||||
]
|
||||
all_annots = set(all_annots)
|
||||
|
||||
# ... For-loop accum variables
|
||||
unindexed_components_started = (
|
||||
False # once one unindexed component is added, stop adding diffused residues
|
||||
)
|
||||
chain = start_chain
|
||||
res_id = start_resid
|
||||
molecule_id = 0
|
||||
|
||||
# ... Insert contig information one- by one-
|
||||
assert len(components_to_accumulate) == len(
|
||||
unindexed_breaks
|
||||
), "Mismatch in number of components to accumulate and breaks"
|
||||
for component, is_break in zip(components_to_accumulate, unindexed_breaks):
|
||||
if component == "/0":
|
||||
# Reset iterators on next chain
|
||||
chain = chr(ord(chain) + 1)
|
||||
molecule_id += 1
|
||||
res_id = 1
|
||||
continue
|
||||
|
||||
# ... Create array to insert
|
||||
if str(component)[0].isalpha(): # motif (e.g. "A22")
|
||||
n = 1
|
||||
|
||||
# ... Fetch the motif residue
|
||||
token = all_tokens[component]
|
||||
|
||||
# ... Insert breakpoint when break clause is met
|
||||
if exists(is_break) and is_break:
|
||||
if not unindexed_components_started:
|
||||
chain = start_chain
|
||||
unindexed_components_started = True
|
||||
token.set_annotation(
|
||||
"is_motif_atom_unindexed_motif_breakpoint",
|
||||
np.ones(token.shape[0], dtype=int),
|
||||
)
|
||||
else:
|
||||
token.set_annotation(
|
||||
"is_motif_atom_unindexed_motif_breakpoint",
|
||||
np.zeros(token.shape[0], dtype=int),
|
||||
)
|
||||
else:
|
||||
n = int(component)
|
||||
# ... Skip if none or unindexed
|
||||
if n == 0 or unindexed_components_started:
|
||||
res_id += n
|
||||
continue
|
||||
|
||||
# ... Create diffused residues
|
||||
token = create_diffused_residues(n, all_annots)
|
||||
|
||||
# ... Set index of insertion
|
||||
token = set_indices(
|
||||
array=token,
|
||||
chain=chain,
|
||||
res_id_start=res_id,
|
||||
molecule_id=molecule_id,
|
||||
component=component,
|
||||
)
|
||||
|
||||
assert (
|
||||
len(get_token_starts(token)) == n
|
||||
), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"
|
||||
|
||||
# ... Insert & Increment residue ID
|
||||
atom_array_accum.append(token)
|
||||
res_id += n
|
||||
|
||||
# ... Concatenate all components
|
||||
atom_array_accum = struc.concatenate(atom_array_accum)
|
||||
atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
|
||||
|
||||
# Reset res_id for unindexed residues to avoid duplicates
|
||||
if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
):
|
||||
max_id = np.max(
|
||||
atom_array_accum[
|
||||
~atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
].res_id
|
||||
)
|
||||
atom_array_accum.res_id[
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
] += max_id + 1
|
||||
|
||||
# ... Bonds
|
||||
if atom_array_accum.bonds is None:
|
||||
atom_array_accum.bonds = BondList(atom_array_accum.array_length())
|
||||
return atom_array_accum
|
||||
|
||||
@@ -17,23 +17,6 @@ from rfd3.constants import (
|
||||
OPTIONAL_CONDITIONING_VALUES,
|
||||
REQUIRED_INFERENCE_ANNOTATIONS,
|
||||
)
|
||||
from rfd3.inference.components import (
|
||||
fetch_mask_from_component,
|
||||
fetch_mask_from_idx,
|
||||
get_design_pattern_with_constraints,
|
||||
get_motif_components_and_breaks,
|
||||
get_name_mask,
|
||||
split_contig,
|
||||
)
|
||||
from rfd3.inference.inference_utils import (
|
||||
create_cb_atoms,
|
||||
create_o_atoms,
|
||||
extract_ligand_array,
|
||||
inference_load_,
|
||||
set_com,
|
||||
set_common_annotations,
|
||||
set_indices,
|
||||
)
|
||||
from rfd3.inference.symmetry.symmetry_utils import (
|
||||
center_symmetric_src_atom_array,
|
||||
make_symmetric_atom_array,
|
||||
@@ -45,8 +28,25 @@ from rfd3.transforms.conditioning_base import (
|
||||
set_default_conditioning_annotations,
|
||||
)
|
||||
from rfd3.transforms.util_transforms import assign_types_
|
||||
from rfd3.utils.inference import (
|
||||
create_cb_atoms,
|
||||
create_o_atoms,
|
||||
extract_ligand_array,
|
||||
inference_load_,
|
||||
set_com,
|
||||
set_common_annotations,
|
||||
set_indices,
|
||||
)
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.utils.components import (
|
||||
fetch_mask_from_component,
|
||||
fetch_mask_from_idx,
|
||||
get_design_pattern_with_constraints,
|
||||
get_motif_components_and_breaks,
|
||||
get_name_mask,
|
||||
split_contig,
|
||||
)
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@@ -9,7 +9,8 @@ from pydantic import (
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from rfd3.inference.components import (
|
||||
|
||||
from modelhub.utils.components import (
|
||||
ComponentStr,
|
||||
fetch_mask_from_idx,
|
||||
get_name_mask,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
from rfd3.inference.components import fetch_mask_from_idx
|
||||
|
||||
from modelhub.utils.components import fetch_mask_from_idx
|
||||
|
||||
|
||||
def expand_contig_to_resid_from_string(contig_string):
|
||||
|
||||
@@ -8,7 +8,6 @@ from pydantic import (
|
||||
ConfigDict,
|
||||
Field,
|
||||
)
|
||||
from rfd3.inference.components import fetch_mask_from_component
|
||||
from rfd3.inference.symmetry.atom_array import (
|
||||
FIXED_ENTITY_ID,
|
||||
FIXED_TRANSFORM_ID,
|
||||
@@ -33,6 +32,7 @@ from rfd3.inference.symmetry.frames import (
|
||||
)
|
||||
from rfd3.transforms.conditioning_base import get_motif_features
|
||||
|
||||
from modelhub.utils.components import fetch_mask_from_component
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
@@ -4,30 +4,6 @@ import torch.nn as nn
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
|
||||
|
||||
class Loss(nn.Module):
|
||||
def __init__(self, **losses):
|
||||
super().__init__()
|
||||
self.to_compute = []
|
||||
for loss_name, loss in losses.items():
|
||||
loss_fn = hydra.utils.instantiate(loss)
|
||||
print(f"Adding loss {loss_name} to the loss function")
|
||||
self.to_compute.append(loss_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
network_input,
|
||||
network_output,
|
||||
loss_input,
|
||||
):
|
||||
loss_dict = {}
|
||||
loss = 0
|
||||
for loss_fn in self.to_compute:
|
||||
loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
|
||||
loss += loss_
|
||||
loss_dict.update(loss_dict_)
|
||||
loss_dict["total_loss"] = loss.detach()
|
||||
return loss, loss_dict
|
||||
|
||||
class SequenceLoss(nn.Module):
|
||||
def __init__(self, weight, min_t=0, max_t=torch.inf):
|
||||
super().__init__()
|
||||
@@ -80,7 +56,7 @@ class SequenceLoss(nn.Module):
|
||||
return self.weight * token_loss, outs
|
||||
|
||||
|
||||
class SimpleDiffusionLoss(nn.Module):
|
||||
class DiffusionLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -94,8 +70,7 @@ class SimpleDiffusionLoss(nn.Module):
|
||||
unindexed_t_alpha=1.0,
|
||||
unindexed_norm_p=1.0,
|
||||
lp_weight=0.0,
|
||||
normalize_virtual_atom_weight=False,
|
||||
alpha_normalized_virtual_atom_weight=3.9,
|
||||
**_, # dump args from old configs
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = weight
|
||||
@@ -108,8 +83,6 @@ class SimpleDiffusionLoss(nn.Module):
|
||||
self.unindexed_t_alpha = unindexed_t_alpha
|
||||
self.lp_weight = lp_weight
|
||||
self.alpha_ligand = alpha_ligand
|
||||
self.normalize_virtual_atom_weight = normalize_virtual_atom_weight
|
||||
self.alpha_normalized_virtual_atom_weight = alpha_normalized_virtual_atom_weight
|
||||
self.alpha_polar_residues = alpha_polar_residues
|
||||
|
||||
self.get_lambda = (
|
||||
@@ -123,18 +96,10 @@ class SimpleDiffusionLoss(nn.Module):
|
||||
crd_mask_L = loss_input["crd_mask_L"] # (D, L)
|
||||
crd_mask_L = crd_mask_L.unsqueeze(0).expand(D, -1)
|
||||
tok_idx = network_input["f"]["atom_to_token_map"]
|
||||
is_motif_token = network_input["f"]["is_motif_token"] # N
|
||||
t = network_input["t"] # (D,)
|
||||
is_original_unindexed_token = loss_input["is_original_unindexed_token"][tok_idx]
|
||||
is_polar_atom = network_input["f"]["is_polar"][tok_idx]
|
||||
is_ligand = network_input["f"]["is_ligand"][tok_idx]
|
||||
|
||||
# Treat fully fixed atoms as non-lossable atoms to provide stable normalization
|
||||
# is_motif_atom_with_fixed_coord = network_input["f"][
|
||||
# "is_motif_atom_with_fixed_coord"
|
||||
# ]
|
||||
# crd_mask_L = crd_mask_L * ~is_motif_atom_with_fixed_coord[None]
|
||||
|
||||
is_virtual_atom = network_input["f"]["is_virtual"] # L
|
||||
is_sidechain_atom = network_input["f"]["is_sidechain"] # L
|
||||
is_sidechain_atom = is_sidechain_atom & ~is_virtual_atom
|
||||
@@ -148,29 +113,6 @@ class SimpleDiffusionLoss(nn.Module):
|
||||
|
||||
# Upweight polar residues
|
||||
w_L[is_polar_atom] *= self.alpha_polar_residues
|
||||
|
||||
if self.normalize_virtual_atom_weight:
|
||||
# Divide by the number of virtual atoms within a token
|
||||
n_virtual_atoms_per_token = torch.zeros_like(is_motif_token).float()
|
||||
n_virtual_atoms_per_token.scatter_add_(
|
||||
0, tok_idx.long(), is_virtual_atom.float()
|
||||
)
|
||||
n_virtual_atoms_per_token = n_virtual_atoms_per_token.clamp(min=1)
|
||||
|
||||
# Also divide by the number of heavy atoms within a token
|
||||
n_sc_atoms_per_token = torch.zeros_like(is_motif_token).float()
|
||||
n_sc_atoms_per_token.scatter_add_(
|
||||
0, tok_idx.long(), is_sidechain_atom.float()
|
||||
)
|
||||
n_sc_atoms_per_token = n_sc_atoms_per_token.clamp(min=1)
|
||||
|
||||
w_L[is_virtual_atom] /= (
|
||||
self.alpha_normalized_virtual_atom_weight
|
||||
) * n_virtual_atoms_per_token[tok_idx][is_virtual_atom]
|
||||
w_L[is_sidechain_atom] /= (
|
||||
10 - self.alpha_normalized_virtual_atom_weight
|
||||
) * n_sc_atoms_per_token[tok_idx][is_sidechain_atom]
|
||||
|
||||
w_L = w_L[None].expand(D, -1) * crd_mask_L
|
||||
|
||||
X_gt_L = torch.nan_to_num(loss_input["X_gt_L_in_input_frame"])
|
||||
|
||||
@@ -3,7 +3,7 @@ import numpy as np
|
||||
from biotite.structure.info import residue
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from modelhub.metrics.base import Metric
|
||||
from modelhub.metrics.metric import Metric
|
||||
|
||||
|
||||
def collapsing_virtual_atoms_batched(
|
||||
|
||||
105
models/rfd3/src/rfd3/model/RFD3.py
Normal file
105
models/rfd3/src/rfd3/model/RFD3.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import os
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from rfd3.model.cfg_utils import (
|
||||
strip_f,
|
||||
)
|
||||
from rfd3.model.inference_sampler import ConditionalDiffusionSampler
|
||||
from rfd3.model.layers.encoders import TokenInitializer
|
||||
from torch import nn
|
||||
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
class RFD3(nn.Module):
|
||||
"""
|
||||
Simplified model for generation
|
||||
This module level serves to wrap the diffusion module of AF3
|
||||
to be roughly equivalent to the AF3 model w/o trunk processing.
|
||||
|
||||
Allows the same sampler to be used
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# Channel dimensions ('global' features)
|
||||
c_s: int,
|
||||
c_z: int,
|
||||
c_atom: int,
|
||||
c_atompair: int,
|
||||
# Arguments for modules that will be instantiated
|
||||
token_initializer: DictConfig | dict,
|
||||
diffusion_module: DictConfig | dict,
|
||||
inference_sampler: DictConfig | dict,
|
||||
**_,
|
||||
):
|
||||
super().__init__()
|
||||
# Check for chunked P_LL mode via environment variable
|
||||
use_chunked_pll = os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"
|
||||
ranked_logger.info(f"RFD3 initialized with chunked_pll={use_chunked_pll}")
|
||||
|
||||
# Simple constant-feature initializer
|
||||
self.token_initializer = TokenInitializer(
|
||||
c_s=c_s,
|
||||
c_z=c_z,
|
||||
c_atom=c_atom,
|
||||
c_atompair=c_atompair,
|
||||
use_chunked_pll=use_chunked_pll,
|
||||
**token_initializer,
|
||||
)
|
||||
|
||||
# Diffusion module instantiated to allow for config scripting
|
||||
self.diffusion_module = hydra.utils.instantiate(
|
||||
diffusion_module, c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z
|
||||
)
|
||||
|
||||
self.use_classifier_free_guidance = (
|
||||
inference_sampler["use_classifier_free_guidance"]
|
||||
and inference_sampler["cfg_scale"] != 1.0
|
||||
)
|
||||
self.cfg_features = inference_sampler.pop("cfg_features", [])
|
||||
|
||||
# ... initialize the inference sampler, which performs a full diffusion rollout during inference
|
||||
self.inference_sampler = ConditionalDiffusionSampler(**inference_sampler)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: dict,
|
||||
coord_atom_lvl_to_be_noised: torch.Tensor = None,
|
||||
n_cycle=None,
|
||||
**_,
|
||||
) -> dict:
|
||||
initializer_outputs = self.token_initializer(input["f"])
|
||||
|
||||
if self.training:
|
||||
# Single denoising step
|
||||
return self.diffusion_module(
|
||||
X_noisy_L=input["X_noisy_L"],
|
||||
t=input["t"],
|
||||
f=input["f"],
|
||||
n_recycle=n_cycle,
|
||||
**initializer_outputs,
|
||||
) # [D, L, 3]
|
||||
else:
|
||||
if self.use_classifier_free_guidance:
|
||||
f_ref = strip_f(input["f"], self.cfg_features)
|
||||
ref_initializer_outputs = self.token_initializer(f_ref)
|
||||
else:
|
||||
f_ref = None
|
||||
ref_initializer_outputs = None
|
||||
|
||||
return self.inference_sampler.sample_diffusion_like_af3(
|
||||
f=input["f"],
|
||||
f_ref=f_ref, # for cfg
|
||||
diffusion_module=self.diffusion_module,
|
||||
diffusion_batch_size=coord_atom_lvl_to_be_noised.shape[0],
|
||||
coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
|
||||
# Forwarded as **kwargs:
|
||||
initializer_outputs=initializer_outputs,
|
||||
ref_initializer_outputs=ref_initializer_outputs, # for cfg
|
||||
)
|
||||
@@ -5,11 +5,11 @@ from contextlib import ExitStack
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rfd3.model.block_utils import (
|
||||
from rfd3.model.layers.block_utils import (
|
||||
bucketize_scaled_distogram,
|
||||
create_attention_indices,
|
||||
)
|
||||
from rfd3.model.blocks import (
|
||||
from rfd3.model.layers.blocks import (
|
||||
CompactStreamingDecoder,
|
||||
Downcast,
|
||||
LinearEmbedWithPool,
|
||||
@@ -17,17 +17,14 @@ from rfd3.model.blocks import (
|
||||
LocalAtomTransformer,
|
||||
LocalTokenTransformer,
|
||||
)
|
||||
from rfd3.model.encoders import (
|
||||
from rfd3.model.layers.encoders import (
|
||||
DiffusionTokenEncoder,
|
||||
)
|
||||
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
|
||||
|
||||
from modelhub.model.AF3_structure import (
|
||||
from modelhub.model.layers.blocks import (
|
||||
FourierEmbedding,
|
||||
)
|
||||
from modelhub.model.layers.af3_diffusion_transformer import (
|
||||
DiffusionTransformer,
|
||||
)
|
||||
from modelhub.model.layers.layer_utils import RMSNorm, linearNoBias
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,7 +50,7 @@ class RFD3DiffusionModule(nn.Module):
|
||||
atom_attention_decoder,
|
||||
# upcast,
|
||||
downcast,
|
||||
use_local_token_attention=False,
|
||||
use_local_token_attention=True,
|
||||
**_,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -111,20 +108,12 @@ class RFD3DiffusionModule(nn.Module):
|
||||
**diffusion_token_encoder,
|
||||
)
|
||||
|
||||
if not use_local_token_attention:
|
||||
self.diffusion_transformer = DiffusionTransformer(
|
||||
c_token=c_token,
|
||||
c_tokenpair=c_z,
|
||||
c_s=c_s,
|
||||
**diffusion_transformer,
|
||||
)
|
||||
else:
|
||||
self.diffusion_transformer = LocalTokenTransformer(
|
||||
c_token=c_token,
|
||||
c_tokenpair=c_z,
|
||||
c_s=c_s,
|
||||
**diffusion_transformer,
|
||||
)
|
||||
self.diffusion_transformer = LocalTokenTransformer(
|
||||
c_token=c_token,
|
||||
c_tokenpair=c_z,
|
||||
c_s=c_s,
|
||||
**diffusion_transformer,
|
||||
)
|
||||
|
||||
self.decoder = CompactStreamingDecoder(
|
||||
c_atom=c_atom,
|
||||
@@ -342,21 +331,18 @@ class RFD3DiffusionModule(nn.Module):
|
||||
)
|
||||
|
||||
# ... Diffusion transformer
|
||||
if not self.use_local_token_attention:
|
||||
A_I = self.diffusion_transformer(A_I, S_I, Z_II, Beta_II=None)
|
||||
else:
|
||||
A_I = self.diffusion_transformer(
|
||||
A_I,
|
||||
S_I,
|
||||
Z_II,
|
||||
f=f,
|
||||
X_L=(
|
||||
X_noisy_L[..., f["is_ca"], :]
|
||||
if X_L_self is None
|
||||
else X_L_self[..., f["is_ca"], :]
|
||||
),
|
||||
full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"),
|
||||
)
|
||||
A_I = self.diffusion_transformer(
|
||||
A_I,
|
||||
S_I,
|
||||
Z_II,
|
||||
f=f,
|
||||
X_L=(
|
||||
X_noisy_L[..., f["is_ca"], :]
|
||||
if X_L_self is None
|
||||
else X_L_self[..., f["is_ca"], :]
|
||||
),
|
||||
full=not (os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"),
|
||||
)
|
||||
|
||||
# ... Decoder readout
|
||||
# Check if using chunked P_LL mode
|
||||
@@ -1,12 +0,0 @@
|
||||
from rfd3.model.encoders import SimpleRecycler
|
||||
|
||||
from modelhub.model.AF3 import AF3
|
||||
|
||||
|
||||
class AF3DesignTrunk(AF3):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.recycler = SimpleRecycler(
|
||||
c_s=kwargs["c_s"], c_z=kwargs["c_z"], **kwargs["recycler"]
|
||||
)
|
||||
self.distogram_head = None
|
||||
@@ -1,143 +1,20 @@
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
from beartype.typing import Any
|
||||
from jaxtyping import Float
|
||||
from omegaconf import DictConfig
|
||||
from rfd3.inference.symmetry.symmetry_utils import (
|
||||
apply_symmetry_to_xyz_atomwise,
|
||||
)
|
||||
from rfd3.model.cfg_utils import (
|
||||
strip_f,
|
||||
strip_X,
|
||||
)
|
||||
from rfd3.model.encoders import TokenInitializer
|
||||
from torch import nn
|
||||
|
||||
from modelhub import SWAP_LAYER_NORM_FOR_RMS_NORM
|
||||
from modelhub.alignment import weighted_rigid_align
|
||||
from modelhub.common import exists
|
||||
from modelhub.data.rotation_augmentation import (
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.rotation_augmentation import (
|
||||
rot_vec_mul,
|
||||
uniform_random_rotation,
|
||||
)
|
||||
from modelhub.diffusion_samplers.inference_sampler import SampleDiffusion
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
class RFD3(nn.Module):
|
||||
"""
|
||||
Simplified model for generation
|
||||
This module level serves to wrap the diffusion module of AF3
|
||||
to be roughly equivalent to the AF3 model w/o trunk processing.
|
||||
|
||||
Allows the same sampler to be used
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
# Channel dimensions ('global' features)
|
||||
c_s: int,
|
||||
c_z: int,
|
||||
c_atom: int,
|
||||
c_atompair: int,
|
||||
# Arguments for modules that will be instantiated
|
||||
token_initializer: DictConfig | dict,
|
||||
diffusion_module: DictConfig | dict,
|
||||
inference_sampler: DictConfig | dict,
|
||||
**_,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Register whether the model uses RMSNorms or LayerNorms
|
||||
self.register_buffer(
|
||||
"use_rmsnorm",
|
||||
torch.tensor(SWAP_LAYER_NORM_FOR_RMS_NORM, dtype=torch.bool),
|
||||
)
|
||||
|
||||
# Check for chunked P_LL mode via environment variable
|
||||
use_chunked_pll = os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"
|
||||
ranked_logger.info(f"RFD3 initialized with chunked_pll={use_chunked_pll}")
|
||||
|
||||
# Simple constant-feature initializer
|
||||
self.token_initializer = TokenInitializer(
|
||||
c_s=c_s,
|
||||
c_z=c_z,
|
||||
c_atom=c_atom,
|
||||
c_atompair=c_atompair,
|
||||
use_chunked_pll=use_chunked_pll,
|
||||
**token_initializer,
|
||||
)
|
||||
|
||||
# Diffusion module instantiated to allow for config scripting
|
||||
self.diffusion_module = hydra.utils.instantiate(
|
||||
diffusion_module, c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z
|
||||
)
|
||||
|
||||
self.use_classifier_free_guidance = (
|
||||
inference_sampler["use_classifier_free_guidance"]
|
||||
and inference_sampler["cfg_scale"] != 1.0
|
||||
)
|
||||
self.cfg_features = inference_sampler.pop("cfg_features", [])
|
||||
|
||||
# ... initialize the inference sampler, which performs a full diffusion rollout during inference
|
||||
self.inference_sampler = ConditionalDiffusionSampler(**inference_sampler)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: dict,
|
||||
coord_atom_lvl_to_be_noised: torch.Tensor = None,
|
||||
n_cycle=None,
|
||||
**_,
|
||||
) -> dict:
|
||||
# Assert that the correct swap is used
|
||||
if bool(self.use_rmsnorm.item()) != SWAP_LAYER_NORM_FOR_RMS_NORM:
|
||||
raise ValueError(
|
||||
"Loaded checkpoint with use RMSNorm {} but environment variable set expects {}".format(
|
||||
self.use_rmsnorm.item(),
|
||||
SWAP_LAYER_NORM_FOR_RMS_NORM,
|
||||
)
|
||||
+ " Set environment variable SWAP_LAYER_NORM_FOR_RMS_NORM to {}".format(
|
||||
{True: "1", False: "0"}[self.use_rmsnorm.item()]
|
||||
)
|
||||
)
|
||||
|
||||
initializer_outputs = self.token_initializer(input["f"])
|
||||
|
||||
if self.training:
|
||||
# Single denoising step
|
||||
return self.diffusion_module(
|
||||
X_noisy_L=input["X_noisy_L"],
|
||||
t=input["t"],
|
||||
f=input["f"],
|
||||
n_recycle=n_cycle,
|
||||
**initializer_outputs,
|
||||
) # [D, L, 3]
|
||||
else:
|
||||
if self.use_classifier_free_guidance:
|
||||
f_ref = strip_f(input["f"], self.cfg_features)
|
||||
ref_initializer_outputs = self.token_initializer(f_ref)
|
||||
else:
|
||||
f_ref = None
|
||||
ref_initializer_outputs = None
|
||||
|
||||
return self.inference_sampler.sample_diffusion_like_af3(
|
||||
f=input["f"],
|
||||
f_ref=f_ref, # for cfg
|
||||
diffusion_module=self.diffusion_module,
|
||||
diffusion_batch_size=input["t"].shape[0],
|
||||
coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
|
||||
# Forwarded as **kwargs:
|
||||
initializer_outputs=initializer_outputs,
|
||||
ref_initializer_outputs=ref_initializer_outputs, # for cfg
|
||||
)
|
||||
|
||||
|
||||
def centre_random_augment_around_motif(
|
||||
X_L: torch.Tensor, # (D, L, 3) noisy diffused coordinates
|
||||
coord_atom_lvl_to_be_noised: torch.Tensor, # (D, L, 3) original coordinates
|
||||
@@ -193,159 +70,70 @@ def centre_random_augment_around_motif(
|
||||
return X_L, R
|
||||
|
||||
|
||||
class SampleDiffusionWithMotif(SampleDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
center_option: str = "all",
|
||||
move_noise_to_reset_com: bool = False, # Reset the COM of the diffuse region after the re-noising operation in each diffusion step
|
||||
s_trans: float = 1.0, # Translational noise scale for augmentation during inference
|
||||
s_jitter_origin: float = 0.0, # Random translation of motif at the start of diffusion
|
||||
fraction_of_steps_to_fix_motif: float = 0.0, # Fraction of steps to let the model not move the motif. e.g. if we have 10 steps, set this value to 0.2 will make model not move motif for the first 2 steps.
|
||||
skip_few_diffusion_steps: bool = False, # Choose to skip some diffusion steps based on the noise scheme
|
||||
inference_noise_scaling_factor: float = 1.0,
|
||||
# Additional argumnets
|
||||
gamma_min2: float = 0.0,
|
||||
allow_realignment: bool = False,
|
||||
insert_motif_at_end: bool = True,
|
||||
use_classifier_free_guidance: bool = False,
|
||||
cfg_scale: float = 2.0,
|
||||
use_frame_guidance: bool = False, # Use frame guidance to align the virtual atoms to the central atom
|
||||
fg_scale: float = 1.5,
|
||||
zero_drift_noise: bool = False,
|
||||
cfg_t_max: float
|
||||
| None = None, # If not None, use classifier-free guidance only for t < cfg_t_max
|
||||
**kwargs,
|
||||
):
|
||||
self.gamma_min2 = gamma_min2
|
||||
self.allow_realignment = allow_realignment
|
||||
self.insert_motif_at_end = insert_motif_at_end
|
||||
self.use_classifier_free_guidance = use_classifier_free_guidance
|
||||
self.cfg_scale = cfg_scale
|
||||
self.cfg_t_max = cfg_t_max
|
||||
@dataclass(kw_only=True)
|
||||
class SampleDiffusionWithMotif:
|
||||
"""Diffusion sampler that supports optional motif alignment."""
|
||||
|
||||
self.center_option = center_option
|
||||
self.fraction_of_steps_to_fix_motif = fraction_of_steps_to_fix_motif
|
||||
self.move_noise_to_reset_com = move_noise_to_reset_com
|
||||
self.s_trans = s_trans
|
||||
self.skip_few_diffusion_steps = skip_few_diffusion_steps
|
||||
self.s_jitter_origin = s_jitter_origin
|
||||
self.inference_noise_scaling_factor = inference_noise_scaling_factor
|
||||
self.zero_drift_noise = zero_drift_noise
|
||||
# Standard EDM args
|
||||
num_timesteps: int # AF-3: 200
|
||||
min_t: int # AF-3: 0
|
||||
max_t: int # AF-3: 1
|
||||
sigma_data: int # AF-3: 16
|
||||
s_min: float # AF-3: 4e-4
|
||||
s_max: int # AF-3: 160
|
||||
p: int # AF-3: 7
|
||||
gamma_0: float # AF-3: 0.8
|
||||
gamma_min: float # AF-3: 1.0
|
||||
noise_scale: float # AF-3: 1.003
|
||||
step_scale: float # AF-3: 1.5
|
||||
solver: Literal["af3"]
|
||||
|
||||
self.use_frame_guidance = use_frame_guidance
|
||||
self.fg_scale = fg_scale
|
||||
# RFD3 / design args
|
||||
center_option: str = "all"
|
||||
s_trans: float = 1.0
|
||||
s_jitter_origin: float = 0.0
|
||||
fraction_of_steps_to_fix_motif: float = 0.0
|
||||
skip_few_diffusion_steps: bool = False
|
||||
allow_realignment: bool = False
|
||||
insert_motif_at_end: bool = True
|
||||
use_classifier_free_guidance: bool = False
|
||||
cfg_scale: float = 2.0
|
||||
cfg_t_max: float | None = None
|
||||
use_frame_guidance: bool = False
|
||||
fg_scale: float = 1.0
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# TODO: Make this a properly-parametrized function in terms of instance variables provided in the configs
|
||||
# For now, it's just hard-coded for early testing
|
||||
def modify_noise_schedule(self, noise_schedule: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Modify the noise schedule to skip more steps at high noise and fewer at low noise.
|
||||
"""
|
||||
mask = torch.ones_like(noise_schedule, dtype=bool)
|
||||
mask_len = len(mask)
|
||||
mask[: mask_len // 4] = torch.arange(mask_len // 4) % 5 == 0
|
||||
mask[mask_len // 4 : mask_len // 2] = (
|
||||
torch.arange(mask_len // 4, mask_len // 2) % 3 == 0
|
||||
)
|
||||
mask[mask_len // 2 : -mask_len // 4] = (
|
||||
torch.arange(mask_len // 2, mask_len - mask_len // 4) % 2 == 0
|
||||
)
|
||||
return noise_schedule[mask]
|
||||
|
||||
def _get_initial_structure(
|
||||
self,
|
||||
c0: torch.Tensor,
|
||||
D: int,
|
||||
L: int,
|
||||
coord_atom_lvl_to_be_noised: torch.Tensor,
|
||||
is_motif_atom_with_fixed_coord,
|
||||
def _construct_inference_noise_schedule(
|
||||
self, device: torch.device, partial_t: float = None
|
||||
) -> torch.Tensor:
|
||||
noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
|
||||
noise[..., is_motif_atom_with_fixed_coord, :] = 0 # Zero out noise going in
|
||||
X_L = noise + coord_atom_lvl_to_be_noised
|
||||
return X_L
|
||||
"""Constructs a noise schedule for use during inference.
|
||||
|
||||
def _move_noise_to_reset_com(self, X_noisy_L, is_motif_atom_with_fixed_coord):
|
||||
The inference noise schedule is defined in the AF-3 supplement as:
|
||||
|
||||
t_hat = sigma_data * (s_max**(1/p) + t * (s_min**(1/p) - s_max**(1/p)))**p
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor representing the noise schedule `t_hat`.
|
||||
|
||||
Reference:
|
||||
AlphaFold 3 Supplement, Section 3.7.1.
|
||||
"""
|
||||
Reset the COM of the diffuse region after the re-noising operation in each diffusion step.
|
||||
"""
|
||||
if self.center_option == "motif":
|
||||
print(
|
||||
"Warning: Moving the noise is not relevant when centering on the motif! Will be ignored."
|
||||
# Create a linearly spaced tensor of timesteps between min_t and max_t
|
||||
t = torch.linspace(self.min_t, self.max_t, self.num_timesteps, device=device)
|
||||
|
||||
# Construct the noise schedule, using the formula provided in the reference
|
||||
t_hat = (
|
||||
self.sigma_data
|
||||
* (
|
||||
(self.s_max) ** (1 / self.p)
|
||||
+ t * (self.s_min ** (1 / self.p) - self.s_max ** (1 / self.p))
|
||||
)
|
||||
elif self.center_option == "diffuse":
|
||||
displacement_vec = torch.mean(
|
||||
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :],
|
||||
dim=-2,
|
||||
keepdim=True,
|
||||
) # (D, 1, 3) - COM of noisy diffused atoms
|
||||
|
||||
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] = (
|
||||
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] - displacement_vec
|
||||
)
|
||||
else:
|
||||
n_diffused = (~is_motif_atom_with_fixed_coord).sum()
|
||||
displacement_vec = (
|
||||
torch.sum(
|
||||
X_noisy_L,
|
||||
dim=-2,
|
||||
keepdim=True,
|
||||
)
|
||||
/ n_diffused
|
||||
)
|
||||
|
||||
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] = (
|
||||
X_noisy_L[..., ~is_motif_atom_with_fixed_coord, :] - displacement_vec
|
||||
)
|
||||
|
||||
return X_noisy_L
|
||||
|
||||
def _skip_few_diffusion_steps(self, noise_schedule: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Modify the noise schedule to skip more steps at high noise and fewer at low noise.
|
||||
i.e. When the noise is high (first few diffusion steps), skip more steps;
|
||||
When the noise is lower, skip fewer steps;
|
||||
When the noise is low, keep all the steps.
|
||||
"""
|
||||
mask = torch.ones_like(noise_schedule, dtype=bool)
|
||||
mask_len = len(mask)
|
||||
mask[: mask_len // 4] = torch.arange(mask_len // 4) % 5 == 0
|
||||
mask[mask_len // 4 : mask_len // 2] = (
|
||||
torch.arange(mask_len // 4, mask_len // 2) % 3 == 0
|
||||
)
|
||||
mask[mask_len // 2 : -mask_len // 4] = (
|
||||
torch.arange(mask_len // 2, mask_len - mask_len // 4) % 2 == 0
|
||||
)
|
||||
return noise_schedule[mask]
|
||||
|
||||
def sample_diffusion_like_af3(
|
||||
self,
|
||||
*,
|
||||
f: dict[str, Any],
|
||||
diffusion_module: torch.nn.Module,
|
||||
diffusion_batch_size: int,
|
||||
coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
|
||||
initializer_outputs,
|
||||
ref_initializer_outputs: dict[str, Any] | None,
|
||||
f_ref: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
# Motif setup to recenter the motif at every step
|
||||
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
|
||||
|
||||
# Book-keeping
|
||||
noise_schedule = self._construct_inference_noise_schedule(
|
||||
device=coord_atom_lvl_to_be_noised.device
|
||||
** self.p
|
||||
)
|
||||
|
||||
# Choose to adjust the noise schedule
|
||||
if self.skip_few_diffusion_steps:
|
||||
noise_schedule = self._skip_few_diffusion_steps(noise_schedule)
|
||||
|
||||
if "partial_t" in f:
|
||||
if partial_t is not None:
|
||||
# For now, partial t is a global parameter
|
||||
partial_t = f["partial_t"].mean()
|
||||
partial_t = float(partial_t.mean())
|
||||
noise_schedule = t_hat
|
||||
ranked_logger.info("Using partial diffusion with t={}".format(partial_t))
|
||||
|
||||
# Debug the noise schedule filtering
|
||||
@@ -382,11 +170,44 @@ class SampleDiffusionWithMotif(SampleDiffusion):
|
||||
f"Using fallback: final step with t={noise_schedule[0].item():.6f}"
|
||||
)
|
||||
|
||||
return t_hat
|
||||
|
||||
def _get_initial_structure(
|
||||
self,
|
||||
c0: torch.Tensor,
|
||||
D: int,
|
||||
L: int,
|
||||
coord_atom_lvl_to_be_noised: torch.Tensor,
|
||||
is_motif_atom_with_fixed_coord,
|
||||
) -> torch.Tensor:
|
||||
noise = c0 * torch.normal(mean=0.0, std=1.0, size=(D, L, 3), device=c0.device)
|
||||
noise[..., is_motif_atom_with_fixed_coord, :] = 0 # Zero out noise going in
|
||||
X_L = noise + coord_atom_lvl_to_be_noised
|
||||
return X_L
|
||||
|
||||
def sample_diffusion_like_af3(
|
||||
self,
|
||||
*,
|
||||
f: dict[str, Any],
|
||||
diffusion_module: torch.nn.Module,
|
||||
diffusion_batch_size: int,
|
||||
coord_atom_lvl_to_be_noised: Float[torch.Tensor, "D L 3"],
|
||||
initializer_outputs,
|
||||
ref_initializer_outputs: dict[str, Any] | None,
|
||||
f_ref: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
# Motif setup to recenter the motif at every step
|
||||
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
|
||||
|
||||
# Book-keeping
|
||||
noise_schedule = self._construct_inference_noise_schedule(
|
||||
device=coord_atom_lvl_to_be_noised.device,
|
||||
partial_t=f.get("partial_t", None),
|
||||
)
|
||||
|
||||
L = f["ref_element"].shape[0]
|
||||
D = diffusion_batch_size
|
||||
|
||||
noise_schedule = noise_schedule * self.inference_noise_scaling_factor
|
||||
|
||||
X_L = self._get_initial_structure(
|
||||
c0=noise_schedule[0],
|
||||
D=D,
|
||||
@@ -433,7 +254,7 @@ class SampleDiffusionWithMotif(SampleDiffusion):
|
||||
|
||||
# Update gamma & step scale
|
||||
gamma = self.gamma_0 if c_t > self.gamma_min else 0
|
||||
step_scale = self.step_scale if c_t > self.gamma_min2 else 3.0
|
||||
step_scale = self.step_scale
|
||||
|
||||
# Compute the value of t_hat
|
||||
t_hat = c_t_minus_1 * (gamma + 1)
|
||||
@@ -444,19 +265,11 @@ class SampleDiffusionWithMotif(SampleDiffusion):
|
||||
* torch.sqrt(torch.square(t_hat) - torch.square(c_t_minus_1))
|
||||
* torch.normal(mean=0.0, std=1.0, size=X_L.shape, device=X_L.device)
|
||||
)
|
||||
if self.zero_drift_noise:
|
||||
epsilon_L = epsilon_L - torch.mean(epsilon_L, dim=-2, keepdim=True)
|
||||
epsilon_L[..., is_motif_atom_with_fixed_coord, :] = (
|
||||
0 # No noise injection for fixed atoms
|
||||
)
|
||||
X_noisy_L = X_L + epsilon_L
|
||||
|
||||
# Adjustg the center of mass
|
||||
if self.move_noise_to_reset_com:
|
||||
X_noisy_L = self._move_noise_to_reset_com(
|
||||
X_noisy_L, is_motif_atom_with_fixed_coord
|
||||
)
|
||||
|
||||
# Denoise the coordinates
|
||||
# Handle chunked mode vs standard mode
|
||||
if "chunked_pairwise_embedder" in initializer_outputs:
|
||||
@@ -624,53 +437,10 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
||||
is_motif_atom_with_fixed_coord = f["is_motif_atom_with_fixed_coord"]
|
||||
# Book-keeping
|
||||
noise_schedule = self._construct_inference_noise_schedule(
|
||||
device=coord_atom_lvl_to_be_noised.device
|
||||
device=coord_atom_lvl_to_be_noised.device,
|
||||
partial_t=f.get("partial_t", None),
|
||||
)
|
||||
|
||||
# Handle partial_t for symmetry sampler (same as regular sampler)
|
||||
if "partial_t" in f:
|
||||
# For now, partial t is a global parameter
|
||||
partial_t = f["partial_t"].mean()
|
||||
ranked_logger.info(
|
||||
"Symmetry sampler: Using partial diffusion with t={}".format(partial_t)
|
||||
)
|
||||
|
||||
# Debug the noise schedule filtering
|
||||
original_schedule_len = len(noise_schedule)
|
||||
original_max = noise_schedule.max().item()
|
||||
original_min = noise_schedule.min().item()
|
||||
|
||||
noise_schedule = noise_schedule[noise_schedule <= partial_t]
|
||||
|
||||
new_schedule_len = len(noise_schedule)
|
||||
if new_schedule_len > 0:
|
||||
new_max = noise_schedule.max().item()
|
||||
new_min = noise_schedule.min().item()
|
||||
ranked_logger.info(
|
||||
f"Symmetry noise schedule: {original_schedule_len} → {new_schedule_len} steps"
|
||||
)
|
||||
ranked_logger.info(
|
||||
f"Symmetry original range: [{original_min:.3f}, {original_max:.3f}]"
|
||||
)
|
||||
ranked_logger.info(
|
||||
f"Symmetry filtered range: [{new_min:.3f}, {new_max:.3f}]"
|
||||
)
|
||||
else:
|
||||
ranked_logger.warning(
|
||||
f"Symmetry sampler: No noise schedule steps found with t <= {partial_t}!"
|
||||
)
|
||||
ranked_logger.info(
|
||||
f"Symmetry original schedule range: [{original_min:.3f}, {original_max:.3f}]"
|
||||
)
|
||||
# Fallback to smallest available step
|
||||
noise_schedule_original = self._construct_inference_noise_schedule(
|
||||
device=coord_atom_lvl_to_be_noised.device
|
||||
)
|
||||
noise_schedule = noise_schedule_original[-1:] # Just use the final step
|
||||
ranked_logger.info(
|
||||
f"Symmetry using fallback: final step with t={noise_schedule[0].item():.6f}"
|
||||
)
|
||||
|
||||
L = f["ref_element"].shape[0]
|
||||
D = diffusion_batch_size
|
||||
X_L = self._get_initial_structure(
|
||||
@@ -711,7 +481,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
||||
|
||||
# Update gamma & step scale
|
||||
gamma = self.gamma_0 if c_t > self.gamma_min else 0
|
||||
step_scale = self.step_scale if c_t > self.gamma_min2 else 1.05
|
||||
step_scale = self.step_scale
|
||||
|
||||
# Compute the value of t_hat
|
||||
t_hat = c_t_minus_1 * (gamma + 1)
|
||||
@@ -6,18 +6,18 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from opt_einsum import contract as einsum
|
||||
from rfd3.model.block_utils import (
|
||||
from rfd3.model.layers.block_utils import (
|
||||
create_attention_indices,
|
||||
indices_to_mask,
|
||||
)
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.model.layers.layer_utils import (
|
||||
from rfd3.model.layers.layer_utils import (
|
||||
AdaLN,
|
||||
LinearBiasInit,
|
||||
RMSNorm,
|
||||
linearNoBias,
|
||||
)
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
@@ -417,7 +417,6 @@ def sparse_pairbias_attention(
|
||||
if B.ndim == 3:
|
||||
B_gathered = B[query_idx, indices, :] # (D, L, k, H)
|
||||
elif B.ndim == 4: # (D, L, L, H)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
B_gathered = B[batch_idx, query_idx, indices, :] # (D, L, k, H)
|
||||
else:
|
||||
assert B.shape == (D, L, k, H), "B must be batched with shape (D, L, k, H)"
|
||||
@@ -6,35 +6,62 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from einops import rearrange
|
||||
from rfd3.model.attention import (
|
||||
from rfd3.model.layers.attention import (
|
||||
GatedCrossAttention,
|
||||
LocalAttentionPairBias,
|
||||
)
|
||||
from rfd3.model.block_utils import (
|
||||
from rfd3.model.layers.block_utils import (
|
||||
build_valid_mask,
|
||||
create_attention_indices,
|
||||
group_atoms,
|
||||
ungroup_atoms,
|
||||
)
|
||||
from torch.nn.functional import one_hot
|
||||
|
||||
from modelhub import DISABLE_CHECKPOINTING
|
||||
from modelhub.common import exists
|
||||
from modelhub.model.layers.af3_diffusion_transformer import (
|
||||
ConditionedTransitionBlock,
|
||||
)
|
||||
from modelhub.model.layers.layer_utils import (
|
||||
from rfd3.model.layers.layer_utils import (
|
||||
AdaLN,
|
||||
EmbeddingLayer,
|
||||
LinearBiasInit,
|
||||
RMSNorm,
|
||||
Transition,
|
||||
collapse,
|
||||
linearNoBias,
|
||||
)
|
||||
from modelhub.model.layers.pairformer_layers import PairformerBlock
|
||||
from rfd3.model.layers.pairformer_layers import PairformerBlock
|
||||
from torch.nn.functional import one_hot
|
||||
|
||||
from modelhub import DISABLE_CHECKPOINTING
|
||||
from modelhub.common import exists
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SwiGLU transition block with adaptive layernorm
|
||||
class ConditionedTransitionBlock(nn.Module):
|
||||
def __init__(self, c_token, c_s, n=2):
|
||||
super().__init__()
|
||||
self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
|
||||
self.linear_1 = linearNoBias(c_token, c_token * n)
|
||||
self.linear_2 = linearNoBias(c_token, c_token * n)
|
||||
self.linear_output_project = nn.Sequential(
|
||||
LinearBiasInit(c_s, c_token, biasinit=-2.0),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self.linear_3 = linearNoBias(c_token * n, c_token)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ai, # [B, I, C_token]
|
||||
Si, # [B, I, C_token]
|
||||
):
|
||||
Ai = self.ada_ln(Ai, Si)
|
||||
# BUG: This is not the correct implementation of SwiGLU
|
||||
# Bi = torch.sigmoid(self.linear_1(Ai)) * self.linear_2(Ai)
|
||||
# FIX: This is the correct implementation of SwiGLU
|
||||
Bi = torch.nn.functional.silu(self.linear_1(Ai)) * self.linear_2(Ai)
|
||||
|
||||
# Output projection (from adaLN-Zero)
|
||||
return self.linear_output_project(Si) * self.linear_3(Bi)
|
||||
|
||||
|
||||
class PositionPairDistEmbedder(nn.Module):
|
||||
def __init__(self, c_atompair, embed_frame=True):
|
||||
super().__init__()
|
||||
@@ -10,8 +10,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelhub.model.layers.layer_utils import RMSNorm, linearNoBias
|
||||
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
|
||||
|
||||
|
||||
class ChunkedPositionPairDistEmbedder(nn.Module):
|
||||
@@ -3,11 +3,11 @@ import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rfd3.model.block_utils import (
|
||||
from rfd3.model.layers.block_utils import (
|
||||
bucketize_scaled_distogram,
|
||||
pairwise_mean_pool,
|
||||
)
|
||||
from rfd3.model.blocks import (
|
||||
from rfd3.model.layers.blocks import (
|
||||
Downcast,
|
||||
LocalAtomTransformer,
|
||||
OneDFeatureEmbedder,
|
||||
@@ -15,19 +15,19 @@ from rfd3.model.blocks import (
|
||||
RelativePositionEncodingWithIndexRemoval,
|
||||
SinusoidalDistEmbed,
|
||||
)
|
||||
from rfd3.model.chunked_pairwise import (
|
||||
from rfd3.model.layers.chunked_pairwise import (
|
||||
ChunkedPairwiseEmbedder,
|
||||
ChunkedPositionPairDistEmbedder,
|
||||
ChunkedSinusoidalDistEmbed,
|
||||
)
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.model.layers.layer_utils import (
|
||||
from rfd3.model.layers.layer_utils import (
|
||||
RMSNorm,
|
||||
Transition,
|
||||
linearNoBias,
|
||||
)
|
||||
from modelhub.model.layers.pairformer_layers import PairformerBlock
|
||||
from rfd3.model.layers.pairformer_layers import PairformerBlock
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
197
models/rfd3/src/rfd3/model/layers/layer_utils.py
Normal file
197
models/rfd3/src/rfd3/model/layers/layer_utils.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.functional import silu
|
||||
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import FusedRMSNorm
|
||||
|
||||
ranked_logger.info("Fused RMSNorm enabled!")
|
||||
RMSNorm_ = FusedRMSNorm
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
ranked_logger.warning(
|
||||
"Using nn.RMSNorm instead of apex.normalization.fused_layer_norm.FusedRMSNorm."
|
||||
"Ensure you're using the correct apptainer"
|
||||
)
|
||||
RMSNorm_ = nn.RMSNorm
|
||||
|
||||
|
||||
# Allow bias=False to be passed for RMSNorm
|
||||
def RMSNorm(*args, **kwargs):
|
||||
if "bias" in kwargs:
|
||||
kwargs.pop("bias")
|
||||
return RMSNorm_(*args, **kwargs)
|
||||
|
||||
|
||||
SWAP_LAYER_NORM_FOR_RMS_NORM = True
|
||||
RMSNorm = RMSNorm if SWAP_LAYER_NORM_FOR_RMS_NORM else nn.LayerNorm
|
||||
linearNoBias = partial(torch.nn.Linear, bias=False)
|
||||
|
||||
|
||||
class EmbeddingLayer(nn.Linear):
|
||||
"""
|
||||
Specialized linear layer for correct weight initialization for embedding layers.
|
||||
|
||||
Embedding layers are functionally a multiplication of an N channel input by an NxC weight matrix to produce an
|
||||
embedding of length C. However, we compute the components separately with a ModuleDict, then sum at the end, for
|
||||
embedding reusability and interoperability purposes.
|
||||
|
||||
This layer uses Xavier initialization as described in [1]_.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty
|
||||
of training deep feedforward neural networks." (2010)
|
||||
http://proceedings.mlr.press/v9/glorot10a.html
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
this_in_features,
|
||||
total_embedding_features,
|
||||
out_features,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
self.total_embedding_features = total_embedding_features
|
||||
self.out_features = out_features
|
||||
super().__init__(
|
||||
this_in_features, out_features, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self, **kwargs):
|
||||
super().reset_parameters()
|
||||
a = math.sqrt(6.0 / float(self.total_embedding_features + self.out_features))
|
||||
nn.init._no_grad_uniform_(self.weight, -a, a)
|
||||
|
||||
|
||||
def collapse(x, L):
|
||||
return x.reshape((L, x.numel() // L))
|
||||
|
||||
|
||||
class MultiDimLinear(nn.Linear):
|
||||
def __init__(self, in_features, out_shape, norm=False, **kwargs):
|
||||
self.out_shape = out_shape
|
||||
out_features = np.prod(out_shape)
|
||||
super().__init__(in_features, out_features, **kwargs)
|
||||
if norm:
|
||||
self.ln = RMSNorm((out_features,))
|
||||
self.use_ln = True
|
||||
else:
|
||||
self.use_ln = False
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self, **kwargs) -> None:
|
||||
super().reset_parameters()
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, x):
|
||||
out = super().forward(x)
|
||||
if self.use_ln:
|
||||
out = self.ln(out)
|
||||
return out.reshape(x.shape[:-1] + self.out_shape)
|
||||
|
||||
|
||||
class LinearBiasInit(nn.Linear):
|
||||
def __init__(self, *args, biasinit, **kwargs):
|
||||
assert biasinit == -2.0 # Sanity check
|
||||
self.biasinit = biasinit
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
super().reset_parameters()
|
||||
self.bias.data.fill_(self.biasinit)
|
||||
|
||||
|
||||
class Transition(nn.Module):
|
||||
def __init__(self, n, c):
|
||||
super().__init__()
|
||||
self.layer_norm_1 = RMSNorm(c)
|
||||
self.linear_1 = linearNoBias(c, n * c)
|
||||
self.linear_2 = linearNoBias(c, n * c)
|
||||
self.linear_3 = linearNoBias(n * c, c)
|
||||
|
||||
@activation_checkpointing
|
||||
def forward(
|
||||
self,
|
||||
X,
|
||||
):
|
||||
X = self.layer_norm_1(X)
|
||||
A = self.linear_1(X)
|
||||
B = self.linear_2(X)
|
||||
X = self.linear_3(silu(A) * B)
|
||||
return X
|
||||
|
||||
|
||||
class AdaLN(nn.Module):
|
||||
def __init__(self, c_a, c_s, n=2):
|
||||
super().__init__()
|
||||
self.ln_a = RMSNorm(normalized_shape=(c_a,), elementwise_affine=False)
|
||||
self.ln_s = RMSNorm(normalized_shape=(c_s,), bias=False)
|
||||
self.to_gain = nn.Sequential(
|
||||
nn.Linear(c_s, c_a),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self.to_bias = linearNoBias(c_s, c_a)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
Ai, # [B, I, C_a]
|
||||
Si, # [B, I, C_s]
|
||||
):
|
||||
"""
|
||||
Output:
|
||||
[B, I, C_a]
|
||||
"""
|
||||
Ai = self.ln_a(Ai)
|
||||
Si = self.ln_s(Si)
|
||||
return self.to_gain(Si) * Ai + self.to_bias(Si)
|
||||
|
||||
|
||||
def create_batch_dimension_if_not_present(batched_n_dim):
|
||||
"""
|
||||
Decorator for adapting a function which expects batched arguments with ndim `batched_n_dim` also
|
||||
accept unbatched arguments.
|
||||
"""
|
||||
|
||||
def wrap(f):
|
||||
def _wrap(arg):
|
||||
inserted_batch_dim = False
|
||||
if arg.ndim == batched_n_dim - 1:
|
||||
arg = arg[None]
|
||||
inserted_batch_dim = True
|
||||
elif arg.ndim == batched_n_dim:
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"arg must have {batched_n_dim - 1} or {batched_n_dim} dimensions, got shape {arg.shape=}"
|
||||
)
|
||||
o = f(arg)
|
||||
|
||||
if inserted_batch_dim:
|
||||
assert o.shape[0] == 1, f"{o.shape=}[0] != 1"
|
||||
return o[0]
|
||||
return o
|
||||
|
||||
return _wrap
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def unpack_args_for_checkpointing(arg_names):
|
||||
def wrap(f):
|
||||
def _wrap(*args):
|
||||
f = args[0]
|
||||
return f(**dict(zip(arg_names, args)))
|
||||
|
||||
return _wrap
|
||||
|
||||
return wrap
|
||||
128
models/rfd3/src/rfd3/model/layers/pairformer_layers.py
Normal file
128
models/rfd3/src/rfd3/model/layers/pairformer_layers.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch
|
||||
from rfd3.model.layers.layer_utils import (
|
||||
MultiDimLinear,
|
||||
RMSNorm,
|
||||
Transition,
|
||||
linearNoBias,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
from modelhub.utils.torch import device_of
|
||||
|
||||
|
||||
class AttentionPairBiasPairformerDeepspeed(nn.Module):
|
||||
def __init__(self, c_a, c_s, c_pair, n_head, kq_norm=False):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.c_a = c_a
|
||||
self.c_pair = c_pair
|
||||
self.c = c_a // n_head
|
||||
|
||||
self.to_q = MultiDimLinear(c_a, (n_head, self.c))
|
||||
self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
|
||||
self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
|
||||
self.to_b = linearNoBias(c_pair, n_head)
|
||||
self.to_g = nn.Sequential(
|
||||
MultiDimLinear(c_a, (n_head, self.c), bias=False),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self.to_a = linearNoBias(c_a, c_a)
|
||||
# self.linear_output_project = nn.Sequential(
|
||||
# LinearBiasInit(c_s, c_a, biasinit=-2.),
|
||||
# nn.Sigmoid(),
|
||||
# )
|
||||
self.ln_0 = RMSNorm((c_pair,))
|
||||
# self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
|
||||
self.ln_1 = RMSNorm((c_a,))
|
||||
self.use_deepspeed_evo = False
|
||||
self.force_bfloat16 = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
A_I, # [I, C_a]
|
||||
S_I, # [I, C_a] | None
|
||||
Z_II, # [I, I, C_z]
|
||||
Beta_II=None, # [I, I]
|
||||
):
|
||||
# Input projections
|
||||
assert S_I is None
|
||||
A_I = self.ln_1(A_I)
|
||||
|
||||
if self.use_deepspeed_evo or self.force_bfloat16:
|
||||
A_I = A_I.to(torch.bfloat16)
|
||||
|
||||
Q_IH = self.to_q(A_I) # / np.sqrt(self.c)
|
||||
K_IH = self.to_k(A_I)
|
||||
V_IH = self.to_v(A_I)
|
||||
B_IIH = self.to_b(self.ln_0(Z_II)) + Beta_II[..., None]
|
||||
G_IH = self.to_g(A_I)
|
||||
|
||||
B, L = B_IIH.shape[:2]
|
||||
|
||||
if not self.use_deepspeed_evo or L <= 24:
|
||||
Q_IH = Q_IH / torch.sqrt(
|
||||
torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
|
||||
)
|
||||
# Attention
|
||||
A_IIH = torch.softmax(
|
||||
torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2
|
||||
) # softmax over j
|
||||
## G_IH: [I, H, C]
|
||||
## A_IIH: [I, I, H]
|
||||
## V_IH: [I, H, C]
|
||||
A_I = torch.einsum("...ijh,...jhc->...ihc", A_IIH, V_IH)
|
||||
A_I = G_IH * A_I # [B, I, H, C]
|
||||
A_I = A_I.flatten(start_dim=-2) # [B, I, Ca]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
A_I = self.to_a(A_I)
|
||||
|
||||
return A_I
|
||||
|
||||
|
||||
class PairformerBlock(nn.Module):
|
||||
"""
|
||||
Attempt to replicate AF3 architecture from scratch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c_s,
|
||||
c_z,
|
||||
attention_pair_bias,
|
||||
p_drop=0.1,
|
||||
triangle_multiplication=None,
|
||||
triangle_attention=None,
|
||||
n_transition=4,
|
||||
use_deepspeed_evo=True,
|
||||
use_triangle_mult=False,
|
||||
use_triangle_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# self.drop_row = Dropout(broadcast_dim=-2, p_drop=p_drop)
|
||||
# self.drop_col = Dropout(broadcast_dim=-3, p_drop=p_drop)
|
||||
|
||||
self.z_transition = Transition(c=c_z, n=n_transition)
|
||||
|
||||
if c_s > 0:
|
||||
self.s_transition = Transition(c=c_s, n=n_transition)
|
||||
|
||||
self.attention_pair_bias = AttentionPairBiasPairformerDeepspeed(
|
||||
c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
|
||||
)
|
||||
|
||||
@activation_checkpointing
|
||||
def forward(self, S_I, Z_II):
|
||||
with torch.amp.autocast(
|
||||
device_type=device_of(self).type, enabled=True, dtype=torch.bfloat16
|
||||
):
|
||||
Z_II = Z_II + self.z_transition(Z_II)
|
||||
if S_I is not None:
|
||||
S_I = S_I + self.attention_pair_bias(
|
||||
S_I, None, Z_II, Beta_II=torch.tensor([0.0], device=Z_II.device)
|
||||
)
|
||||
S_I = S_I + self.s_transition(S_I)
|
||||
return S_I, Z_II
|
||||
@@ -1,13 +1,12 @@
|
||||
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rfd3_exec.sh" "$0" "$@"'
|
||||
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import rootutils
|
||||
from dotenv import load_dotenv
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
|
||||
@@ -15,6 +14,8 @@ from modelhub.utils.logging import suppress_warnings
|
||||
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
|
||||
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
|
||||
|
||||
@@ -27,16 +28,18 @@ _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
|
||||
def run_inference(cfg: DictConfig) -> None:
|
||||
"""Execute the specified inference pipeline"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir = Path(temp_dir)
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
import ipdb
|
||||
run_params_set = {"inputs", "n_batches", "out_dir"}
|
||||
run_params = {k: v for k, v in cfg.items() if k in run_params_set}
|
||||
|
||||
ipdb.set_trace()
|
||||
# Create __init__ args by filtering for all configs not in run_params
|
||||
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
|
||||
init_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in run_params_set}
|
||||
init_cfg = OmegaConf.create(init_cfg_dict)
|
||||
inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)
|
||||
|
||||
inference_engine = instantiate(cfg, temp_dir=temp_dir, _convert_="partial")
|
||||
with suppress_warnings(is_inference=True):
|
||||
inference_engine.eval()
|
||||
# # Run inference
|
||||
with suppress_warnings(is_inference=True):
|
||||
inference_engine.run(**run_params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,7 +4,6 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import ipdb # noqa: F401
|
||||
@@ -38,7 +37,7 @@ print(f"Project root: {os.environ.get('PROJECT_ROOT', '../..')}")
|
||||
|
||||
|
||||
is_inference = True
|
||||
outdir = Path("/home/jbutch/Projects/HT25/af3/modelhub_refactor/rfd3/tests/outs")
|
||||
# outdir = Path("/home/jbutch/Projects/HT25/af3/modelhub_refactor/rfd3/tests/outs")
|
||||
args = TEST_JSON_DATA["1qys-1-refactored"]
|
||||
input = instantiate_example(args, is_inference=is_inference)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from atomworks.common import sum_string_arrays
|
||||
from atomworks.io.utils.io_utils import to_cif_file
|
||||
from atomworks.ml.transforms.center_random_augmentation import CenterRandomAugmentation
|
||||
from biotite.structure import AtomArrayStack
|
||||
from rfd3.trainer.rfd3_trainer import _reassign_unindexed_token_chains
|
||||
from rfd3.trainer.rfd3 import _reassign_unindexed_token_chains
|
||||
from rfd3.transforms.design_transforms import (
|
||||
MotifCenterRandomAugmentation,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import copy
|
||||
import getpass
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -23,14 +22,12 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../src")
|
||||
|
||||
import atomworks
|
||||
from atomworks import parse
|
||||
from atomworks.io.parser import STANDARD_PARSER_ARGS
|
||||
from atomworks.io.utils.io_utils import to_cif_file
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import OmegaConf
|
||||
from rfd3.constants import STANDARD_PARSER_ARGS
|
||||
from rfd3.inference.datasets import (
|
||||
prepare_pipeline_input_from_atom_array,
|
||||
)
|
||||
from rfd3.inference.input_parsing import (
|
||||
DesignInputSpecification,
|
||||
create_atom_array_from_design_specification,
|
||||
)
|
||||
from rfd3.transforms.pipelines import (
|
||||
@@ -106,7 +103,7 @@ TEST_CFG_TRAIN = load_train_or_val_cfg()
|
||||
##########################################################################################
|
||||
|
||||
DIRS = [
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../../tests'),
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../tests"),
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__))),
|
||||
TEST_CFG_TRAIN.paths.data.design_benchmark_data_dir,
|
||||
]
|
||||
@@ -156,9 +153,6 @@ def load_test_json():
|
||||
TEST_JSON_DATA = load_test_json()
|
||||
assert TEST_JSON_DATA, "No test json data loaded!"
|
||||
|
||||
sig = inspect.signature(create_atom_array_from_design_specification)
|
||||
valid_keys_ = sig.parameters.keys()
|
||||
|
||||
|
||||
def filter_inference_args(args):
|
||||
return {k: v for k, v in args.items() if k in valid_keys_}
|
||||
@@ -171,8 +165,11 @@ def instantiate_example(args, is_inference=True):
|
||||
if is_inference:
|
||||
# Keep only the kwargs that the function actually accepts
|
||||
# args = filter_inference_args(args)
|
||||
atom_array, spec = create_atom_array_from_design_specification(**args)
|
||||
input = prepare_pipeline_input_from_atom_array(atom_array)
|
||||
# atom_array, spec = create_atom_array_from_design_specification(**args)
|
||||
# input = prepare_pipeline_input_from_atom_array(atom_array)
|
||||
input = DesignInputSpecification.safe_init(**args).to_pipeline_input(
|
||||
example_id=args.get("example_id", "example")
|
||||
)
|
||||
else:
|
||||
file = args.get("input")
|
||||
if file is None:
|
||||
|
||||
@@ -219,7 +219,6 @@ class FabricTrainer(ABC):
|
||||
)
|
||||
self.initialize_or_update_trainer_state({"scheduler_cfg": scheduler_cfg})
|
||||
|
||||
@abstractmethod
|
||||
def construct_model(self):
|
||||
"""Instantiate the model, updating the trainer state in-place.
|
||||
|
||||
|
||||
@@ -1,54 +1,34 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
spread_token_wise,
|
||||
)
|
||||
from beartype.typing import Any, List, Union
|
||||
from biotite.structure import AtomArray, AtomArrayStack, concatenate, infer_elements
|
||||
from biotite.structure import AtomArray, AtomArrayStack
|
||||
from biotite.structure.residues import get_residue_starts
|
||||
from einops import repeat
|
||||
from jaxtyping import Float, Int
|
||||
from lightning_utilities import apply_to_collection
|
||||
from omegaconf import DictConfig
|
||||
from rfd3.constants import (
|
||||
ATOM14_ATOM_NAMES,
|
||||
VIRTUAL_ATOM_ELEMENT_NAME,
|
||||
association_schemes,
|
||||
association_schemes_stripped,
|
||||
)
|
||||
from rfd3.metrics.design_metrics import get_all_backbone_metrics
|
||||
from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics
|
||||
from rfd3.trainer.fabric_trainer import FabricTrainer
|
||||
from rfd3.trainer.recycling import get_recycle_schedule
|
||||
from rfd3.transforms.conditioning_utils import (
|
||||
from rfd3.trainer.trainer_utils import (
|
||||
_build_atom_array_stack,
|
||||
_cleanup_virtual_atoms_and_assign_atom_name_elements,
|
||||
_reassign_unindexed_token_chains,
|
||||
_reorder_dict,
|
||||
process_unindexed_outputs,
|
||||
)
|
||||
from rfd3.util.io import (
|
||||
from rfd3.utils.io import (
|
||||
build_stack_from_atom_array_and_batched_coords,
|
||||
)
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.metrics.losses import Loss
|
||||
from modelhub.metrics.metric import MetricManager
|
||||
from modelhub.training.EMA import EMA
|
||||
from modelhub.trainers.fabric import FabricTrainer
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.torch import assert_no_nans, assert_same_shape
|
||||
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
|
||||
def _remap_outputs(
|
||||
xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
|
||||
) -> Float[torch.Tensor, "D L 3"]:
|
||||
"""Helper function to remap outputs using a mapping tensor."""
|
||||
for i in range(xyz.shape[0]):
|
||||
xyz[i, mapping[i]] = xyz[i].clone()
|
||||
return xyz
|
||||
|
||||
class AADesignTrainer(FabricTrainer):
|
||||
"""Mostly for unique things like saving outputs and parsing inputs
|
||||
|
||||
@@ -90,9 +70,6 @@ class AADesignTrainer(FabricTrainer):
|
||||
)
|
||||
self.association_scheme = association_scheme
|
||||
self.seed = None
|
||||
self.inference_sampler_overrides = None
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# (Initialize recycle schedule upfront so all GPU's can sample the same number of recycles within a batch)
|
||||
self.n_recycles_train = n_recycles_train
|
||||
@@ -110,31 +87,10 @@ class AADesignTrainer(FabricTrainer):
|
||||
if metrics
|
||||
else None
|
||||
)
|
||||
|
||||
# Loss (full precision)
|
||||
with torch.autocast(device_type=self.fabric.device.type, enabled=False):
|
||||
self.loss = AF3Loss(**loss) if loss else None
|
||||
self.loss = Loss(**loss) if loss else None
|
||||
|
||||
# FROM RF3
|
||||
def construct_model(self):
|
||||
"""Construct the model and optionally wrap with EMA."""
|
||||
# ... instantiate model with Hydra and Fabric
|
||||
with self.fabric.init_module():
|
||||
ranked_logger.info("Instantiating model...")
|
||||
|
||||
model = hydra.utils.instantiate(
|
||||
self.state["train_cfg"].model.net,
|
||||
_recursive_=False,
|
||||
)
|
||||
|
||||
# Optionally, wrap the model with EMA
|
||||
if self.state["train_cfg"].model.ema is not None:
|
||||
ranked_logger.info("Wrapping model with EMA...")
|
||||
model = EMA(model, **self.state["train_cfg"].model.ema)
|
||||
|
||||
self.initialize_or_update_trainer_state({"model": model})
|
||||
|
||||
# ~~~FROM RF3
|
||||
def _assemble_network_inputs(self, example: dict) -> dict:
|
||||
"""Assemble and validate the network inputs."""
|
||||
assert_same_shape(example["coord_atom_lvl_to_be_noised"], example["noise"])
|
||||
@@ -159,7 +115,7 @@ class AADesignTrainer(FabricTrainer):
|
||||
network_input["X_noisy_L"] = torch.nan_to_num(
|
||||
network_input["X_noisy_L"]
|
||||
)
|
||||
ranked_logger.warning(str(e))
|
||||
global_logger.warning(str(e))
|
||||
else:
|
||||
# During validation, since we do not crop, there should be no NaN's in the coordinates to noise
|
||||
# (They were either removed, as is done with fully unresolved chains, or resolved accoring to our pipeline's rules)
|
||||
@@ -172,40 +128,6 @@ class AADesignTrainer(FabricTrainer):
|
||||
|
||||
return network_input
|
||||
|
||||
# FROM RF3
|
||||
def _assemble_metrics_extra_info(self, example: dict, network_output: dict) -> dict:
|
||||
"""Prepares the extra info for the metrics"""
|
||||
# We need the same information as for the loss...
|
||||
metrics_extra_info = self._assemble_loss_extra_info(example)
|
||||
|
||||
# ... and possibly some additional metadata from the example dictionary
|
||||
# TODO: Generalize, so we always use the `extra_info` key, rather than unpacking the ground truth as well
|
||||
metrics_extra_info.update(
|
||||
{
|
||||
# TODO: Remove, instead using `extra_info` for all keys
|
||||
**{
|
||||
k: example["ground_truth"][k]
|
||||
for k in [
|
||||
"interfaces_to_score",
|
||||
"pn_units_to_score",
|
||||
"chain_iid_token_lvl",
|
||||
]
|
||||
if k in example["ground_truth"]
|
||||
},
|
||||
"example_id": example[
|
||||
"example_id"
|
||||
], # We require the example ID for logging
|
||||
# (From the parser)
|
||||
**example.get("extra_info", {}),
|
||||
}
|
||||
)
|
||||
|
||||
# Record metrics_tags for this example
|
||||
metrics_extra_info["metrics_tags"] = example.get("metrics_tags", set())
|
||||
|
||||
# (Create a shallow copy to avoid modifying the original dictionary)
|
||||
return {**metrics_extra_info}
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
batch: Any,
|
||||
@@ -296,9 +218,6 @@ class AADesignTrainer(FabricTrainer):
|
||||
# (Note that forward() passes to the EMA/shadow model if the model is not training)
|
||||
network_output = model.forward(
|
||||
input=network_input,
|
||||
n_cycle=example["feats"]["msa_stack"].shape[
|
||||
0
|
||||
], # Determine the number of recycles from the MSA stack shape
|
||||
coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"],
|
||||
)
|
||||
|
||||
@@ -307,28 +226,19 @@ class AADesignTrainer(FabricTrainer):
|
||||
msg=f"network_output for example_id: {example['example_id']}",
|
||||
)
|
||||
|
||||
# ... Convert output to a stack of atom arrays
|
||||
predicted_atom_array_stack, prediction_metadata = (
|
||||
self._build_predicted_atom_array_stack(network_output, example)
|
||||
)
|
||||
|
||||
metrics_output = {}
|
||||
if compute_metrics and exists(self.metrics):
|
||||
if compute_metrics:
|
||||
assert self.metrics is not None, "Metrics are not defined!"
|
||||
|
||||
metrics_extra_info = self._assemble_metrics_extra_info(
|
||||
example, network_output
|
||||
)
|
||||
|
||||
# Symmetry resolution
|
||||
# TODO: Refactor such that symmetry returns the ideal coordinate permutation, we apply permutation, and pass adjusted prediction to metrics
|
||||
# (without needing to use `extra_info` as we are now)
|
||||
# TODO: Update symmetry resolution to be functional (vs. using class variable), take explicit inputs (vs. all from netowork_ouput), and use extra_info for the keys it needs
|
||||
metrics_extra_info = self.subunit_symm_resolve(
|
||||
network_output,
|
||||
metrics_extra_info,
|
||||
example["symmetry_resolution"],
|
||||
)
|
||||
|
||||
metrics_extra_info = self.residue_symm_resolve(
|
||||
network_output,
|
||||
metrics_extra_info,
|
||||
example["automorphisms"],
|
||||
)
|
||||
|
||||
metrics_output = self.metrics(
|
||||
network_input=network_input,
|
||||
network_output=network_output,
|
||||
@@ -337,22 +247,33 @@ class AADesignTrainer(FabricTrainer):
|
||||
ground_truth_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
|
||||
metrics_extra_info["X_gt_L"], example.get("atom_array", None)
|
||||
),
|
||||
predicted_atom_array_stack=build_stack_from_atom_array_and_batched_coords(
|
||||
network_output["X_L"], example.get("atom_array", None)
|
||||
),
|
||||
predicted_atom_array_stack=predicted_atom_array_stack,
|
||||
prediction_metadata=prediction_metadata,
|
||||
)
|
||||
|
||||
if "X_gt_index_to_X" in metrics_extra_info:
|
||||
# Remap outputs to minimize error with ground truth
|
||||
# TODO: Remap before computing metrics, so that we can avoid pass `extra_info` to metrics (we instead just pass the remapped prediction)
|
||||
mapping = metrics_extra_info["X_gt_index_to_X"] # [D, L]
|
||||
network_output["X_L"] = _remap_outputs(network_output["X_L"], mapping)
|
||||
|
||||
# Avoid gradients in stored values to prevent memory leaks
|
||||
if metrics_output is not None:
|
||||
metrics_output = apply_to_collection(
|
||||
metrics_output, torch.Tensor, lambda x: x.detach()
|
||||
)
|
||||
|
||||
network_output = apply_to_collection(
|
||||
network_output, torch.Tensor, lambda x: x.detach()
|
||||
)
|
||||
if network_output is not None:
|
||||
network_output = apply_to_collection(
|
||||
network_output, torch.Tensor, lambda x: x.detach()
|
||||
)
|
||||
|
||||
return {"metrics_output": metrics_output, "network_output": network_output}
|
||||
return {
|
||||
"metrics_output": metrics_output,
|
||||
"network_output": network_output,
|
||||
"predicted_atom_array_stack": predicted_atom_array_stack,
|
||||
"prediction_metadata": prediction_metadata,
|
||||
}
|
||||
|
||||
def _assemble_loss_extra_info(self, example: dict) -> dict:
|
||||
"""Assembles metadata arguments to the loss function (incremental to the network inputs and outputs)."""
|
||||
@@ -501,7 +422,7 @@ class AADesignTrainer(FabricTrainer):
|
||||
# Align ca and calculate RMSD:
|
||||
if xyz_ca_input.shape == xyz_ca_output.shape:
|
||||
try:
|
||||
from rfd3.util.alignment import weighted_rigid_align
|
||||
from rfd3.utils.alignment import weighted_rigid_align
|
||||
|
||||
xyz_ca_output_aligned = (
|
||||
weighted_rigid_align(
|
||||
@@ -532,266 +453,3 @@ class AADesignTrainer(FabricTrainer):
|
||||
# Reorder metadata dictionaries to ensure 'metrics' and 'specification' are last
|
||||
metadata_dict = {k: _reorder_dict(d) for k, d in metadata_dict.items()}
|
||||
return atom_array_stack, metadata_dict
|
||||
|
||||
|
||||
def _reorder_dict(d: dict) -> OrderedDict:
|
||||
"""
|
||||
Reorders keys in the dictionary to ensure 'metrics' and 'specification' are last (in that order if both present).
|
||||
"""
|
||||
ordered = OrderedDict()
|
||||
first_keys = ["task", "diffused_index_map"]
|
||||
last_keys = ["metrics", "specification", "inference_sampler"]
|
||||
# First
|
||||
for k in first_keys:
|
||||
if k in d:
|
||||
ordered[k] = d[k]
|
||||
# Middle
|
||||
for k in d:
|
||||
if k not in last_keys and k not in first_keys:
|
||||
ordered[k] = d[k]
|
||||
# Last
|
||||
for k in last_keys:
|
||||
if k in d:
|
||||
ordered[k] = d[k]
|
||||
return ordered
|
||||
|
||||
|
||||
def _build_atom_array_stack(
|
||||
coords,
|
||||
src_atom_array,
|
||||
sequence_indices,
|
||||
sequence_logits,
|
||||
allow_sequence_outputs=True,
|
||||
read_sequence_from_sequence_head=True,
|
||||
association_scheme: str = "atom14",
|
||||
):
|
||||
"""
|
||||
Wraps around build_atom_array_and_batched_coords to also include additional modifications to atom array
|
||||
"""
|
||||
atom_array_stack = build_stack_from_atom_array_and_batched_coords(
|
||||
coords, src_atom_array.copy()
|
||||
)
|
||||
|
||||
# ... Spoof empty sequences to alanines
|
||||
atom_array_stack.res_name[
|
||||
atom_array_stack.is_protein & (atom_array_stack.res_name == "UNK")
|
||||
] = "ALA"
|
||||
|
||||
# ... Add sequence if available
|
||||
if allow_sequence_outputs:
|
||||
array_list = []
|
||||
if read_sequence_from_sequence_head and exists(sequence_logits):
|
||||
sequence_encoding = AF3SequenceEncoding()
|
||||
for i, (atom_array, seq_indices, seq_logits) in enumerate(
|
||||
zip(atom_array_stack, sequence_indices, sequence_logits)
|
||||
):
|
||||
# Set residue names
|
||||
diffused_mask = ~atom_array.is_motif_atom_with_fixed_seq
|
||||
three_letter_sequence = sequence_encoding.decode(
|
||||
seq_indices.cpu().numpy().astype(int)
|
||||
) # [I]
|
||||
|
||||
atom_array.res_name[diffused_mask] = three_letter_sequence[
|
||||
atom_array.token_id
|
||||
][diffused_mask] # [L]
|
||||
|
||||
# Set bfactor column as entropy of sequence logits
|
||||
p = torch.softmax(seq_logits, dim=-1).cpu().numpy() # shape (L, 32)
|
||||
res_entropy = -np.sum(p * np.log(p + 1e-10), axis=-1) # shape (L,)
|
||||
atom_array.b_factor = spread_token_wise(atom_array, res_entropy)
|
||||
array_list.append(atom_array.copy())
|
||||
else:
|
||||
# This automatically deletes virtual atoms and assigns resname, atom name, and elements
|
||||
for atom_array in atom_array_stack:
|
||||
atom_array = _readout_seq_from_struc(
|
||||
atom_array, association_scheme=association_scheme
|
||||
)
|
||||
array_list.append(atom_array)
|
||||
|
||||
# Return as list
|
||||
atom_array_stack = array_list
|
||||
|
||||
return atom_array_stack
|
||||
|
||||
|
||||
def _reassign_unindexed_token_chains(atom_array):
|
||||
if np.any((mask := atom_array.is_motif_atom_unindexed)):
|
||||
# HACK: Since res_ids are the same, we should save them with a different chain index.
|
||||
atom_array.chain_id[mask] = "X"
|
||||
atom_array.res_id[mask] = atom_array.orig_res_id[mask]
|
||||
|
||||
# Parse to separate chains
|
||||
starts = get_token_starts(atom_array)
|
||||
unindexed_starts = starts[mask[starts]]
|
||||
token_breaks = atom_array[
|
||||
unindexed_starts
|
||||
].is_motif_atom_unindexed_motif_breakpoint
|
||||
token_group_id = np.cumsum(token_breaks, dtype=int) # Group by motif breaks
|
||||
token_chain_id = np.array([f"X{i}" for i in token_group_id])
|
||||
|
||||
chains = atom_array.chain_id[starts]
|
||||
chains[mask[starts]] = token_chain_id
|
||||
atom_array.chain_id = spread_token_wise(atom_array, chains)
|
||||
return atom_array
|
||||
|
||||
|
||||
def _cleanup_virtual_atoms_and_assign_atom_name_elements(
|
||||
atom_array, association_scheme: str = "atom14"
|
||||
):
|
||||
## remove virtual atoms based on predicted residue and assign correct atom name and elements
|
||||
ret_mask = []
|
||||
atom_names = []
|
||||
# This is used to indicate which residue is unidentified, probably due to an invalid structure.
|
||||
# This is different from the ref_mask, which is used to delete virtual atoms, but this one is used to assign UNK resname for invalid residues.
|
||||
invalid_mask = []
|
||||
|
||||
# ... Iterate through each residue.
|
||||
# Here we iterate through res_id instead of token_id to avoid some atomization cases or something else.
|
||||
res_ids = atom_array.res_id
|
||||
res_start_indices = np.concatenate(
|
||||
[[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
|
||||
)
|
||||
res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
|
||||
warning_issued = False
|
||||
for start, end in zip(res_start_indices, res_end_indices):
|
||||
res_array = atom_array[start:end]
|
||||
|
||||
is_seq_known = all(
|
||||
np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool)
|
||||
) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool))
|
||||
|
||||
# ... If sequence is known for the original atom array, just skip
|
||||
if is_seq_known:
|
||||
ret_mask += [True] * len(res_array)
|
||||
invalid_mask += [False] * len(res_array)
|
||||
res_name = res_array[0].res_name
|
||||
atom_names += res_array.gt_atom_name.tolist()
|
||||
continue
|
||||
|
||||
# ... If sequence is unknown for the original atom array, use the predicted / inferred sequence
|
||||
res_name = res_array[0].res_name
|
||||
if res_name not in association_schemes[association_scheme]:
|
||||
global_logger.warning(
|
||||
"Model predicted non-protein sequence for diffused residue. Cannot clean up outputs. Assigning unknown residue token."
|
||||
)
|
||||
warning_issued = True
|
||||
ret_mask += [True] * len(res_array)
|
||||
invalid_mask += [True] * len(res_array)
|
||||
atom_names += res_array.atom_name.tolist()
|
||||
continue
|
||||
|
||||
scheme = association_schemes[association_scheme][res_name]
|
||||
ret_mask += [True if item is not None else False for item in scheme]
|
||||
atom_names += [item.strip() if item is not None else "VX" for item in scheme]
|
||||
invalid_mask += [False] * len(scheme)
|
||||
|
||||
if len(atom_names) != atom_array.array_length():
|
||||
global_logger.warning(
|
||||
f"{atom_names=}\n{atom_array.atom_name=}\nAtom names length {len(atom_names)} does not match original array length {atom_array.array_length()}."
|
||||
"\nCould not cleanup atom array!!!"
|
||||
)
|
||||
if not warning_issued:
|
||||
raise ValueError("Atom names length does not match original array length. ")
|
||||
return atom_array
|
||||
atom_array.atom_name = atom_names
|
||||
atom_array.element = np.where(
|
||||
atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME,
|
||||
infer_elements(atom_names),
|
||||
atom_array.element,
|
||||
)
|
||||
atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
|
||||
return atom_array[ret_mask]
|
||||
|
||||
|
||||
def _readout_seq_from_struc(
|
||||
atom_array, central_atom="CB", threshold=0.5, association_scheme: str = "atom14"
|
||||
):
|
||||
cur_atom_array_list = []
|
||||
|
||||
# Iterate through each residue
|
||||
res_ids = atom_array.res_id
|
||||
res_start_indices = np.concatenate(
|
||||
[[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
|
||||
)
|
||||
res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
|
||||
|
||||
for start, end in zip(res_start_indices, res_end_indices):
|
||||
# ... Check if the current residue is after padding (seq unknown):
|
||||
cur_res_atom_array = atom_array[start:end]
|
||||
is_seq_known = all(
|
||||
np.array(cur_res_atom_array.is_motif_atom_with_fixed_seq, dtype=bool)
|
||||
)
|
||||
|
||||
# Here it assumes that every non-protein part has its sequence shown (not padded)
|
||||
if not is_seq_known:
|
||||
# For Glycine: it doesn't have CB, so set the virtual atom as CA.
|
||||
# The current way to handle this is to check if predicted CA and CB are too close, because in the case of glycine and we pad virtual atoms based on CB, CB's coords are set as CA.
|
||||
# There might be a better way to do this.
|
||||
CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
|
||||
CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
|
||||
if np.linalg.norm(CA_coord - CB_coord) < threshold:
|
||||
cur_central_atom = "CA"
|
||||
else:
|
||||
cur_central_atom = central_atom
|
||||
|
||||
central_mask = cur_res_atom_array.atom_name == cur_central_atom
|
||||
|
||||
# ... Calculate the distance to the central atom
|
||||
central_coord = cur_res_atom_array.coord[central_mask][
|
||||
0
|
||||
] # Should only have one central atom anyway
|
||||
dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)
|
||||
|
||||
# ... Select virtual atom by the distance. Shouldn't count the central atom itself.
|
||||
is_virtual = (dists < threshold) & ~central_mask
|
||||
|
||||
# ... Throw away virtual atoms
|
||||
cur_res_atom_array_wo_virtual = cur_res_atom_array[~is_virtual]
|
||||
cur_pred_res_atom_names = (
|
||||
cur_res_atom_array_wo_virtual.atom_name
|
||||
) # e.g. [N, CA, C, O, CB, V6, V2]
|
||||
|
||||
# ... Iterate over the possible restypes and find the matched one if there is any
|
||||
has_restype_assigned = False
|
||||
for restype, atom_names in association_schemes_stripped[
|
||||
association_scheme
|
||||
].items():
|
||||
atom_names = np.array(atom_names)
|
||||
|
||||
# Shouldn't match these two
|
||||
if restype in ["UNK", "MSK"]:
|
||||
continue
|
||||
|
||||
# ... Find the index of virtual atom names in the standard atom14 names
|
||||
atom_name_idx_in_atom14_scheme = np.array(
|
||||
[
|
||||
np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
|
||||
for atom_name in cur_pred_res_atom_names
|
||||
]
|
||||
) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
|
||||
atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
|
||||
atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
|
||||
|
||||
# ... Find the matched restype by checking if all the non-None posititons and None positions match
|
||||
# This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
|
||||
if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
|
||||
x is None for x in atom_names[~atom14_scheme_mask]
|
||||
):
|
||||
cur_res_atom_array.res_name = np.array(
|
||||
[restype] * len(cur_res_atom_array)
|
||||
)
|
||||
cur_atom_array_list.append(cur_res_atom_array)
|
||||
has_restype_assigned = True
|
||||
break
|
||||
else:
|
||||
cur_atom_array_list.append(cur_res_atom_array)
|
||||
has_restype_assigned = True
|
||||
|
||||
# ... Give UNK as the residue name if the mapping fails (unrealistic sidechain)
|
||||
if not has_restype_assigned:
|
||||
cur_res_atom_array.res_name = np.array(["UNK"] * len(cur_res_atom_array))
|
||||
cur_atom_array_list.append(cur_res_atom_array)
|
||||
|
||||
cur_atom_array = concatenate(cur_atom_array_list)
|
||||
|
||||
return cur_atom_array
|
||||
502
models/rfd3/src/rfd3/trainer/trainer_utils.py
Normal file
502
models/rfd3/src/rfd3/trainer/trainer_utils.py
Normal file
@@ -0,0 +1,502 @@
|
||||
from collections import Counter, OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
spread_token_wise,
|
||||
)
|
||||
from biotite.structure import concatenate, infer_elements
|
||||
from jaxtyping import Float, Int
|
||||
from rfd3.constants import (
|
||||
ATOM14_ATOM_NAMES,
|
||||
VIRTUAL_ATOM_ELEMENT_NAME,
|
||||
association_schemes,
|
||||
association_schemes_stripped,
|
||||
)
|
||||
from rfd3.utils.io import (
|
||||
build_stack_from_atom_array_and_batched_coords,
|
||||
)
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
#######################################################################
|
||||
# Pythonic Helper functions
|
||||
#######################################################################
|
||||
|
||||
|
||||
def _remap_outputs(
|
||||
xyz: Float[torch.Tensor, "D L 3"], mapping: Int[torch.Tensor, "D L"]
|
||||
) -> Float[torch.Tensor, "D L 3"]:
|
||||
"""Helper function to remap outputs using a mapping tensor."""
|
||||
for i in range(xyz.shape[0]):
|
||||
xyz[i, mapping[i]] = xyz[i].clone()
|
||||
return xyz
|
||||
|
||||
|
||||
def _reorder_dict(d: dict) -> OrderedDict:
|
||||
"""
|
||||
Reorders keys in the dictionary to ensure 'metrics' and 'specification' are last (in that order if both present).
|
||||
"""
|
||||
ordered = OrderedDict()
|
||||
first_keys = ["task", "diffused_index_map"]
|
||||
last_keys = ["metrics", "specification", "inference_sampler"]
|
||||
# First
|
||||
for k in first_keys:
|
||||
if k in d:
|
||||
ordered[k] = d[k]
|
||||
# Middle
|
||||
for k in d:
|
||||
if k not in last_keys and k not in first_keys:
|
||||
ordered[k] = d[k]
|
||||
# Last
|
||||
for k in last_keys:
|
||||
if k in d:
|
||||
ordered[k] = d[k]
|
||||
return ordered
|
||||
|
||||
|
||||
#######################################################################
|
||||
# Biotite-related helper functions
|
||||
#######################################################################
|
||||
|
||||
|
||||
def _build_atom_array_stack(
|
||||
coords,
|
||||
src_atom_array,
|
||||
sequence_indices,
|
||||
sequence_logits,
|
||||
allow_sequence_outputs=True,
|
||||
read_sequence_from_sequence_head=True,
|
||||
association_scheme: str = "atom14",
|
||||
):
|
||||
"""
|
||||
Wraps around build_atom_array_and_batched_coords to also include additional modifications to atom array
|
||||
"""
|
||||
atom_array_stack = build_stack_from_atom_array_and_batched_coords(
|
||||
coords, src_atom_array.copy()
|
||||
)
|
||||
|
||||
# ... Spoof empty sequences to alanines
|
||||
atom_array_stack.res_name[
|
||||
atom_array_stack.is_protein & (atom_array_stack.res_name == "UNK")
|
||||
] = "ALA"
|
||||
|
||||
# ... Add sequence if available
|
||||
if allow_sequence_outputs:
|
||||
array_list = []
|
||||
if read_sequence_from_sequence_head and exists(sequence_logits):
|
||||
sequence_encoding = AF3SequenceEncoding()
|
||||
for i, (atom_array, seq_indices, seq_logits) in enumerate(
|
||||
zip(atom_array_stack, sequence_indices, sequence_logits)
|
||||
):
|
||||
# Set residue names
|
||||
diffused_mask = ~atom_array.is_motif_atom_with_fixed_seq
|
||||
three_letter_sequence = sequence_encoding.decode(
|
||||
seq_indices.cpu().numpy().astype(int)
|
||||
) # [I]
|
||||
|
||||
atom_array.res_name[diffused_mask] = three_letter_sequence[
|
||||
atom_array.token_id
|
||||
][diffused_mask] # [L]
|
||||
|
||||
# Set bfactor column as entropy of sequence logits
|
||||
p = torch.softmax(seq_logits, dim=-1).cpu().numpy() # shape (L, 32)
|
||||
res_entropy = -np.sum(p * np.log(p + 1e-10), axis=-1) # shape (L,)
|
||||
atom_array.b_factor = spread_token_wise(atom_array, res_entropy)
|
||||
array_list.append(atom_array.copy())
|
||||
else:
|
||||
# This automatically deletes virtual atoms and assigns resname, atom name, and elements
|
||||
for atom_array in atom_array_stack:
|
||||
atom_array = _readout_seq_from_struc(
|
||||
atom_array, association_scheme=association_scheme
|
||||
)
|
||||
array_list.append(atom_array)
|
||||
|
||||
# Return as list
|
||||
atom_array_stack = array_list
|
||||
|
||||
return atom_array_stack
|
||||
|
||||
|
||||
def _cleanup_virtual_atoms_and_assign_atom_name_elements(
|
||||
atom_array, association_scheme: str = "atom14"
|
||||
):
|
||||
## remove virtual atoms based on predicted residue and assign correct atom name and elements
|
||||
ret_mask = []
|
||||
atom_names = []
|
||||
# This is used to indicate which residue is unidentified, probably due to an invalid structure.
|
||||
# This is different from the ref_mask, which is used to delete virtual atoms, but this one is used to assign UNK resname for invalid residues.
|
||||
invalid_mask = []
|
||||
|
||||
# ... Iterate through each residue.
|
||||
# Here we iterate through res_id instead of token_id to avoid some atomization cases or something else.
|
||||
res_ids = atom_array.res_id
|
||||
res_start_indices = np.concatenate(
|
||||
[[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
|
||||
)
|
||||
res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
|
||||
warning_issued = False
|
||||
for start, end in zip(res_start_indices, res_end_indices):
|
||||
res_array = atom_array[start:end]
|
||||
|
||||
is_seq_known = all(
|
||||
np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool)
|
||||
) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool))
|
||||
|
||||
# ... If sequence is known for the original atom array, just skip
|
||||
if is_seq_known:
|
||||
ret_mask += [True] * len(res_array)
|
||||
invalid_mask += [False] * len(res_array)
|
||||
res_name = res_array[0].res_name
|
||||
atom_names += res_array.gt_atom_name.tolist()
|
||||
continue
|
||||
|
||||
# ... If sequence is unknown for the original atom array, use the predicted / inferred sequence
|
||||
res_name = res_array[0].res_name
|
||||
if res_name not in association_schemes[association_scheme]:
|
||||
global_logger.warning(
|
||||
"Model predicted non-protein sequence for diffused residue. Cannot clean up outputs. Assigning unknown residue token."
|
||||
)
|
||||
warning_issued = True
|
||||
ret_mask += [True] * len(res_array)
|
||||
invalid_mask += [True] * len(res_array)
|
||||
atom_names += res_array.atom_name.tolist()
|
||||
continue
|
||||
|
||||
scheme = association_schemes[association_scheme][res_name]
|
||||
ret_mask += [True if item is not None else False for item in scheme]
|
||||
atom_names += [item.strip() if item is not None else "VX" for item in scheme]
|
||||
invalid_mask += [False] * len(scheme)
|
||||
|
||||
if len(atom_names) != atom_array.array_length():
|
||||
global_logger.warning(
|
||||
f"{atom_names=}\n{atom_array.atom_name=}\nAtom names length {len(atom_names)} does not match original array length {atom_array.array_length()}."
|
||||
"\nCould not cleanup atom array!!!"
|
||||
)
|
||||
if not warning_issued:
|
||||
raise ValueError("Atom names length does not match original array length. ")
|
||||
return atom_array
|
||||
atom_array.atom_name = atom_names
|
||||
atom_array.element = np.where(
|
||||
atom_array.element == VIRTUAL_ATOM_ELEMENT_NAME,
|
||||
infer_elements(atom_names),
|
||||
atom_array.element,
|
||||
)
|
||||
atom_array.res_name[invalid_mask] = np.array(["UNK"] * sum(invalid_mask))
|
||||
return atom_array[ret_mask]
|
||||
|
||||
|
||||
def _readout_seq_from_struc(
|
||||
atom_array, central_atom="CB", threshold=0.5, association_scheme: str = "atom14"
|
||||
):
|
||||
cur_atom_array_list = []
|
||||
|
||||
# Iterate through each residue
|
||||
res_ids = atom_array.res_id
|
||||
res_start_indices = np.concatenate(
|
||||
[[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
|
||||
)
|
||||
res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])
|
||||
|
||||
for start, end in zip(res_start_indices, res_end_indices):
|
||||
# ... Check if the current residue is after padding (seq unknown):
|
||||
cur_res_atom_array = atom_array[start:end]
|
||||
is_seq_known = all(
|
||||
np.array(cur_res_atom_array.is_motif_atom_with_fixed_seq, dtype=bool)
|
||||
)
|
||||
|
||||
# Here it assumes that every non-protein part has its sequence shown (not padded)
|
||||
if not is_seq_known:
|
||||
# For Glycine: it doesn't have CB, so set the virtual atom as CA.
|
||||
# The current way to handle this is to check if predicted CA and CB are too close, because in the case of glycine and we pad virtual atoms based on CB, CB's coords are set as CA.
|
||||
# There might be a better way to do this.
|
||||
CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
|
||||
CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
|
||||
if np.linalg.norm(CA_coord - CB_coord) < threshold:
|
||||
cur_central_atom = "CA"
|
||||
else:
|
||||
cur_central_atom = central_atom
|
||||
|
||||
central_mask = cur_res_atom_array.atom_name == cur_central_atom
|
||||
|
||||
# ... Calculate the distance to the central atom
|
||||
central_coord = cur_res_atom_array.coord[central_mask][
|
||||
0
|
||||
] # Should only have one central atom anyway
|
||||
dists = np.linalg.norm(cur_res_atom_array.coord - central_coord, axis=-1)
|
||||
|
||||
# ... Select virtual atom by the distance. Shouldn't count the central atom itself.
|
||||
is_virtual = (dists < threshold) & ~central_mask
|
||||
|
||||
# ... Throw away virtual atoms
|
||||
cur_res_atom_array_wo_virtual = cur_res_atom_array[~is_virtual]
|
||||
cur_pred_res_atom_names = (
|
||||
cur_res_atom_array_wo_virtual.atom_name
|
||||
) # e.g. [N, CA, C, O, CB, V6, V2]
|
||||
|
||||
# ... Iterate over the possible restypes and find the matched one if there is any
|
||||
has_restype_assigned = False
|
||||
for restype, atom_names in association_schemes_stripped[
|
||||
association_scheme
|
||||
].items():
|
||||
atom_names = np.array(atom_names)
|
||||
|
||||
# Shouldn't match these two
|
||||
if restype in ["UNK", "MSK"]:
|
||||
continue
|
||||
|
||||
# ... Find the index of virtual atom names in the standard atom14 names
|
||||
atom_name_idx_in_atom14_scheme = np.array(
|
||||
[
|
||||
np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
|
||||
for atom_name in cur_pred_res_atom_names
|
||||
]
|
||||
) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
|
||||
atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
|
||||
atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
|
||||
|
||||
# ... Find the matched restype by checking if all the non-None posititons and None positions match
|
||||
# This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
|
||||
if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
|
||||
x is None for x in atom_names[~atom14_scheme_mask]
|
||||
):
|
||||
cur_res_atom_array.res_name = np.array(
|
||||
[restype] * len(cur_res_atom_array)
|
||||
)
|
||||
cur_atom_array_list.append(cur_res_atom_array)
|
||||
has_restype_assigned = True
|
||||
break
|
||||
else:
|
||||
cur_atom_array_list.append(cur_res_atom_array)
|
||||
has_restype_assigned = True
|
||||
|
||||
# ... Give UNK as the residue name if the mapping fails (unrealistic sidechain)
|
||||
if not has_restype_assigned:
|
||||
cur_res_atom_array.res_name = np.array(["UNK"] * len(cur_res_atom_array))
|
||||
cur_atom_array_list.append(cur_res_atom_array)
|
||||
|
||||
cur_atom_array = concatenate(cur_atom_array_list)
|
||||
|
||||
return cur_atom_array
|
||||
|
||||
|
||||
#######################################################################
|
||||
# Unindexed output parsing
|
||||
#######################################################################
|
||||
|
||||
|
||||
def _reassign_unindexed_token_chains(atom_array):
|
||||
if np.any((mask := atom_array.is_motif_atom_unindexed)):
|
||||
# HACK: Since res_ids are the same, we should save them with a different chain index.
|
||||
atom_array.chain_id[mask] = "X"
|
||||
atom_array.res_id[mask] = atom_array.orig_res_id[mask]
|
||||
|
||||
# Parse to separate chains
|
||||
starts = get_token_starts(atom_array)
|
||||
unindexed_starts = starts[mask[starts]]
|
||||
token_breaks = atom_array[
|
||||
unindexed_starts
|
||||
].is_motif_atom_unindexed_motif_breakpoint
|
||||
token_group_id = np.cumsum(token_breaks, dtype=int) # Group by motif breaks
|
||||
token_chain_id = np.array([f"X{i}" for i in token_group_id])
|
||||
|
||||
chains = atom_array.chain_id[starts]
|
||||
chains[mask[starts]] = token_chain_id
|
||||
atom_array.chain_id = spread_token_wise(atom_array, chains)
|
||||
return atom_array
|
||||
|
||||
|
||||
def process_unindexed_outputs(
|
||||
atom_array,
|
||||
match_atom_names=True,
|
||||
insert_guideposts=False,
|
||||
verbose=False,
|
||||
):
|
||||
"""
|
||||
Process design outputs containing unindexed tokens.
|
||||
Returns metadata such as the assigned positional indices from the input indices
|
||||
and the RMSD of the unindexed tokens.
|
||||
|
||||
Returns:
|
||||
- Diffused atom array (without additional unindexed tokens)
|
||||
- Metadata:
|
||||
- diffused_indices: keys = original (contig) indices, values = diffused indices
|
||||
- insertion_rmsd: overall RMSD of insertion
|
||||
- insertion_rmsd_by_residue: RMSD of insertion for each token
|
||||
|
||||
TODO: Add additional geometry metrics such as bond angle non-ideality, clashes etc.
|
||||
TODO: atom1d conditioning adherence - does the output contain HBonds in the right places, correct rasa values?
|
||||
"""
|
||||
# ... Find assignments based on greedy search
|
||||
starts = get_token_starts(atom_array, add_exclusive_stop=True)
|
||||
|
||||
# [N_diffused,]
|
||||
atom_array_diffused = atom_array[~atom_array.is_motif_atom_unindexed].copy()
|
||||
global_idx = np.arange(atom_array.array_length())[
|
||||
~atom_array.is_motif_atom_unindexed
|
||||
]
|
||||
|
||||
metadata = {
|
||||
"diffused_index_map": {},
|
||||
"insertion_rmsd_by_token": {},
|
||||
"join_point_rmsd_by_token": {},
|
||||
"insertion_rmsd_by_restype": {},
|
||||
}
|
||||
token_maes = []
|
||||
token_rmcds = []
|
||||
n_conjoined_residues = 0
|
||||
|
||||
# Initialize an empty array
|
||||
inserted_mask = np.full_like(atom_array_diffused.is_motif_atom_unindexed, False)
|
||||
|
||||
for start, end in zip(starts[:-1], starts[1:]):
|
||||
token = atom_array[start:end]
|
||||
if not token.is_motif_atom_unindexed.all():
|
||||
continue
|
||||
|
||||
if "src_component" in token.get_annotation_categories():
|
||||
token_pdb_id = token.src_component[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Missing annotation 'src_component' in token. Is this inference?"
|
||||
)
|
||||
|
||||
if "src_sym_component" in token.get_annotation_categories():
|
||||
# if symmetry, token_pdb_id are updated to match the symmetrized component
|
||||
token_pdb_id = token.src_sym_component[0]
|
||||
|
||||
res_name = token.res_name[0]
|
||||
|
||||
# ... Calculate [N_unindex, N_diffused] distance matrix
|
||||
dists = np.linalg.norm(
|
||||
token.coord[:, None] - atom_array_diffused.coord[None, :], axis=-1
|
||||
)
|
||||
|
||||
# ... Match atom indices based on atom names (mask out non-identical) and remove already inserted
|
||||
dists[:, inserted_mask.copy()] = np.inf
|
||||
if match_atom_names:
|
||||
matching_atom_name = (
|
||||
token.atom_name[:, None] == atom_array_diffused.atom_name[None, :]
|
||||
)
|
||||
dists[~matching_atom_name] = np.inf
|
||||
|
||||
# ... Find the res_id's in the diffused regions belonging to the diffused indices
|
||||
row_ind, col_ind = linear_sum_assignment(dists)
|
||||
res_id, chain_id, is_conjoined = indices_to_components_(
|
||||
atom_array_diffused, col_ind
|
||||
)
|
||||
n_conjoined_residues += int(is_conjoined)
|
||||
|
||||
# ... Recompute distance indices based on single residue pairings only
|
||||
token_match = (atom_array_diffused.res_id == res_id) & (
|
||||
atom_array_diffused.chain_id == chain_id
|
||||
)
|
||||
dists[:, ~token_match] = np.nan
|
||||
BIG = 1e12
|
||||
dists = np.nan_to_num(dists, nan=BIG, posinf=BIG, neginf=BIG)
|
||||
row_ind, col_ind = linear_sum_assignment(dists)
|
||||
res_id_, chain_id_, _ = indices_to_components_(atom_array_diffused, col_ind)
|
||||
|
||||
assert (res_id_ == res_id) & (chain_id_ == chain_id)
|
||||
inserted_mask = np.logical_or(inserted_mask, token_match)
|
||||
|
||||
# ... Compute metrics based on the new distances
|
||||
diff = token.coord[row_ind] - atom_array_diffused.coord[col_ind]
|
||||
token_rmsd = float(np.sqrt((diff**2).sum(-1).mean()))
|
||||
token_rmcd = float(np.cbrt((np.abs(diff) ** 3).sum(-1).mean()))
|
||||
token_mae = float((np.abs(diff)).sum(-1).mean())
|
||||
|
||||
metadata["insertion_rmsd_by_token"][token_pdb_id] = token_rmsd
|
||||
token_maes.append(token_mae)
|
||||
token_rmcds.append(token_rmcd)
|
||||
|
||||
if res_name not in metadata["insertion_rmsd_by_restype"]:
|
||||
metadata["insertion_rmsd_by_restype"][res_name] = []
|
||||
metadata["insertion_rmsd_by_restype"][res_name].append(token_rmsd)
|
||||
if not np.any(np.isin(token.atom_name, ["N", "CA", "C", "O"])):
|
||||
if np.sum(token.atomize) == 1:
|
||||
join_atom = np.where(token.atomize)[0][0]
|
||||
elif "CB" in token.atom_name:
|
||||
join_atom = np.where(token.atom_name == "CB")[0][0]
|
||||
else:
|
||||
join_atom = None
|
||||
|
||||
if join_atom is None:
|
||||
global_logger.warning(
|
||||
f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
|
||||
)
|
||||
else:
|
||||
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
|
||||
metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
|
||||
|
||||
metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"
|
||||
|
||||
# ... Decide whether to cleanup guideposts or not
|
||||
if insert_guideposts:
|
||||
atom_array_diffused.coord[global_idx[col_ind]] = token.coord[row_ind]
|
||||
if token.is_motif_atom_with_fixed_seq[0]:
|
||||
atom_array_diffused.res_name[token_match] = token.res_name[0]
|
||||
# atom_array_diffused.is_motif_token[token_match] = True
|
||||
# atom_array_diffused.is_motif_atom[global_idx[col_ind]] = True
|
||||
atom_array_diffused.is_motif_atom_with_fixed_coord[global_idx[col_ind]] = (
|
||||
True
|
||||
)
|
||||
|
||||
# ... Calculate global metrics
|
||||
def safe_mean(x):
|
||||
"""Return nan-safe mean for empty or nan arrays."""
|
||||
x = np.asarray(x, float)
|
||||
if x.size == 0 or not np.isfinite(x).any():
|
||||
return float("nan")
|
||||
return float(np.nanmean(x))
|
||||
|
||||
metadata["insertion.mae"] = safe_mean(token_maes)
|
||||
metadata["insertion.rmcd"] = safe_mean(token_rmcds)
|
||||
metadata["insertion_rmsd"] = safe_mean(
|
||||
list(metadata["insertion_rmsd_by_token"].values())
|
||||
)
|
||||
metadata["join_point_rmsd"] = safe_mean(
|
||||
list(metadata["join_point_rmsd_by_token"].values())
|
||||
)
|
||||
metadata["insertion_rmsd_by_restype"] = {
|
||||
a: safe_mean(v) for a, v in metadata["insertion_rmsd_by_restype"].items()
|
||||
}
|
||||
metadata["n_conjoined_residues"] = n_conjoined_residues
|
||||
|
||||
if not verbose:
|
||||
metadata = {
|
||||
k: v for k, v in metadata.items() if not k.startswith("insertion_rmsd_by_")
|
||||
}
|
||||
|
||||
return atom_array_diffused, metadata
|
||||
|
||||
|
||||
def indices_to_components_(atom_array, col_ind):
|
||||
"""
|
||||
Fetch chain and resids in atom array given a set of raw indices
|
||||
will return 'conjoined' if indices to not map to a unique residue
|
||||
"""
|
||||
res_ids, chain_ids = (
|
||||
atom_array.res_id[col_ind],
|
||||
atom_array.chain_id[col_ind],
|
||||
)
|
||||
if len(set(res_ids.tolist())) > 1 or len(set(chain_ids.tolist())) > 1:
|
||||
global_logger.warning(
|
||||
f"Unindexed token mapped its atoms to multiple diffused residues: {res_ids.tolist()} and chains {chain_ids.tolist()}."
|
||||
)
|
||||
# Handle by majority
|
||||
pair_counts = Counter(zip(chain_ids.tolist(), res_ids.tolist()))
|
||||
(chain_id, res_id), _ = pair_counts.most_common(1)[0]
|
||||
conjoined = True
|
||||
else:
|
||||
res_id = res_ids[0]
|
||||
chain_id = chain_ids[0]
|
||||
conjoined = False
|
||||
|
||||
return res_id, chain_id, conjoined
|
||||
@@ -1,56 +0,0 @@
|
||||
from atomworks.ml.transforms._checks import (
|
||||
check_contains_keys,
|
||||
check_is_instance,
|
||||
)
|
||||
from atomworks.ml.transforms.base import Transform
|
||||
from biotite.structure import AtomArray
|
||||
|
||||
|
||||
class SetOccToZeroOnBfactor(Transform):
|
||||
"""
|
||||
This component marks atoms as occ=0 based on bfactor values
|
||||
|
||||
It takes as input 'brange', a list specifying the Mminimum and maximum B factors to
|
||||
keep.
|
||||
|
||||
Example:
|
||||
brange = [-1.0,70.0] will mark with occ=0 any atom with b>70 or b<-1
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bmin=None,
|
||||
bmax=None,
|
||||
):
|
||||
self.bmin = bmin
|
||||
self.bmax = bmax
|
||||
|
||||
def check_input(self, data: dict):
|
||||
check_contains_keys(data, ["atom_array"])
|
||||
check_is_instance(data, "atom_array", AtomArray)
|
||||
# check_atom_array_annotation(data, ["b_factor", "occupancy"])
|
||||
|
||||
def forward(self, data: dict) -> dict:
|
||||
atom_array = data["atom_array"]
|
||||
|
||||
if (
|
||||
self.bmin is None and self.bmax is None
|
||||
) or "b_factor" not in atom_array.get_annotation_categories():
|
||||
return data
|
||||
|
||||
bfact = atom_array.get_annotation("b_factor")
|
||||
if self.bmin is not None:
|
||||
mask = bfact < self.bmin
|
||||
if self.bmax is not None:
|
||||
mask = mask | (bfact > self.bmax)
|
||||
else:
|
||||
mask = bfact > self.bmax
|
||||
|
||||
occ = atom_array.get_annotation("occupancy")
|
||||
occ[mask] = 0.0
|
||||
|
||||
atom_array.set_annotation("occupancy", occ)
|
||||
|
||||
data["atom_array"] = atom_array
|
||||
|
||||
return data
|
||||
@@ -1,10 +1,6 @@
|
||||
from collections import Counter
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from atomworks.io.utils.bonds import _atom_array_to_networkx_graph
|
||||
from atomworks.ml.utils.token import get_token_starts
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
@@ -192,171 +188,6 @@ def choose_uniformly_random_atom_name(subarray):
|
||||
#################################################################################
|
||||
|
||||
|
||||
def process_unindexed_outputs(
|
||||
atom_array,
|
||||
match_atom_names=True,
|
||||
insert_guideposts=False,
|
||||
verbose=False,
|
||||
):
|
||||
"""
|
||||
Process design outputs containing unindexed tokens.
|
||||
Returns metadata such as the assigned positional indices from the input indices
|
||||
and the RMSD of the unindexed tokens.
|
||||
|
||||
Returns:
|
||||
- Diffused atom array (without additional unindexed tokens)
|
||||
- Metadata:
|
||||
- diffused_indices: keys = original (contig) indices, values = diffused indices
|
||||
- insertion_rmsd: overall RMSD of insertion
|
||||
- insertion_rmsd_by_residue: RMSD of insertion for each token
|
||||
|
||||
TODO: Add additional geometry metrics such as bond angle non-ideality, clashes etc.
|
||||
TODO: atom1d conditioning adherence - does the output contain HBonds in the right places, correct rasa values?
|
||||
"""
|
||||
# ... Find assignments based on greedy search
|
||||
starts = get_token_starts(atom_array, add_exclusive_stop=True)
|
||||
|
||||
# [N_diffused,]
|
||||
atom_array_diffused = atom_array[~atom_array.is_motif_atom_unindexed].copy()
|
||||
global_idx = np.arange(atom_array.array_length())[
|
||||
~atom_array.is_motif_atom_unindexed
|
||||
]
|
||||
|
||||
metadata = {
|
||||
"diffused_index_map": {},
|
||||
"insertion_rmsd_by_token": {},
|
||||
"join_point_rmsd_by_token": {},
|
||||
"insertion_rmsd_by_restype": {},
|
||||
}
|
||||
token_maes = []
|
||||
token_rmcds = []
|
||||
n_conjoined_residues = 0
|
||||
|
||||
# Initialize an empty array
|
||||
inserted_mask = np.full_like(atom_array_diffused.is_motif_atom_unindexed, False)
|
||||
|
||||
for start, end in zip(starts[:-1], starts[1:]):
|
||||
token = atom_array[start:end]
|
||||
if not token.is_motif_atom_unindexed.all():
|
||||
continue
|
||||
|
||||
if "src_component" in token.get_annotation_categories():
|
||||
token_pdb_id = token.src_component[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Missing annotation 'src_component' in token. Is this inference?"
|
||||
)
|
||||
|
||||
if "src_sym_component" in token.get_annotation_categories():
|
||||
# if symmetry, token_pdb_id are updated to match the symmetrized component
|
||||
token_pdb_id = token.src_sym_component[0]
|
||||
|
||||
res_name = token.res_name[0]
|
||||
|
||||
# ... Calculate [N_unindex, N_diffused] distance matrix
|
||||
dists = np.linalg.norm(
|
||||
token.coord[:, None] - atom_array_diffused.coord[None, :], axis=-1
|
||||
)
|
||||
|
||||
# ... Match atom indices based on atom names (mask out non-identical) and remove already inserted
|
||||
dists[:, inserted_mask.copy()] = np.inf
|
||||
if match_atom_names:
|
||||
matching_atom_name = (
|
||||
token.atom_name[:, None] == atom_array_diffused.atom_name[None, :]
|
||||
)
|
||||
dists[~matching_atom_name] = np.inf
|
||||
|
||||
# ... Find the res_id's in the diffused regions belonging to the diffused indices
|
||||
row_ind, col_ind = linear_sum_assignment(dists)
|
||||
res_id, chain_id, is_conjoined = indices_to_components(
|
||||
atom_array_diffused, col_ind
|
||||
)
|
||||
n_conjoined_residues += int(is_conjoined)
|
||||
|
||||
# ... Recompute distance indices based on single residue pairings only
|
||||
token_match = (atom_array_diffused.res_id == res_id) & (
|
||||
atom_array_diffused.chain_id == chain_id
|
||||
)
|
||||
dists[:, ~token_match] = np.nan
|
||||
BIG = 1e12
|
||||
dists = np.nan_to_num(dists, nan=BIG, posinf=BIG, neginf=BIG)
|
||||
row_ind, col_ind = linear_sum_assignment(dists)
|
||||
res_id_, chain_id_, _ = indices_to_components(atom_array_diffused, col_ind)
|
||||
|
||||
assert (res_id_ == res_id) & (chain_id_ == chain_id)
|
||||
inserted_mask = np.logical_or(inserted_mask, token_match)
|
||||
|
||||
# ... Compute metrics based on the new distances
|
||||
diff = token.coord[row_ind] - atom_array_diffused.coord[col_ind]
|
||||
token_rmsd = float(np.sqrt((diff**2).sum(-1).mean()))
|
||||
token_rmcd = float(np.cbrt((np.abs(diff) ** 3).sum(-1).mean()))
|
||||
token_mae = float((np.abs(diff)).sum(-1).mean())
|
||||
|
||||
metadata["insertion_rmsd_by_token"][token_pdb_id] = token_rmsd
|
||||
token_maes.append(token_mae)
|
||||
token_rmcds.append(token_rmcd)
|
||||
|
||||
if res_name not in metadata["insertion_rmsd_by_restype"]:
|
||||
metadata["insertion_rmsd_by_restype"][res_name] = []
|
||||
metadata["insertion_rmsd_by_restype"][res_name].append(token_rmsd)
|
||||
if not np.any(np.isin(token.atom_name, ["N", "CA", "C", "O"])):
|
||||
if np.sum(token.atomize) == 1:
|
||||
join_atom = np.where(token.atomize)[0][0]
|
||||
elif "CB" in token.atom_name:
|
||||
join_atom = np.where(token.atom_name == "CB")[0][0]
|
||||
else:
|
||||
join_atom = None
|
||||
|
||||
if join_atom is None:
|
||||
global_logger.warning(
|
||||
f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
|
||||
)
|
||||
else:
|
||||
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
|
||||
metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
|
||||
|
||||
metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"
|
||||
|
||||
# ... Decide whether to cleanup guideposts or not
|
||||
if insert_guideposts:
|
||||
atom_array_diffused.coord[global_idx[col_ind]] = token.coord[row_ind]
|
||||
if token.is_motif_atom_with_fixed_seq[0]:
|
||||
atom_array_diffused.res_name[token_match] = token.res_name[0]
|
||||
# atom_array_diffused.is_motif_token[token_match] = True
|
||||
# atom_array_diffused.is_motif_atom[global_idx[col_ind]] = True
|
||||
atom_array_diffused.is_motif_atom_with_fixed_coord[global_idx[col_ind]] = (
|
||||
True
|
||||
)
|
||||
|
||||
# ... Calculate global metrics
|
||||
def safe_mean(x):
|
||||
"""Return nan-safe mean for empty or nan arrays."""
|
||||
x = np.asarray(x, float)
|
||||
if x.size == 0 or not np.isfinite(x).any():
|
||||
return float("nan")
|
||||
return float(np.nanmean(x))
|
||||
|
||||
metadata["insertion.mae"] = safe_mean(token_maes)
|
||||
metadata["insertion.rmcd"] = safe_mean(token_rmcds)
|
||||
metadata["insertion_rmsd"] = safe_mean(
|
||||
list(metadata["insertion_rmsd_by_token"].values())
|
||||
)
|
||||
metadata["join_point_rmsd"] = safe_mean(
|
||||
list(metadata["join_point_rmsd_by_token"].values())
|
||||
)
|
||||
metadata["insertion_rmsd_by_restype"] = {
|
||||
a: safe_mean(v) for a, v in metadata["insertion_rmsd_by_restype"].items()
|
||||
}
|
||||
metadata["n_conjoined_residues"] = n_conjoined_residues
|
||||
|
||||
if not verbose:
|
||||
metadata = {
|
||||
k: v for k, v in metadata.items() if not k.startswith("insertion_rmsd_by_")
|
||||
}
|
||||
|
||||
return atom_array_diffused, metadata
|
||||
|
||||
|
||||
def random_condition(p_cond):
|
||||
"""
|
||||
Made this function because I always get confused by which order the
|
||||
@@ -367,28 +198,3 @@ def random_condition(p_cond):
|
||||
return False
|
||||
else:
|
||||
return np.random.rand() < p_cond
|
||||
|
||||
|
||||
def indices_to_components(atom_array, col_ind):
|
||||
"""
|
||||
Fetch chain and resids in atom array given a set of raw indices
|
||||
will return 'conjoined' if indices to not map to a unique residue
|
||||
"""
|
||||
res_ids, chain_ids = (
|
||||
atom_array.res_id[col_ind],
|
||||
atom_array.chain_id[col_ind],
|
||||
)
|
||||
if len(set(res_ids.tolist())) > 1 or len(set(chain_ids.tolist())) > 1:
|
||||
global_logger.warning(
|
||||
f"Unindexed token mapped its atoms to multiple diffused residues: {res_ids.tolist()} and chains {chain_ids.tolist()}."
|
||||
)
|
||||
# Handle by majority
|
||||
pair_counts = Counter(zip(chain_ids.tolist(), res_ids.tolist()))
|
||||
(chain_id, res_id), _ = pair_counts.most_common(1)[0]
|
||||
conjoined = True
|
||||
else:
|
||||
res_id = res_ids[0]
|
||||
chain_id = chain_ids[0]
|
||||
conjoined = False
|
||||
|
||||
return res_id, chain_id, conjoined
|
||||
|
||||
@@ -624,35 +624,33 @@ def build_atom14_base_pipeline(
|
||||
kwargs.setdefault("crop_spatial_probability", 0.0)
|
||||
kwargs.setdefault("dna_contact_crop_probability", 0.0)
|
||||
kwargs.setdefault("max_atoms_in_crop", None)
|
||||
kwargs.setdefault("b_factor_min", None)
|
||||
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
|
||||
kwargs.setdefault("meta_conditioning_probabilities", {})
|
||||
kwargs.setdefault("association_scheme", "atom14")
|
||||
|
||||
kwargs.setdefault("sigma_perturb", 0.0)
|
||||
kwargs.setdefault("sigma_perturb_com", 0.0)
|
||||
kwargs.setdefault("allowed_types", "ALL")
|
||||
kwargs.setdefault("train_conditions", {})
|
||||
|
||||
# TODO: Delete these once all checkpoints are updated with the latest defaults
|
||||
kwargs.setdefault("generate_conformers_for_non_protein_only", True)
|
||||
kwargs.setdefault("atom_1d_features", None)
|
||||
kwargs.setdefault("token_1d_features", None)
|
||||
kwargs.setdefault("diffusion_batch_size", 16)
|
||||
kwargs.setdefault("sigma_data", 16)
|
||||
kwargs.setdefault("return_atom_array", True)
|
||||
kwargs.setdefault("provide_elements_for_unindexed_components", False)
|
||||
kwargs.setdefault("center_option", "all")
|
||||
kwargs.setdefault("use_element_for_atom_names_of_atomized_tokens", False)
|
||||
|
||||
kwargs.setdefault("residue_cache_dir", None)
|
||||
kwargs.setdefault("keep_full_binder_in_spatial_crop", True)
|
||||
kwargs.setdefault("max_binder_length", 999)
|
||||
kwargs.setdefault("max_ppi_hotspots_frac_to_provide", 0)
|
||||
kwargs.setdefault("ppi_hotspot_max_distance", 15)
|
||||
kwargs.setdefault("max_ss_frac_to_provide", 0.0)
|
||||
kwargs.setdefault("min_ss_island_len", 0)
|
||||
kwargs.setdefault("max_ss_island_len", 999)
|
||||
kwargs.setdefault("max_binder_length", 999)
|
||||
|
||||
kwargs.setdefault("b_factor_min", None)
|
||||
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
|
||||
kwargs.setdefault("meta_conditioning_probabilities", {})
|
||||
kwargs.setdefault("association_scheme", "dense")
|
||||
kwargs.setdefault("sigma_perturb", 0.0)
|
||||
kwargs.setdefault("sigma_perturb_com", 0.0)
|
||||
kwargs.setdefault("allowed_types", "ALL")
|
||||
kwargs.setdefault("train_conditions", {})
|
||||
kwargs.setdefault("residue_cache_dir", None)
|
||||
|
||||
# TODO: Delete these once all checkpoints are updated with the latest defaults
|
||||
kwargs.setdefault("generate_conformers_for_non_protein_only", True)
|
||||
# kwargs.setdefault("atom_1d_features", None)
|
||||
# kwargs.setdefault("token_1d_features", None)
|
||||
# kwargs.setdefault("diffusion_batch_size", 16)
|
||||
# kwargs.setdefault("sigma_data", 16)
|
||||
kwargs.setdefault("return_atom_array", True)
|
||||
kwargs.setdefault("provide_elements_for_unindexed_components", False)
|
||||
kwargs.setdefault("center_option", "all")
|
||||
|
||||
return build_atom14_base_pipeline_(
|
||||
is_inference=is_inference,
|
||||
|
||||
@@ -37,7 +37,9 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato
|
||||
raise ValueError(
|
||||
f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}"
|
||||
)
|
||||
atom_names = [str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
|
||||
atom_names = (
|
||||
[str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
|
||||
)
|
||||
idxs = np.array(
|
||||
[
|
||||
association_schemes_stripped[scheme][res_name].index(name)
|
||||
|
||||
@@ -10,6 +10,7 @@ import biotite.structure as struc
|
||||
import numpy as np
|
||||
from atomworks import parse
|
||||
from atomworks.constants import STANDARD_DNA
|
||||
from atomworks.io.parser import STANDARD_PARSER_ARGS
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from atomworks.ml.preprocessing.utils.structure_utils import (
|
||||
get_atom_mask_from_cell_list,
|
||||
@@ -21,18 +22,18 @@ from atomworks.ml.utils.token import (
|
||||
from rfd3.constants import (
|
||||
REQUIRED_CONDITIONING_ANNOTATIONS,
|
||||
)
|
||||
from rfd3.inference.components import (
|
||||
fetch_mask_from_component,
|
||||
get_name_mask,
|
||||
unravel_components,
|
||||
)
|
||||
from rfd3.transforms.conditioning_base import (
|
||||
convert_existing_annotations_to_bool,
|
||||
set_default_conditioning_annotations,
|
||||
)
|
||||
from rfd3.transforms.conditioning_utils import sample_island_tokens
|
||||
from rfd3.constants import STANDARD_PARSER_ARGS
|
||||
|
||||
from modelhub.common import exists
|
||||
from modelhub.utils.components import (
|
||||
fetch_mask_from_component,
|
||||
get_name_mask,
|
||||
unravel_components,
|
||||
)
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -8,7 +8,8 @@ import numpy as np
|
||||
import torch
|
||||
from atomworks.io.utils.io_utils import to_cif_file
|
||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
||||
from rfd3.util.alignment import weighted_rigid_align
|
||||
|
||||
from modelhub.utils.alignment import weighted_rigid_align
|
||||
|
||||
DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
|
||||
CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}
|
||||
@@ -258,7 +258,7 @@ def _viz_from_file(
|
||||
atom_array = pickle.load(f)
|
||||
elif file_path.endswith((".cif", ".cif.gz", ".bcif", ".bcif.gz")):
|
||||
from atomworks.io.utils.io_utils import get_structure, read_any
|
||||
from rfd3.inference.inference_utils import (
|
||||
from rfd3.utils.inference import (
|
||||
_add_design_annotations_from_cif_block_metadata,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@ from rfd3.testing.testing_utils import (
|
||||
TEST_CFG_INFERENCE,
|
||||
TEST_CFG_TRAIN,
|
||||
TEST_JSON_DATA,
|
||||
build_pipelines,
|
||||
instantiate_example,
|
||||
)
|
||||
|
||||
@@ -29,6 +28,23 @@ sys.path.append(PATH_TO_SRC)
|
||||
smoke_test = list(TEST_JSON_DATA.keys())
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_imports():
|
||||
import rfd3
|
||||
|
||||
import modelhub
|
||||
|
||||
print("Imported rfd3 version:", rfd3)
|
||||
print("Imported modelhub version:", modelhub)
|
||||
|
||||
# Try imports from main modules
|
||||
from rfd3.metrics.losses import DiffusionLoss
|
||||
from rfd3.model.RFD3 import RFD3
|
||||
from rfd3.trainer.rfd3 import AADesignTrainer
|
||||
|
||||
print("imported modules:", RFD3, AADesignTrainer, DiffusionLoss)
|
||||
|
||||
|
||||
def test_inference():
|
||||
# Silence outputs:
|
||||
|
||||
@@ -60,6 +76,6 @@ def test_training_pipeline(example_name):
|
||||
example.get("training_condition_name"),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# pytest.main(sys.argv)
|
||||
pytest.main(["-v", __file__ + "::test_train_pipeline[rfd3-NA-full]"])
|
||||
pytest.main(sys.argv)
|
||||
|
||||
@@ -8,7 +8,7 @@ from rfd3.testing.testing_utils import (
|
||||
build_pipelines,
|
||||
instantiate_example,
|
||||
)
|
||||
from rfd3.trainer.rfd3_trainer import (
|
||||
from rfd3.trainer.trainer_utils import (
|
||||
_cleanup_virtual_atoms_and_assign_atom_name_elements,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,18 +3,19 @@ import sys
|
||||
import numpy as np
|
||||
import pytest
|
||||
from atomworks.io.utils.io_utils import load_any
|
||||
from rfd3.inference.components import (
|
||||
from rfd3.inference.input_parsing import DesignInputSpecification
|
||||
from rfd3.inference.parsing import InputSelection
|
||||
from rfd3.testing.testing_utils import (
|
||||
TEST_JSON_DATA,
|
||||
)
|
||||
|
||||
from modelhub.utils.components import (
|
||||
fetch_mask_from_component,
|
||||
fetch_mask_from_idx,
|
||||
fetch_mask_from_name,
|
||||
get_name_mask,
|
||||
unravel_components,
|
||||
)
|
||||
from rfd3.inference.input_parsing import InputSpecification
|
||||
from rfd3.inference.parsing import InputSelection
|
||||
from rfd3.testing.testing_utils import (
|
||||
TEST_JSON_DATA,
|
||||
)
|
||||
|
||||
# TEST 1 - test the selections
|
||||
args = TEST_JSON_DATA["amidase_helix"]
|
||||
@@ -116,7 +117,7 @@ atom_array_ref_unindexed = load_any(file)[0]
|
||||
def test_unindexed_break():
|
||||
# Assert that unindexed selections with breaks work
|
||||
sele = InputSelection.from_any(args["unindex"], atom_array=atom_array_ref_unindexed)
|
||||
comps, breaks = InputSpecification.break_unindexed(sele)
|
||||
comps, breaks = DesignInputSpecification.break_unindexed(sele)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,7 +10,7 @@ from rfd3.testing.testing_utils import (
|
||||
instantiate_example,
|
||||
load_train_or_val_cfg,
|
||||
)
|
||||
from rfd3.transforms.conditioning_utils import (
|
||||
from rfd3.trainer.trainer_utils import (
|
||||
process_unindexed_outputs,
|
||||
)
|
||||
|
||||
@@ -48,10 +48,10 @@ def test_unindexed_cleanup(example, is_inference):
|
||||
) # spoof inference label
|
||||
atom_array_cleaned, metadata = process_unindexed_outputs(atom_array)
|
||||
|
||||
from rfd3.testing.debug_utils import save_pipe_out
|
||||
# from rfd3.testing.debug_utils import save_pipe_out
|
||||
|
||||
save_pipe_out(atom_array)
|
||||
print("Metadata:", metadata)
|
||||
# save_pipe_out(atom_array)
|
||||
# print("Metadata:", metadata)
|
||||
|
||||
xyz_cleaned = np.nan_to_num(atom_array_cleaned.coord)
|
||||
xyz_diffused = np.nan_to_num(atom_array[~atom_array.is_motif_atom_unindexed].coord)
|
||||
|
||||
@@ -192,6 +192,8 @@ def test_atom14_pipeline_regression(
|
||||
# Compare results
|
||||
config_desc = f" ({config.name})" if config.name != "a14-base-train" else ""
|
||||
mode_description = f"{'inference' if is_inference else 'training'}{config_desc}"
|
||||
if "specification" in result:
|
||||
expected_result["specification"] = result["specification"]
|
||||
|
||||
_assert_pipeline_results_equal(
|
||||
result, expected_result, example_name, mode_description
|
||||
|
||||
@@ -32,7 +32,7 @@ dependencies = [
|
||||
# ... typing & documentation
|
||||
"jaxtyping>=0.2.17,<1",
|
||||
"beartype>=0.18.0,<1",
|
||||
"typer>=0.9.0,<1",
|
||||
"typer>=0.20.0,<1",
|
||||
# ... ml tools (core)
|
||||
"torch>=2.2.0,<3",
|
||||
"lightning>=2.5.0",
|
||||
@@ -41,12 +41,17 @@ dependencies = [
|
||||
"einx>=0.1.0,<1",
|
||||
"opt_einsum>=3.4.0,<4",
|
||||
"dm-tree>=0.1.6,<1",
|
||||
"zstandard",
|
||||
"pandas",
|
||||
# "biotite",
|
||||
"atomworks"
|
||||
]
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
rfd3 = [
|
||||
"pydantic>=2.8",
|
||||
"toolz",
|
||||
]
|
||||
rf3 = [
|
||||
"cuequivariance_ops_cu12>=0.6.1; sys_platform == 'linux'",
|
||||
@@ -72,7 +77,8 @@ dev = [
|
||||
"pytest-dotenv>=0.5.2,<1", # load environment variables from .env file
|
||||
"pytest-cov>=4.1.0,<5", # generate coverage report
|
||||
"pytest-benchmark>=5.0.0,<6", # benchmark tests for speed
|
||||
"atomworks==1.0.2",
|
||||
# "atomworks==1.0.2",
|
||||
"pre-commit"
|
||||
]
|
||||
[project.scripts]
|
||||
rf3 = "rf3.cli:app"
|
||||
|
||||
@@ -49,5 +49,9 @@ try:
|
||||
except ImportError:
|
||||
logger.debug("cuEquivariance unavailable: import failed")
|
||||
|
||||
|
||||
# Whether to disable checkpointing globally
|
||||
DISABLE_CHECKPOINTING = False
|
||||
|
||||
# Export for easy access
|
||||
__all__ = ["SHOULD_USE_CUEQUIVARIANCE"]
|
||||
__all__ = ["SHOULD_USE_CUEQUIVARIANCE", "DISABLE_CHECKPOINTING"]
|
||||
|
||||
28
src/modelhub/constants.py
Normal file
28
src/modelhub/constants.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# fmt: off
|
||||
# ... For convenience, define BKBN, or TIP to be used as a shortcut | TIP is the largest set of fixed atom given at least 2 tip atoms
|
||||
TIP_BY_RESTYPE = {
|
||||
"TRP": ["CG","CD1","CD2","NE1","CE2","CE3","CZ2","CZ3","CH2"], # fix both rings
|
||||
"HIS": ["CG","ND1","CD2","CE1","NE2"], # fixed ring
|
||||
"TYR": ["CZ","OH"], # keeps ring dihedral flexible
|
||||
"PHE": ["CG","CD1","CD2","CE1","CE2","CZ"],
|
||||
"ASN": ["CB", "CG","OD1","ND2"],
|
||||
"ASP": ["CB", "CG","OD1","OD2"],
|
||||
"GLN": ["CG", "CD","OE1","NE2"],
|
||||
"GLU": ["CG", "CD","OE1","OE2"],
|
||||
"CYS": ["CB", "SG"],
|
||||
"SER": ["CB", "OG"],
|
||||
"THR": ["CB", "OG1"],
|
||||
"LEU": ["CB", "CG", "CD1", "CD2"],
|
||||
"VAL": ["CG1", "CG2"],
|
||||
"ILE": ["CB", "CG2"],
|
||||
"MET": ["SD", "CE"],
|
||||
"LYS": ["CE","NZ"],
|
||||
"ARG": ["CD","NE","CZ","NH1","NH2"],
|
||||
"PRO": None,
|
||||
"ALA": None,
|
||||
"GLY": None,
|
||||
"UNK": None,
|
||||
"MSK": None
|
||||
}
|
||||
|
||||
# fmt: on
|
||||
216
src/modelhub/inference_engines/base.py
Normal file
216
src/modelhub/inference_engines/base.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import logging
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
from biotite.structure import AtomArray
|
||||
from lightning.fabric import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
|
||||
from modelhub.utils.logging import print_config_tree
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
def merge(cfg, overrides: dict):
|
||||
return OmegaConf.merge(cfg, OmegaConf.create(overrides))
|
||||
|
||||
|
||||
class BaseInferenceEngine:
|
||||
"""
|
||||
Base inference engine.
|
||||
Separates model setup (expensive, once) from inference (can run multiple times).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ckpt_path: PathLike,
|
||||
num_nodes: int = 1,
|
||||
devices_per_node: int = 1,
|
||||
# Config overrides
|
||||
transform_overrides={},
|
||||
inference_sampler_overrides={},
|
||||
trainer_overrides={},
|
||||
# Debug
|
||||
print_config: bool = False,
|
||||
seed: int | None = None,
|
||||
):
|
||||
"""Initialize inference engine and load model.
|
||||
|
||||
Model config is loaded from checkpoint and overridden with parameters provided here.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to model checkpoint.
|
||||
seed: Random seed. If None, uses external RNG state. Defaults to ``None``.
|
||||
num_nodes: Number of nodes for distributed inference. Defaults to ``1``.
|
||||
devices_per_node: Number of devices per node. Defaults to ``1``.
|
||||
print_config: Whether to print config trees. Defaults to ``False``.
|
||||
"""
|
||||
# Set attrs
|
||||
self.initialized_ = False
|
||||
self.trainer = None
|
||||
self.pipeline = None
|
||||
self.print_config = print_config
|
||||
self.ckpt_path = ckpt_path
|
||||
|
||||
# Set random seed (only if seed is not None)
|
||||
if seed is not None:
|
||||
ranked_logger.info(f"Seeding everything with seed={seed}...")
|
||||
seed_everything(seed, workers=True, verbose=True)
|
||||
else:
|
||||
ranked_logger.info("Seed is None - using external RNG state")
|
||||
self.seed = seed
|
||||
|
||||
# Stored for later;
|
||||
self.transform_overrides = transform_overrides
|
||||
self.overrides: dict[str, Any] = {}
|
||||
|
||||
base_overrides = {
|
||||
"trainer.seed": seed,
|
||||
"trainer.metrics": {},
|
||||
"trainer.loss": None,
|
||||
"trainer.num_nodes": num_nodes,
|
||||
"trainer.devices_per_node": devices_per_node,
|
||||
}
|
||||
for key, value in base_overrides.items():
|
||||
self._assign_override(key, value)
|
||||
|
||||
for key, value in trainer_overrides.items():
|
||||
self._assign_override(f"trainer.{key}", value)
|
||||
|
||||
for key, value in inference_sampler_overrides.items():
|
||||
self._assign_override(f"model.net.inference_sampler.{key}", value)
|
||||
|
||||
###################################################################################
|
||||
# Required subclasss methods
|
||||
###################################################################################
|
||||
|
||||
def initialize(self):
|
||||
if self.initialized_:
|
||||
return getattr(self, "cfg", None)
|
||||
|
||||
# Load checkpoint and config
|
||||
ranked_logger.info(
|
||||
f"Loading checkpoint from {Path(self.ckpt_path).resolve()}..."
|
||||
)
|
||||
checkpoint = torch.load(self.ckpt_path, "cpu", weights_only=False)
|
||||
cfg = self._override_checkpoint_config(checkpoint["train_cfg"])
|
||||
|
||||
# Load pipeline first before trainer/model
|
||||
self._construct_pipeline(cfg)
|
||||
self._construct_trainer(cfg, checkpoint=checkpoint)
|
||||
|
||||
ranked_logger.info("Model loaded and ready for inference.")
|
||||
self.initialized_ = True
|
||||
return cfg
|
||||
|
||||
def run(
|
||||
self,
|
||||
inputs: (
|
||||
Dict[str, dict] | AtomArray | list[AtomArray] | PathLike | list[PathLike]
|
||||
),
|
||||
*_,
|
||||
) -> dict[str, dict] | None:
|
||||
self.initialize()
|
||||
raise NotImplementedError(
|
||||
"Subclasses must implement inference logic in `run` method."
|
||||
)
|
||||
|
||||
###################################################################################
|
||||
# Util methods
|
||||
###################################################################################
|
||||
|
||||
def _override_checkpoint_config(self, cfg):
|
||||
cfg = merge(cfg, self.overrides)
|
||||
cfg = set_accelerator_based_on_availability(cfg)
|
||||
return cfg
|
||||
|
||||
def _construct_trainer(self, cfg, checkpoint=None):
|
||||
"""
|
||||
Sets attr self.trainer
|
||||
"""
|
||||
# Instantiate trainer
|
||||
ranked_logger.info("Instantiating trainer...")
|
||||
if self.print_config:
|
||||
print_config_tree(
|
||||
cfg.trainer, resolve=True, title="INFERENCE TRAINER CONFIGURATION"
|
||||
)
|
||||
trainer = hydra.utils.instantiate(
|
||||
cfg.trainer,
|
||||
_convert_="partial",
|
||||
_recursive_=False,
|
||||
)
|
||||
|
||||
# Setup model
|
||||
ranked_logger.info("Setting up model...")
|
||||
trainer.fabric.launch()
|
||||
trainer.initialize_or_update_trainer_state(
|
||||
{"train_cfg": cfg}
|
||||
) # config from training stores net params
|
||||
trainer.construct_model()
|
||||
|
||||
ranked_logger.info("Loading model weights from checkpoint...")
|
||||
trainer.load_checkpoint(checkpoint=checkpoint or self.ckpt_path)
|
||||
|
||||
# Ensure optimizer isn't loaded
|
||||
trainer.state["optimizer"] = None
|
||||
trainer.state["train_cfg"].model.optimizer = None
|
||||
trainer.setup_model_optimizers_and_schedulers()
|
||||
trainer.state["model"].eval()
|
||||
self.trainer = trainer
|
||||
|
||||
def _assign_override(self, dotted_key: str, value: Any) -> None:
|
||||
"""Assign ``value`` into ``self.overrides`` using a dotted path."""
|
||||
target = self.overrides
|
||||
keys = dotted_key.split(".")
|
||||
for key in keys[:-1]:
|
||||
if key not in target or not isinstance(target[key], dict):
|
||||
target[key] = {}
|
||||
target = target[key]
|
||||
target[keys[-1]] = value
|
||||
|
||||
def _construct_pipeline(self, cfg):
|
||||
"""
|
||||
Sets attr self.pipeline
|
||||
"""
|
||||
# Construct pipeline
|
||||
ranked_logger.info("Building Transform pipeline...")
|
||||
first_val_dataset_key, first_val_dataset = next(iter(cfg.datasets.val.items()))
|
||||
ranked_logger.info(
|
||||
f"Using settings from validation dataset: {first_val_dataset_key}."
|
||||
)
|
||||
transform = first_val_dataset.dataset.transform
|
||||
transform = merge(transform, self.transform_overrides)
|
||||
|
||||
if self.print_config:
|
||||
print_config_tree(
|
||||
transform,
|
||||
resolve=True,
|
||||
title="INFERENCE TRANSFORM PIPELINE",
|
||||
)
|
||||
|
||||
self.pipeline = hydra.utils.instantiate(transform)
|
||||
|
||||
# aliases for run
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
# for use as a context manager: e.g. `with BaseInferenceEngine(...) as engine:` to automatically cleanup
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
self.trainer = None
|
||||
self.pipeline = None
|
||||
self.initialized_ = False
|
||||
27
src/modelhub/metrics/losses.py
Normal file
27
src/modelhub/metrics/losses.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import hydra
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Loss(nn.Module):
|
||||
def __init__(self, **losses):
|
||||
super().__init__()
|
||||
self.to_compute = []
|
||||
for loss_name, loss in losses.items():
|
||||
loss_fn = hydra.utils.instantiate(loss)
|
||||
print(f"Adding loss {loss_name} to the loss function")
|
||||
self.to_compute.append(loss_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
network_input,
|
||||
network_output,
|
||||
loss_input,
|
||||
):
|
||||
loss_dict = {}
|
||||
loss = 0
|
||||
for loss_fn in self.to_compute:
|
||||
loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input)
|
||||
loss += loss_
|
||||
loss_dict.update(loss_dict_)
|
||||
loss_dict["total_loss"] = loss.detach()
|
||||
return loss, loss_dict
|
||||
47
src/modelhub/model/layers/blocks.py
Normal file
47
src/modelhub/model/layers/blocks.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||
|
||||
|
||||
class FourierEmbedding(nn.Module):
|
||||
def __init__(self, c):
|
||||
super().__init__()
|
||||
self.c = c
|
||||
self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
|
||||
self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
# super().reset_parameters()
|
||||
nn.init.normal_(self.w)
|
||||
nn.init.normal_(self.b)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
t, # [D]
|
||||
):
|
||||
return torch.cos(2 * pi * (t[..., None] * self.w + self.b))
|
||||
|
||||
|
||||
class Dropout(nn.Module):
|
||||
# Dropout entire row or column
|
||||
def __init__(self, broadcast_dim=None, p_drop=0.15):
|
||||
super(Dropout, self).__init__()
|
||||
# give ones with probability of 1-p_drop / zeros with p_drop
|
||||
self.sampler = torch.distributions.bernoulli.Bernoulli(
|
||||
torch.tensor([1 - p_drop])
|
||||
)
|
||||
self.broadcast_dim = broadcast_dim
|
||||
self.p_drop = p_drop
|
||||
|
||||
def forward(self, x):
|
||||
if not self.training: # no drophead during evaluation mode
|
||||
return x
|
||||
shape = list(x.shape)
|
||||
if self.broadcast_dim is not None:
|
||||
shape[self.broadcast_dim] = 1
|
||||
mask = self.sampler.sample(shape).to(x.device).view(shape)
|
||||
|
||||
x = mask * x / (1.0 - self.p_drop)
|
||||
return x
|
||||
@@ -69,11 +69,13 @@ class FabricTrainer(ABC):
|
||||
checkpoint_every_n_steps: int | None = None,
|
||||
clip_grad_max_norm: float | None = None,
|
||||
skip_nan_grad: bool = False,
|
||||
error_if_grad_nonfinite: bool = False,
|
||||
limit_train_batches: int | float = float("inf"),
|
||||
limit_val_batches: int | float = float("inf"),
|
||||
prevalidate: bool = False,
|
||||
nccl_timeout: int = 3_200,
|
||||
find_unused_parameters: bool = False,
|
||||
skip_optimizer_loading: bool = False,
|
||||
) -> None:
|
||||
"""Base Trainer class built around Lightning Fabric.
|
||||
|
||||
@@ -103,6 +105,7 @@ class FabricTrainer(ABC):
|
||||
clip_grad_max_norm: Maximum gradient norm to clip to (default: None). If None, no gradient clipping is performed.
|
||||
skip_nan_grad: Whether to skip optimizer updates when gradients contain NaN or Inf values (default: False).
|
||||
Useful for training stability, especially with mixed precision or challenging datasets.
|
||||
error_if_grad_nonfinite: Whether to raise when gradient clipping detects NaN or Inf gradients (default: False).
|
||||
limit_train_batches: Limit on the number of training batches per epoch (default: float("inf")).
|
||||
Helpful for debugging; should NOT be used when training production models.
|
||||
limit_val_batches: Limit on the number of validation batches per epoch (default: float("inf")).
|
||||
@@ -111,6 +114,7 @@ class FabricTrainer(ABC):
|
||||
nccl_timeout: Timeout for NCCL operations (default: 3200). Only used with DDP strategy.
|
||||
find_unused_parameters: Whether to let DDP find and skip gradient synchronization for unused parameters in the model (default: False). NOTE: Setting to True will incur a performance penalty,
|
||||
but allow for training for bespoke use cases where parts of the model are frozen.
|
||||
skip_optimizer_loading: Whether to skip loading the optimizer/scheduler state when restoring from checkpoints (default: False).
|
||||
|
||||
References:
|
||||
(1) Fabric Arguments (https://lightning.ai/docs/fabric/stable/api/fabric_args.html)
|
||||
@@ -145,6 +149,7 @@ class FabricTrainer(ABC):
|
||||
# Training
|
||||
self.clip_grad_max_norm = clip_grad_max_norm
|
||||
self.skip_nan_grad = skip_nan_grad
|
||||
self.error_if_grad_nonfinite = error_if_grad_nonfinite
|
||||
self.grad_accum_steps = grad_accum_steps
|
||||
|
||||
# Stopping
|
||||
@@ -162,6 +167,7 @@ class FabricTrainer(ABC):
|
||||
self.output_dir = Path(output_dir) if output_dir else None
|
||||
self.checkpoint_every_n_epochs = checkpoint_every_n_epochs
|
||||
self.checkpoint_every_n_steps = checkpoint_every_n_steps
|
||||
self.skip_optimizer_loading = skip_optimizer_loading
|
||||
|
||||
# Inject trainer reference into callbacks for easy access to trainer state
|
||||
self._inject_trainer_into_callbacks()
|
||||
@@ -262,14 +268,28 @@ class FabricTrainer(ABC):
|
||||
)
|
||||
self.initialize_or_update_trainer_state({"scheduler_cfg": scheduler_cfg})
|
||||
|
||||
@abstractmethod
|
||||
def construct_model(self):
|
||||
"""Instantiate the model, updating the trainer state in-place.
|
||||
|
||||
This method must set the "model" key in the state dictionary using `self.initialize_or_update_trainer_state()`.
|
||||
For an example, see the `construct_model` method in the `AF3Trainer`
|
||||
Construct the model and optionally wrap with EMA.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
# ... instantiate model with Hydra and Fabric
|
||||
with self.fabric.init_module():
|
||||
ranked_logger.info("Instantiating model...")
|
||||
|
||||
model = hydra.utils.instantiate(
|
||||
self.state["train_cfg"].model.net,
|
||||
_recursive_=False,
|
||||
)
|
||||
|
||||
# Optionally, wrap the model with EMA
|
||||
if self.state["train_cfg"].model.ema is not None:
|
||||
ranked_logger.info("Wrapping model with EMA...")
|
||||
model = EMA(model, **self.state["train_cfg"].model.ema)
|
||||
|
||||
self.initialize_or_update_trainer_state({"model": model})
|
||||
|
||||
def setup_model_optimizers_and_schedulers(self) -> None:
|
||||
"""Setup the model, optimizer(s), and scheduler(s) with Fabric.
|
||||
@@ -354,6 +374,11 @@ class FabricTrainer(ABC):
|
||||
), "Checkpoint path not found in checkpoint configuration!"
|
||||
ckpt_path = Path(ckpt_config.path)
|
||||
|
||||
reset_optimizer = bool(
|
||||
getattr(ckpt_config, "reset_optimizer", False)
|
||||
or self.skip_optimizer_loading
|
||||
)
|
||||
|
||||
if ckpt_path.is_dir():
|
||||
# If given a directory, load the latest checkpoint from the directory
|
||||
ranked_logger.info(
|
||||
@@ -362,14 +387,14 @@ class FabricTrainer(ABC):
|
||||
self.load_checkpoint(
|
||||
self.get_latest_checkpoint(ckpt_path),
|
||||
weight_loading_config=ckpt_config.weight_loading_config,
|
||||
reset_optimizer=ckpt_config.reset_optimizer,
|
||||
reset_optimizer=reset_optimizer,
|
||||
)
|
||||
else:
|
||||
# If given a specific checkpoint file, load that checkpoint
|
||||
self.load_checkpoint(
|
||||
ckpt_path,
|
||||
weight_loading_config=ckpt_config.weight_loading_config,
|
||||
reset_optimizer=ckpt_config.reset_optimizer,
|
||||
reset_optimizer=reset_optimizer,
|
||||
)
|
||||
|
||||
# Apply parameter freezing if a freezing config is provided
|
||||
@@ -716,7 +741,7 @@ class FabricTrainer(ABC):
|
||||
module=model,
|
||||
optimizer=optimizer,
|
||||
max_norm=self.clip_grad_max_norm,
|
||||
error_if_nonfinite=False, # Don't error on NaN/Inf if skip_nan_grad is True
|
||||
error_if_nonfinite=self.error_if_grad_nonfinite,
|
||||
)
|
||||
|
||||
# ... step the optimizer
|
||||
|
||||
@@ -8,8 +8,8 @@ logger = logging.getLogger(__name__)
|
||||
def weighted_rigid_align(
|
||||
X_L, # [B, L, 3]
|
||||
X_gt_L, # [B, L, 3]
|
||||
X_exists_L, # [L]
|
||||
w_L, # [B, L]
|
||||
X_exists_L=None, # [L]
|
||||
w_L=None, # [B, L]
|
||||
):
|
||||
"""
|
||||
Weighted rigid body alignment of X_gt_L onto X_L with weights w_L
|
||||
@@ -21,6 +21,13 @@ def weighted_rigid_align(
|
||||
assert X_L.shape == X_gt_L.shape
|
||||
assert X_L.shape[:-1] == w_L.shape
|
||||
|
||||
if X_exists_L is None:
|
||||
X_exists_L = torch.ones((X_L.shape[-2]), dtype=torch.bool)
|
||||
if w_L is None:
|
||||
w_L = torch.ones_like(X_L[..., 0])
|
||||
else:
|
||||
w_L = w_L.to(torch.float32)
|
||||
|
||||
# Assert `X_exists_L` is a boolean mask
|
||||
assert (
|
||||
X_exists_L.dtype == torch.bool
|
||||
@@ -44,6 +44,7 @@ def set_accelerator_based_on_availability(cfg: dict | DictConfig):
|
||||
cfg.trainer.num_nodes = 1
|
||||
else:
|
||||
cfg.trainer.accelerator = "gpu"
|
||||
return cfg
|
||||
|
||||
|
||||
class RankedLogger(logging.LoggerAdapter):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from rf3.flow_matching.rigid_utils import rot_vec_mul
|
||||
|
||||
from modelhub.utils.rigid import rot_vec_mul
|
||||
|
||||
|
||||
def centre(X_L, X_exists_L):
|
||||
Reference in New Issue
Block a user