diff --git a/models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py b/models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py index e3dd713..9058aa0 100644 --- a/models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py +++ b/models/rfd3/src/rfd3/inference/symmetry/symmetry_utils.py @@ -100,14 +100,11 @@ def make_symmetric_atom_array( src_atom_array is not None ), "Source atom array must be provided for symmetric motifs" frames = get_symmetry_frames_from_atom_array(src_atom_array, frames) - elif sym_conf.is_symmetric_motif is None: + else: + # At this point, asym case would have been caught by the check_symmetry_config function. ranked_logger.info( "No motifs found in atom array. Generating unconditional symmetric proteins." ) - else: - raise ValueError( - "Asymmetric motif inputs are not supported yet. Please provide a symmetric motif or no motifs." - ) # Add symmetry annotations to the asu atom array asu_atom_array = add_sym_annotations(asu_atom_array, sym_conf) diff --git a/models/rfd3/src/rfd3/model/inference_sampler.py b/models/rfd3/src/rfd3/model/inference_sampler.py index 01c8f57..6ed957e 100644 --- a/models/rfd3/src/rfd3/model/inference_sampler.py +++ b/models/rfd3/src/rfd3/model/inference_sampler.py @@ -1,4 +1,5 @@ import inspect +import time from dataclasses import dataclass from typing import Any, Literal @@ -8,6 +9,7 @@ from rfd3.inference.symmetry.symmetry_utils import apply_symmetry_to_xyz_atomwis from rfd3.model.cfg_utils import strip_X from foundry.common import exists +from foundry.utils.alignment import weighted_rigid_align from foundry.utils.ddp import RankedLogger from foundry.utils.rotation_augmentation import ( rot_vec_mul, @@ -112,7 +114,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig): ) # Fallback to smallest available step noise_schedule_original = self._construct_inference_noise_schedule( - device=coord_atom_lvl_to_be_noised.device + device=device ) noise_schedule = noise_schedule_original[-1:] # Just use the final step ranked_logger.info( @@ -223,6 +225,7 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig): # Handle chunked mode vs standard mode if "chunked_pairwise_embedder" in initializer_outputs: # Chunked mode: explicitly provide P_LL=None + tic = time.time() chunked_embedder = initializer_outputs[ "chunked_pairwise_embedder" ] # Don't pop, just get @@ -240,6 +243,8 @@ class SampleDiffusionWithMotif(SampleDiffusionConfig): initializer_outputs=other_outputs, **other_outputs, ) + toc = time.time() + ranked_logger.info(f"Chunked mode time: {toc - tic} seconds") else: # Standard mode: P_LL is included in initializer_outputs outs = diffusion_module( @@ -447,6 +452,7 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif): # Handle chunked mode vs standard mode (same as default sampler) if "chunked_pairwise_embedder" in initializer_outputs: # Chunked mode: explicitly provide P_LL=None + tic = time.time() chunked_embedder = initializer_outputs[ "chunked_pairwise_embedder" ] # Don't pop, just get @@ -464,6 +470,8 @@ class SampleDiffusionWithSymmetry(SampleDiffusionWithMotif): initializer_outputs=other_outputs, **other_outputs, ) + toc = time.time() + ranked_logger.info(f"Chunked mode time: {toc - tic} seconds") else: # Standard mode: P_LL is included in initializer_outputs outs = diffusion_module( diff --git a/models/rfd3/src/rfd3/model/layers/block_utils.py b/models/rfd3/src/rfd3/model/layers/block_utils.py index e9d0b69..aeac08c 100644 --- a/models/rfd3/src/rfd3/model/layers/block_utils.py +++ b/models/rfd3/src/rfd3/model/layers/block_utils.py @@ -118,14 +118,14 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices): Parameters ---------- - P_LK_indices : (D, L, k) LongTensor + P_LK_indices : (B, L, k) LongTensor Key indices | P_LK_indices[d, i, k] = global atom index for which atom i attends to. - P_LK : (D, L, k, c) FloatTensor + P_LK : (B, L, k, c) FloatTensor Key features to scatter add into - P_LA_indices : (D, L, a) LongTensor + P_LA_indices : (B, L, a) LongTensor Additional feature indices to scatter into P_LK. - P_LA : (D, L, a, c) FloatTensor + P_LA : (B, L, a, c) FloatTensor Features corresponding to P_LA. Both index tensors contain indices representing D batch dim, @@ -135,42 +135,42 @@ def scatter_add_pair_features(P_LK_tgt, P_LK_indices, P_LA_src, P_LA_indices): """ # Handle case when indices and P_LA don't have batch dimensions - D, L, k = P_LK_indices.shape + B, L, k = P_LK_indices.shape if P_LA_indices.ndim == 2: - P_LA_indices = P_LA_indices.unsqueeze(0).expand(D, -1, -1) + P_LA_indices = P_LA_indices.unsqueeze(0).expand(B, -1, -1) if P_LA_src.ndim == 3: - P_LA_src = P_LA_src.unsqueeze(0).expand(D, -1, -1) + P_LA_src = P_LA_src.unsqueeze(0).expand(B, -1, -1) assert ( P_LA_src.shape[-1] == P_LK_tgt.shape[-1] ), "Channel dims do not match, got: {} vs {}".format( P_LA_src.shape[-1], P_LK_tgt.shape[-1] ) - matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (D, L, a, k) + matches = P_LA_indices.unsqueeze(-1) == P_LK_indices.unsqueeze(-2) # (B, L, a, k) if not torch.all(matches.sum(dim=(-1, -2)) >= 1): raise ValueError("Found multiple scatter indices for some atoms") elif not torch.all(matches.sum(dim=-1) <= 1): raise ValueError("Did not find a scatter index for every atom") - k_indices = matches.long().argmax(dim=-1) # (D, L, a) + k_indices = matches.long().argmax(dim=-1) # (B, L, a) scatter_indices = k_indices.unsqueeze(-1).expand( -1, -1, -1, P_LK_tgt.shape[-1] - ) # (D, L, a, c) + ) # (B, L, a, c) P_LK_tgt = P_LK_tgt.scatter_add(dim=2, index=scatter_indices, src=P_LA_src) return P_LK_tgt def _batched_gather(values: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """ - values : (D, L, C) - idx : (D, L, k) - returns: (D, L, k, C) + values : (B, L, C) + idx : (B, L, k) + returns: (B, L, k, C) """ - D, L, C = values.shape + B, L, C = values.shape k = idx.shape[-1] - # (D, L, 1, C) → stride-0 along k → (D, L, k, C) + # (B, L, 1, C) → stride-0 along k → (B, L, k, C) src = values.unsqueeze(2).expand(-1, -1, k, -1) - idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (D, L, k, C) + idx = idx.unsqueeze(-1).expand(-1, -1, -1, C) # (B, L, k, C) return torch.gather(src, 1, idx) # dim=1 is the L-axis @@ -196,7 +196,7 @@ def create_attention_indices( X_L = torch.randn( (1, L, 3), device=device, dtype=torch.float ) # [L, 3] - random - D_LL = torch.cdist(X_L, X_L, p=2) # [D, L, L] - pairwise atom distances + D_LL = torch.cdist(X_L, X_L, p=2) # [B, L, L] - pairwise atom distances # Create attention indices using neighbour distances base_mask = ~f["unindexing_pair_mask"][ @@ -231,7 +231,7 @@ def create_attention_indices( k_max=k_actual, chain_id=chain_ids, base_mask=base_mask, - ) # [D, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query + ) # [B, L, k] | indices[b, i, j] = atom index for atom i to j-th attn query return attn_indices @@ -245,7 +245,7 @@ def get_sparse_attention_indices_with_inter_chain( Args: tok_idx: atom to token mapping - D_LL: pairwise distances [D, L, L] + D_LL: pairwise distances [B, L, L] n_seq_neighbours: number of sequence neighbors k_intra: number of intra-chain attention keys k_inter: number of inter-chain attention keys @@ -253,29 +253,29 @@ def get_sparse_attention_indices_with_inter_chain( base_mask: base mask for valid pairs Returns: - attn_indices: [D, L, k_total] where k_total = k_intra + k_inter + attn_indices: [B, L, k_total] where k_total = k_intra + k_inter """ - D, L, _ = D_LL.shape + B, L, _ = D_LL.shape # Get regular intra-chain indices (limited to k_intra) intra_indices = get_sparse_attention_indices( tok_idx, D_LL, n_seq_neighbours, k_intra, chain_id, base_mask - ) # [D, L, k_intra] + ) # [B, L, k_intra] # Get inter-chain indices for clash avoidance - inter_indices = torch.zeros(D, L, k_inter, dtype=torch.long, device=D_LL.device) - - for d in range(D): - for l in range(L): - query_chain = chain_id[l] + inter_indices = torch.zeros(B, L, k_inter, dtype=torch.long, device=D_LL.device) + unique_chains = torch.unique(chain_id) + for b in range(B): + for c in unique_chains: + query_chain = chain_id[c] # Find atoms from different chains - other_chain_mask = (chain_id != query_chain) & base_mask[l, :] + other_chain_mask = (chain_id != query_chain) & base_mask[c, :] other_chain_atoms = torch.where(other_chain_mask)[0] if len(other_chain_atoms) > 0: # Get distances to other chains - distances_to_other = D_LL[d, l, other_chain_atoms] + distances_to_other = D_LL[b, c, other_chain_atoms] # Select k_inter closest atoms from other chains n_select = min(k_inter, len(other_chain_atoms)) @@ -283,23 +283,23 @@ def get_sparse_attention_indices_with_inter_chain( selected_atoms = other_chain_atoms[closest_idx] # Fill inter-chain indices - inter_indices[d, l, :n_select] = selected_atoms + inter_indices[b, c, :n_select] = selected_atoms # Pad with random atoms if needed if n_select < k_inter: padding = torch.randint( 0, L, (k_inter - n_select,), device=D_LL.device ) - inter_indices[d, l, n_select:] = padding + inter_indices[b, c, n_select:] = padding else: # No other chains found, fill with random indices - inter_indices[d, l, :] = torch.randint( + inter_indices[b, c, :] = torch.randint( 0, L, (k_inter,), device=D_LL.device ) # Combine intra and inter chain indices combined_indices = torch.cat( [intra_indices, inter_indices], dim=-1 - ) # [D, L, k_total] + ) # [B, L, k_total] return combined_indices diff --git a/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py b/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py index 9f14d44..e34aa4c 100644 --- a/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py +++ b/models/rfd3/src/rfd3/model/layers/chunked_pairwise.py @@ -30,32 +30,32 @@ class ChunkedPositionPairDistEmbedder(nn.Module): def compute_pairs_chunked( self, - query_pos: torch.Tensor, # [D, 3] - key_pos: torch.Tensor, # [D, k, 3] - valid_mask: torch.Tensor, # [D, k, 1] + query_pos: torch.Tensor, # [B, 3] + key_pos: torch.Tensor, # [B, k, 3] + valid_mask: torch.Tensor, # [B, k, 1] ) -> torch.Tensor: """ Compute pairwise embeddings for specific query-key pairs. Args: - query_pos: Query positions [D, 3] - key_pos: Key positions [D, k, 3] - valid_mask: Valid pair mask [D, k, 1] + query_pos: Query positions [B, 3] + key_pos: Key positions [B, k, 3] + valid_mask: Valid pair mask [B, k, 1] Returns: - P_sparse: Pairwise embeddings [D, k, c_atompair] + P_sparse: Pairwise embeddings [B, k, c_atompair] """ - D, k = key_pos.shape[:2] + B, k = key_pos.shape[:2] - # Compute pairwise distances: [D, k, 3] - D_pairs = query_pos.unsqueeze(1) - key_pos # [D, 1, 3] - [D, k, 3] = [D, k, 3] + # Compute pairwise distances: [B, k, 3] + D_pairs = query_pos.unsqueeze(1) - key_pos # [B, 1, 3] - [B, k, 3] = [B, k, 3] if self.embed_frame: # Embed pairwise distances - P_pairs = self.process_d(D_pairs) * valid_mask # [D, k, c_atompair] + P_pairs = self.process_d(D_pairs) * valid_mask # [B, k, c_atompair] # Add inverse distance embedding - norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [D, k, 1] + norm_sq = torch.linalg.norm(D_pairs, dim=-1, keepdim=True) ** 2 # [B, k, 1] inv_dist = 1 / (1 + norm_sq) P_pairs = P_pairs + self.process_inverse_dist(inv_dist) * valid_mask @@ -95,19 +95,19 @@ class ChunkedSinusoidalDistEmbed(nn.Module): def compute_pairs_chunked( self, - query_pos: torch.Tensor, # [D, 3] - key_pos: torch.Tensor, # [D, k, 3] - valid_mask: torch.Tensor, # [D, k, 1] + query_pos: torch.Tensor, # [B, 3] + key_pos: torch.Tensor, # [B, k, 3] + valid_mask: torch.Tensor, # [B, k, 1] ) -> torch.Tensor: """ Compute sinusoidal distance embeddings for specific query-key pairs. """ - D, k = key_pos.shape[:2] + B, k = key_pos.shape[:2] device = query_pos.device # Compute pairwise distances - D_pairs = query_pos.unsqueeze(1) - key_pos # [D, k, 3] - dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [D, k] + D_pairs = query_pos.unsqueeze(1) - key_pos # [B, k, 3] + dist_matrix = torch.linalg.norm(D_pairs, dim=-1) # [B, k] # Sinusoidal embedding half_dim = self.n_freqs @@ -117,13 +117,13 @@ class ChunkedSinusoidalDistEmbed(nn.Module): / half_dim ) # [n_freqs] - angles = dist_matrix.unsqueeze(-1) * freq # [D, k, n_freqs] + angles = dist_matrix.unsqueeze(-1) * freq # [B, k, n_freqs] sin_embed = torch.sin(angles) cos_embed = torch.cos(angles) - sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [D, k, 2*n_freqs] + sincos_embed = torch.cat([sin_embed, cos_embed], dim=-1) # [B, k, 2*n_freqs] # Linear projection - P_pairs = self.output_proj(sincos_embed) # [D, k, c_atompair] + P_pairs = self.output_proj(sincos_embed) # [B, k, c_atompair] P_pairs = P_pairs * valid_mask # Add linear embedding of valid mask @@ -191,8 +191,8 @@ class ChunkedPairwiseEmbedder(nn.Module): def forward_chunked( self, f: dict, - indices: torch.Tensor, # [D, L, k] - sparse attention indices - C_L: torch.Tensor, # [D, L, c_token] - atom features + indices: torch.Tensor, # [B, L, k] - sparse attention indices + C_L: torch.Tensor, # [B, L, c_token] - atom features Z_init_II: torch.Tensor, # [I, I, c_z] - token pair features tok_idx: torch.Tensor, # [L] - atom to token mapping ) -> torch.Tensor: @@ -208,20 +208,20 @@ class ChunkedPairwiseEmbedder(nn.Module): Args: f: Feature dictionary - indices: Sparse attention indices [D, L, k] - C_L: Atom-level features [D, L, c_token] + indices: Sparse attention indices [B, L, k] + C_L: Atom-level features [B, L, c_token] Z_init_II: Token-level pair features [I, I, c_z] tok_idx: Atom to token mapping [L] Returns: - P_LL_sparse: Sparse pairwise features [D, L, k, c_atompair] + P_LL_sparse: Sparse pairwise features [B, L, k, c_atompair] """ - D, L, k = indices.shape + B, L, k = indices.shape device = indices.device # Initialize sparse P_LL P_LL_sparse = torch.zeros( - D, L, k, self.c_atompair, device=device, dtype=C_L.dtype + B, L, k, self.c_atompair, device=device, dtype=C_L.dtype ) # Handle both batched and non-batched C_L @@ -237,71 +237,72 @@ class ChunkedPairwiseEmbedder(nn.Module): if valid_indices.dim() == 2: # [L, k] - add batch dimension valid_indices = valid_indices.unsqueeze(0).expand( C_L.shape[0], -1, -1 - ) # [D, L, k] + ) # [B, L, k] # 1. Motif position embedding (if exists) if self.motif_pos_embedder is not None and "motif_pos" in f: motif_pos = f["motif_pos"] # [L, 3] is_motif = f["is_motif_atom_with_fixed_coord"] # [L] - + is_motif_idx = torch.where(is_motif)[0] # For each query position - for l in range(L): - if is_motif[l]: # Only compute if query is motif - key_indices = valid_indices[:, l, :] # [D, k] - use clamped indices - key_pos = motif_pos[key_indices] # [D, k, 3] - query_pos = motif_pos[l].unsqueeze(0).expand(D, -1) # [D, 3] + for l in is_motif_idx: + key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices + key_pos = motif_pos[key_indices] # [B, k, 3] + query_pos = motif_pos[l].unsqueeze(0).expand(B, -1) # [B, 3] - # Valid mask: both query and keys must be motif - key_is_motif = is_motif[key_indices] # [D, k] - valid_mask = key_is_motif.unsqueeze(-1).float() # [D, k, 1] + # Valid mask: both query and keys must be motif + key_is_motif = is_motif[key_indices] # [B, k] + valid_mask = key_is_motif.unsqueeze(-1).float() # [B, k, 1] - if valid_mask.sum() > 0: - motif_pairs = self.motif_pos_embedder.compute_pairs_chunked( - query_pos, key_pos, valid_mask - ) - P_LL_sparse[:, l, :, :] += motif_pairs + if valid_mask.sum() > 0: + motif_pairs = self.motif_pos_embedder.compute_pairs_chunked( + query_pos, key_pos, valid_mask + ) + P_LL_sparse[:, l, :, :] += motif_pairs # 2. Reference position embedding (if exists) if self.ref_pos_embedder is not None and "ref_pos" in f: ref_pos = f["ref_pos"] # [L, 3] ref_space_uid = f["ref_space_uid"] # [L] is_motif_seq = f["is_motif_atom_with_fixed_seq"] # [L] + is_motif_seq_idx = torch.where(is_motif_seq)[0] + for l in is_motif_seq_idx: + key_indices = valid_indices[:, l, :] # [B, k] - use clamped indices + key_pos = ref_pos[key_indices] # [B, k, 3] + query_pos = ref_pos[l].unsqueeze(0).expand(B, -1) # [B, 3] - for l in range(L): - if is_motif_seq[l]: # Only compute if query has sequence - key_indices = valid_indices[:, l, :] # [D, k] - use clamped indices - key_pos = ref_pos[key_indices] # [D, k, 3] - query_pos = ref_pos[l].unsqueeze(0).expand(D, -1) # [D, 3] + # Valid mask: same token and both have sequence + key_space_uid = ref_space_uid[key_indices] # [B, k] + key_is_motif_seq = is_motif_seq[key_indices] # [B, k] - # Valid mask: same token and both have sequence - key_space_uid = ref_space_uid[key_indices] # [D, k] - key_is_motif_seq = is_motif_seq[key_indices] # [D, k] + same_token = key_space_uid == ref_space_uid[l] # [B, k] + valid_mask = ( + (same_token & key_is_motif_seq).unsqueeze(-1).float() + ) # [B, k, 1] - same_token = key_space_uid == ref_space_uid[l] # [D, k] - valid_mask = ( - (same_token & key_is_motif_seq).unsqueeze(-1).float() - ) # [D, k, 1] - - if valid_mask.sum() > 0: - ref_pairs = self.ref_pos_embedder.compute_pairs_chunked( - query_pos, key_pos, valid_mask - ) - P_LL_sparse[:, l, :, :] += ref_pairs + if valid_mask.sum() > 0: + ref_pairs = self.ref_pos_embedder.compute_pairs_chunked( + query_pos, key_pos, valid_mask + ) + P_LL_sparse[:, l, :, :] += ref_pairs # 3. Single embedding terms (broadcasted) + # Expand C_L to match valid_indices batch dimension + if C_L.shape[0] != B: + C_L = C_L.expand(B, -1, -1) # [B, L, c_token] # Gather key features for each query + C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [B, L, k, c_token] C_L_keys = torch.gather( - C_L.unsqueeze(2).expand(-1, -1, k, -1), + C_L_queries, 1, valid_indices.unsqueeze(-1).expand(-1, -1, -1, C_L.shape[-1]), - ) # [D, L, k, c_token] - C_L_queries = C_L.unsqueeze(2).expand(-1, -1, k, -1) # [D, L, k, c_token] + ) # [B, L, k, c_token] # Add single embeddings - match standard implementation structure # Standard does: self.process_single_l(C_L).unsqueeze(-2) + self.process_single_m(C_L).unsqueeze(-3) - # We need to broadcast from [D, L, k, c_atompair] to match this - single_l = self.process_single_l(C_L_queries) # [D, L, k, c_atompair] - single_m = self.process_single_m(C_L_keys) # [D, L, k, c_atompair] + # We need to broadcast from [B, L, k, c_atompair] to match this + single_l = self.process_single_l(C_L_queries) # [B, L, k, c_atompair] + single_m = self.process_single_m(C_L_keys) # [B, L, k, c_atompair] P_LL_sparse += single_l + single_m # 4. Token pair features Z_init_II @@ -312,15 +313,16 @@ class ChunkedPairwiseEmbedder(nn.Module): else: tok_idx_expanded = tok_idx - tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [D, L, k] + # Expand tok_idx_expanded to match valid_indices batch dimension + if tok_idx_expanded.shape[0] != B: + tok_idx_expanded = tok_idx_expanded.expand(B, -1) # [B, L] + tok_queries = tok_idx_expanded.unsqueeze(2).expand(-1, -1, k) # [B, L, k] # Use valid_indices for token mapping as well - tok_keys = torch.gather( - tok_idx_expanded.unsqueeze(2).expand(-1, -1, k), 1, valid_indices - ) # [D, L, k] + tok_keys = torch.gather(tok_queries, 1, valid_indices) # [B, L, k] # Gather Z_init_II[tok_queries, tok_keys] with safe indexing # Z_init_II shape is [I, I, c_z] (3D), not 4D - # tok_queries shape: [D, L, k] - each value is a token index + # tok_queries shape: [B, L, k] - each value is a token index # We want: Z_init_II[tok_queries[d,l,k], tok_keys[d,l,k], :] for all d,l,k I_z, I_z2, c_z = Z_init_II.shape @@ -338,20 +340,20 @@ class ChunkedPairwiseEmbedder(nn.Module): # Then we need to gather the sparse version Z_pairs_processed = torch.zeros( - D, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype + B, L, k, self.c_atompair, device=device, dtype=Z_processed.dtype ) - for d in range(D): + for b in range(B): # For this batch, get the token queries and keys - tq = tok_queries[d] # [L, k] - tk = tok_keys[d] # [L, k] + tq = tok_queries[b] # [L, k] + tk = tok_keys[b] # [L, k] # Ensure indices are within bounds tq = torch.clamp(tq, 0, I_z - 1) tk = torch.clamp(tk, 0, I_z2 - 1) # Apply the double token indexing like standard implementation - Z_pairs_processed[d] = Z_processed[tq, tk] # [L, k, c_atompair] + Z_pairs_processed[b] = Z_processed[tq, tk] # [L, k, c_atompair] P_LL_sparse += Z_pairs_processed diff --git a/pyproject.toml b/pyproject.toml index 9bcfbc5..20b669b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "rc-foundry" dynamic = ["version"] description = "Shared utilities and training infrastructure for biomolecular structure prediction models." readme = "README.md" -requires-python = "==3.12" +requires-python = ">=3.12" authors = [ { name = "Institute for Protein Design", email = "contact@ipd.uw.edu" }, ]