fix: weight initialization bug in chunked P_LL (#229)

- Also cache static projections for speedup
This commit is contained in:
Aiko
2026-02-25 16:29:51 -08:00
committed by GitHub
parent 99a0cb773a
commit 290b2fd0bb
6 changed files with 135 additions and 104 deletions

3
.gitignore vendored
View File

@@ -184,6 +184,9 @@ ruff.toml
# Development # Development
dev.py dev.py
lib
.gitmodules
.ipd/
# Pytest # Pytest
*.benchmarks/ *.benchmarks/

View File

@@ -1,4 +1,5 @@
import inspect import inspect
import logging
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
@@ -16,6 +17,7 @@ from foundry.utils.rotation_augmentation import (
uniform_random_rotation, uniform_random_rotation,
) )
logging.basicConfig(level=logging.INFO)
ranked_logger = RankedLogger(__name__, rank_zero_only=True) ranked_logger = RankedLogger(__name__, rank_zero_only=True)
@@ -246,7 +248,9 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
**other_outputs, **other_outputs,
) )
toc = time.time() toc = time.time()
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds") ranked_logger.info(
f"[chunked] step {step_num}: {(toc - tic)*1000:.1f} ms"
)
else: else:
# Standard mode: P_LL is included in initializer_outputs # Standard mode: P_LL is included in initializer_outputs
outs = diffusion_module( outs = diffusion_module(
@@ -473,7 +477,9 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
**other_outputs, **other_outputs,
) )
toc = time.time() toc = time.time()
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds") ranked_logger.info(
f"[chunked] step {step_num}: {(toc - tic)*1000:.1f} ms"
)
else: else:
# Standard mode: P_LL is included in initializer_outputs # Standard mode: P_LL is included in initializer_outputs
outs = diffusion_module( outs = diffusion_module(

View File

@@ -10,23 +10,26 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from rfd3.model.layers.blocks import PositionPairDistEmbedder, SinusoidalDistEmbed
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
class ChunkedPositionPairDistEmbedder(nn.Module): class ChunkedPositionPairDistEmbedder:
""" """
Memory-efficient version of PositionPairDistEmbedder that computes pairs on-demand. Memory-efficient version of PositionPairDistEmbedder that computes pairs on-demand.
Uses a trained PositionPairDistEmbedder instance and shares the forward method.
""" """
def __init__(self, c_atompair, embed_frame=True): def __init__(self, embedder_instance: PositionPairDistEmbedder):
super().__init__() """
self.c_atompair = c_atompair Initialize the ChunkedPositionPairDistEmbedder from a parent PositionPairDistEmbedder instance.
self.embed_frame = embed_frame """
if embed_frame: self.embed_frame = embedder_instance.embed_frame
self.process_d = linearNoBias(3, c_atompair) if embedder_instance.embed_frame:
self.process_d = embedder_instance.process_d
self.process_inverse_dist = linearNoBias(1, c_atompair) self.process_inverse_dist = embedder_instance.process_inverse_dist
self.process_valid_mask = linearNoBias(1, c_atompair) self.process_valid_mask = embedder_instance.process_valid_mask
self.forward = embedder_instance.forward
def compute_pairs_chunked( def compute_pairs_chunked(
self, self,
@@ -45,8 +48,6 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
Returns: Returns:
P_sparse: Pairwise embeddings [B, k, c_atompair] P_sparse: Pairwise embeddings [B, k, c_atompair]
""" """
B, k = key_pos.shape[:2]
# Compute pairwise distances: [B, k, 3] # Compute pairwise distances: [B, k, 3]
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3] D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
@@ -78,20 +79,25 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
return P_pairs return P_pairs
class ChunkedSinusoidalDistEmbed(nn.Module): class ChunkedSinusoidalDistEmbed:
""" """
Memory-efficient version of SinusoidalDistEmbed. Memory-efficient version of SinusoidalDistEmbed.
Uses a trained SinusoidalDistEmbed instance and shares the forward method.
""" """
def __init__(self, c_atompair, n_freqs=32): def __init__(self, embedder_instance: SinusoidalDistEmbed):
super().__init__() """
assert c_atompair % 2 == 0, "Output embedding dim must be even" Initialize the ChunkedSinusoidalDistEmbed from a parent SinusoidalDistEmbed instance.
"""
assert (
embedder_instance.c_atompair % 2 == 0
), "Output embedding dim must be even"
self.n_freqs = n_freqs self.n_freqs = embedder_instance.n_freqs
self.c_atompair = c_atompair self.c_atompair = embedder_instance.c_atompair
self.output_proj = embedder_instance.output_proj
self.output_proj = linearNoBias(2 * n_freqs, c_atompair) self.process_valid_mask = embedder_instance.process_valid_mask
self.process_valid_mask = linearNoBias(1, c_atompair) self.forward = embedder_instance.forward
def compute_pairs_chunked( def compute_pairs_chunked(
self, self,
@@ -102,7 +108,6 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
""" """
Compute sinusoidal distance embeddings for specific query-key pairs. Compute sinusoidal distance embeddings for specific query-key pairs.
""" """
B, k = key_pos.shape[:2]
device = query_pos.device device = query_pos.device
# Compute pairwise distances # Compute pairwise distances
@@ -134,24 +139,28 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
return P_pairs return P_pairs
class ChunkedPairwiseEmbedder(nn.Module): class ChunkedPairwiseEmbedder:
""" """
Main chunked pairwise embedder that combines all embedding types. Main chunked pairwise embedder that combines all embedding types.
This replaces the full P_LL computation with sparse computation. This replaces the full P_LL computation with sparse computation.
Not an nn.Module: all sub-components are shared references from TokenInitializer
and already registered there. Inheriting nn.Module would cause them to appear
under duplicate paths (chunked_pairwise_embedder.*) in state_dict, which don't
exist in checkpoints trained without use_chunked_pll.
""" """
def __init__( def __init__(
self, self,
c_atompair: int, c_atompair: int,
motif_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None, motif_pos_embedder: ChunkedSinusoidalDistEmbed,
ref_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None, ref_pos_embedder: ChunkedPositionPairDistEmbedder,
process_single_l: Optional[nn.Module] = None, process_single_l: Optional[nn.Module] = None,
process_single_m: Optional[nn.Module] = None, process_single_m: Optional[nn.Module] = None,
process_z: Optional[nn.Module] = None, process_z: Optional[nn.Module] = None,
pair_mlp: Optional[nn.Module] = None, pair_mlp: Optional[nn.Module] = None,
**kwargs, **kwargs,
): ):
super().__init__()
self.c_atompair = c_atompair self.c_atompair = c_atompair
self.motif_pos_embedder = motif_pos_embedder self.motif_pos_embedder = motif_pos_embedder
self.ref_pos_embedder = ref_pos_embedder self.ref_pos_embedder = ref_pos_embedder
@@ -188,31 +197,53 @@ class ChunkedPairwiseEmbedder(nn.Module):
linearNoBias(c_atompair, c_atompair), linearNoBias(c_atompair, c_atompair),
) )
# Cached static projections — populated once at tokenization by
# cache_static_projections(). None means "not yet cached; run the MLP."
self._sl_cached: Optional[torch.Tensor] = None # [L, c_atompair]
self._sm_cached: Optional[torch.Tensor] = None # [L, c_atompair]
self._Z_proc_cached: Optional[torch.Tensor] = None # [I, I, c_atompair]
def cache_static_projections(
self, C_L: torch.Tensor, Z_init_II: torch.Tensor
) -> None:
"""
Precompute and cache the three MLP projections that are identical across
all diffusion steps (they depend only on the static tokenization outputs).
Call this once after tokenization, before the diffusion loop.
forward_chunked will then replace those MLP calls with free tensor indexing.
Args:
C_L: Atom features [L, c_token]
Z_init_II: Token-pair features [I, I, c_z]
"""
self._sl_cached = self.process_single_l(C_L) # [L, c_atompair]
self._sm_cached = self.process_single_m(C_L) # [L, c_atompair]
self._Z_proc_cached = self.process_z(Z_init_II) # [I, I, c_atompair]
def forward_chunked( def forward_chunked(
self, self,
f: dict, f: dict,
indices: torch.Tensor, # [B, L, k] - sparse attention indices indices: torch.Tensor, # [B, L, k] - sparse attention indices
C_L: torch.Tensor, # [B, L, c_token] - atom features C_L: torch.Tensor, # [L, c_token] or [B, L, c_token] - atom features
Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
tok_idx: torch.Tensor, # [L] - atom to token mapping tok_idx: torch.Tensor, # [L] - atom to token mapping
) -> torch.Tensor: ) -> torch.Tensor:
# Add logging for chunked P_LL computation
import logging
logger = logging.getLogger(__name__)
logger.info(
f"ChunkedPairwiseEmbedder: Computing sparse P_LL for {indices.shape[1]} atoms with {indices.shape[2]} neighbors each"
)
""" """
Compute P_LL only for the pairs specified by attention indices. Compute P_LL only for the pairs specified by attention indices.
When cache_static_projections() has been called beforehand, the three MLP
terms (process_single_l, process_single_m, process_z) are replaced with
free tensor index operations, since those projections are identical across
all diffusion steps.
Args: Args:
f: Feature dictionary f: Feature dictionary (motif_pos, ref_pos, etc.)
indices: Sparse attention indices [B, L, k] indices: Sparse attention indices [B, L, k]
C_L: Atom-level features [B, L, c_token] C_L: Atom-level features [L, c_token] or [B, L, c_token]
Z_init_II: Token-level pair features [I, I, c_z] Z_init_II: Token-level pair features [I, I, c_z]
tok_idx: Atom to token mapping [L] tok_idx: Atom-to-token mapping [L]
Returns: Returns:
P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair] P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
""" """
@@ -286,73 +317,53 @@ class ChunkedPairwiseEmbedder(nn.Module):
) )
P_LL_sparse[:, l, :, :] += ref_pairs P_LL_sparse[:, l, :, :] += ref_pairs
# 3. Single embedding terms (broadcasted) # 3. Single embedding terms
# Expand C_L to match valid_indices batch dimension if self._sl_cached is not None:
if C_L.shape[0] != B: # Fast path: MLP already run at tokenisation — just index into the result.
C_L = C_L.expand(B, -1, -1) # [B, L, c_token] # sl_cached [L, c_atompair]: query atom l always maps to row l.
# Gather key features for each query single_l = self._sl_cached.unsqueeze(0).unsqueeze(2).expand(B, -1, k, -1)
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token] # sm_cached [L, c_atompair]: key atoms are given by valid_indices [B, L, k].
C_L_keys = torch.gather( single_m = self._sm_cached[valid_indices] # [B, L, k, c_atompair]
C_L_queries, else:
1, # Slow path (no cache): run the MLPs over the raw atom features.
valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]), if C_L.shape[0] != B:
) # [B, L, k, c_token] C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
# Add single embeddings - match standard implementation structure C_L_keys = torch.gather(
# Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3) C_L_queries,
# We need to broadcast from [B, L, k, c_atompair] to match this 1,
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair] valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair] ) # [B, L, k, c_token]
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
P_LL_sparse += single_l + single_m P_LL_sparse += single_l + single_m
# 4. Token pair features Z_init_II # 4. Token pair features Z_init_II
# Map atoms to tokens and gather token pair features # Map atoms to tokens and gather token pair features.
# Handle tok_idx dimensions properly
if tok_idx.dim() == 1: # [L] - add batch dimension for consistency if tok_idx.dim() == 1: # [L] - add batch dimension for consistency
tok_idx_expanded = tok_idx.unsqueeze(0) # [1, L] tok_idx_expanded = tok_idx.unsqueeze(0) # [1, L]
else: else:
tok_idx_expanded = tok_idx tok_idx_expanded = tok_idx
# Expand tok_idx_expanded to match valid_indices batch dimension
if tok_idx_expanded.shape[0] != B: if tok_idx_expanded.shape[0] != B:
tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L] tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k] tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
# Use valid_indices for token mapping as well
tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k] tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
# Gather Z_init_II[tok_queries, tok_keys] with safe indexing if self._Z_proc_cached is not None:
# Z_init_II shape is [I, I, c_z] (3D), not 4D # Fast path: process_z already run at tokenisation.
# tok_queries shape: [B, L, k] - each value is a token index Z_processed = self._Z_proc_cached # [I, I, c_atompair]
# We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k else:
# Slow path: run the MLP over the token-pair matrix.
I_z, I_z2, c_z = Z_init_II.shape Z_processed = self.process_z(Z_init_II) # [I, I, c_atompair]
# CRITICAL: Match standard implementation exactly!
# Standard does: self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :]
# This means: 1) Process Z_init_II first, 2) Then do double token indexing
# Step 1: Process Z_init_II to get processed token pair features
Z_processed = self.process_z(Z_init_II) # [I, I, c_atompair]
# Step 2: Do the double indexing like the standard implementation
# Standard: Z_processed[..., tok_idx, :, :][..., tok_idx, :]
# This creates Z_processed[tok_idx, :][:, tok_idx] which is [L, L, c_atompair]
# Then we need to gather the sparse version
I_z, I_z2 = Z_processed.shape[:2]
Z_pairs_processed = torch.zeros( Z_pairs_processed = torch.zeros(
B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
) )
for b in range(B): for b in range(B):
# For this batch, get the token queries and keys tq = torch.clamp(tok_queries[b], 0, I_z - 1) # [L, k]
tq = tok_queries[b] # [L, k] tk = torch.clamp(tok_keys[b], 0, I_z2 - 1) # [L, k]
tk = tok_keys[b] # [L, k]
# Ensure indices are within bounds
tq = torch.clamp(tq, 0, I_z - 1)
tk = torch.clamp(tk, 0, I_z2 - 1)
# Apply the double token indexing like standard implementation
Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair] Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
P_LL_sparse += Z_pairs_processed P_LL_sparse += Z_pairs_processed
@@ -369,8 +380,12 @@ def create_chunked_embedders(
""" """
Factory function to create chunked pairwise embedder with standard components. Factory function to create chunked pairwise embedder with standard components.
""" """
motif_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame) motif_pos_embedder = ChunkedSinusoidalDistEmbed(
ref_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame) embedder_instance=SinusoidalDistEmbed(c_atompair, embed_frame)
)
ref_pos_embedder = ChunkedPositionPairDistEmbedder(
embedder_instance=PositionPairDistEmbedder(c_atompair, embed_frame)
)
return ChunkedPairwiseEmbedder( return ChunkedPairwiseEmbedder(
c_atompair=c_atompair, c_atompair=c_atompair,

View File

@@ -128,17 +128,21 @@ class TokenInitializer(nn.Module):
# Atom pair feature processing # Atom pair feature processing
if self.use_chunked_pll: if self.use_chunked_pll:
# Initialize chunked embedders and share the trained MLPs! # Share trained components so chunked inference uses same weights as full training.
motif_pos_embedder = ChunkedSinusoidalDistEmbed(
embedder_instance=self.motif_pos_embedder
)
ref_pos_embedder = ChunkedPositionPairDistEmbedder(
embedder_instance=self.ref_pos_embedder
)
self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder( self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder(
c_atompair=c_atompair, c_atompair=c_atompair,
motif_pos_embedder=ChunkedSinusoidalDistEmbed(c_atompair=c_atompair), motif_pos_embedder=motif_pos_embedder,
ref_pos_embedder=ChunkedPositionPairDistEmbedder( ref_pos_embedder=ref_pos_embedder,
c_atompair, embed_frame=False process_single_l=self.process_single_l,
), process_single_m=self.process_single_m,
process_single_l=self.process_single_l, # Share trained parameters! process_z=self.process_z,
process_single_m=self.process_single_m, # Share trained parameters! pair_mlp=self.pair_mlp,
process_z=self.process_z, # Share trained parameters!
pair_mlp=self.pair_mlp, # Share trained parameters!
) )
self.process_pll = linearNoBias(c_atompair, c_atompair) self.process_pll = linearNoBias(c_atompair, c_atompair)
self.project_pll = linearNoBias(c_atompair, c_z) self.project_pll = linearNoBias(c_atompair, c_z)
@@ -223,7 +227,9 @@ class TokenInitializer(nn.Module):
C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :] C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :]
if self.use_chunked_pll: if self.use_chunked_pll:
# Chunked mode: return embedder for later sparse computation # Precompute static MLP projections once so forward_chunked can
# skip those MLP calls at every subsequent diffusion step.
self.chunked_pairwise_embedder.cache_static_projections(C_L, Z_init_II)
return { return {
"Q_L_init": Q_L_init, "Q_L_init": Q_L_init,
"C_L": C_L, "C_L": C_L,

View File

@@ -1,4 +1,4 @@
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"' #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/foundry_exec.sh" "$0" "$@"'
import os import os

View File

@@ -13,6 +13,7 @@ from rfd3.trainer.trainer_utils import (
_build_atom_array_stack, _build_atom_array_stack,
_cleanup_virtual_atoms_and_assign_atom_name_elements, _cleanup_virtual_atoms_and_assign_atom_name_elements,
_reassign_unindexed_token_chains, _reassign_unindexed_token_chains,
_remap_outputs,
_reorder_dict, _reorder_dict,
process_unindexed_outputs, process_unindexed_outputs,
) )