fix: weight initialization bug in chunked P_LL (#229)

- Also cache static projections for speedup
This commit is contained in:
Aiko
2026-02-25 16:29:51 -08:00
committed by GitHub
parent 99a0cb773a
commit 290b2fd0bb
6 changed files with 135 additions and 104 deletions

3
.gitignore vendored
View File

@@ -184,6 +184,9 @@ ruff.toml
# Development
dev.py
lib
.gitmodules
.ipd/
# Pytest
*.benchmarks/

View File

@@ -1,4 +1,5 @@
import inspect
import logging
import time
from dataclasses import dataclass
from typing import Any, Literal
@@ -16,6 +17,7 @@ from foundry.utils.rotation_augmentation import (
uniform_random_rotation,
)
logging.basicConfig(level=logging.INFO)
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
@@ -246,7 +248,9 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
**other_outputs,
)
toc = time.time()
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
ranked_logger.info(
f"[chunked] step {step_num}: {(toc - tic)*1000:.1f} ms"
)
else:
# Standard mode: P_LL is included in initializer_outputs
outs = diffusion_module(
@@ -473,7 +477,9 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
**other_outputs,
)
toc = time.time()
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
ranked_logger.info(
f"[chunked] step {step_num}: {(toc - tic)*1000:.1f} ms"
)
else:
# Standard mode: P_LL is included in initializer_outputs
outs = diffusion_module(

View File

@@ -10,23 +10,26 @@ from typing import Optional
import torch
import torch.nn as nn
from rfd3.model.layers.blocks import PositionPairDistEmbedder, SinusoidalDistEmbed
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
class ChunkedPositionPairDistEmbedder(nn.Module):
class ChunkedPositionPairDistEmbedder:
"""
Memory-efficient version of PositionPairDistEmbedder that computes pairs on-demand.
Uses a trained PositionPairDistEmbedder instance and shares the forward method.
"""
def __init__(self, c_atompair, embed_frame=True):
super().__init__()
self.c_atompair = c_atompair
self.embed_frame = embed_frame
if embed_frame:
self.process_d = linearNoBias(3, c_atompair)
self.process_inverse_dist = linearNoBias(1, c_atompair)
self.process_valid_mask = linearNoBias(1, c_atompair)
def __init__(self, embedder_instance: PositionPairDistEmbedder):
"""
Initialize the ChunkedPositionPairDistEmbedder from a parent PositionPairDistEmbedder instance.
"""
self.embed_frame = embedder_instance.embed_frame
if embedder_instance.embed_frame:
self.process_d = embedder_instance.process_d
self.process_inverse_dist = embedder_instance.process_inverse_dist
self.process_valid_mask = embedder_instance.process_valid_mask
self.forward = embedder_instance.forward
def compute_pairs_chunked(
self,
@@ -45,8 +48,6 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
Returns:
P_sparse: Pairwise embeddings [B, k, c_atompair]
"""
B, k = key_pos.shape[:2]
# Compute pairwise distances: [B, k, 3]
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
@@ -78,20 +79,25 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
return P_pairs
class ChunkedSinusoidalDistEmbed(nn.Module):
class ChunkedSinusoidalDistEmbed:
"""
Memory-efficient version of SinusoidalDistEmbed.
Uses a trained SinusoidalDistEmbed instance and shares the forward method.
"""
def __init__(self, c_atompair, n_freqs=32):
super().__init__()
assert c_atompair % 2 == 0, "Output embedding dim must be even"
def __init__(self, embedder_instance: SinusoidalDistEmbed):
"""
Initialize the ChunkedSinusoidalDistEmbed from a parent SinusoidalDistEmbed instance.
"""
assert (
embedder_instance.c_atompair % 2 == 0
), "Output embedding dim must be even"
self.n_freqs = n_freqs
self.c_atompair = c_atompair
self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
self.process_valid_mask = linearNoBias(1, c_atompair)
self.n_freqs = embedder_instance.n_freqs
self.c_atompair = embedder_instance.c_atompair
self.output_proj = embedder_instance.output_proj
self.process_valid_mask = embedder_instance.process_valid_mask
self.forward = embedder_instance.forward
def compute_pairs_chunked(
self,
@@ -102,7 +108,6 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
"""
Compute sinusoidal distance embeddings for specific query-key pairs.
"""
B, k = key_pos.shape[:2]
device = query_pos.device
# Compute pairwise distances
@@ -134,24 +139,28 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
return P_pairs
class ChunkedPairwiseEmbedder(nn.Module):
class ChunkedPairwiseEmbedder:
"""
Main chunked pairwise embedder that combines all embedding types.
This replaces the full P_LL computation with sparse computation.
Not an nn.Module: all sub-components are shared references from TokenInitializer
and already registered there. Inheriting nn.Module would cause them to appear
under duplicate paths (chunked_pairwise_embedder.*) in state_dict, which don't
exist in checkpoints trained without use_chunked_pll.
"""
def __init__(
self,
c_atompair: int,
motif_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
ref_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
motif_pos_embedder: ChunkedSinusoidalDistEmbed,
ref_pos_embedder: ChunkedPositionPairDistEmbedder,
process_single_l: Optional[nn.Module] = None,
process_single_m: Optional[nn.Module] = None,
process_z: Optional[nn.Module] = None,
pair_mlp: Optional[nn.Module] = None,
**kwargs,
):
super().__init__()
self.c_atompair = c_atompair
self.motif_pos_embedder = motif_pos_embedder
self.ref_pos_embedder = ref_pos_embedder
@@ -188,31 +197,53 @@ class ChunkedPairwiseEmbedder(nn.Module):
linearNoBias(c_atompair, c_atompair),
)
# Cached static projections — populated once at tokenization by
# cache_static_projections(). None means "not yet cached; run the MLP."
self._sl_cached: Optional[torch.Tensor] = None # [L, c_atompair]
self._sm_cached: Optional[torch.Tensor] = None # [L, c_atompair]
self._Z_proc_cached: Optional[torch.Tensor] = None # [I, I, c_atompair]
def cache_static_projections(
self, C_L: torch.Tensor, Z_init_II: torch.Tensor
) -> None:
"""
Precompute and cache the three MLP projections that are identical across
all diffusion steps (they depend only on the static tokenization outputs).
Call this once after tokenization, before the diffusion loop.
forward_chunked will then replace those MLP calls with free tensor indexing.
Args:
C_L: Atom features [L, c_token]
Z_init_II: Token-pair features [I, I, c_z]
"""
self._sl_cached = self.process_single_l(C_L) # [L, c_atompair]
self._sm_cached = self.process_single_m(C_L) # [L, c_atompair]
self._Z_proc_cached = self.process_z(Z_init_II) # [I, I, c_atompair]
def forward_chunked(
self,
f: dict,
indices: torch.Tensor, # [B, L, k] - sparse attention indices
C_L: torch.Tensor, # [B, L, c_token] - atom features
C_L: torch.Tensor, # [L, c_token] or [B, L, c_token] - atom features
Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
tok_idx: torch.Tensor, # [L] - atom to token mapping
) -> torch.Tensor:
# Add logging for chunked P_LL computation
import logging
logger = logging.getLogger(__name__)
logger.info(
f"ChunkedPairwiseEmbedder: Computing sparse P_LL for {indices.shape[1]} atoms with {indices.shape[2]} neighbors each"
)
"""
Compute P_LL only for the pairs specified by attention indices.
When cache_static_projections() has been called beforehand, the three MLP
terms (process_single_l, process_single_m, process_z) are replaced with
free tensor index operations, since those projections are identical across
all diffusion steps.
Args:
f: Feature dictionary
indices: Sparse attention indices [B, L, k]
C_L: Atom-level features [B, L, c_token]
f: Feature dictionary (motif_pos, ref_pos, etc.)
indices: Sparse attention indices [B, L, k]
C_L: Atom-level features [L, c_token] or [B, L, c_token]
Z_init_II: Token-level pair features [I, I, c_z]
tok_idx: Atom to token mapping [L]
tok_idx: Atom-to-token mapping [L]
Returns:
P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
"""
@@ -286,73 +317,53 @@ class ChunkedPairwiseEmbedder(nn.Module):
)
P_LL_sparse[:, l, :, :] += ref_pairs
# 3. Single embedding terms (broadcasted)
# Expand C_L to match valid_indices batch dimension
if C_L.shape[0] != B:
C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
# Gather key features for each query
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
C_L_keys = torch.gather(
C_L_queries,
1,
valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
) # [B, L, k, c_token]
# Add single embeddings - match standard implementation structure
# Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
# We need to broadcast from [B, L, k, c_atompair] to match this
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
# 3. Single embedding terms
if self._sl_cached is not None:
# Fast path: MLP already run at tokenisation — just index into the result.
# sl_cached [L, c_atompair]: query atom l always maps to row l.
single_l = self._sl_cached.unsqueeze(0).unsqueeze(2).expand(B, -1, k, -1)
# sm_cached [L, c_atompair]: key atoms are given by valid_indices [B, L, k].
single_m = self._sm_cached[valid_indices] # [B, L, k, c_atompair]
else:
# Slow path (no cache): run the MLPs over the raw atom features.
if C_L.shape[0] != B:
C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
C_L_keys = torch.gather(
C_L_queries,
1,
valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
) # [B, L, k, c_token]
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
P_LL_sparse += single_l + single_m
# 4. Token pair features Z_init_II
# Map atoms to tokens and gather token pair features
# Handle tok_idx dimensions properly
# Map atoms to tokens and gather token pair features.
if tok_idx.dim() == 1: # [L] - add batch dimension for consistency
tok_idx_expanded = tok_idx.unsqueeze(0) # [1, L]
else:
tok_idx_expanded = tok_idx
# Expand tok_idx_expanded to match valid_indices batch dimension
if tok_idx_expanded.shape[0] != B:
tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
# Use valid_indices for token mapping as well
tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
# Gather Z_init_II[tok_queries, tok_keys] with safe indexing
# Z_init_II shape is [I, I, c_z] (3D), not 4D
# tok_queries shape: [B, L, k] - each value is a token index
# We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
I_z, I_z2, c_z = Z_init_II.shape
# CRITICAL: Match standard implementation exactly!
# Standard does: self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :]
# This means: 1) Process Z_init_II first, 2) Then do double token indexing
# Step 1: Process Z_init_II to get processed token pair features
Z_processed = self.process_z(Z_init_II) # [I, I, c_atompair]
# Step 2: Do the double indexing like the standard implementation
# Standard: Z_processed[..., tok_idx, :, :][..., tok_idx, :]
# This creates Z_processed[tok_idx, :][:, tok_idx] which is [L, L, c_atompair]
# Then we need to gather the sparse version
if self._Z_proc_cached is not None:
# Fast path: process_z already run at tokenisation.
Z_processed = self._Z_proc_cached # [I, I, c_atompair]
else:
# Slow path: run the MLP over the token-pair matrix.
Z_processed = self.process_z(Z_init_II) # [I, I, c_atompair]
I_z, I_z2 = Z_processed.shape[:2]
Z_pairs_processed = torch.zeros(
B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
)
for b in range(B):
# For this batch, get the token queries and keys
tq = tok_queries[b] # [L, k]
tk = tok_keys[b] # [L, k]
# Ensure indices are within bounds
tq = torch.clamp(tq, 0, I_z - 1)
tk = torch.clamp(tk, 0, I_z2 - 1)
# Apply the double token indexing like standard implementation
tq = torch.clamp(tok_queries[b], 0, I_z - 1) # [L, k]
tk = torch.clamp(tok_keys[b], 0, I_z2 - 1) # [L, k]
Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
P_LL_sparse += Z_pairs_processed
@@ -369,8 +380,12 @@ def create_chunked_embedders(
"""
Factory function to create chunked pairwise embedder with standard components.
"""
motif_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
ref_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
motif_pos_embedder = ChunkedSinusoidalDistEmbed(
embedder_instance=SinusoidalDistEmbed(c_atompair, embed_frame)
)
ref_pos_embedder = ChunkedPositionPairDistEmbedder(
embedder_instance=PositionPairDistEmbedder(c_atompair, embed_frame)
)
return ChunkedPairwiseEmbedder(
c_atompair=c_atompair,

View File

@@ -128,17 +128,21 @@ class TokenInitializer(nn.Module):
# Atom pair feature processing
if self.use_chunked_pll:
# Initialize chunked embedders and share the trained MLPs!
# Share trained components so chunked inference uses same weights as full training.
motif_pos_embedder = ChunkedSinusoidalDistEmbed(
embedder_instance=self.motif_pos_embedder
)
ref_pos_embedder = ChunkedPositionPairDistEmbedder(
embedder_instance=self.ref_pos_embedder
)
self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder(
c_atompair=c_atompair,
motif_pos_embedder=ChunkedSinusoidalDistEmbed(c_atompair=c_atompair),
ref_pos_embedder=ChunkedPositionPairDistEmbedder(
c_atompair, embed_frame=False
),
process_single_l=self.process_single_l, # Share trained parameters!
process_single_m=self.process_single_m, # Share trained parameters!
process_z=self.process_z, # Share trained parameters!
pair_mlp=self.pair_mlp, # Share trained parameters!
motif_pos_embedder=motif_pos_embedder,
ref_pos_embedder=ref_pos_embedder,
process_single_l=self.process_single_l,
process_single_m=self.process_single_m,
process_z=self.process_z,
pair_mlp=self.pair_mlp,
)
self.process_pll = linearNoBias(c_atompair, c_atompair)
self.project_pll = linearNoBias(c_atompair, c_z)
@@ -223,7 +227,9 @@ class TokenInitializer(nn.Module):
C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :]
if self.use_chunked_pll:
# Chunked mode: return embedder for later sparse computation
# Precompute static MLP projections once so forward_chunked can
# skip those MLP calls at every subsequent diffusion step.
self.chunked_pairwise_embedder.cache_static_projections(C_L, Z_init_II)
return {
"Q_L_init": Q_L_init,
"C_L": C_L,

View File

@@ -1,4 +1,4 @@
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/foundry_exec.sh" "$0" "$@"'
import os

View File

@@ -13,6 +13,7 @@ from rfd3.trainer.trainer_utils import (
_build_atom_array_stack,
_cleanup_virtual_atoms_and_assign_atom_name_elements,
_reassign_unindexed_token_chains,
_remap_outputs,
_reorder_dict,
process_unindexed_outputs,
)