wip: stage 2, al working

This commit is contained in:
ncorley
2025-08-10 15:20:32 -07:00
parent 07aa1765d6
commit 7d32dcecf3
32 changed files with 238 additions and 2455 deletions

View File

@@ -1,5 +1,5 @@
# Model architecture
_target_: modelhub.model.AF3.AF3
_target_: modelhub.model.RF3.RF3
# +---------- Channel dimensions ----------+
c_s: 384

View File

@@ -2,11 +2,11 @@ defaults:
- rf3_net
# Model architecture
_target_: modelhub.model.AF3.AF3WithConfidence
_target_: modelhub.model.RF3.RF3WithConfidence
# +---------- Mini rollout sampler ----------+
# From the AF-3 main text:
# > ...To remedy this, we developed a diffusion rollout procedure for the full-structure prediction generation during training (using a larger step size than normal)
# > ... To remedy this, we developed a diffusion rollout procedure for the full-structure prediction generation during training (using a larger step size than normal)
# They do not further elaborate on how they adjusted the step size during diffusion rollout, but this may be a fruitful area of exploration moving forwards
mini_rollout_sampler:
solver: "af3"

View File

@@ -6,7 +6,6 @@ defaults:
- _self_
net:
_target_: projects.rfscore.model.RFScore.RFScore
feature_initializer:
input_feature_embedder:
atom_attention_encoder:

View File

@@ -4,4 +4,4 @@ defaults:
- _self_
net:
_target_: projects.rfscore.model.RFScore.RFScoreWithConfidence
_target_: modelhub.model.RF3.RF3WithConfidence

View File

@@ -1,10 +0,0 @@
defaults:
- default
fit_and_evaluate_on_experimental_data_callback:
_target_: projects.rfscore.callbacks.experimental_validation.FitAndEvaluateOnExperimentalDataCallback
# Metric outputs to use as features for experimental validation
feature_metrics:
- selected_atom_by_atom_distances
datasets:
- ts1_unconditional

View File

@@ -1,30 +0,0 @@
# @package _global_
name: rfscore-arbitrary-templating
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
defaults:
- override /callbacks: default
- override /datasets: rfscore_arbitrary_templating
- override /model: rfscore
- override /logger: csv
- override /trainer/metrics: structure_prediction
tags:
# list of tags to add to the run (& on wandb to easily find & filter runs)
- arbitrary-templating
- rfscore
model:
lr_scheduler:
base_lr: 0.9e-3 # 1/2 of original learning rate (1.8e-3)
ckpt_config:
_target_: modelhub.utils.weights.CheckpointConfig
path: /net/software/containers/versions/modelhub_inference/ckpts/modelhub_af3_with_confidence_latest.ckpt
reset_optimizer: true
weight_loading_config:
_target_: modelhub.utils.weights.WeightLoadingConfig
fallback_policy: copy_and_zero_pad
param_policies:
"*.recycler.template_embedder.emb_templ.weight": zero_init

View File

@@ -1,12 +0,0 @@
# @package _global_
# ^ The "package" determines where the content of the config is placed in the output config
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
# This line enables us to access the common configs in the root-level `configs` directory
# For more information, see: https://hydra.cc/docs/advanced/search_path/
hydra:
searchpath:
- pkg://configs
defaults:
- inference_engine: ???

View File

@@ -1,10 +0,0 @@
# @package _global_
defaults:
- /inference_engine/af3
- _self_
_target_: projects.rfscore.inference_engines.rfscore.RFScoreInferenceEngine
ckpt_path: /projects/ml/rfscore/inference/rfscore_no_confidence_2025-03-03.ckpt
template_selection_strings: null

View File

@@ -1,36 +0,0 @@
# @package _global_
defaults:
- /model/af3
- _self_
model:
net:
_target_: projects.rfscore.model.RFScore.RFScore
feature_initializer:
input_feature_embedder:
atom_attention_encoder:
c_atom_1d_features: 392 # Add three channels to concatenate the ground-truth reference conformer as well
atom_1d_features:
- ref_pos
- ref_charge
- ref_mask
- ref_element
- ref_atom_name_chars
- ref_pos_ground_truth
recycler:
template_embedder:
raw_template_dim: 66
diffusion_module:
atom_attention_encoder:
c_atom_1d_features: 392 # Add three channels to concatenate the ground-truth reference conformer as well
atom_1d_features:
- ref_pos
- ref_charge
- ref_mask
- ref_element
- ref_atom_name_chars
- ref_pos_ground_truth

