mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fix: the cause of the slow inference using low_memory_mode
- remove the for loop that goes through atoms in chunked PLL - support multiple batches in chunked PLL
This commit is contained in:
@@ -100,14 +100,11 @@ def make_symmetric_atom_array(
|
||||
src_atom_array is not None
|
||||
), "Source atom array must be provided for symmetric motifs"
|
||||
frames = get_symmetry_frames_from_atom_array(src_atom_array, frames)
|
||||
elif sym_conf.is_symmetric_motif is None:
|
||||
else:
|
||||
# At this point, asym case would have been caught by the check_symmetry_config function.
|
||||
ranked_logger.info(
|
||||
"No motifs found in atom array. Generating unconditional symmetric proteins."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Asymmetric motif inputs are not supported yet. Please provide a symmetric motif or no motifs."
|
||||
)
|
||||
|
||||
# Add symmetry annotations to the asu atom array
|
||||
asu_atom_array = add_sym_annotations(asu_atom_array, sym_conf)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
@@ -13,6 +14,7 @@ from foundry.utils.rotation_augmentation import (
|
||||
rot_vec_mul,
|
||||
uniform_random_rotation,
|
||||
)
|
||||
from foundry.utils.alignment import weighted_rigid_align
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
@@ -112,7 +114,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
||||
)
|
||||
# Fallback to smallest available step
|
||||
noise_schedule_original = self._construct_inference_noise_schedule(
|
||||
device=coord_atom_lvl_to_be_noised.device
|
||||
device=device
|
||||
)
|
||||
noise_schedule = noise_schedule_original[-1:] # Just use the final step
|
||||
ranked_logger.info(
|
||||
@@ -223,6 +225,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
||||
# Handle chunked mode vs standard mode
|
||||
if "chunked_pairwise_embedder" in initializer_outputs:
|
||||
# Chunked mode: explicitly provide P_LL=None
|
||||
tic = time.time()
|
||||
chunked_embedder = initializer_outputs[
|
||||
"chunked_pairwise_embedder"
|
||||
] # Don't pop, just get
|
||||
@@ -240,6 +243,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
||||
initializer_outputs=other_outputs,
|
||||
**other_outputs,
|
||||
)
|
||||
toc = time.time()
|
||||
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
||||
else:
|
||||
# Standard mode: P_LL is included in initializer_outputs
|
||||
outs = diffusion_module(
|
||||
@@ -447,6 +452,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
||||
# Handle chunked mode vs standard mode (same as default sampler)
|
||||
if "chunked_pairwise_embedder" in initializer_outputs:
|
||||
# Chunked mode: explicitly provide P_LL=None
|
||||
tic = time.time()
|
||||
chunked_embedder = initializer_outputs[
|
||||
"chunked_pairwise_embedder"
|
||||
] # Don't pop, just get
|
||||
@@ -464,6 +470,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif):
|
||||
initializer_outputs=other_outputs,
|
||||
**other_outputs,
|
||||
)
|
||||
toc = time.time()
|
||||
ranked_logger.info(f"Chunked mode time: {toc - tic} seconds")
|
||||
else:
|
||||
# Standard mode: P_LL is included in initializer_outputs
|
||||
outs = diffusion_module(
|
||||
|
||||
@@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
P_LK_indices : (D, L, k) LongTensor
|
||||
P_LK_indices : (B, L, k) LongTensor
|
||||
Key indices | P_LK_indices[d, i, k] = global atom index for which atom i attends to.
|
||||
P_LK : (D, L, k, c) FloatTensor
|
||||
P_LK : (B, L, k, c) FloatTensor
|
||||
Key features to scatter add into
|
||||
|
||||
P_LA_indices : (D, L, a) LongTensor
|
||||
P_LA_indices : (B, L, a) LongTensor
|
||||
Additional feature indices to scatter into P_LK.
|
||||
P_LA : (D, L, a, c) FloatTensor
|
||||
P_LA : (B, L, a, c) FloatTensor
|
||||
Features corresponding to P_LA.
|
||||
|
||||
Both index tensors contain indices representing D batch dim,
|
||||
@@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices):
|
||||
|
||||
"""
|
||||
# Handle case when indices and P_LA don't have batch dimensions
|
||||
D, L, k = P_LK_indices.shape
|
||||
B, L, k = P_LK_indices.shape
|
||||
if P_LA_indices.ndim == 2:
|
||||
P_LA_indices = P_LA_indices.unsqueeze(0).expand(D, -1, -1)
|
||||
P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1)
|
||||
if P_LA_src.ndim == 3:
|
||||
P_LA_src = P_LA_src.unsqueeze(0).expand(D, -1, -1)
|
||||
P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1)
|
||||
assert (
|
||||
P_LA_src.shape[-1] == P_LK_tgt.shape[-1]
|
||||
), "Channel dims do not match, got: {} vs {}".format(
|
||||
P_LA_src.shape[-1], P_LK_tgt.shape[-1]
|
||||
)
|
||||
|
||||
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (D, L, a, k)
|
||||
matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (B, L, a, k)
|
||||
if not torch.all(matches.sum(dim=(-1, -2)) >= 1):
|
||||
raise ValueError("Found multiple scatter indices for some atoms")
|
||||
elif not torch.all(matches.sum(dim=-1) <= 1):
|
||||
raise ValueError("Did not find a scatter index for every atom")
|
||||
k_indices = matches.long().argmax(dim=-1) # (D, L, a)
|
||||
k_indices = matches.long().argmax(dim=-1) # (B, L, a)
|
||||
scatter_indices = k_indices.unsqueeze(-1).expand(
|
||||
-1, -1, -1, P_LK_tgt.shape[-1]
|
||||
) # (D, L, a, c)
|
||||
) # (B, L, a, c)
|
||||
P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src)
|
||||
return P_LK_tgt
|
||||
|
||||
|
||||
def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
values : (D, L, C)
|
||||
idx : (D, L, k)
|
||||
returns: (D, L, k, C)
|
||||
values : (B, L, C)
|
||||
idx : (B, L, k)
|
||||
returns: (B, L, k, C)
|
||||
"""
|
||||
D, L, C = values.shape
|
||||
B, L, C = values.shape
|
||||
k = idx.shape[-1]
|
||||
|
||||
# (D, L, 1, C) → stride-0 along k → (D, L, k, C)
|
||||
# (B, L, 1, C) → stride-0 along k → (B, L, k, C)
|
||||
src = values.unsqueeze(2).expand(-1, -1, k, -1)
|
||||
idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (D, L, k, C)
|
||||
idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (B, L, k, C)
|
||||
|
||||
return torch.gather(src, 1, idx) # dim=1 is the L-axis
|
||||
|
||||
@@ -196,7 +196,7 @@ def create_attention_indices(
|
||||
X_L = torch.randn(
|
||||
(1, L, 3), device=device, dtype=torch.float
|
||||
) # [L, 3] - random
|
||||
D_LL = torch.cdist(X_L, X_L, p=2) # [D, L, L] - pairwise atom distances
|
||||
D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances
|
||||
|
||||
# Create attention indices using neighbour distances
|
||||
base_mask = ~f["unindexing_pair_mask"][
|
||||
@@ -231,7 +231,7 @@ def create_attention_indices(
|
||||
k_max=k_actual,
|
||||
chain_id=chain_ids,
|
||||
base_mask=base_mask,
|
||||
) # [D, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
|
||||
) # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query
|
||||
|
||||
return attn_indices
|
||||
|
||||
@@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain(
|
||||
|
||||
Args:
|
||||
tok_idx: atom to token mapping
|
||||
D_LL: pairwise distances [D, L, L]
|
||||
D_LL: pairwise distances [B, L, L]
|
||||
n_seq_neighbours: number of sequence neighbors
|
||||
k_intra: number of intra-chain attention keys
|
||||
k_inter: number of inter-chain attention keys
|
||||
@@ -253,29 +253,29 @@ def get_sparse_attention_indices_with_inter_chain(
|
||||
base_mask: base mask for valid pairs
|
||||
|
||||
Returns:
|
||||
attn_indices: [D, L, k_total] where k_total = k_intra + k_inter
|
||||
attn_indices: [B, L, k_total] where k_total = k_intra + k_inter
|
||||
"""
|
||||
D, L, _ = D_LL.shape
|
||||
B, L, _ = D_LL.shape
|
||||
|
||||
# Get regular intra-chain indices (limited to k_intra)
|
||||
intra_indices = get_sparse_attention_indices(
|
||||
tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask
|
||||
) # [D, L, k_intra]
|
||||
) # [B, L, k_intra]
|
||||
|
||||
# Get inter-chain indices for clash avoidance
|
||||
inter_indices = torch.zeros(D, L, k_inter, dtype=torch.long, device=D_LL.device)
|
||||
|
||||
for d in range(D):
|
||||
for l in range(L):
|
||||
query_chain = chain_id[l]
|
||||
inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device)
|
||||
unique_chains = torch.unique(chain_id)
|
||||
for b in range(B):
|
||||
for c in unique_chains:
|
||||
query_chain = chain_id[c]
|
||||
|
||||
# Find atoms from different chains
|
||||
other_chain_mask = (chain_id != query_chain) & base_mask[l, :]
|
||||
other_chain_mask = (chain_id != query_chain) & base_mask[c, :]
|
||||
other_chain_atoms = torch.where(other_chain_mask)[0]
|
||||
|
||||
if len(other_chain_atoms) > 0:
|
||||
# Get distances to other chains
|
||||
distances_to_other = D_LL[d, l, other_chain_atoms]
|
||||
# Get distances to other chains
|
||||
distances_to_other = D_LL[b, c, other_chain_atoms]
|
||||
|
||||
# Select k_inter closest atoms from other chains
|
||||
n_select = min(k_inter, len(other_chain_atoms))
|
||||
@@ -283,23 +283,23 @@ def get_sparse_attention_indices_with_inter_chain(
|
||||
selected_atoms = other_chain_atoms[closest_idx]
|
||||
|
||||
# Fill inter-chain indices
|
||||
inter_indices[d, l, :n_select] = selected_atoms
|
||||
inter_indices[b, c, :n_select] = selected_atoms
|
||||
# Pad with random atoms if needed
|
||||
if n_select < k_inter:
|
||||
padding = torch.randint(
|
||||
0, L, (k_inter - n_select,), device=D_LL.device
|
||||
)
|
||||
inter_indices[d, l, n_select:] = padding
|
||||
inter_indices[b, c, n_select:] = padding
|
||||
else:
|
||||
# No other chains found, fill with random indices
|
||||
inter_indices[d, l, :] = torch.randint(
|
||||
inter_indices[b, c, :] = torch.randint(
|
||||
0, L, (k_inter,), device=D_LL.device
|
||||
)
|
||||
|
||||
# Combine intra and inter chain indices
|
||||
combined_indices = torch.cat(
|
||||
[intra_indices, inter_indices], dim=-1
|
||||
) # [D, L, k_total]
|
||||
) # [B, L, k_total]
|
||||
|
||||
return combined_indices
|
||||
|
||||
|
||||
@@ -30,32 +30,32 @@ class ChunkedPositionPairDistEmbedder(nn.Module):
|
||||
|
||||
def compute_pairs_chunked(
|
||||
self,
|
||||
query_pos: torch.Tensor, # [D, 3]
|
||||
key_pos: torch.Tensor, # [D, k, 3]
|
||||
valid_mask: torch.Tensor, # [D, k, 1]
|
||||
query_pos: torch.Tensor, # [B, 3]
|
||||
key_pos: torch.Tensor, # [B, k, 3]
|
||||
valid_mask: torch.Tensor, # [B, k, 1]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute pairwise embeddings for specific query-key pairs.
|
||||
|
||||
Args:
|
||||
query_pos: Query positions [D, 3]
|
||||
key_pos: Key positions [D, k, 3]
|
||||
valid_mask: Valid pair mask [D, k, 1]
|
||||
query_pos: Query positions [B, 3]
|
||||
key_pos: Key positions [B, k, 3]
|
||||
valid_mask: Valid pair mask [B, k, 1]
|
||||
|
||||
Returns:
|
||||
P_sparse: Pairwise embeddings [D, k, c_atompair]
|
||||
P_sparse: Pairwise embeddings [B, k, c_atompair]
|
||||
"""
|
||||
D, k = key_pos.shape[:2]
|
||||
B, k = key_pos.shape[:2]
|
||||
|
||||
# Compute pairwise distances: [D, k, 3]
|
||||
D_pairs = query_pos.unsqueeze(1) - key_pos # [D, 1, 3] - [D, k, 3] = [D, k, 3]
|
||||
# Compute pairwise distances: [B, k, 3]
|
||||
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3]
|
||||
|
||||
if self.embed_frame:
|
||||
# Embed pairwise distances
|
||||
P_pairs = self.process_d(D_pairs) * valid_mask # [D, k, c_atompair]
|
||||
P_pairs = self.process_d(D_pairs) * valid_mask # [B, k, c_atompair]
|
||||
|
||||
# Add inverse distance embedding
|
||||
norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [D, k, 1]
|
||||
norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [B, k, 1]
|
||||
inv_dist = 1 / (1 + norm_sq)
|
||||
P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask
|
||||
|
||||
@@ -95,19 +95,19 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
|
||||
|
||||
def compute_pairs_chunked(
|
||||
self,
|
||||
query_pos: torch.Tensor, # [D, 3]
|
||||
key_pos: torch.Tensor, # [D, k, 3]
|
||||
valid_mask: torch.Tensor, # [D, k, 1]
|
||||
query_pos: torch.Tensor, # [B, 3]
|
||||
key_pos: torch.Tensor, # [B, k, 3]
|
||||
valid_mask: torch.Tensor, # [B, k, 1]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute sinusoidal distance embeddings for specific query-key pairs.
|
||||
"""
|
||||
D, k = key_pos.shape[:2]
|
||||
B, k = key_pos.shape[:2]
|
||||
device = query_pos.device
|
||||
|
||||
# Compute pairwise distances
|
||||
D_pairs = query_pos.unsqueeze(1) - key_pos # [D, k, 3]
|
||||
dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [D, k]
|
||||
D_pairs = query_pos.unsqueeze(1) - key_pos # [B, k, 3]
|
||||
dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [B, k]
|
||||
|
||||
# Sinusoidal embedding
|
||||
half_dim = self.n_freqs
|
||||
@@ -117,13 +117,13 @@ class ChunkedSinusoidalDistEmbed(nn.Module):
|
||||
/ half_dim
|
||||
) # [n_freqs]
|
||||
|
||||
angles = dist_matrix.unsqueeze(-1) * freq # [D, k, n_freqs]
|
||||
angles = dist_matrix.unsqueeze(-1) * freq # [B, k, n_freqs]
|
||||
sin_embed = torch.sin(angles)
|
||||
cos_embed = torch.cos(angles)
|
||||
sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [D, k, 2*n_freqs]
|
||||
sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [B, k, 2*n_freqs]
|
||||
|
||||
# Linear projection
|
||||
P_pairs = self.output_proj(sincos_embed) # [D, k, c_atompair]
|
||||
P_pairs = self.output_proj(sincos_embed) # [B, k, c_atompair]
|
||||
P_pairs = P_pairs * valid_mask
|
||||
|
||||
# Add linear embedding of valid mask
|
||||
@@ -191,8 +191,8 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
||||
def forward_chunked(
|
||||
self,
|
||||
f: dict,
|
||||
indices: torch.Tensor, # [D, L, k] - sparse attention indices
|
||||
C_L: torch.Tensor, # [D, L, c_token] - atom features
|
||||
indices: torch.Tensor, # [B, L, k] - sparse attention indices
|
||||
C_L: torch.Tensor, # [B, L, c_token] - atom features
|
||||
Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features
|
||||
tok_idx: torch.Tensor, # [L] - atom to token mapping
|
||||
) -> torch.Tensor:
|
||||
@@ -208,20 +208,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
||||
|
||||
Args:
|
||||
f: Feature dictionary
|
||||
indices: Sparse attention indices [D, L, k]
|
||||
C_L: Atom-level features [D, L, c_token]
|
||||
indices: Sparse attention indices [B, L, k]
|
||||
C_L: Atom-level features [B, L, c_token]
|
||||
Z_init_II: Token-level pair features [I, I, c_z]
|
||||
tok_idx: Atom to token mapping [L]
|
||||
|
||||
Returns:
|
||||
P_LL_sparse: Sparse pairwise features [D, L, k, c_atompair]
|
||||
P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair]
|
||||
"""
|
||||
D, L, k = indices.shape
|
||||
B, L, k = indices.shape
|
||||
device = indices.device
|
||||
|
||||
# Initialize sparse P_LL
|
||||
P_LL_sparse = torch.zeros(
|
||||
D, L, k, self.c_atompair, device=device, dtype=C_L.dtype
|
||||
B, L, k, self.c_atompair, device=device, dtype=C_L.dtype
|
||||
)
|
||||
|
||||
# Handle both batched and non-batched C_L
|
||||
@@ -237,71 +237,72 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
||||
if valid_indices.dim() == 2: # [L, k] - add batch dimension
|
||||
valid_indices = valid_indices.unsqueeze(0).expand(
|
||||
C_L.shape[0], -1, -1
|
||||
) # [D, L, k]
|
||||
) # [B, L, k]
|
||||
|
||||
# 1. Motif position embedding (if exists)
|
||||
if self.motif_pos_embedder is not None and "motif_pos" in f:
|
||||
motif_pos = f["motif_pos"] # [L, 3]
|
||||
is_motif = f["is_motif_atom_with_fixed_coord"] # [L]
|
||||
|
||||
is_motif_idx = torch.where(is_motif)[0]
|
||||
# For each query position
|
||||
for l in range(L):
|
||||
if is_motif[l]: # Only compute if query is motif
|
||||
key_indices = valid_indices[:, l, :] # [D, k] - use clamped indices
|
||||
key_pos = motif_pos[key_indices] # [D, k, 3]
|
||||
query_pos = motif_pos[l].unsqueeze(0).expand(D, -1) # [D, 3]
|
||||
for l in is_motif_idx:
|
||||
key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices
|
||||
key_pos = motif_pos[key_indices] # [B, k, 3]
|
||||
query_pos = motif_pos[l].unsqueeze(0).expand(B, -1) # [B, 3]
|
||||
|
||||
# Valid mask: both query and keys must be motif
|
||||
key_is_motif = is_motif[key_indices] # [D, k]
|
||||
valid_mask = key_is_motif.unsqueeze(-1).float() # [D, k, 1]
|
||||
# Valid mask: both query and keys must be motif
|
||||
key_is_motif = is_motif[key_indices] # [B, k]
|
||||
valid_mask = key_is_motif.unsqueeze(-1).float() # [B, k, 1]
|
||||
|
||||
if valid_mask.sum() > 0:
|
||||
motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
|
||||
query_pos, key_pos, valid_mask
|
||||
)
|
||||
P_LL_sparse[:, l, :, :] += motif_pairs
|
||||
if valid_mask.sum() > 0:
|
||||
motif_pairs = self.motif_pos_embedder.compute_pairs_chunked(
|
||||
query_pos, key_pos, valid_mask
|
||||
)
|
||||
P_LL_sparse[:, l, :, :] += motif_pairs
|
||||
|
||||
# 2. Reference position embedding (if exists)
|
||||
if self.ref_pos_embedder is not None and "ref_pos" in f:
|
||||
ref_pos = f["ref_pos"] # [L, 3]
|
||||
ref_space_uid = f["ref_space_uid"] # [L]
|
||||
is_motif_seq = f["is_motif_atom_with_fixed_seq"] # [L]
|
||||
is_motif_seq_idx = torch.where(is_motif_seq)[0]
|
||||
for l in is_motif_seq_idx:
|
||||
key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices
|
||||
key_pos = ref_pos[key_indices] # [B, k, 3]
|
||||
query_pos = ref_pos[l].unsqueeze(0).expand(B, -1) # [B, 3]
|
||||
|
||||
for l in range(L):
|
||||
if is_motif_seq[l]: # Only compute if query has sequence
|
||||
key_indices = valid_indices[:, l, :] # [D, k] - use clamped indices
|
||||
key_pos = ref_pos[key_indices] # [D, k, 3]
|
||||
query_pos = ref_pos[l].unsqueeze(0).expand(D, -1) # [D, 3]
|
||||
# Valid mask: same token and both have sequence
|
||||
key_space_uid = ref_space_uid[key_indices] # [B, k]
|
||||
key_is_motif_seq = is_motif_seq[key_indices] # [B, k]
|
||||
|
||||
# Valid mask: same token and both have sequence
|
||||
key_space_uid = ref_space_uid[key_indices] # [D, k]
|
||||
key_is_motif_seq = is_motif_seq[key_indices] # [D, k]
|
||||
same_token = key_space_uid == ref_space_uid[l] # [B, k]
|
||||
valid_mask = (
|
||||
(same_token & key_is_motif_seq).unsqueeze(-1).float()
|
||||
) # [B, k, 1]
|
||||
|
||||
same_token = key_space_uid == ref_space_uid[l] # [D, k]
|
||||
valid_mask = (
|
||||
(same_token & key_is_motif_seq).unsqueeze(-1).float()
|
||||
) # [D, k, 1]
|
||||
|
||||
if valid_mask.sum() > 0:
|
||||
ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
|
||||
query_pos, key_pos, valid_mask
|
||||
)
|
||||
P_LL_sparse[:, l, :, :] += ref_pairs
|
||||
if valid_mask.sum() > 0:
|
||||
ref_pairs = self.ref_pos_embedder.compute_pairs_chunked(
|
||||
query_pos, key_pos, valid_mask
|
||||
)
|
||||
P_LL_sparse[:, l, :, :] += ref_pairs
|
||||
|
||||
# 3. Single embedding terms (broadcasted)
|
||||
# Expand C_L to match valid_indices batch dimension
|
||||
if C_L.shape[0] != B:
|
||||
C_L = C_L.expand(B, -1, -1) # [B, L, c_token]
|
||||
# Gather key features for each query
|
||||
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token]
|
||||
C_L_keys = torch.gather(
|
||||
C_L.unsqueeze(2).expand(-1, -1, k, -1),
|
||||
C_L_queries,
|
||||
1,
|
||||
valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]),
|
||||
) # [D, L, k, c_token]
|
||||
C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [D, L, k, c_token]
|
||||
) # [B, L, k, c_token]
|
||||
|
||||
# Add single embeddings - match standard implementation structure
|
||||
# Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3)
|
||||
# We need to broadcast from [D, L, k, c_atompair] to match this
|
||||
single_l = self.process_single_l(C_L_queries) # [D, L, k, c_atompair]
|
||||
single_m = self.process_single_m(C_L_keys) # [D, L, k, c_atompair]
|
||||
# We need to broadcast from [B, L, k, c_atompair] to match this
|
||||
single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair]
|
||||
single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair]
|
||||
P_LL_sparse += single_l + single_m
|
||||
|
||||
# 4. Token pair features Z_init_II
|
||||
@@ -312,15 +313,16 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
||||
else:
|
||||
tok_idx_expanded = tok_idx
|
||||
|
||||
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [D, L, k]
|
||||
# Expand tok_idx_expanded to match valid_indices batch dimension
|
||||
if tok_idx_expanded.shape[0] != B:
|
||||
tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L]
|
||||
tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k]
|
||||
# Use valid_indices for token mapping as well
|
||||
tok_keys = torch.gather(
|
||||
tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices
|
||||
) # [D, L, k]
|
||||
tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k]
|
||||
|
||||
# Gather Z_init_II[tok_queries, tok_keys] with safe indexing
|
||||
# Z_init_II shape is [I, I, c_z] (3D), not 4D
|
||||
# tok_queries shape: [D, L, k] - each value is a token index
|
||||
# tok_queries shape: [B, L, k] - each value is a token index
|
||||
# We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k
|
||||
|
||||
I_z, I_z2, c_z = Z_init_II.shape
|
||||
@@ -338,20 +340,20 @@ class ChunkedPairwiseEmbedder(nn.Module):
|
||||
# Then we need to gather the sparse version
|
||||
|
||||
Z_pairs_processed = torch.zeros(
|
||||
D, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
|
||||
B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype
|
||||
)
|
||||
|
||||
for d in range(D):
|
||||
for b in range(B):
|
||||
# For this batch, get the token queries and keys
|
||||
tq = tok_queries[d] # [L, k]
|
||||
tk = tok_keys[d] # [L, k]
|
||||
tq = tok_queries[b] # [L, k]
|
||||
tk = tok_keys[b] # [L, k]
|
||||
|
||||
# Ensure indices are within bounds
|
||||
tq = torch.clamp(tq, 0, I_z - 1)
|
||||
tk = torch.clamp(tk, 0, I_z2 - 1)
|
||||
|
||||
# Apply the double token indexing like standard implementation
|
||||
Z_pairs_processed[d] = Z_processed[tq, tk] # [L, k, c_atompair]
|
||||
Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair]
|
||||
|
||||
P_LL_sparse += Z_pairs_processed
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ name = "rc-foundry"
|
||||
dynamic = ["version"]
|
||||
description = "Shared utilities and training infrastructure for biomolecular structure prediction models."
|
||||
readme = "README.md"
|
||||
requires-python = "==3.12"
|
||||
requires-python = ">=3.12"
|
||||
authors = [
|
||||
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user