Merge pull request #111 from RosettaCommons/fix/chunkedPLL

fix: resolve the cause of slow inference when using low_memory_mode
This commit is contained in:
Jasper Butcher
2025-12-17 11:49:16 +01:00
committed by GitHub
5 changed files with 124 additions and 117 deletions

View File

@@ -100,14 +100,11 @@ def make_symmetric_atom_array(
src_atom_array is not None
), "Source atom array must be provided for symmetric motifs"
frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
elif sym_conf.is_symmetric_motif is None:
else:
# At this point, asym case would have been caught by the check_symmetry_config function.
ranked_logger.info(
"No motifs found in atom array. Generating unconditional symmetric proteins."
)
else:
raise ValueError(
"Asymmetric motif inputs are not supported yet. Please provide a symmetric motif or no motifs."
)
# Add symmetry annotations to the asu atom array
asu_atom_array = add_sym_annotations(asu_atom_array, sym_conf)

View File

@@ -1,4 +1,5 @@
import inspect
import time
from dataclasses import dataclass
from typing import Any, Literal
@@ -8,6 +9,7 @@ from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwis
from rfd3.model.cfg_utils import strip_X
from foundry.common import exists
from foundry.utils.alignment import weighted_rigid_align
from foundry.utils.ddp import RankedLogger
from foundry.utils.rotation_augmentation import (
rot_vec_mul,
@@ -112,7 +114,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
)
# Fallback to smallest available step
noise_schedule_original = self._construct_inference_noise_schedule(
device=coord_atom_lvl_to_be_noised.device
device=device
)
noise_schedule = noise_schedule_original[-1:] # Just use the final step
ranked_logger.info(
@@ -223,6 +225,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
# Handle chunked mode vs standard mode
if "chunked_pairwise_embedder" in initializer_outputs:
# Chunked mode: explicitly provide P_LL=None
tic = time.time()
chunked_embedder = initializer_outputs[
"chunked_pairwise_embedder"
] # Don't pop, just get
@@ -240,6 +243,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
initializer_outputs=other_outputs,
**other_outputs,
)
toc = time.time()
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
else:
# Standard mode: P_LL is included in initializer_outputs
outs = diffusion_module(
@@ -447,6 +452,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
# Handle chunked mode vs standard mode (same as default sampler)
if "chunked_pairwise_embedder" in initializer_outputs:
# Chunked mode: explicitly provide P_LL=None
tic = time.time()
chunked_embedder = initializer_outputs[
"chunked_pairwise_embedder"
] # Don't pop, just get
@@ -464,6 +470,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
initializer_outputs=other_outputs,
**other_outputs,
)
toc = time.time()
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
else:
# Standard mode: P_LL is included in initializer_outputs
outs = diffusion_module(

View File

@@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
Parameters
----------
P_LK_indices : (D, L, k) LongTensor
P_LK_indices : (B, L, k) LongTensor
Key indices | P_LK_indices[b, i, k] = global index of the atom that atom i attends to.
P_LK : (D, L, k, c) FloatTensor
P_LK : (B, L, k, c) FloatTensor
Key features to scatter add into
P_LA_indices : (D, L, a) LongTensor
P_LA_indices : (B, L, a) LongTensor
Additional feature indices to scatter into P_LK.
P_LA : (D, L, a, c) FloatTensor
P_LA : (B, L, a, c) FloatTensor
Features corresponding to P_LA.
Both index tensors contain indices representing B batch dim,
@@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
"""
# Handle case when indices and P_LA don't have batch dimensions
D, L, k = P_LK_indices.shape
B, L, k = P_LK_indices.shape
if P_LA_indices.ndim == 2:
P_LA_indices = P_LA_indices.unsqueeze(0).expand(D, -1, -1)
P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
if P_LA_src.ndim == 3:
P_LA_src = P_LA_src.unsqueeze(0).expand(D, -1, -1)
P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1)
assert (
P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
), "Channel dims do not match, got: {} vs {}".format(
P_LA_src.shape[-1], P_LK_tgt.shape[-1]
)
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (D, L, a, k)
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (B, L, a, k)
if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
raise ValueError("Found multiple scatter indices for some atoms")
elif not torch.all(matches.sum(dim=-1) <= 1):
raise ValueError("Did not find a scatter index for every atom")
k_indices = matches.long().argmax(dim=-1) # (D, L, a)
k_indices = matches.long().argmax(dim=-1) # (B, L, a)
scatter_indices = k_indices.unsqueeze(-1).expand(
-1, -1, -1, P_LK_tgt.shape[-1]
) # (D, L, a, c)
) # (B, L, a, c)
P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
return P_LK_tgt
def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
values : (D, L, C)
idx : (D, L, k)
returns: (D, L, k, C)
values : (B, L, C)
idx : (B, L, k)
returns: (B, L, k, C)
"""
D, L, C = values.shape
B, L, C = values.shape
k = idx.shape[-1]
# (D, L, 1, C) → stride-0 along k → (D, L, k, C)
# (B, L, 1, C) → stride-0 along k → (B, L, k, C)
src = values.unsqueeze(2).expand(-1, -1, k, -1)
idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (D, L, k, C)
idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (B, L, k, C)
return torch.gather(src, 1, idx) # dim=1 is the L-axis
@@ -196,7 +196,7 @@ def create_attention_indices(
X_L = torch.randn(
(1, L, 3), device=device, dtype=torch.float
) # [L, 3] - random
D_LL = torch.cdist(X_L, X_L, p=2) # [D, L, L] - pairwise atom distances
D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances
# Create attention indices using neighbour distances
base_mask = ~f["unindexing_pair_mask"][
@@ -231,7 +231,7 @@ def create_attention_indices(
k_max=k_actual,
chain_id=chain_ids,
base_mask=base_mask,
) # [D, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
) # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
return attn_indices
@@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain(
Args:
tok_idx: atom to token mapping
D_LL: pairwise distances [D, L, L]
D_LL: pairwise distances [B, L, L]
n_seq_neighbours: number of sequence neighbors
k_intra: number of intra-chain attention keys
k_inter: number of inter-chain attention keys
@@ -253,29 +253,29 @@ def get_sparse_attention_indices_with_inter_chain(
base_mask: base mask for valid pairs
Returns:
attn_indices: [D, L, k_total] where k_total = k_intra + k_inter
attn_indices: [B, L, k_total] where k_total = k_intra + k_inter
"""
D, L, _ = D_LL.shape
B, L, _ = D_LL.shape
# Get regular intra-chain indices (limited to k_intra)
intra_indices = get_sparse_attention_indices(
tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
) # [D, L, k_intra]
) # [B, L, k_intra]
# Get inter-chain indices for clash avoidance
inter_indices = torch.zeros(D, L, k_inter, dtype=torch.long, device=D_LL.device)
for d in range(D):
for l in range(L):
query_chain = chain_id[l]
inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device)
unique_chains = torch.unique(chain_id)
for b in range(B):
for c in unique_chains:
query_chain = chain_id[c]
# Find atoms from different chains
other_chain_mask = (chain_id != query_chain) & base_mask[l, :]
other_chain_mask = (chain_id != query_chain) & base_mask[c, :]
other_chain_atoms = torch.where(other_chain_mask)[0]
if len(other_chain_atoms) > 0:
# Get distances to other chains
distances_to_other = D_LL[d, l, other_chain_atoms]
distances_to_other = D_LL[b, c, other_chain_atoms]
# Select k_inter closest atoms from other chains
n_select = min(k_inter, len(other_chain_atoms))
@@ -283,23 +283,23 @@ def get_sparse_attention_indices_with_inter_chain(
selected_atoms = other_chain_atoms[closest_idx]
# Fill inter-chain indices
inter_indices[d, l, :n_select] = selected_atoms
inter_indices[b, c, :n_select] = selected_atoms
# Pad with random atoms if needed
if n_select < k_inter:
padding = torch.randint(
0, L, (k_inter - n_select,), device=D_LL.device
)
inter_indices[d, l, n_select:] = padding
inter_indices[b, c, n_select:] = padding
else:
# No other chains found, fill with random indices
inter_indices[d, l, :] = torch.randint(
inter_indices[b, c, :] = torch.randint(
0, L, (k_inter,), device=D_LL.device
)
# Combine intra and inter chain indices
combined_indices = torch.cat(
[intra_indices, inter_indices], dim=-1
) # [D, L, k_total]
) # [B, L, k_total]
return combined_indices

View File

@@ -30,32 +30,32 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
def compute_pairs_chunked(
self,
query_pos: torch.Tensor, # [D, 3]
key_pos: torch.Tensor, # [D, k, 3]
valid_mask: torch.Tensor, # [D, k, 1]
query_pos: torch.Tensor, # [B, 3]
key_pos: torch.Tensor, # [B, k, 3]
valid_mask: torch.Tensor, # [B, k, 1]
) -> torch.Tensor:
"""
Compute pairwise embeddings for specific query-key pairs.
Args:
query_pos: Query positions [D, 3]
key_pos: Key positions [D, k, 3]
valid_mask: Valid pair mask [D, k, 1]
query_pos: Query positions [B, 3]
key_pos: Key positions [B, k, 3]
valid_mask: Valid pair mask [B, k, 1]
Returns:
P_sparse: Pairwise embeddings [D, k, c_atompair]
P_sparse: Pairwise embeddings [B, k, c_atompair]
"""
D, k = key_pos.shape[:2]
B, k = key_pos.shape[:2]
# Compute pairwise distances: [D, k, 3]
D_pairs = query_pos.unsqueeze(1) - key_pos # [D, 1, 3] - [D, k, 3] = [D, k, 3]
# Compute pairwise distances: [B, k, 3]
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
if self.embed_frame:
# Embed pairwise distances
P_pairs = self.process_d(D_pairs) * valid_mask # [D, k, c_atompair]
P_pairs = self.process_d(D_pairs) * valid_mask # [B, k, c_atompair]
# Add inverse distance embedding
norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [D, k, 1]
norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [B, k, 1]
inv_dist = 1 / (1 + norm_sq)
P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask
@@ -95,19 +95,19 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
def compute_pairs_chunked(
self,
query_pos: torch.Tensor, # [D, 3]
key_pos: torch.Tensor, # [D, k, 3]
valid_mask: torch.Tensor, # [D, k, 1]
query_pos: torch.Tensor, # [B, 3]
key_pos: torch.Tensor, # [B, k, 3]
valid_mask: torch.Tensor, # [B, k, 1]
) -> torch.Tensor:
"""
Compute sinusoidal distance embeddings for specific query-key pairs.
"""
D, k = key_pos.shape[:2]
B, k = key_pos.shape[:2]
device = query_pos.device
# Compute pairwise distances
D_pairs = query_pos.unsqueeze(1) - key_pos # [D, k, 3]
dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [D, k]
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, k, 3]
dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [B, k]
# Sinusoidal embedding
half_dim = self.n_freqs
@@ -117,13 +117,13 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
/ half_dim
) # [n_freqs]
angles = dist_matrix.unsqueeze(-1) * freq # [D, k, n_freqs]
angles = dist_matrix.unsqueeze(-1) * freq # [B, k, n_freqs]
sin_embed = torch.sin(angles)
cos_embed = torch.cos(angles)
sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [D, k, 2*n_freqs]
sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [B, k, 2*n_freqs]
# Linear projection
P_pairs = self.output_proj(sincos_embed) # [D, k, c_atompair]
P_pairs = self.output_proj(sincos_embed) # [B, k, c_atompair]
P_pairs = P_pairs * valid_mask
# Add linear embedding of valid mask
@@ -191,8 +191,8 @@ class ChunkedPairwiseEmbedder(nn.Module):
def forward_chunked(
self,
f: dict,
indices: torch.Tensor, # [D, L, k] - sparse attention indices
C_L: torch.Tensor, # [D, L, c_token] - atom features
indices: torch.Tensor, # [B, L, k] - sparse attention indices
C_L: torch.Tensor, # [B, L, c_token] - atom features
Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
tok_idx: torch.Tensor, # [L] - atom to token mapping
) -> torch.Tensor:
@@ -208,20 +208,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
Args:
f: Feature dictionary
indices: Sparse attention indices [D, L, k]
C_L: Atom-level features [D, L, c_token]
indices: Sparse attention indices [B, L, k]
C_L: Atom-level features [B, L, c_token]
Z_init_II: Token-level pair features [I, I, c_z]
tok_idx: Atom to token mapping [L]
Returns:
P_LL_sparse: Sparse pairwise features [D, L, k, c_atompair]
P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
"""
D, L, k = indices.shape
B, L, k = indices.shape
device = indices.device
# Initialize sparse P_LL
P_LL_sparse = torch.zeros(
D, L, k, self.c_atompair, device=device, dtype=C_L.dtype
B, L, k, self.c_atompair, device=device, dtype=C_L.dtype
)
# Handle both batched and non-batched C_L
@@ -237,71 +237,72 @@ class ChunkedPairwiseEmbedder(nn.Module):
if valid_indices.dim() == 2: # [L, k] - add batch dimension
valid_indices = valid_indices.unsqueeze(0).expand(
C_L.shape[0], -1, -1
) # [D, L, k]
) # [B, L, k]
# 1. Motif position embedding (if exists)
if self.motif_pos_embedder is not None and "motif_pos" in f:
motif_pos = f["motif_pos"] # [L, 3]
is_motif = f["is_motif_atom_with_fixed_coord"] # [L]
is_motif_idx = torch.where(is_motif)[0]
# For each query position
for l in range(L):
if is_motif[l]: # Only compute if query is motif
key_indices = valid_indices[:, l, :] # [D, k] - use clamped indices
key_pos = motif_pos[key_indices] # [D, k, 3]
query_pos = motif_pos[l].unsqueeze(0).expand(D, -1) # [D, 3]
for l in is_motif_idx:
key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices
key_pos = motif_pos[key_indices] # [B, k, 3]
query_pos = motif_pos[l].unsqueeze(0).expand(B, -1) # [B, 3]
# Valid mask: both query and keys must be motif
key_is_motif = is_motif[key_indices] # [D, k]
valid_mask = key_is_motif.unsqueeze(-1).float() # [D, k, 1]
# Valid mask: both query and keys must be motif
key_is_motif = is_motif[key_indices] # [B, k]
valid_mask = key_is_motif.unsqueeze(-1).float() # [B, k, 1]
if valid_mask.sum() > 0:
motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
query_pos, key_pos, valid_mask
)
P_LL_sparse[:, l, :, :] += motif_pairs
if valid_mask.sum() > 0:
motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
query_pos, key_pos, valid_mask
)
P_LL_sparse[:, l, :, :] += motif_pairs
# 2. Reference position embedding (if exists)
if self.ref_pos_embedder is not None and "ref_pos" in f:
ref_pos = f["ref_pos"] # [L, 3]
ref_space_uid = f["ref_space_uid"] # [L]
is_motif_seq = f["is_motif_atom_with_fixed_seq"] # [L]
is_motif_seq_idx = torch.where(is_motif_seq)[0]
for l in is_motif_seq_idx:
key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices
key_pos = ref_pos[key_indices] # [B, k, 3]
query_pos = ref_pos[l].unsqueeze(0).expand(B, -1) # [B, 3]
for l in range(L):
if is_motif_seq[l]: # Only compute if query has sequence
key_indices = valid_indices[:, l, :] # [D, k] - use clamped indices
key_pos = ref_pos[key_indices] # [D, k, 3]
query_pos = ref_pos[l].unsqueeze(0).expand(D, -1) # [D, 3]
# Valid mask: same token and both have sequence
key_space_uid = ref_space_uid[key_indices] # [B, k]
key_is_motif_seq = is_motif_seq[key_indices] # [B, k]
# Valid mask: same token and both have sequence
key_space_uid = ref_space_uid[key_indices] # [D, k]
key_is_motif_seq = is_motif_seq[key_indices] # [D, k]
same_token = key_space_uid == ref_space_uid[l] # [B, k]
valid_mask = (
(same_token & key_is_motif_seq).unsqueeze(-1).float()
) # [B, k, 1]
same_token = key_space_uid == ref_space_uid[l] # [D, k]
valid_mask = (
(same_token & key_is_motif_seq).unsqueeze(-1).float()
) # [D, k, 1]
if valid_mask.sum() > 0:
ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
query_pos, key_pos, valid_mask
)
P_LL_sparse[:, l, :, :] += ref_pairs
if valid_mask.sum() > 0:
ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
query_pos, key_pos, valid_mask
)
P_LL_sparse[:, l, :, :] += ref_pairs
# 3. Single embedding terms (broadcasted)
# Expand C_L to match valid_indices batch dimension
if C_L.shape[0] != B:
C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
# Gather key features for each query
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
C_L_keys = torch.gather(
C_L.unsqueeze(2).expand(-1, -1, k, -1),
C_L_queries,
1,
valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
) # [D, L, k, c_token]
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [D, L, k, c_token]
) # [B, L, k, c_token]
# Add single embeddings - match standard implementation structure
# Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
# We need to broadcast from [D, L, k, c_atompair] to match this
single_l = self.process_single_l(C_L_queries) # [D, L, k, c_atompair]
single_m = self.process_single_m(C_L_keys) # [D, L, k, c_atompair]
# We need to broadcast from [B, L, k, c_atompair] to match this
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
P_LL_sparse += single_l + single_m
# 4. Token pair features Z_init_II
@@ -312,15 +313,16 @@ class ChunkedPairwiseEmbedder(nn.Module):
else:
tok_idx_expanded = tok_idx
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [D, L, k]
# Expand tok_idx_expanded to match valid_indices batch dimension
if tok_idx_expanded.shape[0] != B:
tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
# Use valid_indices for token mapping as well
tok_keys = torch.gather(
tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices
) # [D, L, k]
tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
# Gather Z_init_II[tok_queries, tok_keys] with safe indexing
# Z_init_II shape is [I, I, c_z] (3D), not 4D
# tok_queries shape: [D, L, k] - each value is a token index
# tok_queries shape: [B, L, k] - each value is a token index
# We want: Z_init_II[tok_queries[b,l,k], tok_keys[b,l,k], :] for all b,l,k
I_z, I_z2, c_z = Z_init_II.shape
@@ -338,20 +340,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
# Then we need to gather the sparse version
Z_pairs_processed = torch.zeros(
D, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
)
for d in range(D):
for b in range(B):
# For this batch, get the token queries and keys
tq = tok_queries[d] # [L, k]
tk = tok_keys[d] # [L, k]
tq = tok_queries[b] # [L, k]
tk = tok_keys[b] # [L, k]
# Ensure indices are within bounds
tq = torch.clamp(tq, 0, I_z - 1)
tk = torch.clamp(tk, 0, I_z2 - 1)
# Apply the double token indexing like standard implementation
Z_pairs_processed[d] = Z_processed[tq, tk] # [L, k, c_atompair]
Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
P_LL_sparse += Z_pairs_processed

View File

@@ -3,7 +3,7 @@ name = "rc-foundry"
dynamic = ["version"]
description = "Shared utilities and training infrastructure for biomolecular structure prediction models."
readme = "README.md"
requires-python = "==3.12"
requires-python = ">=3.12"
authors = [
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
]