mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fix: weight initialization bug in chunked P_LL (#229)
- Also cache static projections for speedup
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -184,6 +184,9 @@ ruff.toml
|
||||
|
||||
# Development
|
||||
dev.py
|
||||
lib
|
||||
.gitmodules
|
||||
.ipd/
|
||||
|
||||
# Pytest
|
||||
*.benchmarks/
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
@@ -16,6 +17,7 @@ from foundry.utils.rotation_augmentation import (
|
||||
uniform_random_rotation,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@@ -246,7 +248,9 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
||||
**other_outputs,
|
||||
)
|
||||
toc = time.time()
|
||||
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
||||
ranked_logger.info(
|
||||
f"[chunked] step {step_num}: {(toc - tic)*1000:.1f} ms"
|
||||
)
|
||||
else:
|
||||
# Standard mode: P_LL is included in initializer_outputs
|
||||
outs = diffusion_module(
|
||||
@@ -473,7 +477,9 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
||||
**other_outputs,
|
||||
)
|
||||
toc = time.time()
|
||||
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
||||
ranked_logger.info(
|
||||
f"[chunked] step {step_num}: {(toc - tic)*1000:.1f} ms"
|
||||
)
|
||||
else:
|
||||
# Standard mode: P_LL is included in initializer_outputs
|
||||
outs = diffusion_module(
|
||||
|
||||
@@ -10,23 +10,26 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rfd3.model.layers.blocks import PositionPairDistEmbedder, SinusoidalDistEmbed
|
||||
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
|
||||
|
||||
|
||||
class ChunkedPositionPairDistEmbedder(nn.Module):
|
||||
class ChunkedPositionPairDistEmbedder:
|
||||
"""
|
||||
Memory-efficient version of PositionPairDistEmbedder that computes pairs on-demand.
|
||||
Uses a trained PositionPairDistEmbedder instance and shares the forward method.
|
||||
"""
|
||||
|
||||
def __init__(self, c_atompair, embed_frame=True):
|
||||
super().__init__()
|
||||
self.c_atompair = c_atompair
|
||||
self.embed_frame = embed_frame
|
||||
if embed_frame:
|
||||
self.process_d = linearNoBias(3, c_atompair)
|
||||
|
||||
self.process_inverse_dist = linearNoBias(1, c_atompair)
|
||||
self.process_valid_mask = linearNoBias(1, c_atompair)
|
||||
def __init__(self, embedder_instance: PositionPairDistEmbedder):
|
||||
"""
|
||||
Initialize the ChunkedPositionPairDistEmbedder from a parent PositionPairDistEmbedder instance.
|
||||
"""
|
||||
self.embed_frame = embedder_instance.embed_frame
|
||||
if embedder_instance.embed_frame:
|
||||
self.process_d = embedder_instance.process_d
|
||||
self.process_inverse_dist = embedder_instance.process_inverse_dist
|
||||
self.process_valid_mask = embedder_instance.process_valid_mask
|
||||
self.forward = embedder_instance.forward
|
||||
|
||||
def compute_pairs_chunked(
|
||||
self,
|
||||
@@ -45,8 +48,6 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
|
||||
Returns:
|
||||
P_sparse: Pairwise embeddings [B, k, c_atompair]
|
||||
"""
|
||||
B, k = key_pos.shape[:2]
|
||||
|
||||
# Compute pairwise distances: [B, k, 3]
|
||||
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
|
||||
|
||||
@@ -78,20 +79,25 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
|
||||
return P_pairs
|
||||
|
||||
|
||||
class ChunkedSinusoidalDistEmbed(nn.Module):
|
||||
class ChunkedSinusoidalDistEmbed:
|
||||
"""
|
||||
Memory-efficient version of SinusoidalDistEmbed.
|
||||
Uses a trained SinusoidalDistEmbed instance and shares the forward method.
|
||||
"""
|
||||
|
||||
def __init__(self, c_atompair, n_freqs=32):
|
||||
super().__init__()
|
||||
assert c_atompair % 2 == 0, "Output embedding dim must be even"
|
||||
def __init__(self, embedder_instance: SinusoidalDistEmbed):
|
||||
"""
|
||||
Initialize the ChunkedSinusoidalDistEmbed from a parent SinusoidalDistEmbed instance.
|
||||
"""
|
||||
assert (
|
||||
embedder_instance.c_atompair % 2 == 0
|
||||
), "Output embedding dim must be even"
|
||||
|
||||
self.n_freqs = n_freqs
|
||||
self.c_atompair = c_atompair
|
||||
|
||||
self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
|
||||
self.process_valid_mask = linearNoBias(1, c_atompair)
|
||||
self.n_freqs = embedder_instance.n_freqs
|
||||
self.c_atompair = embedder_instance.c_atompair
|
||||
self.output_proj = embedder_instance.output_proj
|
||||
self.process_valid_mask = embedder_instance.process_valid_mask
|
||||
self.forward = embedder_instance.forward
|
||||
|
||||
def compute_pairs_chunked(
|
||||
self,
|
||||
@@ -102,7 +108,6 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
|
||||
"""
|
||||
Compute sinusoidal distance embeddings for specific query-key pairs.
|
||||
"""
|
||||
B, k = key_pos.shape[:2]
|
||||
device = query_pos.device
|
||||
|
||||
# Compute pairwise distances
|
||||
@@ -134,24 +139,28 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
|
||||
return P_pairs
|
||||
|
||||
|
||||
class ChunkedPairwiseEmbedder(nn.Module):
|
||||
class ChunkedPairwiseEmbedder:
|
||||
"""
|
||||
Main chunked pairwise embedder that combines all embedding types.
|
||||
This replaces the full P_LL computation with sparse computation.
|
||||
|
||||
Not an nn.Module: all sub-components are shared references from TokenInitializer
|
||||
and already registered there. Inheriting nn.Module would cause them to appear
|
||||
under duplicate paths (chunked_pairwise_embedder.*) in state_dict, which don't
|
||||
exist in checkpoints trained without use_chunked_pll.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c_atompair: int,
|
||||
motif_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
|
||||
ref_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
|
||||
motif_pos_embedder: ChunkedSinusoidalDistEmbed,
|
||||
ref_pos_embedder: ChunkedPositionPairDistEmbedder,
|
||||
process_single_l: Optional[nn.Module] = None,
|
||||
process_single_m: Optional[nn.Module] = None,
|
||||
process_z: Optional[nn.Module] = None,
|
||||
pair_mlp: Optional[nn.Module] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.c_atompair = c_atompair
|
||||
self.motif_pos_embedder = motif_pos_embedder
|
||||
self.ref_pos_embedder = ref_pos_embedder
|
||||
@@ -188,30 +197,52 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
||||
linearNoBias(c_atompair, c_atompair),
|
||||
)
|
||||
|
||||
# Cached static projections — populated once at tokenization by
|
||||
# cache_static_projections(). None means "not yet cached; run the MLP."
|
||||
self._sl_cached: Optional[torch.Tensor] = None # [L, c_atompair]
|
||||
self._sm_cached: Optional[torch.Tensor] = None # [L, c_atompair]
|
||||
self._Z_proc_cached: Optional[torch.Tensor] = None # [I, I, c_atompair]
|
||||
|
||||
def cache_static_projections(
|
||||
self, C_L: torch.Tensor, Z_init_II: torch.Tensor
|
||||
) -> None:
|
||||
"""
|
||||
Precompute and cache the three MLP projections that are identical across
|
||||
all diffusion steps (they depend only on the static tokenization outputs).
|
||||
|
||||
Call this once after tokenization, before the diffusion loop.
|
||||
forward_chunked will then replace those MLP calls with free tensor indexing.
|
||||
|
||||
Args:
|
||||
C_L: Atom features [L, c_token]
|
||||
Z_init_II: Token-pair features [I, I, c_z]
|
||||
"""
|
||||
self._sl_cached = self.process_single_l(C_L) # [L, c_atompair]
|
||||
self._sm_cached = self.process_single_m(C_L) # [L, c_atompair]
|
||||
self._Z_proc_cached = self.process_z(Z_init_II) # [I, I, c_atompair]
|
||||
|
||||
def forward_chunked(
|
||||
self,
|
||||
f: dict,
|
||||
indices: torch.Tensor, # [B, L, k] - sparse attention indices
|
||||
C_L: torch.Tensor, # [B, L, c_token] - atom features
|
||||
C_L: torch.Tensor, # [L, c_token] or [B, L, c_token] - atom features
|
||||
Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
|
||||
tok_idx: torch.Tensor, # [L] - atom to token mapping
|
||||
) -> torch.Tensor:
|
||||
# Add logging for chunked P_LL computation
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(
|
||||
f"ChunkedPairwiseEmbedder: Computing sparse P_LL for {indices.shape[1]} atoms with {indices.shape[2]} neighbors each"
|
||||
)
|
||||
"""
|
||||
Compute P_LL only for the pairs specified by attention indices.
|
||||
|
||||
When cache_static_projections() has been called beforehand, the three MLP
|
||||
terms (process_single_l, process_single_m, process_z) are replaced with
|
||||
free tensor index operations, since those projections are identical across
|
||||
all diffusion steps.
|
||||
|
||||
Args:
|
||||
f: Feature dictionary
|
||||
f: Feature dictionary (motif_pos, ref_pos, etc.)
|
||||
indices: Sparse attention indices [B, L, k]
|
||||
C_L: Atom-level features [B, L, c_token]
|
||||
C_L: Atom-level features [L, c_token] or [B, L, c_token]
|
||||
Z_init_II: Token-level pair features [I, I, c_z]
|
||||
tok_idx: Atom to token mapping [L]
|
||||
tok_idx: Atom-to-token mapping [L]
|
||||
|
||||
Returns:
|
||||
P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
|
||||
@@ -286,73 +317,53 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
||||
)
|
||||
P_LL_sparse[:, l, :, :] += ref_pairs
|
||||
|
||||
# 3. Single embedding terms (broadcasted)
|
||||
# Expand C_L to match valid_indices batch dimension
|
||||
# 3. Single embedding terms
|
||||
if self._sl_cached is not None:
|
||||
# Fast path: MLP already run at tokenisation — just index into the result.
|
||||
# sl_cached [L, c_atompair]: query atom l always maps to row l.
|
||||
single_l = self._sl_cached.unsqueeze(0).unsqueeze(2).expand(B, -1, k, -1)
|
||||
# sm_cached [L, c_atompair]: key atoms are given by valid_indices [B, L, k].
|
||||
single_m = self._sm_cached[valid_indices] # [B, L, k, c_atompair]
|
||||
else:
|
||||
# Slow path (no cache): run the MLPs over the raw atom features.
|
||||
if C_L.shape[0] != B:
|
||||
C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
|
||||
# Gather key features for each query
|
||||
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
|
||||
C_L_keys = torch.gather(
|
||||
C_L_queries,
|
||||
1,
|
||||
valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
|
||||
) # [B, L, k, c_token]
|
||||
|
||||
# Add single embeddings - match standard implementation structure
|
||||
# Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
|
||||
# We need to broadcast from [B, L, k, c_atompair] to match this
|
||||
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
|
||||
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
|
||||
P_LL_sparse += single_l + single_m
|
||||
|
||||
# 4. Token pair features Z_init_II
|
||||
# Map atoms to tokens and gather token pair features
|
||||
# Handle tok_idx dimensions properly
|
||||
# Map atoms to tokens and gather token pair features.
|
||||
if tok_idx.dim() == 1: # [L] - add batch dimension for consistency
|
||||
tok_idx_expanded = tok_idx.unsqueeze(0) # [1, L]
|
||||
else:
|
||||
tok_idx_expanded = tok_idx
|
||||
|
||||
# Expand tok_idx_expanded to match valid_indices batch dimension
|
||||
if tok_idx_expanded.shape[0] != B:
|
||||
tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
|
||||
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
|
||||
# Use valid_indices for token mapping as well
|
||||
tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
|
||||
|
||||
# Gather Z_init_II[tok_queries, tok_keys] with safe indexing
|
||||
# Z_init_II shape is [I, I, c_z] (3D), not 4D
|
||||
# tok_queries shape: [B, L, k] - each value is a token index
|
||||
# We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
|
||||
|
||||
I_z, I_z2, c_z = Z_init_II.shape
|
||||
|
||||
# CRITICAL: Match standard implementation exactly!
|
||||
# Standard does: self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :]
|
||||
# This means: 1) Process Z_init_II first, 2) Then do double token indexing
|
||||
|
||||
# Step 1: Process Z_init_II to get processed token pair features
|
||||
if self._Z_proc_cached is not None:
|
||||
# Fast path: process_z already run at tokenisation.
|
||||
Z_processed = self._Z_proc_cached # [I, I, c_atompair]
|
||||
else:
|
||||
# Slow path: run the MLP over the token-pair matrix.
|
||||
Z_processed = self.process_z(Z_init_II) # [I, I, c_atompair]
|
||||
|
||||
# Step 2: Do the double indexing like the standard implementation
|
||||
# Standard: Z_processed[..., tok_idx, :, :][..., tok_idx, :]
|
||||
# This creates Z_processed[tok_idx, :][:, tok_idx] which is [L, L, c_atompair]
|
||||
# Then we need to gather the sparse version
|
||||
|
||||
I_z, I_z2 = Z_processed.shape[:2]
|
||||
Z_pairs_processed = torch.zeros(
|
||||
B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
|
||||
)
|
||||
|
||||
for b in range(B):
|
||||
# For this batch, get the token queries and keys
|
||||
tq = tok_queries[b] # [L, k]
|
||||
tk = tok_keys[b] # [L, k]
|
||||
|
||||
# Ensure indices are within bounds
|
||||
tq = torch.clamp(tq, 0, I_z - 1)
|
||||
tk = torch.clamp(tk, 0, I_z2 - 1)
|
||||
|
||||
# Apply the double token indexing like standard implementation
|
||||
tq = torch.clamp(tok_queries[b], 0, I_z - 1) # [L, k]
|
||||
tk = torch.clamp(tok_keys[b], 0, I_z2 - 1) # [L, k]
|
||||
Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
|
||||
|
||||
P_LL_sparse += Z_pairs_processed
|
||||
@@ -369,8 +380,12 @@ def create_chunked_embedders(
|
||||
"""
|
||||
Factory function to create chunked pairwise embedder with standard components.
|
||||
"""
|
||||
motif_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
|
||||
ref_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
|
||||
motif_pos_embedder = ChunkedSinusoidalDistEmbed(
|
||||
embedder_instance=SinusoidalDistEmbed(c_atompair, embed_frame)
|
||||
)
|
||||
ref_pos_embedder = ChunkedPositionPairDistEmbedder(
|
||||
embedder_instance=PositionPairDistEmbedder(c_atompair, embed_frame)
|
||||
)
|
||||
|
||||
return ChunkedPairwiseEmbedder(
|
||||
c_atompair=c_atompair,
|
||||
|
||||
@@ -128,17 +128,21 @@ class TokenInitializer(nn.Module):
|
||||
|
||||
# Atom pair feature processing
|
||||
if self.use_chunked_pll:
|
||||
# Initialize chunked embedders and share the trained MLPs!
|
||||
# Share trained components so chunked inference uses same weights as full training.
|
||||
motif_pos_embedder = ChunkedSinusoidalDistEmbed(
|
||||
embedder_instance=self.motif_pos_embedder
|
||||
)
|
||||
ref_pos_embedder = ChunkedPositionPairDistEmbedder(
|
||||
embedder_instance=self.ref_pos_embedder
|
||||
)
|
||||
self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder(
|
||||
c_atompair=c_atompair,
|
||||
motif_pos_embedder=ChunkedSinusoidalDistEmbed(c_atompair=c_atompair),
|
||||
ref_pos_embedder=ChunkedPositionPairDistEmbedder(
|
||||
c_atompair, embed_frame=False
|
||||
),
|
||||
process_single_l=self.process_single_l, # Share trained parameters!
|
||||
process_single_m=self.process_single_m, # Share trained parameters!
|
||||
process_z=self.process_z, # Share trained parameters!
|
||||
pair_mlp=self.pair_mlp, # Share trained parameters!
|
||||
motif_pos_embedder=motif_pos_embedder,
|
||||
ref_pos_embedder=ref_pos_embedder,
|
||||
process_single_l=self.process_single_l,
|
||||
process_single_m=self.process_single_m,
|
||||
process_z=self.process_z,
|
||||
pair_mlp=self.pair_mlp,
|
||||
)
|
||||
self.process_pll = linearNoBias(c_atompair, c_atompair)
|
||||
self.project_pll = linearNoBias(c_atompair, c_z)
|
||||
@@ -223,7 +227,9 @@ class TokenInitializer(nn.Module):
|
||||
C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :]
|
||||
|
||||
if self.use_chunked_pll:
|
||||
# Chunked mode: return embedder for later sparse computation
|
||||
# Precompute static MLP projections once so forward_chunked can
|
||||
# skip those MLP calls at every subsequent diffusion step.
|
||||
self.chunked_pairwise_embedder.cache_static_projections(C_L, Z_init_II)
|
||||
return {
|
||||
"Q_L_init": Q_L_init,
|
||||
"C_L": C_L,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
|
||||
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/foundry_exec.sh" "$0" "$@"'
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from rfd3.trainer.trainer_utils import (
|
||||
_build_atom_array_stack,
|
||||
_cleanup_virtual_atoms_and_assign_atom_name_elements,
|
||||
_reassign_unindexed_token_chains,
|
||||
_remap_outputs,
|
||||
_reorder_dict,
|
||||
process_unindexed_outputs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user