mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 09:04:23 +08:00
sync
This commit is contained in:
@@ -31,3 +31,7 @@ repos:
|
||||
language: system
|
||||
types: [python]
|
||||
pass_filenames: true # For speed, we only check the files that are changed
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.2
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
@@ -29,14 +29,6 @@ from esm.sdk.retry import retry_decorator
|
||||
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
|
||||
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
|
||||
from esm.utils.msa import MSA
|
||||
from esm.utils.structure.input_builder import (
|
||||
StructurePredictionInput,
|
||||
serialize_structure_prediction_input,
|
||||
)
|
||||
from esm.utils.structure.molecular_complex import (
|
||||
MolecularComplex,
|
||||
MolecularComplexResult,
|
||||
)
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -217,70 +209,6 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
|
||||
return self._process_fold_response(data, sequence)
|
||||
|
||||
@retry_decorator
|
||||
async def async_fold_all_atom(
|
||||
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
||||
"""Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
|
||||
|
||||
Args:
|
||||
all_atom_input: StructurePredictionInput containing sequences for different molecule types
|
||||
model_name: Override the client level model name if needed
|
||||
"""
|
||||
request = self._process_fold_all_atom_request(
|
||||
all_atom_input, model_name if model_name is not None else self.model
|
||||
)
|
||||
|
||||
try:
|
||||
data = await self._async_post("fold_all_atom", request)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return self._process_fold_all_atom_response(data)
|
||||
|
||||
@retry_decorator
|
||||
def fold_all_atom(
|
||||
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
||||
"""Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
|
||||
|
||||
Args:
|
||||
all_atom_input: StructurePredictionInput containing sequences for different molecule types
|
||||
model_name: Override the client level model name if needed
|
||||
"""
|
||||
request = self._process_fold_all_atom_request(
|
||||
all_atom_input, model_name if model_name is not None else self.model
|
||||
)
|
||||
|
||||
try:
|
||||
data = self._post("fold_all_atom", request)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return self._process_fold_all_atom_response(data)
|
||||
|
||||
@staticmethod
|
||||
def _process_fold_all_atom_request(
|
||||
all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
request: dict[str, Any] = {
|
||||
"all_atom_input": serialize_structure_prediction_input(all_atom_input),
|
||||
"model": model_name,
|
||||
}
|
||||
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult:
|
||||
complex_data = data.get("complex")
|
||||
molecular_complex = MolecularComplex.from_state_dict(complex_data)
|
||||
return MolecularComplexResult(
|
||||
complex=molecular_complex,
|
||||
plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True),
|
||||
ptm=data.get("ptm", None),
|
||||
distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True),
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
async def async_inverse_fold(
|
||||
self,
|
||||
|
||||
@@ -9,11 +9,13 @@ from subprocess import check_output
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
import biotite.structure as bs
|
||||
import biotite.structure.io.pdbx as pdbx
|
||||
import brotli
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import torch
|
||||
from biotite.structure.io.pdbx import CIFFile, set_structure
|
||||
|
||||
from esm.utils import residue_constants
|
||||
from esm.utils.structure.metrics import compute_lddt, compute_rmsd
|
||||
@@ -52,9 +54,11 @@ class Molecule:
|
||||
token_idx: int
|
||||
atom_positions: np.ndarray # [N_atoms, 3]
|
||||
atom_elements: np.ndarray # [N_atoms] element strings
|
||||
residue_type: int
|
||||
molecule_type: int # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
|
||||
confidence: float
|
||||
atom_names: np.ndarray | None = None # [N_atoms] atom names (optional)
|
||||
atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional)
|
||||
residue_type: int = 0
|
||||
molecule_type: int = 0 # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
|
||||
confidence: float = 0.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -76,21 +80,40 @@ class MolecularComplex:
|
||||
# Token-to-atom mapping for efficient access
|
||||
token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array
|
||||
|
||||
# Chain information
|
||||
chain_id: np.ndarray # [N_tokens] chain identifier for each token
|
||||
|
||||
# Confidence data
|
||||
plddt: np.ndarray # Per-token confidence scores [N_tokens]
|
||||
|
||||
# Metadata
|
||||
metadata: MolecularComplexMetadata
|
||||
|
||||
# Optional atom names and hetero flags (preserved from original structures)
|
||||
atom_names: np.ndarray | None = None # [N_atoms] atom names (optional)
|
||||
atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate array dimensions."""
|
||||
n_tokens = len(self.sequence)
|
||||
n_atoms = len(self.atom_positions)
|
||||
assert (
|
||||
self.token_to_atoms.shape[0] == n_tokens
|
||||
), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens"
|
||||
assert (
|
||||
self.chain_id.shape[0] == n_tokens
|
||||
), f"chain_id shape {self.chain_id.shape} != {n_tokens} tokens"
|
||||
assert (
|
||||
self.plddt.shape[0] == n_tokens
|
||||
), f"plddt shape {self.plddt.shape} != {n_tokens} tokens"
|
||||
if self.atom_names is not None:
|
||||
assert (
|
||||
self.atom_names.shape[0] == n_atoms
|
||||
), f"atom_names shape {self.atom_names.shape} != {n_atoms} atoms"
|
||||
if self.atom_hetero is not None:
|
||||
assert (
|
||||
self.atom_hetero.shape[0] == n_atoms
|
||||
), f"atom_hetero shape {self.atom_hetero.shape} != {n_atoms} atoms"
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return number of tokens."""
|
||||
@@ -109,6 +132,12 @@ class MolecularComplex:
|
||||
# Extract atom data for this token
|
||||
token_atom_positions = self.atom_positions[start_atom:end_atom]
|
||||
token_atom_elements = self.atom_elements[start_atom:end_atom]
|
||||
token_atom_names = None
|
||||
if self.atom_names is not None:
|
||||
token_atom_names = self.atom_names[start_atom:end_atom]
|
||||
token_atom_hetero = None
|
||||
if self.atom_hetero is not None:
|
||||
token_atom_hetero = self.atom_hetero[start_atom:end_atom]
|
||||
|
||||
# Default values for residue/molecule type (would be extended based on actual implementation)
|
||||
residue_type = 0 # Default to standard residue
|
||||
@@ -119,6 +148,8 @@ class MolecularComplex:
|
||||
token_idx=idx,
|
||||
atom_positions=token_atom_positions,
|
||||
atom_elements=token_atom_elements,
|
||||
atom_names=token_atom_names,
|
||||
atom_hetero=token_atom_hetero,
|
||||
residue_type=residue_type,
|
||||
molecule_type=molecule_type,
|
||||
confidence=self.plddt[idx],
|
||||
@@ -151,6 +182,8 @@ class MolecularComplex:
|
||||
# Convert atom37 to flat arrays
|
||||
flat_positions = []
|
||||
flat_elements = []
|
||||
flat_names = []
|
||||
flat_hetero = []
|
||||
token_to_atoms = []
|
||||
|
||||
atom_idx = 0
|
||||
@@ -180,6 +213,12 @@ class MolecularComplex:
|
||||
) # First character is element
|
||||
flat_elements.append(element)
|
||||
|
||||
# Add atom name
|
||||
flat_names.append(atom_name)
|
||||
|
||||
# Add hetero flag (all proteins are non-hetero)
|
||||
flat_hetero.append(False)
|
||||
|
||||
atom_idx += 1
|
||||
|
||||
# Record token-to-atom mapping [start_idx, end_idx)
|
||||
@@ -189,17 +228,20 @@ class MolecularComplex:
|
||||
# Convert to numpy arrays
|
||||
atom_positions = np.array(flat_positions, dtype=np.float32)
|
||||
atom_elements = np.array(flat_elements, dtype=object)
|
||||
atom_names = np.array(flat_names, dtype=object)
|
||||
atom_hetero = np.array(flat_hetero, dtype=bool)
|
||||
token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
|
||||
|
||||
# Extract confidence scores (skip chain breaks)
|
||||
# Extract confidence scores and chain_ids (skip chain breaks)
|
||||
confidence_scores = []
|
||||
residue_idx = 0
|
||||
for aa in pc.sequence:
|
||||
chain_ids = []
|
||||
for seq_idx, aa in enumerate(pc.sequence):
|
||||
if aa != "|":
|
||||
confidence_scores.append(pc.confidence[residue_idx])
|
||||
residue_idx += 1
|
||||
confidence_scores.append(pc.confidence[seq_idx])
|
||||
chain_ids.append(pc.chain_id[seq_idx])
|
||||
|
||||
confidence_array = np.array(confidence_scores, dtype=np.float32)
|
||||
chain_id_array = np.array(chain_ids, dtype=np.int64)
|
||||
|
||||
# Create metadata - convert entity IDs to strings for MolecularComplexMetadata
|
||||
entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()}
|
||||
@@ -215,8 +257,11 @@ class MolecularComplex:
|
||||
atom_positions=atom_positions,
|
||||
atom_elements=atom_elements,
|
||||
token_to_atoms=token_to_atoms_array,
|
||||
chain_id=chain_id_array,
|
||||
plddt=confidence_array,
|
||||
metadata=metadata,
|
||||
atom_names=atom_names,
|
||||
atom_hetero=atom_hetero,
|
||||
)
|
||||
|
||||
def to_protein_complex(self) -> ProteinComplex:
|
||||
@@ -251,13 +296,27 @@ class MolecularComplex:
|
||||
atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32)
|
||||
atom37_mask = np.zeros((n_residues, 37), dtype=bool)
|
||||
|
||||
# Convert tokens back to single-letter sequence
|
||||
single_letter_sequence = "".join(
|
||||
[residue_constants.restype_3to1[token] for token in protein_tokens]
|
||||
)
|
||||
|
||||
# Extract confidence scores for protein residues only
|
||||
# Extract confidence scores and chain_ids for protein residues only
|
||||
protein_confidence = self.plddt[protein_indices]
|
||||
protein_chain_ids = self.chain_id[protein_indices]
|
||||
|
||||
# Convert tokens back to single-letter sequence with chain breaks
|
||||
single_letter_residues = []
|
||||
prev_chain_id = None
|
||||
|
||||
for i, (token, chain_id_val) in enumerate(
|
||||
zip(protein_tokens, protein_chain_ids)
|
||||
):
|
||||
# Add chain break if we're switching to a new chain
|
||||
if prev_chain_id is not None and chain_id_val != prev_chain_id:
|
||||
single_letter_residues.append("|")
|
||||
single_letter_residues.append(residue_constants.restype_3to1[token])
|
||||
prev_chain_id = chain_id_val
|
||||
|
||||
single_letter_sequence = "".join(single_letter_residues)
|
||||
|
||||
# Calculate final sequence length (includes chain breaks)
|
||||
sequence_length = len(single_letter_sequence)
|
||||
|
||||
# Convert flat atoms back to atom37 representation
|
||||
for res_idx, token_idx in enumerate(protein_indices):
|
||||
@@ -283,19 +342,69 @@ class MolecularComplex:
|
||||
atom37_mask[res_idx, atom_type_idx] = True
|
||||
atom_count += 1
|
||||
|
||||
# Create other required arrays for ProteinComplex
|
||||
# For simplicity, assume all protein residues belong to the same entity/chain
|
||||
entity_id = np.zeros(n_residues, dtype=np.int64)
|
||||
chain_id = np.zeros(n_residues, dtype=np.int64)
|
||||
sym_id = np.zeros(n_residues, dtype=np.int64)
|
||||
residue_index = np.arange(1, n_residues + 1, dtype=np.int64)
|
||||
insertion_code = np.array([""] * n_residues, dtype=object)
|
||||
# Create arrays that match sequence length (including chain breaks)
|
||||
# Initialize arrays with proper size
|
||||
chain_id_expanded = np.full(sequence_length, -1, dtype=np.int64)
|
||||
entity_id_expanded = np.full(sequence_length, -1, dtype=np.int64)
|
||||
sym_id_expanded = np.zeros(sequence_length, dtype=np.int64)
|
||||
residue_index_expanded = np.zeros(sequence_length, dtype=np.int64)
|
||||
insertion_code_expanded = np.array([""] * sequence_length, dtype=object)
|
||||
confidence_expanded = np.zeros(sequence_length, dtype=np.float32)
|
||||
atom37_positions_expanded = np.full(
|
||||
(sequence_length, 37, 3), np.nan, dtype=np.float32
|
||||
)
|
||||
atom37_mask_expanded = np.zeros((sequence_length, 37), dtype=bool)
|
||||
|
||||
# Map residue data to sequence positions (skipping chain breaks)
|
||||
residue_idx = 0
|
||||
residue_counter_per_chain = {}
|
||||
|
||||
for seq_pos, char in enumerate(single_letter_sequence):
|
||||
if char != "|":
|
||||
# This is a residue position
|
||||
chain_id_val = protein_chain_ids[residue_idx]
|
||||
|
||||
chain_id_expanded[seq_pos] = chain_id_val
|
||||
entity_id_expanded[seq_pos] = chain_id_val # Simplified mapping
|
||||
|
||||
# Track residue numbering per chain
|
||||
if chain_id_val not in residue_counter_per_chain:
|
||||
residue_counter_per_chain[chain_id_val] = 1
|
||||
else:
|
||||
residue_counter_per_chain[chain_id_val] += 1
|
||||
|
||||
residue_index_expanded[seq_pos] = residue_counter_per_chain[
|
||||
chain_id_val
|
||||
]
|
||||
confidence_expanded[seq_pos] = protein_confidence[residue_idx]
|
||||
atom37_positions_expanded[seq_pos] = atom37_positions[residue_idx]
|
||||
atom37_mask_expanded[seq_pos] = atom37_mask[residue_idx]
|
||||
|
||||
residue_idx += 1
|
||||
# Chain break positions keep default values (-1, False, etc.)
|
||||
|
||||
# Use the expanded arrays
|
||||
chain_id = chain_id_expanded
|
||||
entity_id = entity_id_expanded
|
||||
sym_id = sym_id_expanded
|
||||
residue_index = residue_index_expanded
|
||||
insertion_code = insertion_code_expanded
|
||||
protein_confidence = confidence_expanded
|
||||
atom37_positions = atom37_positions_expanded
|
||||
atom37_mask = atom37_mask_expanded
|
||||
|
||||
# Create protein complex metadata preserving chain information
|
||||
# Convert MolecularComplex metadata to ProteinComplex format
|
||||
unique_chain_ids = np.unique(protein_chain_ids)
|
||||
entity_lookup = {int(cid): int(cid) for cid in unique_chain_ids}
|
||||
chain_lookup = {
|
||||
int(cid): self.metadata.chain_lookup.get(int(cid), chr(65 + int(cid)))
|
||||
for cid in unique_chain_ids
|
||||
}
|
||||
|
||||
# Create simplified protein complex metadata
|
||||
# Map the first entity/chain from molecular complex metadata
|
||||
protein_metadata = ProteinComplexMetadata(
|
||||
entity_lookup={0: 1}, # Single entity (int for ProteinComplexMetadata)
|
||||
chain_lookup={0: "A"}, # Single chain
|
||||
entity_lookup=entity_lookup,
|
||||
chain_lookup=chain_lookup,
|
||||
assembly_composition=self.metadata.assembly_composition,
|
||||
)
|
||||
|
||||
@@ -336,7 +445,9 @@ class MolecularComplex:
|
||||
|
||||
# Get structure - handle missing model information gracefully
|
||||
try:
|
||||
structure = pdbx.get_structure(mmcif_file, model=1)
|
||||
structure = pdbx.get_structure(
|
||||
mmcif_file, model=1, extra_fields=["b_factor"]
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
# Fallback for mmCIF files without model information
|
||||
try:
|
||||
@@ -374,8 +485,11 @@ class MolecularComplex:
|
||||
sequence_tokens = []
|
||||
flat_positions = []
|
||||
flat_elements = []
|
||||
flat_names = []
|
||||
flat_hetero = []
|
||||
token_to_atoms = []
|
||||
confidence_scores = []
|
||||
chain_ids = [] # Track chain IDs for each token
|
||||
|
||||
atom_idx = 0
|
||||
|
||||
@@ -396,9 +510,16 @@ class MolecularComplex:
|
||||
}
|
||||
chain_residue_groups[chain_id][res_id]["atoms"].append(atom)
|
||||
|
||||
# Create a mapping from chain_id to numeric indices
|
||||
chain_id_to_numeric = {
|
||||
chain_id: idx
|
||||
for idx, chain_id in enumerate(sorted(chain_residue_groups.keys()))
|
||||
}
|
||||
|
||||
# Process each chain and residue
|
||||
for chain_id in sorted(chain_residue_groups.keys()):
|
||||
residues = chain_residue_groups[chain_id]
|
||||
numeric_chain_id = chain_id_to_numeric[chain_id]
|
||||
|
||||
for res_id in sorted(residues.keys()):
|
||||
residue_data = residues[res_id]
|
||||
@@ -422,6 +543,9 @@ class MolecularComplex:
|
||||
token_name = res_name
|
||||
|
||||
sequence_tokens.append(token_name)
|
||||
chain_ids.append(
|
||||
numeric_chain_id
|
||||
) # Store the numeric chain ID for this token
|
||||
token_start = atom_idx
|
||||
|
||||
# Add all atoms from this residue
|
||||
@@ -432,6 +556,14 @@ class MolecularComplex:
|
||||
element = atom.element
|
||||
flat_elements.append(element)
|
||||
|
||||
# Get atom name
|
||||
atom_name = atom.atom_name
|
||||
flat_names.append(atom_name)
|
||||
|
||||
# Get hetero flag
|
||||
hetero_flag = atom.hetero
|
||||
flat_hetero.append(hetero_flag)
|
||||
|
||||
atom_idx += 1
|
||||
|
||||
# Record token-to-atom mapping
|
||||
@@ -446,20 +578,36 @@ class MolecularComplex:
|
||||
# Create minimal arrays if no atoms found
|
||||
atom_positions = np.zeros((0, 3), dtype=np.float32)
|
||||
atom_elements = np.zeros(0, dtype=object)
|
||||
atom_names = np.zeros(0, dtype=object)
|
||||
atom_hetero = np.zeros(0, dtype=bool)
|
||||
token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32)
|
||||
chain_id_array = (
|
||||
np.array(chain_ids, dtype=np.int64)
|
||||
if chain_ids
|
||||
else np.zeros(len(sequence_tokens), dtype=np.int64)
|
||||
)
|
||||
else:
|
||||
atom_positions = np.array(flat_positions, dtype=np.float32)
|
||||
atom_elements = np.array(flat_elements, dtype=object)
|
||||
atom_names = np.array(flat_names, dtype=object)
|
||||
atom_hetero = np.array(flat_hetero, dtype=bool)
|
||||
token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
|
||||
chain_id_array = np.array(chain_ids, dtype=np.int64)
|
||||
|
||||
confidence_array = np.array(confidence_scores, dtype=np.float32)
|
||||
|
||||
# Create metadata
|
||||
# Create metadata using the chain_id_to_numeric mapping
|
||||
if chain_residue_groups:
|
||||
chain_lookup = {
|
||||
numeric_id: chain_id
|
||||
for chain_id, numeric_id in chain_id_to_numeric.items()
|
||||
}
|
||||
else:
|
||||
chain_lookup = {}
|
||||
|
||||
metadata = MolecularComplexMetadata(
|
||||
entity_lookup=entity_info,
|
||||
chain_lookup={
|
||||
i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys())
|
||||
},
|
||||
chain_lookup=chain_lookup,
|
||||
assembly_composition=None,
|
||||
)
|
||||
|
||||
@@ -475,168 +623,107 @@ class MolecularComplex:
|
||||
atom_positions=atom_positions,
|
||||
atom_elements=atom_elements,
|
||||
token_to_atoms=token_to_atoms_array,
|
||||
chain_id=chain_id_array,
|
||||
plddt=confidence_array,
|
||||
metadata=metadata,
|
||||
atom_names=atom_names,
|
||||
atom_hetero=atom_hetero,
|
||||
)
|
||||
|
||||
def to_mmcif(self) -> str:
|
||||
"""Write MolecularComplex to mmcif string.
|
||||
"""Write MolecularComplex to mmcif string using biotite.
|
||||
|
||||
Returns:
|
||||
String representation of the complex in mmCIF format
|
||||
"""
|
||||
# No need for element mapping - already using element characters
|
||||
# Pre-allocate AtomArray
|
||||
n_atoms = len(self.atom_positions)
|
||||
atom_array = bs.AtomArray(length=n_atoms)
|
||||
|
||||
lines = []
|
||||
# Set coordinates directly (already vectorized)
|
||||
atom_array.coord = self.atom_positions
|
||||
|
||||
# Header
|
||||
lines.append(f"data_{self.id}")
|
||||
lines.append("#")
|
||||
lines.append(f"_entry.id {self.id}")
|
||||
lines.append("#")
|
||||
# Pre-allocate per-atom arrays
|
||||
atom_res_ids = np.zeros(n_atoms, dtype=np.int32)
|
||||
atom_chain_ids = np.empty(n_atoms, dtype=object)
|
||||
atom_res_names = np.empty(n_atoms, dtype=object)
|
||||
atom_hetero = np.zeros(n_atoms, dtype=bool)
|
||||
atom_bfactors = np.zeros(n_atoms, dtype=np.float32)
|
||||
atom_names = np.empty(n_atoms, dtype=object)
|
||||
|
||||
# Structure metadata
|
||||
lines.append("_struct.entry_id {}".format(self.id))
|
||||
lines.append("_struct.title 'Protein Structure'")
|
||||
lines.append("#")
|
||||
# Track residue IDs per chain
|
||||
chain_res_counters = {}
|
||||
|
||||
# Entity information
|
||||
entity_id = 1
|
||||
chain_counter = 0
|
||||
lines.append("loop_")
|
||||
lines.append("_entity.id")
|
||||
lines.append("_entity.type")
|
||||
lines.append("_entity.pdbx_description")
|
||||
# Vectorized expansion of token-level to atom-level annotations
|
||||
for token_idx, (start, end) in enumerate(self.token_to_atoms):
|
||||
token = self.sequence[token_idx]
|
||||
chain_id_numeric = self.chain_id[token_idx]
|
||||
chain_id_str = self.metadata.chain_lookup.get(
|
||||
int(chain_id_numeric), chr(65 + int(chain_id_numeric))
|
||||
)
|
||||
|
||||
# Determine entities based on sequence
|
||||
protein_tokens = []
|
||||
other_tokens = []
|
||||
# Track residue numbering per chain
|
||||
if chain_id_numeric not in chain_res_counters:
|
||||
chain_res_counters[chain_id_numeric] = 1
|
||||
res_id = chain_res_counters[chain_id_numeric]
|
||||
chain_res_counters[chain_id_numeric] += 1
|
||||
|
||||
for i, token in enumerate(self.sequence):
|
||||
if token in residue_constants.restype_3to1:
|
||||
protein_tokens.append((i, token))
|
||||
else:
|
||||
other_tokens.append((i, token))
|
||||
|
||||
if protein_tokens:
|
||||
lines.append(f"{entity_id} polymer 'Protein chain'")
|
||||
entity_id += 1
|
||||
|
||||
for token in set(token for _, token in other_tokens):
|
||||
lines.append(f"{entity_id} non-polymer 'Ligand {token}'")
|
||||
entity_id += 1
|
||||
|
||||
lines.append("#")
|
||||
|
||||
# Chain assignments
|
||||
lines.append("loop_")
|
||||
lines.append("_struct_asym.id")
|
||||
lines.append("_struct_asym.entity_id")
|
||||
|
||||
chain_id = "A"
|
||||
if protein_tokens:
|
||||
lines.append(f"{chain_id} 1")
|
||||
chain_counter += 1
|
||||
chain_id = chr(ord(chain_id) + 1)
|
||||
|
||||
entity_id = 2
|
||||
for token in set(token for _, token in other_tokens):
|
||||
lines.append(f"{chain_id} {entity_id}")
|
||||
entity_id += 1
|
||||
chain_counter += 1
|
||||
if chain_counter < 26:
|
||||
chain_id = chr(ord(chain_id) + 1)
|
||||
|
||||
lines.append("#")
|
||||
|
||||
# Atom site information
|
||||
lines.append("loop_")
|
||||
lines.append("_atom_site.group_PDB")
|
||||
lines.append("_atom_site.id")
|
||||
lines.append("_atom_site.type_symbol")
|
||||
lines.append("_atom_site.label_atom_id")
|
||||
lines.append("_atom_site.label_alt_id")
|
||||
lines.append("_atom_site.label_comp_id")
|
||||
lines.append("_atom_site.label_asym_id")
|
||||
lines.append("_atom_site.label_entity_id")
|
||||
lines.append("_atom_site.label_seq_id")
|
||||
lines.append("_atom_site.pdbx_PDB_ins_code")
|
||||
lines.append("_atom_site.Cartn_x")
|
||||
lines.append("_atom_site.Cartn_y")
|
||||
lines.append("_atom_site.Cartn_z")
|
||||
lines.append("_atom_site.occupancy")
|
||||
lines.append("_atom_site.B_iso_or_equiv")
|
||||
lines.append("_atom_site.pdbx_PDB_model_num")
|
||||
lines.append("_atom_site.auth_seq_id")
|
||||
lines.append("_atom_site.auth_comp_id")
|
||||
lines.append("_atom_site.auth_asym_id")
|
||||
lines.append("_atom_site.auth_atom_id")
|
||||
|
||||
atom_id = 1
|
||||
seq_id = 1
|
||||
chain_id = "A"
|
||||
entity_id = 1
|
||||
|
||||
for token_idx, token in enumerate(self.sequence):
|
||||
start_atom, end_atom = self.token_to_atoms[token_idx]
|
||||
|
||||
# Determine if this is a protein residue or ligand
|
||||
# Determine if protein
|
||||
is_protein = token in residue_constants.restype_3to1
|
||||
group_pdb = "ATOM" if is_protein else "HETATM"
|
||||
current_entity_id = 1 if is_protein else 2 # Simplified entity assignment
|
||||
current_chain_id = "A" if is_protein else "B" # Simplified chain assignment
|
||||
|
||||
# Create atom names for this token
|
||||
atom_names = []
|
||||
if is_protein:
|
||||
# Use standard protein atom names
|
||||
res_atoms = residue_constants.residue_atoms.get(
|
||||
# Get atom names for this residue
|
||||
if self.atom_names is not None:
|
||||
# Use stored atom names (preserves original names from mmCIF)
|
||||
names = list(self.atom_names[start:end])
|
||||
elif is_protein:
|
||||
# Fallback: use standard protein atom names
|
||||
standard_names = residue_constants.residue_atoms.get(
|
||||
token, ["N", "CA", "C", "O"]
|
||||
)
|
||||
atom_names = res_atoms[: end_atom - start_atom]
|
||||
names = standard_names[: end - start]
|
||||
# Pad if needed
|
||||
while len(names) < (end - start):
|
||||
names.append(f"X{len(names)+1}")
|
||||
else:
|
||||
# Generate generic atom names for ligands
|
||||
for i in range(end_atom - start_atom):
|
||||
atom_names.append(f"C{i+1}")
|
||||
# Fallback: generate names for ligands/nucleic acids
|
||||
names = [f"C{i+1}" for i in range(end - start)]
|
||||
|
||||
# Pad atom names if needed
|
||||
while len(atom_names) < (end_atom - start_atom):
|
||||
atom_names.append(f"X{len(atom_names)+1}")
|
||||
# Vectorized assignment for this token's atoms
|
||||
atom_res_ids[start:end] = res_id
|
||||
atom_chain_ids[start:end] = chain_id_str
|
||||
atom_res_names[start:end] = token
|
||||
# Use stored hetero flags if available, otherwise guess based on protein status
|
||||
if self.atom_hetero is not None:
|
||||
atom_hetero[start:end] = self.atom_hetero[start:end]
|
||||
else:
|
||||
atom_hetero[start:end] = not is_protein
|
||||
atom_bfactors[start:end] = self.plddt[token_idx] * 100.0
|
||||
atom_names[start:end] = names
|
||||
|
||||
# Write atoms for this token
|
||||
for atom_idx_in_token, global_atom_idx in enumerate(
|
||||
range(start_atom, end_atom)
|
||||
):
|
||||
pos = self.atom_positions[global_atom_idx]
|
||||
element_char = self.atom_elements[global_atom_idx]
|
||||
element_symbol = element_char if isinstance(element_char, str) else "C"
|
||||
# Set all AtomArray attributes at once (convert object arrays to proper string arrays)
|
||||
atom_array.res_id = atom_res_ids
|
||||
atom_array.chain_id = np.array(atom_chain_ids, dtype="U4")
|
||||
atom_array.res_name = np.array(atom_res_names, dtype="U4")
|
||||
atom_array.hetero = atom_hetero
|
||||
atom_array.b_factor = atom_bfactors
|
||||
atom_array.atom_name = np.array(atom_names, dtype="U4")
|
||||
|
||||
atom_name = (
|
||||
atom_names[atom_idx_in_token]
|
||||
if atom_idx_in_token < len(atom_names)
|
||||
else f"X{atom_idx_in_token+1}"
|
||||
)
|
||||
# Use existing elements or infer them from atom names
|
||||
if self.atom_elements is not None and len(self.atom_elements) == n_atoms:
|
||||
# Convert object array to proper string array for biotite
|
||||
atom_array.element = np.array(self.atom_elements, dtype="U4")
|
||||
else:
|
||||
# Use biotite's built-in element inference
|
||||
atom_array.element = bs.infer_elements(atom_array)
|
||||
|
||||
# Format atom site line
|
||||
bfactor = (
|
||||
self.plddt[token_idx] * 100.0
|
||||
if len(self.plddt) > token_idx
|
||||
else 50.0
|
||||
)
|
||||
# Create CIF file and set structure
|
||||
cif_file = CIFFile()
|
||||
set_structure(cif_file, atom_array, data_block=self.id)
|
||||
|
||||
line = (
|
||||
f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . "
|
||||
f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? "
|
||||
f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 "
|
||||
f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}"
|
||||
)
|
||||
lines.append(line)
|
||||
atom_id += 1
|
||||
|
||||
seq_id += 1
|
||||
|
||||
lines.append("#")
|
||||
return "\n".join(lines)
|
||||
# Convert to string
|
||||
output = io.StringIO()
|
||||
cif_file.write(output)
|
||||
return output.getvalue()
|
||||
|
||||
def dockq(self, native: "MolecularComplex") -> Any:
|
||||
"""Compute DockQ score against native structure.
|
||||
@@ -909,7 +996,10 @@ class MolecularComplex:
|
||||
if isinstance(v, list) and k in [
|
||||
"atom_positions",
|
||||
"atom_elements",
|
||||
"atom_names",
|
||||
"atom_hetero",
|
||||
"token_to_atoms",
|
||||
"chain_id",
|
||||
"plddt",
|
||||
]:
|
||||
dct[k] = np.array(v)
|
||||
@@ -918,10 +1008,20 @@ class MolecularComplex:
|
||||
if isinstance(v, np.ndarray):
|
||||
if k in ["atom_positions", "plddt"]:
|
||||
dct[k] = v.astype(np.float32)
|
||||
elif k in ["token_to_atoms"]:
|
||||
dct[k] = v.astype(np.int32)
|
||||
elif k in ["token_to_atoms", "chain_id"]:
|
||||
dct[k] = (
|
||||
v.astype(np.int32)
|
||||
if k == "token_to_atoms"
|
||||
else v.astype(np.int64)
|
||||
)
|
||||
|
||||
dct["metadata"] = MolecularComplexMetadata(**dct["metadata"])
|
||||
|
||||
# Backward compatibility: if chain_id is missing, create default array
|
||||
if "chain_id" not in dct:
|
||||
# Default all tokens to chain 0
|
||||
dct["chain_id"] = np.zeros(len(dct["sequence"]), dtype=np.int64)
|
||||
|
||||
return cls(**dct)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user