View File

@@ -1,52 +0,0 @@
# @package _global_
# ^ The "package" determines where the content of the config is placed in the output config
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
# This line enables us to access the common configs in the root-level `configs` directory
# For more information, see: https://hydra.cc/docs/advanced/search_path/
hydra:
searchpath:
- pkg://configs
# @package _global_
# ^ The "package" determines where the content of the config is placed in the output config
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
defaults:
- callbacks: default
- logger: csv
- trainer: af3
- paths: default
- datasets: ???
- dataloader: default
- hydra: default
- model: rfscore
# We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
- _self_
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: ???
# debug configs to add onto any experiment for quickly testing or debugging code
- debug: null
# DO NOT set these here. Set them in the relevant experiment config file.
# ... these are just here to ensure users always specify these fields in their experiment configs.
name: ???
tags: ???
# NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
# here.
# ... task name (determines the output directory path)
task_name: "train"
project: rfscore # required for W&B logging
seed: 42
# Provide checkpoint path to resume training from a checkpoint
# NOTE: If using W&B, must also set the `id` and `resume` fields in the `logger/wandb` config
ckpt_path: null

View File

@@ -1,8 +0,0 @@
defaults:
- structure_prediction
extra_info:
_target_: projects.rfscore.metrics.metadata.ExtraInfo
selected_atom_by_atom_distances:
_target_: projects.rfscore.metrics.distances.SelectedAtomByAtomDistances

View File

@@ -1,55 +0,0 @@
# @package _global_
# ^ The "package" determines where the content of the config is placed in the output config
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
# This line enables us to access the common configs in the root-level `configs` directory
# For more information, see: https://hydra.cc/docs/advanced/search_path/
hydra:
searchpath:
- pkg://configs
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
defaults:
- callbacks: default
- logger: csv
- trainer: af3
- paths: default
- datasets: ???
- dataloader: default
- hydra: default
- model: rfscore
# We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
- _self_
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: ???
# debug configs to add onto any experiment for quickly testing or debugging code
- debug: null
# DO NOT set these here. Set them in the relevant experiment config file.
# ... these are just here to ensure users always specify these fields in their experiment configs.
name: ???
tags: ???
# NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
# here.
# ... task name (determines the output directory path)
task_name: "validate"
project: rfscore # required for W&B logging
seed: 42
# Dump CIF files for validation structures
callbacks:
dump_validation_structures_callback:
dump_predictions: True
one_model_per_file: False
dump_trajectories: False
# passing checkpoint path required for validation
# DO NOT set here; set in the experiment config file
ckpt_path: ???

View File

@@ -1,4 +1,4 @@
from modelhub.model.AF3 import AF3, AF3WithConfidence
from modelhub.model.RF3 import AF3, AF3WithConfidence
from projects.rfscore.model.recycler import RFScoreRecycler

View File

