mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fix: weight initialization bug in chunked P_LL (#229)
- Also cache static projections for speedup
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -184,6 +184,9 @@ ruff.toml
|
|||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev.py
|
dev.py
|
||||||
|
lib
|
||||||
|
.gitmodules
|
||||||
|
.ipd/
|
||||||
|
|
||||||
# Pytest
|
# Pytest
|
||||||
*.benchmarks/
|
*.benchmarks/
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
@@ -16,6 +17,7 @@ from foundry.utils.rotation_augmentation import (
|
|||||||
uniform_random_rotation,
|
uniform_random_rotation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -246,7 +248,9 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
|||||||
**other_outputs,
|
**other_outputs,
|
||||||
)
|
)
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
ranked_logger.info(
|
||||||
|
f"[chunked] step {step_num}: {(toc - tic)*1000:.1f} ms"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Standard mode: P_LL is included in initializer_outputs
|
# Standard mode: P_LL is included in initializer_outputs
|
||||||
outs = diffusion_module(
|
outs = diffusion_module(
|
||||||
@@ -473,7 +477,9 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
|||||||
**other_outputs,
|
**other_outputs,
|
||||||
)
|
)
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
ranked_logger.info(
|
||||||
|
f"[chunked] step {step_num}: {(toc - tic)*1000:.1f} ms"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Standard mode: P_LL is included in initializer_outputs
|
# Standard mode: P_LL is included in initializer_outputs
|
||||||
outs = diffusion_module(
|
outs = diffusion_module(
|
||||||
|
|||||||
@@ -10,23 +10,26 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from rfd3.model.layers.blocks import PositionPairDistEmbedder, SinusoidalDistEmbed
|
||||||
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
|
from rfd3.model.layers.layer_utils import RMSNorm, linearNoBias
|
||||||
|
|
||||||
|
|
||||||
class ChunkedPositionPairDistEmbedder(nn.Module):
|
class ChunkedPositionPairDistEmbedder:
|
||||||
"""
|
"""
|
||||||
Memory-efficient version of PositionPairDistEmbedder that computes pairs on-demand.
|
Memory-efficient version of PositionPairDistEmbedder that computes pairs on-demand.
|
||||||
|
Uses a trained PositionPairDistEmbedder instance and shares the forward method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, c_atompair, embed_frame=True):
|
def __init__(self, embedder_instance: PositionPairDistEmbedder):
|
||||||
super().__init__()
|
"""
|
||||||
self.c_atompair = c_atompair
|
Initialize the ChunkedPositionPairDistEmbedder from a parent PositionPairDistEmbedder instance.
|
||||||
self.embed_frame = embed_frame
|
"""
|
||||||
if embed_frame:
|
self.embed_frame = embedder_instance.embed_frame
|
||||||
self.process_d = linearNoBias(3, c_atompair)
|
if embedder_instance.embed_frame:
|
||||||
|
self.process_d = embedder_instance.process_d
|
||||||
self.process_inverse_dist = linearNoBias(1, c_atompair)
|
self.process_inverse_dist = embedder_instance.process_inverse_dist
|
||||||
self.process_valid_mask = linearNoBias(1, c_atompair)
|
self.process_valid_mask = embedder_instance.process_valid_mask
|
||||||
|
self.forward = embedder_instance.forward
|
||||||
|
|
||||||
def compute_pairs_chunked(
|
def compute_pairs_chunked(
|
||||||
self,
|
self,
|
||||||
@@ -45,8 +48,6 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
P_sparse: Pairwise embeddings [B, k, c_atompair]
|
P_sparse: Pairwise embeddings [B, k, c_atompair]
|
||||||
"""
|
"""
|
||||||
B, k = key_pos.shape[:2]
|
|
||||||
|
|
||||||
# Compute pairwise distances: [B, k, 3]
|
# Compute pairwise distances: [B, k, 3]
|
||||||
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
|
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
|
||||||
|
|
||||||
@@ -78,20 +79,25 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
|
|||||||
return P_pairs
|
return P_pairs
|
||||||
|
|
||||||
|
|
||||||
class ChunkedSinusoidalDistEmbed(nn.Module):
|
class ChunkedSinusoidalDistEmbed:
|
||||||
"""
|
"""
|
||||||
Memory-efficient version of SinusoidalDistEmbed.
|
Memory-efficient version of SinusoidalDistEmbed.
|
||||||
|
Uses a trained SinusoidalDistEmbed instance and shares the forward method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, c_atompair, n_freqs=32):
|
def __init__(self, embedder_instance: SinusoidalDistEmbed):
|
||||||
super().__init__()
|
"""
|
||||||
assert c_atompair % 2 == 0, "Output embedding dim must be even"
|
Initialize the ChunkedSinusoidalDistEmbed from a parent SinusoidalDistEmbed instance.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
embedder_instance.c_atompair % 2 == 0
|
||||||
|
), "Output embedding dim must be even"
|
||||||
|
|
||||||
self.n_freqs = n_freqs
|
self.n_freqs = embedder_instance.n_freqs
|
||||||
self.c_atompair = c_atompair
|
self.c_atompair = embedder_instance.c_atompair
|
||||||
|
self.output_proj = embedder_instance.output_proj
|
||||||
self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
|
self.process_valid_mask = embedder_instance.process_valid_mask
|
||||||
self.process_valid_mask = linearNoBias(1, c_atompair)
|
self.forward = embedder_instance.forward
|
||||||
|
|
||||||
def compute_pairs_chunked(
|
def compute_pairs_chunked(
|
||||||
self,
|
self,
|
||||||
@@ -102,7 +108,6 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Compute sinusoidal distance embeddings for specific query-key pairs.
|
Compute sinusoidal distance embeddings for specific query-key pairs.
|
||||||
"""
|
"""
|
||||||
B, k = key_pos.shape[:2]
|
|
||||||
device = query_pos.device
|
device = query_pos.device
|
||||||
|
|
||||||
# Compute pairwise distances
|
# Compute pairwise distances
|
||||||
@@ -134,24 +139,28 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
|
|||||||
return P_pairs
|
return P_pairs
|
||||||
|
|
||||||
|
|
||||||
class ChunkedPairwiseEmbedder(nn.Module):
|
class ChunkedPairwiseEmbedder:
|
||||||
"""
|
"""
|
||||||
Main chunked pairwise embedder that combines all embedding types.
|
Main chunked pairwise embedder that combines all embedding types.
|
||||||
This replaces the full P_LL computation with sparse computation.
|
This replaces the full P_LL computation with sparse computation.
|
||||||
|
|
||||||
|
Not an nn.Module: all sub-components are shared references from TokenInitializer
|
||||||
|
and already registered there. Inheriting nn.Module would cause them to appear
|
||||||
|
under duplicate paths (chunked_pairwise_embedder.*) in state_dict, which don't
|
||||||
|
exist in checkpoints trained without use_chunked_pll.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
c_atompair: int,
|
c_atompair: int,
|
||||||
motif_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
|
motif_pos_embedder: ChunkedSinusoidalDistEmbed,
|
||||||
ref_pos_embedder: Optional[ChunkedPositionPairDistEmbedder] = None,
|
ref_pos_embedder: ChunkedPositionPairDistEmbedder,
|
||||||
process_single_l: Optional[nn.Module] = None,
|
process_single_l: Optional[nn.Module] = None,
|
||||||
process_single_m: Optional[nn.Module] = None,
|
process_single_m: Optional[nn.Module] = None,
|
||||||
process_z: Optional[nn.Module] = None,
|
process_z: Optional[nn.Module] = None,
|
||||||
pair_mlp: Optional[nn.Module] = None,
|
pair_mlp: Optional[nn.Module] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
|
||||||
self.c_atompair = c_atompair
|
self.c_atompair = c_atompair
|
||||||
self.motif_pos_embedder = motif_pos_embedder
|
self.motif_pos_embedder = motif_pos_embedder
|
||||||
self.ref_pos_embedder = ref_pos_embedder
|
self.ref_pos_embedder = ref_pos_embedder
|
||||||
@@ -188,31 +197,53 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
|||||||
linearNoBias(c_atompair, c_atompair),
|
linearNoBias(c_atompair, c_atompair),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Cached static projections — populated once at tokenization by
|
||||||
|
# cache_static_projections(). None means "not yet cached; run the MLP."
|
||||||
|
self._sl_cached: Optional[torch.Tensor] = None # [L, c_atompair]
|
||||||
|
self._sm_cached: Optional[torch.Tensor] = None # [L, c_atompair]
|
||||||
|
self._Z_proc_cached: Optional[torch.Tensor] = None # [I, I, c_atompair]
|
||||||
|
|
||||||
|
def cache_static_projections(
|
||||||
|
self, C_L: torch.Tensor, Z_init_II: torch.Tensor
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Precompute and cache the three MLP projections that are identical across
|
||||||
|
all diffusion steps (they depend only on the static tokenization outputs).
|
||||||
|
|
||||||
|
Call this once after tokenization, before the diffusion loop.
|
||||||
|
forward_chunked will then replace those MLP calls with free tensor indexing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
C_L: Atom features [L, c_token]
|
||||||
|
Z_init_II: Token-pair features [I, I, c_z]
|
||||||
|
"""
|
||||||
|
self._sl_cached = self.process_single_l(C_L) # [L, c_atompair]
|
||||||
|
self._sm_cached = self.process_single_m(C_L) # [L, c_atompair]
|
||||||
|
self._Z_proc_cached = self.process_z(Z_init_II) # [I, I, c_atompair]
|
||||||
|
|
||||||
def forward_chunked(
|
def forward_chunked(
|
||||||
self,
|
self,
|
||||||
f: dict,
|
f: dict,
|
||||||
indices: torch.Tensor, # [B, L, k] - sparse attention indices
|
indices: torch.Tensor, # [B, L, k] - sparse attention indices
|
||||||
C_L: torch.Tensor, # [B, L, c_token] - atom features
|
C_L: torch.Tensor, # [L, c_token] or [B, L, c_token] - atom features
|
||||||
Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
|
Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
|
||||||
tok_idx: torch.Tensor, # [L] - atom to token mapping
|
tok_idx: torch.Tensor, # [L] - atom to token mapping
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Add logging for chunked P_LL computation
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.info(
|
|
||||||
f"ChunkedPairwiseEmbedder: Computing sparse P_LL for {indices.shape[1]} atoms with {indices.shape[2]} neighbors each"
|
|
||||||
)
|
|
||||||
"""
|
"""
|
||||||
Compute P_LL only for the pairs specified by attention indices.
|
Compute P_LL only for the pairs specified by attention indices.
|
||||||
|
|
||||||
|
When cache_static_projections() has been called beforehand, the three MLP
|
||||||
|
terms (process_single_l, process_single_m, process_z) are replaced with
|
||||||
|
free tensor index operations, since those projections are identical across
|
||||||
|
all diffusion steps.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
f: Feature dictionary
|
f: Feature dictionary (motif_pos, ref_pos, etc.)
|
||||||
indices: Sparse attention indices [B, L, k]
|
indices: Sparse attention indices [B, L, k]
|
||||||
C_L: Atom-level features [B, L, c_token]
|
C_L: Atom-level features [L, c_token] or [B, L, c_token]
|
||||||
Z_init_II: Token-level pair features [I, I, c_z]
|
Z_init_II: Token-level pair features [I, I, c_z]
|
||||||
tok_idx: Atom to token mapping [L]
|
tok_idx: Atom-to-token mapping [L]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
|
P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
|
||||||
"""
|
"""
|
||||||
@@ -286,73 +317,53 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
|||||||
)
|
)
|
||||||
P_LL_sparse[:, l, :, :] += ref_pairs
|
P_LL_sparse[:, l, :, :] += ref_pairs
|
||||||
|
|
||||||
# 3. Single embedding terms (broadcasted)
|
# 3. Single embedding terms
|
||||||
# Expand C_L to match valid_indices batch dimension
|
if self._sl_cached is not None:
|
||||||
if C_L.shape[0] != B:
|
# Fast path: MLP already run at tokenisation — just index into the result.
|
||||||
C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
|
# sl_cached [L, c_atompair]: query atom l always maps to row l.
|
||||||
# Gather key features for each query
|
single_l = self._sl_cached.unsqueeze(0).unsqueeze(2).expand(B, -1, k, -1)
|
||||||
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
|
# sm_cached [L, c_atompair]: key atoms are given by valid_indices [B, L, k].
|
||||||
C_L_keys = torch.gather(
|
single_m = self._sm_cached[valid_indices] # [B, L, k, c_atompair]
|
||||||
C_L_queries,
|
else:
|
||||||
1,
|
# Slow path (no cache): run the MLPs over the raw atom features.
|
||||||
valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
|
if C_L.shape[0] != B:
|
||||||
) # [B, L, k, c_token]
|
C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
|
||||||
|
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
|
||||||
# Add single embeddings - match standard implementation structure
|
C_L_keys = torch.gather(
|
||||||
# Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
|
C_L_queries,
|
||||||
# We need to broadcast from [B, L, k, c_atompair] to match this
|
1,
|
||||||
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
|
valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
|
||||||
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
|
) # [B, L, k, c_token]
|
||||||
|
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
|
||||||
|
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
|
||||||
P_LL_sparse += single_l + single_m
|
P_LL_sparse += single_l + single_m
|
||||||
|
|
||||||
# 4. Token pair features Z_init_II
|
# 4. Token pair features Z_init_II
|
||||||
# Map atoms to tokens and gather token pair features
|
# Map atoms to tokens and gather token pair features.
|
||||||
# Handle tok_idx dimensions properly
|
|
||||||
if tok_idx.dim() == 1: # [L] - add batch dimension for consistency
|
if tok_idx.dim() == 1: # [L] - add batch dimension for consistency
|
||||||
tok_idx_expanded = tok_idx.unsqueeze(0) # [1, L]
|
tok_idx_expanded = tok_idx.unsqueeze(0) # [1, L]
|
||||||
else:
|
else:
|
||||||
tok_idx_expanded = tok_idx
|
tok_idx_expanded = tok_idx
|
||||||
|
|
||||||
# Expand tok_idx_expanded to match valid_indices batch dimension
|
|
||||||
if tok_idx_expanded.shape[0] != B:
|
if tok_idx_expanded.shape[0] != B:
|
||||||
tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
|
tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
|
||||||
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
|
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
|
||||||
# Use valid_indices for token mapping as well
|
|
||||||
tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
|
tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
|
||||||
|
|
||||||
# Gather Z_init_II[tok_queries, tok_keys] with safe indexing
|
if self._Z_proc_cached is not None:
|
||||||
# Z_init_II shape is [I, I, c_z] (3D), not 4D
|
# Fast path: process_z already run at tokenisation.
|
||||||
# tok_queries shape: [B, L, k] - each value is a token index
|
Z_processed = self._Z_proc_cached # [I, I, c_atompair]
|
||||||
# We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
|
else:
|
||||||
|
# Slow path: run the MLP over the token-pair matrix.
|
||||||
I_z, I_z2, c_z = Z_init_II.shape
|
Z_processed = self.process_z(Z_init_II) # [I, I, c_atompair]
|
||||||
|
|
||||||
# CRITICAL: Match standard implementation exactly!
|
|
||||||
# Standard does: self.process_z(Z_init_II)[..., tok_idx, :, :][..., tok_idx, :]
|
|
||||||
# This means: 1) Process Z_init_II first, 2) Then do double token indexing
|
|
||||||
|
|
||||||
# Step 1: Process Z_init_II to get processed token pair features
|
|
||||||
Z_processed = self.process_z(Z_init_II) # [I, I, c_atompair]
|
|
||||||
|
|
||||||
# Step 2: Do the double indexing like the standard implementation
|
|
||||||
# Standard: Z_processed[..., tok_idx, :, :][..., tok_idx, :]
|
|
||||||
# This creates Z_processed[tok_idx, :][:, tok_idx] which is [L, L, c_atompair]
|
|
||||||
# Then we need to gather the sparse version
|
|
||||||
|
|
||||||
|
I_z, I_z2 = Z_processed.shape[:2]
|
||||||
Z_pairs_processed = torch.zeros(
|
Z_pairs_processed = torch.zeros(
|
||||||
B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
|
B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
for b in range(B):
|
for b in range(B):
|
||||||
# For this batch, get the token queries and keys
|
tq = torch.clamp(tok_queries[b], 0, I_z - 1) # [L, k]
|
||||||
tq = tok_queries[b] # [L, k]
|
tk = torch.clamp(tok_keys[b], 0, I_z2 - 1) # [L, k]
|
||||||
tk = tok_keys[b] # [L, k]
|
|
||||||
|
|
||||||
# Ensure indices are within bounds
|
|
||||||
tq = torch.clamp(tq, 0, I_z - 1)
|
|
||||||
tk = torch.clamp(tk, 0, I_z2 - 1)
|
|
||||||
|
|
||||||
# Apply the double token indexing like standard implementation
|
|
||||||
Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
|
Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
|
||||||
|
|
||||||
P_LL_sparse += Z_pairs_processed
|
P_LL_sparse += Z_pairs_processed
|
||||||
@@ -369,8 +380,12 @@ def create_chunked_embedders(
|
|||||||
"""
|
"""
|
||||||
Factory function to create chunked pairwise embedder with standard components.
|
Factory function to create chunked pairwise embedder with standard components.
|
||||||
"""
|
"""
|
||||||
motif_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
|
motif_pos_embedder = ChunkedSinusoidalDistEmbed(
|
||||||
ref_pos_embedder = ChunkedPositionPairDistEmbedder(c_atompair, embed_frame)
|
embedder_instance=SinusoidalDistEmbed(c_atompair, embed_frame)
|
||||||
|
)
|
||||||
|
ref_pos_embedder = ChunkedPositionPairDistEmbedder(
|
||||||
|
embedder_instance=PositionPairDistEmbedder(c_atompair, embed_frame)
|
||||||
|
)
|
||||||
|
|
||||||
return ChunkedPairwiseEmbedder(
|
return ChunkedPairwiseEmbedder(
|
||||||
c_atompair=c_atompair,
|
c_atompair=c_atompair,
|
||||||
|
|||||||
@@ -128,17 +128,21 @@ class TokenInitializer(nn.Module):
|
|||||||
|
|
||||||
# Atom pair feature processing
|
# Atom pair feature processing
|
||||||
if self.use_chunked_pll:
|
if self.use_chunked_pll:
|
||||||
# Initialize chunked embedders and share the trained MLPs!
|
# Share trained components so chunked inference uses same weights as full training.
|
||||||
|
motif_pos_embedder = ChunkedSinusoidalDistEmbed(
|
||||||
|
embedder_instance=self.motif_pos_embedder
|
||||||
|
)
|
||||||
|
ref_pos_embedder = ChunkedPositionPairDistEmbedder(
|
||||||
|
embedder_instance=self.ref_pos_embedder
|
||||||
|
)
|
||||||
self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder(
|
self.chunked_pairwise_embedder = ChunkedPairwiseEmbedder(
|
||||||
c_atompair=c_atompair,
|
c_atompair=c_atompair,
|
||||||
motif_pos_embedder=ChunkedSinusoidalDistEmbed(c_atompair=c_atompair),
|
motif_pos_embedder=motif_pos_embedder,
|
||||||
ref_pos_embedder=ChunkedPositionPairDistEmbedder(
|
ref_pos_embedder=ref_pos_embedder,
|
||||||
c_atompair, embed_frame=False
|
process_single_l=self.process_single_l,
|
||||||
),
|
process_single_m=self.process_single_m,
|
||||||
process_single_l=self.process_single_l, # Share trained parameters!
|
process_z=self.process_z,
|
||||||
process_single_m=self.process_single_m, # Share trained parameters!
|
pair_mlp=self.pair_mlp,
|
||||||
process_z=self.process_z, # Share trained parameters!
|
|
||||||
pair_mlp=self.pair_mlp, # Share trained parameters!
|
|
||||||
)
|
)
|
||||||
self.process_pll = linearNoBias(c_atompair, c_atompair)
|
self.process_pll = linearNoBias(c_atompair, c_atompair)
|
||||||
self.project_pll = linearNoBias(c_atompair, c_z)
|
self.project_pll = linearNoBias(c_atompair, c_z)
|
||||||
@@ -223,7 +227,9 @@ class TokenInitializer(nn.Module):
|
|||||||
C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :]
|
C_L = Q_L_init + self.process_s_trunk(S_init_I)[..., tok_idx, :]
|
||||||
|
|
||||||
if self.use_chunked_pll:
|
if self.use_chunked_pll:
|
||||||
# Chunked mode: return embedder for later sparse computation
|
# Precompute static MLP projections once so forward_chunked can
|
||||||
|
# skip those MLP calls at every subsequent diffusion step.
|
||||||
|
self.chunked_pairwise_embedder.cache_static_projections(C_L, Z_init_II)
|
||||||
return {
|
return {
|
||||||
"Q_L_init": Q_L_init,
|
"Q_L_init": Q_L_init,
|
||||||
"C_L": C_L,
|
"C_L": C_L,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
|
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/foundry_exec.sh" "$0" "$@"'
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from rfd3.trainer.trainer_utils import (
|
|||||||
_build_atom_array_stack,
|
_build_atom_array_stack,
|
||||||
_cleanup_virtual_atoms_and_assign_atom_name_elements,
|
_cleanup_virtual_atoms_and_assign_atom_name_elements,
|
||||||
_reassign_unindexed_token_chains,
|
_reassign_unindexed_token_chains,
|
||||||
|
_remap_outputs,
|
||||||
_reorder_dict,
|
_reorder_dict,
|
||||||
process_unindexed_outputs,
|
process_unindexed_outputs,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user