This commit is contained in:
Ishaan Mathur
2025-10-07 17:50:39 +00:00
parent 79c1208e96
commit 67861b681c
3 changed files with 273 additions and 241 deletions

View File

@@ -31,3 +31,7 @@ repos:
language: system
types: [python]
pass_filenames: true # For speed, we only check the files that are changed
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.2
hooks:
- id: gitleaks

View File

@@ -29,14 +29,6 @@ from esm.sdk.retry import retry_decorator
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
from esm.utils.msa import MSA
from esm.utils.structure.input_builder import (
StructurePredictionInput,
serialize_structure_prediction_input,
)
from esm.utils.structure.molecular_complex import (
MolecularComplex,
MolecularComplexResult,
)
from esm.utils.types import FunctionAnnotation
@@ -217,70 +209,6 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
return self._process_fold_response(data, sequence)
@retry_decorator
async def async_fold_all_atom(
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
"""Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
Args:
all_atom_input: StructurePredictionInput containing sequences for different molecule types
model_name: Override the client level model name if needed
"""
request = self._process_fold_all_atom_request(
all_atom_input, model_name if model_name is not None else self.model
)
try:
data = await self._async_post("fold_all_atom", request)
except ESMProteinError as e:
return e
return self._process_fold_all_atom_response(data)
@retry_decorator
def fold_all_atom(
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
"""Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
Args:
all_atom_input: StructurePredictionInput containing sequences for different molecule types
model_name: Override the client level model name if needed
"""
request = self._process_fold_all_atom_request(
all_atom_input, model_name if model_name is not None else self.model
)
try:
data = self._post("fold_all_atom", request)
except ESMProteinError as e:
return e
return self._process_fold_all_atom_response(data)
@staticmethod
def _process_fold_all_atom_request(
all_atom_input: StructurePredictionInput, model_name: str | None = None
) -> dict[str, Any]:
request: dict[str, Any] = {
"all_atom_input": serialize_structure_prediction_input(all_atom_input),
"model": model_name,
}
return request
@staticmethod
def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult:
complex_data = data.get("complex")
molecular_complex = MolecularComplex.from_state_dict(complex_data)
return MolecularComplexResult(
complex=molecular_complex,
plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True),
ptm=data.get("ptm", None),
distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True),
)
@retry_decorator
async def async_inverse_fold(
self,

View File

@@ -9,11 +9,13 @@ from subprocess import check_output
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, List
import biotite.structure as bs
import biotite.structure.io.pdbx as pdbx
import brotli
import msgpack
import numpy as np
import torch
from biotite.structure.io.pdbx import CIFFile, set_structure
from esm.utils import residue_constants
from esm.utils.structure.metrics import compute_lddt, compute_rmsd
@@ -52,9 +54,11 @@ class Molecule:
token_idx: int
atom_positions: np.ndarray # [N_atoms, 3]
atom_elements: np.ndarray # [N_atoms] element strings
residue_type: int
molecule_type: int # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
confidence: float
atom_names: np.ndarray | None = None # [N_atoms] atom names (optional)
atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional)
residue_type: int = 0
molecule_type: int = 0 # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
confidence: float = 0.0
@dataclass(frozen=True)
@@ -76,21 +80,40 @@ class MolecularComplex:
# Token-to-atom mapping for efficient access
token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array
# Chain information
chain_id: np.ndarray # [N_tokens] chain identifier for each token
# Confidence data
plddt: np.ndarray # Per-token confidence scores [N_tokens]
# Metadata
metadata: MolecularComplexMetadata
# Optional atom names and hetero flags (preserved from original structures)
atom_names: np.ndarray | None = None # [N_atoms] atom names (optional)
atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional)
def __post_init__(self):
"""Validate array dimensions."""
n_tokens = len(self.sequence)
n_atoms = len(self.atom_positions)
assert (
self.token_to_atoms.shape[0] == n_tokens
), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens"
assert (
self.chain_id.shape[0] == n_tokens
), f"chain_id shape {self.chain_id.shape} != {n_tokens} tokens"
assert (
self.plddt.shape[0] == n_tokens
), f"plddt shape {self.plddt.shape} != {n_tokens} tokens"
if self.atom_names is not None:
assert (
self.atom_names.shape[0] == n_atoms
), f"atom_names shape {self.atom_names.shape} != {n_atoms} atoms"
if self.atom_hetero is not None:
assert (
self.atom_hetero.shape[0] == n_atoms
), f"atom_hetero shape {self.atom_hetero.shape} != {n_atoms} atoms"
def __len__(self) -> int:
"""Return number of tokens."""
@@ -109,6 +132,12 @@ class MolecularComplex:
# Extract atom data for this token
token_atom_positions = self.atom_positions[start_atom:end_atom]
token_atom_elements = self.atom_elements[start_atom:end_atom]
token_atom_names = None
if self.atom_names is not None:
token_atom_names = self.atom_names[start_atom:end_atom]
token_atom_hetero = None
if self.atom_hetero is not None:
token_atom_hetero = self.atom_hetero[start_atom:end_atom]
# Default values for residue/molecule type (would be extended based on actual implementation)
residue_type = 0 # Default to standard residue
@@ -119,6 +148,8 @@ class MolecularComplex:
token_idx=idx,
atom_positions=token_atom_positions,
atom_elements=token_atom_elements,
atom_names=token_atom_names,
atom_hetero=token_atom_hetero,
residue_type=residue_type,
molecule_type=molecule_type,
confidence=self.plddt[idx],
@@ -151,6 +182,8 @@ class MolecularComplex:
# Convert atom37 to flat arrays
flat_positions = []
flat_elements = []
flat_names = []
flat_hetero = []
token_to_atoms = []
atom_idx = 0
@@ -180,6 +213,12 @@ class MolecularComplex:
) # First character is element
flat_elements.append(element)
# Add atom name
flat_names.append(atom_name)
# Add hetero flag (all proteins are non-hetero)
flat_hetero.append(False)
atom_idx += 1
# Record token-to-atom mapping [start_idx, end_idx)
@@ -189,17 +228,20 @@ class MolecularComplex:
# Convert to numpy arrays
atom_positions = np.array(flat_positions, dtype=np.float32)
atom_elements = np.array(flat_elements, dtype=object)
atom_names = np.array(flat_names, dtype=object)
atom_hetero = np.array(flat_hetero, dtype=bool)
token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
# Extract confidence scores (skip chain breaks)
# Extract confidence scores and chain_ids (skip chain breaks)
confidence_scores = []
residue_idx = 0
for aa in pc.sequence:
chain_ids = []
for seq_idx, aa in enumerate(pc.sequence):
if aa != "|":
confidence_scores.append(pc.confidence[residue_idx])
residue_idx += 1
confidence_scores.append(pc.confidence[seq_idx])
chain_ids.append(pc.chain_id[seq_idx])
confidence_array = np.array(confidence_scores, dtype=np.float32)
chain_id_array = np.array(chain_ids, dtype=np.int64)
# Create metadata - convert entity IDs to strings for MolecularComplexMetadata
entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()}
@@ -215,8 +257,11 @@ class MolecularComplex:
atom_positions=atom_positions,
atom_elements=atom_elements,
token_to_atoms=token_to_atoms_array,
chain_id=chain_id_array,
plddt=confidence_array,
metadata=metadata,
atom_names=atom_names,
atom_hetero=atom_hetero,
)
def to_protein_complex(self) -> ProteinComplex:
@@ -251,13 +296,27 @@ class MolecularComplex:
atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32)
atom37_mask = np.zeros((n_residues, 37), dtype=bool)
# Convert tokens back to single-letter sequence
single_letter_sequence = "".join(
[residue_constants.restype_3to1[token] for token in protein_tokens]
)
# Extract confidence scores for protein residues only
# Extract confidence scores and chain_ids for protein residues only
protein_confidence = self.plddt[protein_indices]
protein_chain_ids = self.chain_id[protein_indices]
# Convert tokens back to single-letter sequence with chain breaks
single_letter_residues = []
prev_chain_id = None
for i, (token, chain_id_val) in enumerate(
zip(protein_tokens, protein_chain_ids)
):
# Add chain break if we're switching to a new chain
if prev_chain_id is not None and chain_id_val != prev_chain_id:
single_letter_residues.append("|")
single_letter_residues.append(residue_constants.restype_3to1[token])
prev_chain_id = chain_id_val
single_letter_sequence = "".join(single_letter_residues)
# Calculate final sequence length (includes chain breaks)
sequence_length = len(single_letter_sequence)
# Convert flat atoms back to atom37 representation
for res_idx, token_idx in enumerate(protein_indices):
@@ -283,19 +342,69 @@ class MolecularComplex:
atom37_mask[res_idx, atom_type_idx] = True
atom_count += 1
# Create other required arrays for ProteinComplex
# For simplicity, assume all protein residues belong to the same entity/chain
entity_id = np.zeros(n_residues, dtype=np.int64)
chain_id = np.zeros(n_residues, dtype=np.int64)
sym_id = np.zeros(n_residues, dtype=np.int64)
residue_index = np.arange(1, n_residues + 1, dtype=np.int64)
insertion_code = np.array([""] * n_residues, dtype=object)
# Create arrays that match sequence length (including chain breaks)
# Initialize arrays with proper size
chain_id_expanded = np.full(sequence_length, -1, dtype=np.int64)
entity_id_expanded = np.full(sequence_length, -1, dtype=np.int64)
sym_id_expanded = np.zeros(sequence_length, dtype=np.int64)
residue_index_expanded = np.zeros(sequence_length, dtype=np.int64)
insertion_code_expanded = np.array([""] * sequence_length, dtype=object)
confidence_expanded = np.zeros(sequence_length, dtype=np.float32)
atom37_positions_expanded = np.full(
(sequence_length, 37, 3), np.nan, dtype=np.float32
)
atom37_mask_expanded = np.zeros((sequence_length, 37), dtype=bool)
# Map residue data to sequence positions (skipping chain breaks)
residue_idx = 0
residue_counter_per_chain = {}
for seq_pos, char in enumerate(single_letter_sequence):
if char != "|":
# This is a residue position
chain_id_val = protein_chain_ids[residue_idx]
chain_id_expanded[seq_pos] = chain_id_val
entity_id_expanded[seq_pos] = chain_id_val # Simplified mapping
# Track residue numbering per chain
if chain_id_val not in residue_counter_per_chain:
residue_counter_per_chain[chain_id_val] = 1
else:
residue_counter_per_chain[chain_id_val] += 1
residue_index_expanded[seq_pos] = residue_counter_per_chain[
chain_id_val
]
confidence_expanded[seq_pos] = protein_confidence[residue_idx]
atom37_positions_expanded[seq_pos] = atom37_positions[residue_idx]
atom37_mask_expanded[seq_pos] = atom37_mask[residue_idx]
residue_idx += 1
# Chain break positions keep default values (-1, False, etc.)
# Use the expanded arrays
chain_id = chain_id_expanded
entity_id = entity_id_expanded
sym_id = sym_id_expanded
residue_index = residue_index_expanded
insertion_code = insertion_code_expanded
protein_confidence = confidence_expanded
atom37_positions = atom37_positions_expanded
atom37_mask = atom37_mask_expanded
# Create protein complex metadata preserving chain information
# Convert MolecularComplex metadata to ProteinComplex format
unique_chain_ids = np.unique(protein_chain_ids)
entity_lookup = {int(cid): int(cid) for cid in unique_chain_ids}
chain_lookup = {
int(cid): self.metadata.chain_lookup.get(int(cid), chr(65 + int(cid)))
for cid in unique_chain_ids
}
# Create simplified protein complex metadata
# Map the first entity/chain from molecular complex metadata
protein_metadata = ProteinComplexMetadata(
entity_lookup={0: 1}, # Single entity (int for ProteinComplexMetadata)
chain_lookup={0: "A"}, # Single chain
entity_lookup=entity_lookup,
chain_lookup=chain_lookup,
assembly_composition=self.metadata.assembly_composition,
)
@@ -336,7 +445,9 @@ class MolecularComplex:
# Get structure - handle missing model information gracefully
try:
structure = pdbx.get_structure(mmcif_file, model=1)
structure = pdbx.get_structure(
mmcif_file, model=1, extra_fields=["b_factor"]
)
except (KeyError, ValueError):
# Fallback for mmCIF files without model information
try:
@@ -374,8 +485,11 @@ class MolecularComplex:
sequence_tokens = []
flat_positions = []
flat_elements = []
flat_names = []
flat_hetero = []
token_to_atoms = []
confidence_scores = []
chain_ids = [] # Track chain IDs for each token
atom_idx = 0
@@ -396,9 +510,16 @@ class MolecularComplex:
}
chain_residue_groups[chain_id][res_id]["atoms"].append(atom)
# Create a mapping from chain_id to numeric indices
chain_id_to_numeric = {
chain_id: idx
for idx, chain_id in enumerate(sorted(chain_residue_groups.keys()))
}
# Process each chain and residue
for chain_id in sorted(chain_residue_groups.keys()):
residues = chain_residue_groups[chain_id]
numeric_chain_id = chain_id_to_numeric[chain_id]
for res_id in sorted(residues.keys()):
residue_data = residues[res_id]
@@ -422,6 +543,9 @@ class MolecularComplex:
token_name = res_name
sequence_tokens.append(token_name)
chain_ids.append(
numeric_chain_id
) # Store the numeric chain ID for this token
token_start = atom_idx
# Add all atoms from this residue
@@ -432,6 +556,14 @@ class MolecularComplex:
element = atom.element
flat_elements.append(element)
# Get atom name
atom_name = atom.atom_name
flat_names.append(atom_name)
# Get hetero flag
hetero_flag = atom.hetero
flat_hetero.append(hetero_flag)
atom_idx += 1
# Record token-to-atom mapping
@@ -446,20 +578,36 @@ class MolecularComplex:
# Create minimal arrays if no atoms found
atom_positions = np.zeros((0, 3), dtype=np.float32)
atom_elements = np.zeros(0, dtype=object)
atom_names = np.zeros(0, dtype=object)
atom_hetero = np.zeros(0, dtype=bool)
token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32)
chain_id_array = (
np.array(chain_ids, dtype=np.int64)
if chain_ids
else np.zeros(len(sequence_tokens), dtype=np.int64)
)
else:
atom_positions = np.array(flat_positions, dtype=np.float32)
atom_elements = np.array(flat_elements, dtype=object)
atom_names = np.array(flat_names, dtype=object)
atom_hetero = np.array(flat_hetero, dtype=bool)
token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
chain_id_array = np.array(chain_ids, dtype=np.int64)
confidence_array = np.array(confidence_scores, dtype=np.float32)
# Create metadata
# Create metadata using the chain_id_to_numeric mapping
if chain_residue_groups:
chain_lookup = {
numeric_id: chain_id
for chain_id, numeric_id in chain_id_to_numeric.items()
}
else:
chain_lookup = {}
metadata = MolecularComplexMetadata(
entity_lookup=entity_info,
chain_lookup={
i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys())
},
chain_lookup=chain_lookup,
assembly_composition=None,
)
@@ -475,168 +623,107 @@ class MolecularComplex:
atom_positions=atom_positions,
atom_elements=atom_elements,
token_to_atoms=token_to_atoms_array,
chain_id=chain_id_array,
plddt=confidence_array,
metadata=metadata,
atom_names=atom_names,
atom_hetero=atom_hetero,
)
def to_mmcif(self) -> str:
"""Write MolecularComplex to mmcif string.
"""Write MolecularComplex to mmcif string using biotite.
Returns:
String representation of the complex in mmCIF format
"""
# No need for element mapping - already using element characters
# Pre-allocate AtomArray
n_atoms = len(self.atom_positions)
atom_array = bs.AtomArray(length=n_atoms)
lines = []
# Set coordinates directly (already vectorized)
atom_array.coord = self.atom_positions
# Header
lines.append(f"data_{self.id}")
lines.append("#")
lines.append(f"_entry.id {self.id}")
lines.append("#")
# Pre-allocate per-atom arrays
atom_res_ids = np.zeros(n_atoms, dtype=np.int32)
atom_chain_ids = np.empty(n_atoms, dtype=object)
atom_res_names = np.empty(n_atoms, dtype=object)
atom_hetero = np.zeros(n_atoms, dtype=bool)
atom_bfactors = np.zeros(n_atoms, dtype=np.float32)
atom_names = np.empty(n_atoms, dtype=object)
# Structure metadata
lines.append("_struct.entry_id {}".format(self.id))
lines.append("_struct.title 'Protein Structure'")
lines.append("#")
# Track residue IDs per chain
chain_res_counters = {}
# Entity information
entity_id = 1
chain_counter = 0
lines.append("loop_")
lines.append("_entity.id")
lines.append("_entity.type")
lines.append("_entity.pdbx_description")
# Vectorized expansion of token-level to atom-level annotations
for token_idx, (start, end) in enumerate(self.token_to_atoms):
token = self.sequence[token_idx]
chain_id_numeric = self.chain_id[token_idx]
chain_id_str = self.metadata.chain_lookup.get(
int(chain_id_numeric), chr(65 + int(chain_id_numeric))
)
# Determine entities based on sequence
protein_tokens = []
other_tokens = []
# Track residue numbering per chain
if chain_id_numeric not in chain_res_counters:
chain_res_counters[chain_id_numeric] = 1
res_id = chain_res_counters[chain_id_numeric]
chain_res_counters[chain_id_numeric] += 1
for i, token in enumerate(self.sequence):
if token in residue_constants.restype_3to1:
protein_tokens.append((i, token))
else:
other_tokens.append((i, token))
if protein_tokens:
lines.append(f"{entity_id} polymer 'Protein chain'")
entity_id += 1
for token in set(token for _, token in other_tokens):
lines.append(f"{entity_id} non-polymer 'Ligand {token}'")
entity_id += 1
lines.append("#")
# Chain assignments
lines.append("loop_")
lines.append("_struct_asym.id")
lines.append("_struct_asym.entity_id")
chain_id = "A"
if protein_tokens:
lines.append(f"{chain_id} 1")
chain_counter += 1
chain_id = chr(ord(chain_id) + 1)
entity_id = 2
for token in set(token for _, token in other_tokens):
lines.append(f"{chain_id} {entity_id}")
entity_id += 1
chain_counter += 1
if chain_counter < 26:
chain_id = chr(ord(chain_id) + 1)
lines.append("#")
# Atom site information
lines.append("loop_")
lines.append("_atom_site.group_PDB")
lines.append("_atom_site.id")
lines.append("_atom_site.type_symbol")
lines.append("_atom_site.label_atom_id")
lines.append("_atom_site.label_alt_id")
lines.append("_atom_site.label_comp_id")
lines.append("_atom_site.label_asym_id")
lines.append("_atom_site.label_entity_id")
lines.append("_atom_site.label_seq_id")
lines.append("_atom_site.pdbx_PDB_ins_code")
lines.append("_atom_site.Cartn_x")
lines.append("_atom_site.Cartn_y")
lines.append("_atom_site.Cartn_z")
lines.append("_atom_site.occupancy")
lines.append("_atom_site.B_iso_or_equiv")
lines.append("_atom_site.pdbx_PDB_model_num")
lines.append("_atom_site.auth_seq_id")
lines.append("_atom_site.auth_comp_id")
lines.append("_atom_site.auth_asym_id")
lines.append("_atom_site.auth_atom_id")
atom_id = 1
seq_id = 1
chain_id = "A"
entity_id = 1
for token_idx, token in enumerate(self.sequence):
start_atom, end_atom = self.token_to_atoms[token_idx]
# Determine if this is a protein residue or ligand
# Determine if protein
is_protein = token in residue_constants.restype_3to1
group_pdb = "ATOM" if is_protein else "HETATM"
current_entity_id = 1 if is_protein else 2 # Simplified entity assignment
current_chain_id = "A" if is_protein else "B" # Simplified chain assignment
# Create atom names for this token
atom_names = []
if is_protein:
# Use standard protein atom names
res_atoms = residue_constants.residue_atoms.get(
# Get atom names for this residue
if self.atom_names is not None:
# Use stored atom names (preserves original names from mmCIF)
names = list(self.atom_names[start:end])
elif is_protein:
# Fallback: use standard protein atom names
standard_names = residue_constants.residue_atoms.get(
token, ["N", "CA", "C", "O"]
)
atom_names = res_atoms[: end_atom - start_atom]
names = standard_names[: end - start]
# Pad if needed
while len(names) < (end - start):
names.append(f"X{len(names)+1}")
else:
# Generate generic atom names for ligands
for i in range(end_atom - start_atom):
atom_names.append(f"C{i+1}")
# Fallback: generate names for ligands/nucleic acids
names = [f"C{i+1}" for i in range(end - start)]
# Pad atom names if needed
while len(atom_names) < (end_atom - start_atom):
atom_names.append(f"X{len(atom_names)+1}")
# Vectorized assignment for this token's atoms
atom_res_ids[start:end] = res_id
atom_chain_ids[start:end] = chain_id_str
atom_res_names[start:end] = token
# Use stored hetero flags if available, otherwise guess based on protein status
if self.atom_hetero is not None:
atom_hetero[start:end] = self.atom_hetero[start:end]
else:
atom_hetero[start:end] = not is_protein
atom_bfactors[start:end] = self.plddt[token_idx] * 100.0
atom_names[start:end] = names
# Write atoms for this token
for atom_idx_in_token, global_atom_idx in enumerate(
range(start_atom, end_atom)
):
pos = self.atom_positions[global_atom_idx]
element_char = self.atom_elements[global_atom_idx]
element_symbol = element_char if isinstance(element_char, str) else "C"
# Set all AtomArray attributes at once (convert object arrays to proper string arrays)
atom_array.res_id = atom_res_ids
atom_array.chain_id = np.array(atom_chain_ids, dtype="U4")
atom_array.res_name = np.array(atom_res_names, dtype="U4")
atom_array.hetero = atom_hetero
atom_array.b_factor = atom_bfactors
atom_array.atom_name = np.array(atom_names, dtype="U4")
atom_name = (
atom_names[atom_idx_in_token]
if atom_idx_in_token < len(atom_names)
else f"X{atom_idx_in_token+1}"
)
# Use existing elements or infer them from atom names
if self.atom_elements is not None and len(self.atom_elements) == n_atoms:
# Convert object array to proper string array for biotite
atom_array.element = np.array(self.atom_elements, dtype="U4")
else:
# Use biotite's built-in element inference
atom_array.element = bs.infer_elements(atom_array)
# Format atom site line
bfactor = (
self.plddt[token_idx] * 100.0
if len(self.plddt) > token_idx
else 50.0
)
# Create CIF file and set structure
cif_file = CIFFile()
set_structure(cif_file, atom_array, data_block=self.id)
line = (
f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . "
f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? "
f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 "
f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}"
)
lines.append(line)
atom_id += 1
seq_id += 1
lines.append("#")
return "\n".join(lines)
# Convert to string
output = io.StringIO()
cif_file.write(output)
return output.getvalue()
def dockq(self, native: "MolecularComplex") -> Any:
"""Compute DockQ score against native structure.
@@ -909,7 +996,10 @@ class MolecularComplex:
if isinstance(v, list) and k in [
"atom_positions",
"atom_elements",
"atom_names",
"atom_hetero",
"token_to_atoms",
"chain_id",
"plddt",
]:
dct[k] = np.array(v)
@@ -918,10 +1008,20 @@ class MolecularComplex:
if isinstance(v, np.ndarray):
if k in ["atom_positions", "plddt"]:
dct[k] = v.astype(np.float32)
elif k in ["token_to_atoms"]:
dct[k] = v.astype(np.int32)
elif k in ["token_to_atoms", "chain_id"]:
dct[k] = (
v.astype(np.int32)
if k == "token_to_atoms"
else v.astype(np.int64)
)
dct["metadata"] = MolecularComplexMetadata(**dct["metadata"])
# Backward compatibility: if chain_id is missing, create default array
if "chain_id" not in dct:
# Default all tokens to chain 0
dct["chain_id"] = np.zeros(len(dct["sequence"]), dtype=np.int64)
return cls(**dct)
@classmethod