@@ -1,5 +1,5 @@
from modelhub.model.AF3_structure import Recycler
from projects.rfscore.model.template_embedder import RFScoreTemplateEmbedder
from modelhub.model.RF3_structure import Recycler
from projects.rfscore.model.template_embedder import RF3TemplateEmbedder
class RFScoreRecycler(Recycler):
@@ -10,7 +10,7 @@ class RFScoreRecycler(Recycler):
super().__init__(**kwargs)
# ... override the template embedder to use the RFScore template embedder, which provides additional conditioning
self.template_embedder = RFScoreTemplateEmbedder(
self.template_embedder = RF3TemplateEmbedder(
c_z=kwargs["c_z"],
**kwargs["template_embedder"],
)

View File

@@ -4,14 +4,12 @@ from torch.nn.functional import relu
from modelhub.model.layers.pairformer_layers import PairformerBlock
from modelhub.training.checkpoint import activation_checkpointing
from modelhub.utils.torch_utils import device_of
from projects.rfscore.model.embeddings import FourierEmbedding
from projects.rfscore.transforms.ground_truth_template import (
from modelhub.data.ground_truth_template import (
af3_noise_scale_to_noise_level,
)
class RFScoreTemplateEmbedder(nn.Module):
class RF3TemplateEmbedder(nn.Module):
def __init__(
self,
n_block,

View File

@@ -1,6 +1,15 @@
from datahub.transforms._checks import check_atom_array_annotation
from datahub.transforms.crop import compute_local_hash
from modelhub.data.ground_truth_template import (
FeaturizeNoisedGroundTruthAsTemplateDistogram,
TokenGroupNoiseScaleSampler,
af3_noise_scale_distribution_wrapped,
af3_noise_scale_to_noise_level,
)
from omegaconf import DictConfig
from cifutils.enums import ChainType
from functools import partial
import torch
def annotate_pre_crop_hash(data: dict) -> dict:
hash_pre = compute_local_hash(data["atom_array"])
@@ -29,3 +38,89 @@ def set_to_occupancy_0_where_crop_hashes_differ(data: dict) -> dict:
atom_array.occupancy[mask] = 0
return data
def _is_atomized(arr):
    """Mask function for the 'atomized' token group: returns ``arr.atomize``."""
    return arr.atomize


def _is_not_atomized(arr):
    """Mask function for the 'not_atomized' token group: returns ``~arr.atomize``."""
    return ~arr.atomize


def _constant_noise_scale(size, noise_scale):
    """Return a tensor of shape ``size`` in which every entry equals ``noise_scale``."""
    return torch.ones(size) * noise_scale


def build_ground_truth_distogram_transform(
    *,
    template_noise_scales: dict[str, float | None] | DictConfig,
    allowed_chain_types_for_conditioning: list[ChainType] | None = None,
    p_condition_per_token: float = 1.0,
    p_provide_inter_molecule_distances: float = 0.0,
    is_inference: bool = False,
) -> FeaturizeNoisedGroundTruthAsTemplateDistogram:
    """
    Build a FeaturizeNoisedGroundTruthAsTemplateDistogram transform for either training or inference.

    For inference, we must be deterministic, so we:
        - Use the configured noise scales as constants (no sampling)
        - Always apply token-level conditioning (``p_condition_per_token`` is ignored)

    The noise scales are read from ``template_noise_scales`` exactly once, when the
    transform is built, for both training and inference. Mutating the config
    afterwards therefore does not change an already-built transform.

    Args:
        template_noise_scales (dict[str, float | None] | DictConfig):
            Noise scales for 'atomized' and 'not_atomized' tokens; both keys are looked up
            (a plain dict raises KeyError if one is missing). A value of None means no
            sampler is registered for that token group.
            If is_inference=True, these are used as constants.
            If is_inference=False, these are converted with `af3_noise_scale_to_noise_level`
            and used as the `upper_noise_level` of the noise scale distribution.
        allowed_chain_types_for_conditioning (list[ChainType] | None):
            List of allowed chain types for conditioning. None disables conditioning.
            Forwarded unchanged as `allowed_chain_types`.
        p_condition_per_token (float):
            Probability of conditioning each eligible token. For inference, this is always 1.0.
        p_provide_inter_molecule_distances (float):
            Probability of providing inter-molecule (inter-chain) distances.
        is_inference (bool):
            If True, use constant noise scales and always condition. If False, use distributions and provided probability.

    Returns:
        FeaturizeNoisedGroundTruthAsTemplateDistogram: The configured transform.
    """
    # (config key, mask function) for each token group, in registration order.
    token_groups = (
        ("atomized", _is_atomized),
        ("not_atomized", _is_not_atomized),
    )

    mask_and_sampling_fns = []
    for key, mask_fn in token_groups:
        noise_scale = template_noise_scales[key]
        if noise_scale is None:
            # No sampler for this token group
            continue

        if is_inference:
            # Constant noise scale for inference, rather than sampling (no stochasticity)
            sampling_fn = partial(_constant_noise_scale, noise_scale=noise_scale)
        else:
            # Noise scale distribution for training, bounded above by the configured scale
            sampling_fn = partial(
                af3_noise_scale_distribution_wrapped,
                upper_noise_level=af3_noise_scale_to_noise_level(noise_scale).item(),
            )
        mask_and_sampling_fns.append((mask_fn, sampling_fn))

    # Always condition for inference; apply conditioning to only some tokens during training
    p_condition = 1.0 if is_inference else p_condition_per_token

    return FeaturizeNoisedGroundTruthAsTemplateDistogram(
        noise_scale_distribution=TokenGroupNoiseScaleSampler(
            mask_and_sampling_fns=tuple(mask_and_sampling_fns),
        ),
        allowed_chain_types=allowed_chain_types_for_conditioning,
        p_condition_per_token=p_condition,
        p_provide_inter_molecule_distances=p_provide_inter_molecule_distances,
    )

View File

@@ -98,7 +98,7 @@ from modelhub.data.pipeline_utils import (
annotate_pre_crop_hash,
set_to_occupancy_0_where_crop_hashes_differ,
)
from projects.rfscore.pipelines.composed import build_ground_truth_distogram_transform
from modelhub.data.pipeline_utils import build_ground_truth_distogram_transform
def TrainingRoute(transform):

View File

@@ -14,7 +14,7 @@ from lightning.fabric import seed_everything
from omegaconf import OmegaConf
from modelhub.inference_engines.base import InferenceEngine
from modelhub.model.AF3 import ShouldEarlyStopFn
from modelhub.model.RF3 import ShouldEarlyStopFn
from modelhub.utils.datasets import (
assemble_distributed_inference_loader_from_list_of_paths,
)

View File

@@ -11,7 +11,7 @@ from modelhub.diffusion_samplers.inference_sampler import (
SampleDiffusion,
SamplePartialDiffusion,
)
from modelhub.model.AF3_structure import DiffusionModule, DistogramHead, Recycler
from modelhub.model.RF3_structure import DiffusionModule, DistogramHead, Recycler
from modelhub.model.layers.pairformer_layers import (
FeatureInitializer,
)
@@ -52,8 +52,8 @@ class ShouldEarlyStopFn(Protocol):
...
class AF3(nn.Module):
"""AF3 Network module.
class RF3(nn.Module):
"""RF3 Network module.
We adhere to the PyTorch Lightning Style Guide; see (1).
@@ -89,14 +89,11 @@ class AF3(nn.Module):
c_z: Token-level pair representation channel dimension
c_atom: Atom-level single representation channel dimension
c_atompair: Atom-level pair representation channel dimension
c_s_inputs: TBD what the heck this is
loss: Arguments for the loss function
partial_optimizer: Optimizer (partially initialized) to be used for training. The "configure_optimizers" method will finish instantiating the optimizer.
partial_lr_scheduler: Learning rate scheduler (partially initialized) to be used for training. The "configure_optimizers" method will finish instantiating the scheduler.
c_s_inputs: Output dimension of the InputFeatureEmbedder
"""
super().__init__()
# ... initialize the FeatureInitializer, which creates the initial token- and atom-level representations and conditioning
# ... initialize the FeatureInitializer, which creates the initial token-level representations and conditioning
self.feature_initializer = FeatureInitializer(
c_s=c_s,
c_z=c_z,
@@ -286,6 +283,7 @@ class AF3(nn.Module):
def recycle(
self,
# TODO: Jax typing
S_inputs_I,
S_init_I,
Z_init_II,
@@ -311,7 +309,7 @@ class AF3(nn.Module):
)
class AF3WithConfidence(AF3):
class RF3WithConfidence(RF3):
"""Model for training and inference with confidence metric computation"""
def __init__(
@@ -444,6 +442,8 @@ class AF3WithConfidence(AF3):
)
# ... run non-batched confidence head
# TODO: Write a version of the confidence head that splits into batches based on memory available
# (Currently, we OOM with the full batch size, so we loop, which is slow)
D = sample_diffusion_outs["X_L"].shape[0]
confidence_stack = {}
for i in range(D):

View File

@@ -13,7 +13,7 @@ from modelhub.model.layers.pairformer_layers import (
MSAModule,
PairformerBlock,
RelativePositionEncoding,
TemplateEmbedder,
RF3TemplateEmbedder,
)
from modelhub.training.checkpoint import activation_checkpointing
@@ -292,7 +292,7 @@ class Recycler(nn.Module):
nn.LayerNorm(c_z),
linearNoBias(c_z, c_z),
)
self.template_embedder = TemplateEmbedder(c_z=c_z, **template_embedder)
self.template_embedder = RF3TemplateEmbedder(c_z=c_z, **template_embedder)
self.msa_module = MSAModule(**msa_module)
self.process_sh = nn.Sequential(
nn.LayerNorm(c_s),

View File

@@ -6,7 +6,7 @@ from opt_einsum import contract as einsum
from modelhub.chemical import ChemicalData as ChemData
from modelhub.flow_matching import data_utils as du
from modelhub.model.AF3_structure import FourierEmbedding
from modelhub.model.RF3_structure import FourierEmbedding
from modelhub.model.layers.Attention_module import FeedForwardLayer
from modelhub.model.layers.SE3_network import (
SE3TransformerWrapper,

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import modelhub
from modelhub.model.AF3_structure import PairformerBlock, linearNoBias
from modelhub.model.RF3_structure import PairformerBlock, linearNoBias
# TODO: Get from RF2AA encoding instead
CHEM_DATA_LEGACY = {"NHEAVY": 23, "aa2num": {"UNK": 20, "GLY": 7, "MAS": 21}}

View File

@@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn.functional import one_hot, relu
from modelhub.model.AF3_blocks import MsaPairWeightedAverage, MsaSubsampleEmbedder
from modelhub.model.RF3_blocks import MsaPairWeightedAverage, MsaSubsampleEmbedder
from modelhub.model.layers.af3_diffusion_transformer import AtomTransformer
from modelhub.model.layers.Attention_module import (
TriangleAttention,
@@ -23,6 +23,9 @@ from modelhub.model.layers.outer_product import (
)
from modelhub.training.checkpoint import activation_checkpointing
from modelhub.util_module import Dropout
from modelhub.data.ground_truth_template import (
af3_noise_scale_to_noise_level,
)
class AtomAttentionEncoderPairformer(nn.Module):
@@ -646,7 +649,11 @@ class MSAModule(nn.Module):
return Z_II
class TemplateEmbedder(nn.Module):
class AF3TemplateEmbedder(nn.Module):
"""
AF3-like TemplateEmbedding (e.g., protein-only, etc.)
Unused in RF3.
"""
def __init__(self, n_block, raw_template_dim, c_z, c, p_drop):
super().__init__()
self.c = c
@@ -751,3 +758,110 @@ class TemplateEmbedder(nn.Module):
template_restype,
asym_id,
)
class RF3TemplateEmbedder(nn.Module):
"""
Template track that enables conditioning on noisy ground-truth templates at the token level.
Supports all chain types.
"""
# Constructor arguments:
#   n_block:          number of PairformerBlock layers in the template pairformer stack
#   raw_template_dim: input width of `emb_templ`; must equal the channel count of the features
#                     concatenated in `forward` (distogram bins + 2, i.e. 66 for the 64 bins noted there)
#   c_z:              channel width of the pair representation Z_II (input of `emb_pair`, output of `agg_emb`)
#   c:                internal template channel width used by the pairformer stack
#   p_drop:           forwarded to every PairformerBlock as `p_drop`
def __init__(
self,
n_block,
raw_template_dim,
c_z,
c,
p_drop,
):
super().__init__()
self.c = c
# ... bias-free projection of the (layer-normalized) pair representation from c_z to c channels
self.emb_pair = nn.Linear(c_z, c, bias=False)
self.norm_pair_before_pairformer = nn.LayerNorm(c_z)
self.norm_after_pairformer = nn.LayerNorm(c)
# ... bias-free projection of the raw template features to c channels
self.emb_templ = nn.Linear(raw_template_dim, c, bias=False)
# template pairformer does not operate on sequence representation
self.pairformer = nn.ModuleList(
[
PairformerBlock(
c_s=0,
c_z=c,
p_drop=p_drop,
triangle_multiplication=dict(d_hidden=c),
triangle_attention=dict(d_hidden=c),
attention_pair_bias={},
n_transition=4,
)
for _ in range(n_block)
]
)
# NOTE: this is not consistent with the AF3 paper, which outputs this tensor in the template_channel dimension.
# In Algorithm 1, line 9, the outputs of this function are added to the Z_II tensor, which has dimensions [B, I, I, C_z],
# so we make the outputs of this module also have those dimensions
self.agg_emb = nn.Linear(c, c_z, bias=False)
# Embed the distogram condition stored in the feature dict `f` and return a tensor with c_z channels
# (the output of `agg_emb`). Reads f["has_distogram_condition"], f["distogram_condition_noise_scale"]
# and f["distogram_condition"]; Z_II is the pair representation, indexed as [I, I, c_z] here.
def forward(
self,
f,
Z_II,
):
# The whole embedding runs inside a function wrapped by `activation_checkpointing`.
# NOTE: Z_II is read from the enclosing `forward` scope rather than passed in as an argument.
@activation_checkpointing
def embed_templates_like_rfscore(
has_distogram_condition, # [I, I]
distogram_condition_noise_scale, # [I]
distogram_condition, # [I, I, 64], where 64 is the number of distogram bins
):
I = Z_II.shape[0] # n_tokens
# Transform noise scale to reasonable range
# ... combine the per-token noise scales pairwise: sqrt(scale_i^2 + scale_j^2), broadcast to [I, I]
joint_noise_scale = (
distogram_condition_noise_scale[None, :] ** 2
+ distogram_condition_noise_scale[:, None] ** 2
).sqrt()
joint_noise_level = af3_noise_scale_to_noise_level(joint_noise_scale)
# ---------------------------- #
# ... concatenate along the channel dimension
template_feats = torch.cat(
[
distogram_condition, # [I, I, 64]
has_distogram_condition.unsqueeze(-1), # [I, I, 1]
joint_noise_level.unsqueeze(-1), # [I, I, 1]
],
dim=-1,
) # [I, I, 66]
# ... remove any invalid interactions
# (every channel, including the noise-level channel, is multiplied by the condition mask)
template_feats = template_feats * has_distogram_condition.unsqueeze(
-1
) # [I, I, 66], where 66 = 64 + 1 + 1
# ... embed template features
template_channels = self.emb_templ(template_feats) # [I, I, c]
# ---------------------------- #
# ... pass through pairformer
# u_II is the accumulator; it is allocated on Z_II's device with torch's default dtype
u_II = torch.zeros(I, I, self.c, device=Z_II.device)
v_II = (
self.emb_pair(self.norm_pair_before_pairformer(Z_II))
+ template_channels
) # [I, I, c]
# ... each block is called with None in place of the sequence representation; only the pair output is kept
for block in self.pairformer:
_, v_II = block(None, v_II)
# ... add the layer-normalized pair track to the accumulator
u_II = u_II + self.norm_after_pairformer(v_II)
# ... ReLU, then project from c back to c_z channels
return self.agg_emb(relu(u_II))
# rfscore template embedding (noisy ground-truth template as input)
embedded_templates = embed_templates_like_rfscore(
has_distogram_condition=f["has_distogram_condition"], # [I, I]
distogram_condition_noise_scale=f["distogram_condition_noise_scale"], # [I]
distogram_condition=f[
"distogram_condition"
], # [I, I, 64], where 64 is the number of distogram bins
)
return embedded_templates

View File

@@ -13,7 +13,7 @@ from modelhub.loss.af3_losses import (
SubunitSymmetryResolution,
)
from modelhub.metrics.base import MetricManager
from modelhub.model.AF3 import ShouldEarlyStopFn
from modelhub.model.RF3 import ShouldEarlyStopFn
from modelhub.trainers.fabric import FabricTrainer
from modelhub.training.EMA import EMA
from modelhub.utils.ddp import RankedLogger