Merge branch 'main' into ipsae

Tim O'Donnell
2026-02-17 13:51:36 -05:00
committed by GitHub
40 changed files with 14501 additions and 140 deletions

View File

@@ -12,11 +12,14 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN apt-get update && \
apt-get install -y --no-install-recommends \
python3.10 \
python3.10-dev \
python3-pip \
python3-venv \
python3-wheel \
software-properties-common \
curl \
&& add-apt-repository -y ppa:deadsnakes/ppa \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
python3.11 \
python3.11-dev \
python3.11-venv \
build-essential \
git \
cmake \
@@ -30,9 +33,10 @@ RUN apt-get update && \
libboost-all-dev \
&& rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 && \
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 && \
python -m pip install --upgrade pip setuptools setuptools_scm wheel
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1 && \
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 && \
curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11 && \
python3.11 -m pip install --upgrade pip setuptools setuptools_scm wheel
WORKDIR /app
@@ -56,7 +60,4 @@ RUN groupadd --gid ${USER_GID} ${USERNAME} && \
RUN mkdir -p "${HF_HOME}" && chown -R ${USER_UID}:${USER_GID} "${HF_HOME}"
USER ${USERNAME}
WORKDIR /workspace
ENTRYPOINT ["boltzgen"]
CMD ["--help"]
WORKDIR /workspace

View File

@@ -4,7 +4,8 @@
<img src="assets/boltzgen.png" alt="BoltzGen logo" width="60%">
[Paper](https://hannes-stark.com/assets/boltzgen.pdf) |
[Slack](https://boltz.bio/join-slack) <br> <br>
[Slack](https://boltz.bio/join-slack) |
[Video](https://www.youtube.com/watch?v=9d_QWUUI1Qo) <br> <br>
![alt text](assets/cover.png)
</div>
@@ -73,10 +74,10 @@ docker build -t boltzgen .
# Run an example
mkdir -p workdir # output
mkdir -p cache # where models will be downloaded to
docker run --rm --gpus all -v "$(realpath workdir)":/workdir -v "$(realpath cache)":/cache -v "$(realpath example)":/example \
boltzgen run /example/vanilla_protein/1g13prot.yaml --output /workdir/test \
--protocol protein-anything \
--num_designs 2
docker run --rm --gpus all -v "$(realpath workdir)":/workdir -v "$(realpath cache)":/cache -v "$(realpath example)":/example boltzgen \
boltzgen run /example/vanilla_protein/1g13prot.yaml --output /workdir/test \
--protocol protein-anything \
--num_designs 2
```
In the example above, the model weights are downloaded the first time the image is run. To bake the weights into the image at build time, run:
@@ -144,7 +145,7 @@ When the pipeline completes your output directory will have:
- `/final_designs_metrics_<budget>.csv` — metrics for the selected final set.
- `/results_overview.pdf` — plots
# Protocols
# Protocols
| Protocol (design-target) | Appropriate for | Major config differences |
|--------------------------|---------------------------------------------------------------------------|------------------------|
@@ -153,6 +154,7 @@ When the pipeline completes your output directory will have:
| protein-small_molecule | Design proteins to bind small molecules | Includes binding affinity prediction. Includes `design folding` step. |
| antibody-anything | Design antibody CDRs | No Cys are generated in inverse folding. No `design folding` step. Don't compute largest hydrophobic patch. |
| nanobody-anything | Design nanobody CDRs | Same settings as antibody-anything |
| protein-redesign | Redesign or optimize existing proteins | No `design folding` step. Uses `design_mask` for target/template definition. |
All configuration parameters can be overridden using the `--config` option; see `boltzgen run --help` or the `Advanced Users` section below for details.
@@ -182,6 +184,7 @@ We provide many example `.yaml` files in the `example/` directory, including:
- `example/fab_targets/pdl1.yaml`
- `example/denovo_zinc_finger_against_dna/zinc_finger.yaml`
- `example/protein_binding_small_molecule/chorismite.yaml`
- `example/small_molecule_from_file_and_smiles/4g37.yaml`
Small example of a protein design against a target protein without binding site specified:
```yaml
@@ -311,6 +314,33 @@ constraints:
```
## Symmetric complex design (inverse-folding only)
For symmetric complexes (e.g., homo-dimers), you can tie sequence generation during inverse folding by specifying the same `symmetric_group` for each symmetric chain. The `protein-redesign` protocol allows scoring complexes that have no separate binders/targets.
```yaml
entities:
  - file:
      path: symmetric_dimer.cif
      include:
        - chain:
            id: A
            res_index: 100..300
            symmetric_group: 1 # Link chains A and B for symmetric sampling
        - chain:
            id: B
            res_index: 100..300
            symmetric_group: 1 # Same group = same sampled insertion length
      # Mark residues to redesign on both chains
      design:
        - chain:
            id: A
            res_index: 200..210
        - chain:
            id: B
            res_index: 200..210
```
# Running only specific pipeline steps
@@ -325,6 +355,16 @@ boltzgen run example/cyclotide/3ivq.yaml \
--num_designs 2
```
If you want to run only the inverse folding and subsequent design evaluation steps (but not the backbone design step), you can also run:
**Run only inverse_folding step:**
```bash
boltzgen run example/inverse_folding/1brs.yaml \
--output workbench/if-only \
--only_inverse_fold \
--inverse_fold_num_sequences 2
```
**Available steps:**
- `design` - Generate num_design candidates using the diffusion model based on your design specification
- `inverse_folding` - Redesign sequences from the previous step using our inverse folding model

View File

@@ -22,6 +22,7 @@ We provide many example `.yaml` files in the `example/` directory, including:
- [fab_targets/pdl1.yaml](fab_targets/pdl1.yaml)
- [denovo_zinc_finger_against_dna/zinc_finger.yaml](denovo_zinc_finger_against_dna/zinc_finger.yaml)
- [protein_binding_small_molecule/chorismite.yaml](protein_binding_small_molecule/chorismite.yaml)
- [small_molecule_from_file_and_smiles/4g37.yaml](small_molecule_from_file_and_smiles/4g37.yaml)
Small example of a protein design against a target protein without binding site specified:
```yaml
@@ -343,6 +344,9 @@ constraints:
atom2: [Q, 1, CK] # Connect sulfur of Cys-4 in chain R to atom CK in ligand Q
```
We now support constraint specifications for small molecules taken from the input file or defined by SMILES. See `example/small_molecule_from_file_and_smiles/4g37.yaml`. Brief guidelines:
* Small molecules from the file: look up the `atom_name` in the CCD and specify it.
* Small molecules from SMILES: count the index of the target atom among the atoms of the same element in the SMILES string, and specify the element type with that index (e.g. `C6` for the 6th carbon in the SMILES).
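
The counting rule for SMILES atoms can be checked by hand with a few lines of Python. This is only a rough sketch: it assumes atoms are numbered per element in the order they appear in the SMILES string, and it handles only unbracketed organic-subset atoms. `label_smiles_atoms` is a hypothetical helper, not part of BoltzGen, and the names BoltzGen actually assigns should be confirmed with `boltzgen check`.

```python
import re
from collections import Counter
from typing import List

# Unbracketed organic-subset atoms only; bracket atoms such as [Na+] are NOT handled.
_ATOM = re.compile(r"Cl|Br|[BCNOPSFI]|[bcnops]")


def label_smiles_atoms(smiles: str) -> List[str]:
    """Label heavy atoms as <element><per-element index> in SMILES order."""
    counts: Counter = Counter()
    labels = []
    for match in _ATOM.finditer(smiles):
        element = match.group().capitalize()  # aromatic 'c' counts as carbon
        counts[element] += 1
        labels.append(f"{element}{counts[element]}")
    return labels


labels = label_smiles_atoms("CC(=O)NCCNC(C)=O")
print(labels)
```

For the ligand in `4g37.yaml`, this labels the two terminal methyl carbons `C1` and `C6`, which are the atoms used in that file's `bond` constraints.
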
Here is a comprehensive list of all the keys from your YAML file with explanations for each.

View File

@@ -1,7 +1,14 @@
entities:
- protein:
- protein:
id: G
sequence: 15..20AAAAAAVTTTT18PPP # range between 15 and 20 inclusive on both sides
residue_constraints:
- position: 1
allowed: A # Only Alanine at position 1
- position: 3..5
disallowed: CM # No Cysteine or Methionine at positions 3-5
- position: 8
allowed: AGS # Only Ala, Gly, or Ser at position 8
- protein:
id: R
sequence: 3..5C6C3 # Random number of design residues between 3 and 5, then a Cysteine, then 6 design residues, then ...
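
The `sequence` mini-language used above (a number for a fixed count of design residues, `a..b` for an inclusive random count, letters for fixed residues) can be sketched as follows. This is an illustrative reading of the YAML comments, not BoltzGen's parser: `tokenize`, `length_bounds`, `sample`, and the `X` placeholder for design residues are all hypothetical.

```python
import random
import re
from typing import List

_TOKEN = re.compile(r"(\d+)\.\.(\d+)|(\d+)|([A-Za-z])")


def tokenize(spec: str) -> List[tuple]:
    """Split a spec into ("design", lo, hi) and ("fixed", letter) tokens."""
    tokens, pos = [], 0
    while pos < len(spec):
        m = _TOKEN.match(spec, pos)
        if m is None:
            raise ValueError(f"Unexpected character at offset {pos}: {spec[pos]!r}")
        lo, hi, count, letter = m.groups()
        if lo is not None:
            tokens.append(("design", int(lo), int(hi)))
        elif count is not None:
            tokens.append(("design", int(count), int(count)))
        else:
            tokens.append(("fixed", letter.upper()))
        pos = m.end()
    return tokens


def length_bounds(spec: str) -> tuple:
    """Return the (min, max) chain length the spec can produce."""
    lo = hi = 0
    for tok in tokenize(spec):
        if tok[0] == "design":
            lo, hi = lo + tok[1], hi + tok[2]
        else:
            lo, hi = lo + 1, hi + 1
    return lo, hi


def sample(spec: str, rng: random.Random, placeholder: str = "X") -> str:
    """Draw one concrete layout; design residues are shown as a placeholder."""
    parts = []
    for tok in tokenize(spec):
        if tok[0] == "design":
            parts.append(placeholder * rng.randint(tok[1], tok[2]))  # inclusive range
        else:
            parts.append(tok[1])
    return "".join(parts)


print(length_bounds("3..5C6C3"))
print(sample("3..5C6C3", random.Random(0)))
```

Because `position` values in `residue_constraints` index into the sampled chain, knowing the possible length range helps when deciding which positions are safe to constrain.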

File diff suppressed because it is too large

View File

@@ -0,0 +1,16 @@
entities:
  - file:
      path: 1brs.cif
      # Include target and binder
      include:
        - chain:
            id: A
        - chain:
            id: D
      # Redesign specified binder residues
      design:
        - chain:
            id: D
            res_index: 33..46

View File

@@ -0,0 +1,33 @@
# Test file for per-residue amino acid constraints
# This tests the new residue_constraints feature
#
# Usage:
# boltzgen check example/residue_constraints_test.yaml
# boltzgen run example/residue_constraints_test.yaml --output test_output/ --steps design inverse_folding --num_designs 50
#
# Note: Use --num_designs 50 (not 5) for statistically meaningful verification.
# With only 5 designs, blacklist constraints have a ~21% false-pass probability.
#
# Note: Supports both string format ("AGS") and list format ([A, G, S])
# String format is preferred for consistency with sequence/binding_types
entities:
  - protein:
      id: A
      sequence: 10
      residue_constraints:
        # Position 1: Force Alanine only
        - position: 1
          allowed: A
        # Positions 3-5: Exclude Cysteine and Methionine
        - position: 3..5
          disallowed: CM
        # Position 8: Allow only small amino acids
        - position: 8
          allowed: AGS
        # Position 10: Force Proline
        - position: 10
          allowed: P
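
The "~21% false-pass probability" quoted in the header comment is consistent with a simple back-of-the-envelope model. The model is an assumption made here, not something the file states: an unconstrained sampler picks each of the 20 amino acids uniformly and independently, so the three blacklisted positions (3..5) avoid both C and M by chance with probability 18/20 per position per design.

```python
def false_pass_probability(
    num_designs: int,
    num_positions: int = 3,  # positions 3..5 in the test file
    num_blocked: int = 2,    # C and M
    alphabet: int = 20,
) -> float:
    """Chance that a model ignoring the blacklist still passes the check.

    Assumes uniform, independent sampling at every position (a simplification).
    """
    p_clean = (alphabet - num_blocked) / alphabet
    return p_clean ** (num_designs * num_positions)


for n in (5, 50):
    print(n, false_pass_probability(n))
```

Under this model 5 designs give about 0.206, while 50 designs push the false-pass chance below one in a million, which matches the recommendation to use `--num_designs 50`.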

File diff suppressed because it is too large

View File

@@ -0,0 +1,37 @@
entities:
  - protein:
      id: B
      sequence: 1..3C1..2C1..3 # 1 to 3 residues before the first CYS, then 1 to 2 residues, the second CYS, followed by 1 to 3 residues
  - file:
      path: 4g37.pdb
      include:
        - chain:
            id: A
        - chain:
            id: C # XLX
  - ligand:
      id: D
      smiles: "CC(=O)NCCNC(C)=O"
constraints:
  ## constraints from the file.
  # covalent bond between C3 in XLX (from CCD) and SG in chain A of 4g37.pdb
  - bond:
      atom1: [A, 105, SG]
      atom2: [C, 1, C3]
  # additional covalent bond between C11 in XLX (from CCD) and SG in chain B (design)
  - bond:
      atom1: [B, 2, SG]
      atom2: [C, 1, C11]
  ## constraints from the smiles.
  # covalent bond between C1 in D (from smiles) and OG in chain A of 4g37.pdb
  - bond:
      atom1: [A, 339, OG]
      atom2: [D, 1, C1]
  # additional covalent bond between C6 in D (from smiles) and SG in chain B (design)
  - bond:
      atom1: [B, 4, SG]
      atom2: [D, 1, C6]
#boltzgen run example/small_molecule_from_file_and_smiles/4g37.yaml --output workbench --protocol protein-anything --num_designs 2 --devices 1 --steps design

View File

@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "boltzgen"
requires-python = ">=3.11"
version = "0.1.8"
version = "0.3.0"
readme = { file = "PYPI_DESCRIPTION.md", content-type = "text/markdown" }
description = "Protein design"
dependencies = [

View File

@@ -34,7 +34,7 @@ if [[ "$MODE" == "submit" ]]; then
elif [[ "$MODE" == "process" ]]; then
boltzgen merge "$OUT"/task-* --output "$MERGED_OUT"
boltzgen run "$DESIGN_SPEC" --steps filtering --output "$MERGED_OUT"
boltzgen run "$DESIGN_SPEC" --steps filtering --protocol protein-anything --output "$MERGED_OUT"
else
echo "Usage: $0 {submit|process}"

View File

@@ -100,6 +100,16 @@ protocol_configs = {
"analysis": ["largest_hydrophobic=false", "largest_hydrophobic_refolded=false"],
"filtering": ["filter_cysteine=true"],
},
"protein-redesign": {
# For redesigning/optimizing existing proteins (e.g., symmetric dimers)
# where all chains may have designed residues. Skips design_folding and
# uses design_mask (not chain_design_mask) for target/template definition.
"folding": ["data.design_mask_templates=true"],
"analysis": ["use_design_mask_for_target=true"],
"filtering": [
"metrics_override={design_to_target_iptm: null, neg_min_design_to_target_pae: null, design_ptm: null, plip_hbonds_refolded: null, plip_saltbridge_refolded: null, delta_sasa_refolded: null, plip_hbonds: null, plip_saltbridge: null, delta_sasa_original: null, design_residue_iptm: 1, iptm: 2, ptm: 3, neg_filter_rmsd_design: 4}",
],
},
}
assert all(
step_name in step_names for cfg in protocol_configs.values() for step_name in cfg
@@ -981,6 +991,30 @@ class BinderDesignPipeline:
)
if args.only_inverse_fold:
exclude_residues = []
inverse_fold_avoid = (
args.inverse_fold_avoid
if args.inverse_fold_avoid is not None
else (
"C"
if protocol
in [
"peptide-anything",
"nanobody-anything",
"antibody-anything",
]
else ""
)
)
for one_letter_code in inverse_fold_avoid:
exclude_residues.append(const.prot_letter_to_token[one_letter_code])
if len(exclude_residues) > 0:
print(
f"Inverse fold will avoid the following residues: {exclude_residues}"
)
print(f"Inverse-folded designs will be saved to: {output_dir}")
# Designs from inverse folding
self.steps.append(
PipelineStep(
@@ -991,9 +1025,12 @@ class BinderDesignPipeline:
f"data.cfg.yaml_path=[{', '.join(str(s) for s in args.design_spec)}]",
f"trainer.devices={devices}",
f"data.cfg.multiplicity={getattr(args, 'inverse_fold_num_sequences', 10)}",
f"data.cfg.skip_existing={args.reuse}",
f"data.cfg.output_dir={output_dir}",
f"override.use_kernels={use_kernels}",
f"checkpoint={get_artifact_path(args, args.inverse_fold_checkpoint)}",
f"data.cfg.moldir={moldir}",
f"override.inverse_fold_args.inverse_fold_restriction=[{', '.join(exclude_residues)}]",
]
+ config_args_by_step.get("inverse_folding", []),
)
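
The default for `inverse_fold_avoid` in the `--only_inverse_fold` branch above can be restated as a small standalone function. This is a sketch of the selection logic only. `PROT_LETTER_TO_TOKEN` is a hypothetical, truncated stand-in for `const.prot_letter_to_token`, and the function names are illustrative.

```python
from typing import List, Optional

# Protocols for which cysteine is avoided by default (taken from the hunk above).
CYS_FREE_PROTOCOLS = {"peptide-anything", "nanobody-anything", "antibody-anything"}

# Hypothetical subset of the 1-letter -> 3-letter mapping, for illustration only.
PROT_LETTER_TO_TOKEN = {"A": "ALA", "C": "CYS", "M": "MET", "W": "TRP"}


def resolve_inverse_fold_avoid(cli_value: Optional[str], protocol: str) -> str:
    """An explicit value always wins; otherwise fall back to the protocol default."""
    if cli_value is not None:
        return cli_value
    return "C" if protocol in CYS_FREE_PROTOCOLS else ""


def excluded_tokens(cli_value: Optional[str], protocol: str) -> List[str]:
    """Map each avoided 1-letter code to its 3-letter token."""
    avoid = resolve_inverse_fold_avoid(cli_value, protocol)
    return [PROT_LETTER_TO_TOKEN[letter] for letter in avoid]


print(excluded_tokens(None, "antibody-anything"))
print(excluded_tokens("CM", "protein-anything"))
print(excluded_tokens("", "antibody-anything"))
```

Because the check is `is not None`, an explicit empty string would turn off the cysteine default for the antibody, nanobody and peptide protocols, assuming the argument parser passes the empty string through unchanged.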

View File

@@ -3469,11 +3469,17 @@ eval_keys_confidence = [
"design_iptm",
"design_iiptm",
"design_to_target_iptm",
"design_residue_iptm",
"target_ptm",
"design_ptm",
"design_ipsae_min",
"design_to_target_ipsae",
"target_to_design_ipsae",
"ligand_iptm",
"complex_plddt",
"complex_iplddt",
"complex_pde",
"complex_ipde",
]
eval_keys_affinity = [
@@ -3522,6 +3528,7 @@ token_features = [
"res_type_clone",
"is_standard",
"design_mask",
"aa_constraint_mask",
"binding_type",
"structure_group",
"token_bonds",

View File

@@ -1,4 +1,5 @@
import json
import warnings
from dataclasses import asdict, dataclass
from pathlib import Path
import re
@@ -291,6 +292,7 @@ Chain = [
("res_idx", np.dtype("i4")),
("res_num", np.dtype("i4")),
("cyclic_period", np.dtype("i4")),
("symmetric_group", np.dtype("i4")),
]
Interface = [
@@ -907,7 +909,7 @@ class Structure(NumpySerializable):
old_to_new = {old.item(): new for new, old in enumerate(atom_indices)}
old_to_new_res = {old.item(): new for new, old in enumerate(res_indices)}
old_to_new_res_chain = {}
res_chain_map = {}
for i in range(len(residues)):
original_atom_range = np.arange(
@@ -966,8 +968,7 @@ class Structure(NumpySerializable):
chain_start = orig_chain["res_idx"]
chain_end = orig_chain["res_idx"] + orig_chain["res_num"]
chain_res_indices = [r for r in res_indices if chain_start <= r < chain_end]
chain_res_indices -= orig_chain["res_idx"]
old_to_new_res_chain[i] = {
res_chain_map[i] = {
old.item(): new for new, old in enumerate(chain_res_indices)
}
@@ -996,8 +997,8 @@ class Structure(NumpySerializable):
chain_atom_start = chain["atom_idx"]
chain_atom_end = chain["atom_idx"] + chain["atom_num"]
if chain_atom_start <= res["atom_idx"] < chain_atom_end:
res_idx_item = residues[i]["res_idx"].item()
residues[i]["res_idx"] = old_to_new_res_chain[chain_idx].get(
res_idx_item = res_indices[i]
residues[i]["res_idx"] = res_chain_map[chain_idx].get(
res_idx_item
)
@@ -1216,7 +1217,8 @@ class Structure(NumpySerializable):
len(atom_data),
0,
len(res_data),
0,
0, # cyclic_period
0, # symmetric_group
)
]
@@ -1523,7 +1525,8 @@ class Structure(NumpySerializable):
num_atoms,
total_res,
len(chain_res_selector),
0,
0, # cyclic_period
0, # symmetric_group
)
)
total_res += len(chain_res_selector)
@@ -1759,6 +1762,18 @@ def biotite_array_from_feat(feat):
atom_pad_mask & atom_resolved_mask
].bool()
# add chain design mask
chain_design_mask = feat["chain_design_mask"].bool()
atom_chain_design_mask = (
(feat["atom_to_token"].float() @ chain_design_mask.unsqueeze(-1).float())
.bool()
.squeeze()
)
atom_array.add_annotation("is_chain_design", bool)
atom_array.is_chain_design = atom_chain_design_mask[
atom_pad_mask & atom_resolved_mask
].bool()
return atom_array
@@ -1916,6 +1931,7 @@ class DesignInfo(NumpySerializable):
res_structure_groups: npt.NDArray[np.int_]
res_ss_types: npt.NDArray[np.int_]
res_binding_type: npt.NDArray[np.int_]
res_aa_constraint_mask: npt.NDArray[np.float32] # Shape: (num_residues, 20), 0=allowed, 1=disallowed
@classmethod
def is_valid(self, info: "DesignInfo") -> bool:
@@ -1925,6 +1941,7 @@ class DesignInfo(NumpySerializable):
len(info.res_design_mask) == len(info.res_structure_groups)
and len(info.res_structure_groups) == len(info.res_ss_types)
and len(info.res_ss_types) == len(info.res_binding_type)
and len(info.res_aa_constraint_mask) == len(info.res_design_mask)
), (
"There must be a bug in the code. All residue level design info objects should have the same length."
)
@@ -1941,6 +1958,22 @@ class DesignInfo(NumpySerializable):
msg = "Misspecified design info. There were residues that have a secondary structure type specified but are not set to be designed."
raise ValueError(msg)
# Validate residue constraints
has_constraints = info.res_aa_constraint_mask.any(axis=1)
if any(has_constraints & ~info.res_design_mask.astype(bool)):
warnings.warn(
"Residue constraints specified for non-designed residues "
"will be ignored during inverse folding.",
UserWarning,
stacklevel=2,
)
# Check if any designed position has ALL amino acids blocked
all_blocked = info.res_aa_constraint_mask.all(axis=1)
if any(all_blocked & info.res_design_mask.astype(bool)):
msg = "Invalid residue constraints: some designed positions have all amino acids disallowed."
raise ValueError(msg)
return True
@@ -2035,11 +2068,13 @@ Token = [
("design_mask", np.dtype("?")),
("binding_type", np.dtype("i4")),
("structure_group", np.dtype("i4")),
("aa_constraint_mask", np.dtype("20f4")), # Per-residue AA constraints: 20 floats (one per canonical AA)
("ccd", np.dtype("5i4")),
("target_msa_mask", np.dtype("?")),
("design_ss_mask", np.dtype("?")),
("feature_asym_id", np.dtype("i4")),
("feature_res_idx", np.dtype("i4")),
("symmetric_group", np.dtype("i4")),
]
TokenBond = [

View File

@@ -701,6 +701,12 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
res_type = from_numpy(token_data["res_type"]).long()
is_standard = from_numpy(token_data["is_standard"])
design = from_numpy(token_data["design_mask"]).long()
# Per-residue amino acid constraint mask (shape: num_tokens x 20)
if "aa_constraint_mask" in token_data.dtype.names:
aa_constraint_mask = from_numpy(token_data["aa_constraint_mask"]).float()
else:
# Default: no constraints (all zeros = all AAs allowed)
aa_constraint_mask = torch.zeros(len(token_data), len(const.canonical_tokens))
res_type = one_hot(res_type, num_classes=const.num_tokens)
modified = from_numpy(token_data["modified"]).long()
ccd = from_numpy(token_data["ccd"]).long()
@@ -711,6 +717,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
design_ss_mask = from_numpy(token_data["design_ss_mask"])
feature_residue_index = from_numpy(token_data["feature_res_idx"]).long()
feature_asym_id = from_numpy(token_data["feature_asym_id"]).long()
symmetric_group = from_numpy(token_data["symmetric_group"]).long()
token_to_res = from_numpy(data.token_to_res).long()
method = (
@@ -863,6 +870,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
res_type = pad_dim(res_type, 0, pad_len)
is_standard = pad_dim(is_standard, 0, pad_len)
design = pad_dim(design, 0, pad_len)
aa_constraint_mask = pad_dim(aa_constraint_mask, 0, pad_len)
binding_type = pad_dim(binding_type, 0, pad_len)
structure_group = pad_dim(structure_group, 0, pad_len)
pad_mask = pad_dim(pad_mask, 0, pad_len)
@@ -887,6 +895,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
design_ss_mask = pad_dim(design_ss_mask, 0, pad_len)
feature_residue_index = pad_dim(feature_residue_index, 0, pad_len)
feature_asym_id = pad_dim(feature_asym_id, 0, pad_len)
symmetric_group = pad_dim(symmetric_group, 0, pad_len)
token_to_res = pad_dim(token_to_res, 0, pad_len)
token_features = {
"token_index": token_index,
@@ -899,6 +908,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
"res_type_clone": res_type.clone(),
"is_standard": is_standard,
"design_mask": design,
"aa_constraint_mask": aa_constraint_mask,
"binding_type": binding_type,
"structure_group": structure_group,
"token_bonds": bonds,
@@ -921,6 +931,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
"design_ss_mask": design_ss_mask,
"feature_residue_index": feature_residue_index,
"feature_asym_id": feature_asym_id,
"symmetric_group": symmetric_group,
"ligand_affinity_mask": ligand_affinity_mask,
"token_to_res": token_to_res,
}

View File

@@ -1,8 +1,9 @@
import itertools
import os
import pickle
from pathlib import Path
import random
from typing import Dict
from typing import Dict, List
import zipfile
import numpy as np
@@ -14,10 +15,44 @@ from boltzgen.data import const
from boltzgen.data.pad import pad_dim
from boltzgen.model.loss.confidence import lddt_dist
MOLDIR_ZIP_CACHE = {} # moldir -> open zip file
MOLDIR_ZIP_CACHE: Dict[
tuple[int, Path], zipfile.ZipFile
] = {} # moldir -> open zip file
def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
def _get_zipfile(moldir: Path) -> zipfile.ZipFile:
"""
Retrieve a cached ZipFile object for the given molecule directory zip file.
If the ZipFile for the provided path and current process does not exist in the cache,
it is created and stored. Otherwise, the cached instance is returned.
Parameters
----------
moldir : Path
Path to the .zip file containing molecule data.
Returns
-------
zipfile.ZipFile
An open ZipFile object for reading molecule data.
Notes
-----
The cache is keyed by both process ID and zip file path to prevent issues with forked processes.
"""
pid = os.getpid()
key = (pid, moldir)
zf = MOLDIR_ZIP_CACHE.get(key)
if zf is None:
zf = zipfile.ZipFile(moldir, "r")
MOLDIR_ZIP_CACHE[key] = zf
return zf
def load_molecules(moldir: str, molecules: List[str]) -> Dict[str, Mol]:
"""Load the given input data.
Parameters
@@ -33,12 +68,10 @@ def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
The loaded molecules.
"""
moldir = Path(moldir)
loaded_mols = {}
loaded_mols: Dict[str, Mol] = {}
if moldir.is_file() and moldir.suffix == ".zip":
if moldir not in MOLDIR_ZIP_CACHE:
MOLDIR_ZIP_CACHE[moldir] = zipfile.ZipFile(moldir, "r")
zip_file = MOLDIR_ZIP_CACHE[moldir]
zip_file = _get_zipfile(moldir)
for molecule in molecules:
pkl_filename = f"{molecule}.pkl"
with zip_file.open(pkl_filename, "r") as f:
@@ -621,12 +654,16 @@ def get_chain_symmetries(cropped, backbone_only, atom14, atom37, max_n_symmetrie
crop_to_all_atom_map.shape[0]
)
connections_edge_index = []
crop_atom_set = set(crop_to_all_atom_map.astype(np.int64))
for connection in structure.bonds:
if (connection["chain_1"] == connection["chain_2"]) and (
connection["res_1"] == connection["res_2"]
):
continue
connections_edge_index.append([connection["atom_1"], connection["atom_2"]])
atom_1, atom_2 = connection["atom_1"], connection["atom_2"]
# Only include bonds where BOTH atoms are in the crop
if atom_1 in crop_atom_set and atom_2 in crop_atom_set:
connections_edge_index.append([atom_1, atom_2])
if len(connections_edge_index) > 0:
connections_edge_index = np.array(connections_edge_index, dtype=np.int64).T
connections_edge_index = all_atom_to_crop_map[connections_edge_index]
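
The effect of the new crop filter in the hunk above can be illustrated without NumPy. In this sketch, which uses hypothetical names and toy indices, a bond is kept only when both of its atoms are inside the crop, and the surviving pair is then renumbered into crop-local indices. Presumably the filter prevents bonds that reach outside the crop from being remapped to meaningless indices.

```python
from typing import List, Sequence, Tuple


def bonds_within_crop(
    bonds: Sequence[Tuple[int, int]],
    crop_to_all_atom: Sequence[int],
) -> List[Tuple[int, int]]:
    """Keep bonds fully inside the crop and renumber them to crop-local indices."""
    crop_atom_set = set(crop_to_all_atom)
    all_to_crop = {atom: i for i, atom in enumerate(crop_to_all_atom)}
    return [
        (all_to_crop[a1], all_to_crop[a2])
        for a1, a2 in bonds
        if a1 in crop_atom_set and a2 in crop_atom_set
    ]


crop_to_all_atom = [10, 11, 12, 40]        # global atom indices kept by the crop
bonds = [(10, 40), (12, 99), (7, 11)]      # only the first bond is fully inside
print(bonds_within_crop(bonds, crop_to_all_atom))
```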

View File

@@ -679,6 +679,8 @@ def parse_polymer( # noqa: C901, PLR0915, PLR0912
If the alignment fails.
"""
assert entity_poly_seq is not None
# Since the polymer object already contains the global idx, we don't need to perform the alignment
sequence = [_entity[1] for _entity in entity_poly_seq]
@@ -911,8 +913,8 @@ def parse_mmcif( # noqa: C901, PLR0915, PLR0912
Path to the MMCIF file.
use_original_res_idx : bool
Uses the res_idx for the res_idx in the Residues in the returned structure
that was in the mmcif file for each residue instead of using the index in the
Uses the res_idx for the res_idx in the Residues in the returned structure
that was in the mmcif file for each residue instead of using the index in the
seqres that is obtained after aligning the seqres to the sequence of amino acids from the present residues.
Returns
@@ -1152,6 +1154,7 @@ def mmcif_from_block( # noqa: C901, PLR0915, PLR0912
mols=mols,
moldir=moldir,
use_original_res_idx=use_original_res_idx,
entity_poly_seq=entity_poly_seq[entity.name],
)
if parsed_polymer is not None:
ensemble_chains[ref_chain_map[subchain_id]] = parsed_polymer
@@ -1255,6 +1258,7 @@ def mmcif_from_block( # noqa: C901, PLR0915, PLR0912
res_idx,
res_num,
0, # cyclic period
0, # symmetric_group (default, can be overridden via YAML)
)
)
chain_to_idx[chain.name] = asym_id

View File

@@ -117,6 +117,7 @@ def patch_chain_names(structure: ParsedStructure) -> ParsedStructure:
chain["res_idx"],
chain["res_num"],
chain["cyclic_period"],
chain["symmetric_group"],
)
)
chains_new = np.array(chains_new, dtype=Chain)

View File

@@ -144,6 +144,7 @@ class ParsedChain:
cyclic_period: int
sequence: Optional[str] = None
sampleidx_to_specidx: Optional[np.ndarray] = None
symmetric_group: int = 0
@dataclass(frozen=True)
@@ -335,6 +336,12 @@ yaml_keys = [
"leaving_atoms",
"atom",
"use_assembly",
"symmetric_group",
# Per-residue amino acid constraints
"residue_constraints",
"position",
"allowed",
"disallowed",
]
@@ -483,6 +490,7 @@ def parse_polymer(
components: dict[str, Mol],
cyclic: bool,
mol_dir: Path,
symmetric_group: int = 0,
) -> Optional[ParsedChain]:
"""Process a sequence into a chain object.
@@ -630,6 +638,7 @@ def parse_polymer(
cyclic_period=cyclic_period,
sequence=raw_sequence,
sampleidx_to_specidx=sampleidx_to_specidx,
symmetric_group=symmetric_group,
)
@@ -675,6 +684,182 @@ def parse_range(ranges, c_start=0, c_end=None):
return indices
def _normalize_aa_spec(aa_spec) -> list[str]:
"""Normalize amino acid specification to a list of individual codes.
Supports both BoltzGen conventions:
- String format: "AGS" (concatenated 1-letter codes, consistent with sequence/binding_types)
- List format: [A, G, S] or [ALA, GLY, SER]
Parameters
----------
aa_spec : str or list
Amino acid specification in string or list format
Returns
-------
list[str]
List of individual amino acid codes
"""
if isinstance(aa_spec, str):
# String format: "AGS" -> ["A", "G", "S"]
# Handle both "AGS" and "ALA" (single 3-letter code)
aa_spec = aa_spec.strip().upper()
if len(aa_spec) <= 3 and aa_spec.isalpha():
# Could be single 3-letter code like "ALA" or 1-3 single letters like "A", "AG", "AGS"
# Check if it's a valid 3-letter code
if len(aa_spec) == 3 and aa_spec in ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"]:
return [aa_spec]
# Otherwise treat as concatenated 1-letter codes
return list(aa_spec)
else:
# Longer string: treat as concatenated 1-letter codes
return list(aa_spec)
elif isinstance(aa_spec, list):
# List format: [A, G, S] or [ALA, GLY, SER]
return [str(x).strip().upper() for x in aa_spec]
else:
raise ValueError(f"Invalid amino acid specification: {aa_spec}")
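
One consequence of the string handling above is worth spelling out: a three-character string that happens to be a valid 3-letter code is read as that single residue, not as three 1-letter codes. The condensed restatement below (same branching, with a set instead of the inline list) shows the difference. It is a sketch for illustration, not the function as committed.

```python
from typing import List, Union

THREE_LETTER = {
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
}


def normalize_aa_spec(aa_spec: Union[str, list]) -> List[str]:
    """Return individual amino acid codes from a string or list specification."""
    if isinstance(aa_spec, str):
        aa_spec = aa_spec.strip().upper()
        if len(aa_spec) == 3 and aa_spec in THREE_LETTER:
            return [aa_spec]       # a single 3-letter code
        return list(aa_spec)       # concatenated 1-letter codes
    if isinstance(aa_spec, list):
        return [str(x).strip().upper() for x in aa_spec]
    raise ValueError(f"Invalid amino acid specification: {aa_spec}")


print(normalize_aa_spec("AGS"))            # three 1-letter codes
print(normalize_aa_spec("SER"))            # read as serine, not S + E + R
print(normalize_aa_spec(["S", "E", "R"]))  # list form removes the ambiguity
```

To allow serine, glutamate and arginine at a position, the list form `[S, E, R]` (or a different letter order such as `ERS`) avoids the collision with `SER`.
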
def _convert_aa_names_to_indices(
aa_names: list,
canonical_tokens: list[str],
prot_letter_to_token: dict[str, str],
) -> list[int]:
"""Convert amino acid names (1-letter or 3-letter) to canonical token indices.
Parameters
----------
aa_names : list
List of amino acid names (1-letter like 'A' or 3-letter like 'ALA')
canonical_tokens : list[str]
List of canonical 3-letter amino acid codes
prot_letter_to_token : dict[str, str]
Mapping from 1-letter to 3-letter codes
Returns
-------
list[int]
List of indices into canonical_tokens
"""
indices = []
for name in aa_names:
name = str(name).strip().upper()
# Convert 1-letter to 3-letter if needed
if len(name) == 1:
if name not in prot_letter_to_token:
raise ValueError(f"Unknown amino acid code: {name}")
name = prot_letter_to_token[name]
# Find index in canonical_tokens
if name not in canonical_tokens:
raise ValueError(f"Unknown amino acid: {name}")
indices.append(canonical_tokens.index(name))
return indices
def parse_residue_constraints(
constraints_spec: list,
chain_length: int,
canonical_tokens: list[str],
prot_letter_to_token: dict[str, str],
) -> np.ndarray:
"""Parse residue_constraints into a per-residue constraint mask.
Parameters
----------
constraints_spec : list
List of constraint specifications from YAML
chain_length : int
Length of the chain (number of residues)
canonical_tokens : list[str]
List of canonical 3-letter amino acid codes (20 AAs)
prot_letter_to_token : dict[str, str]
Mapping from 1-letter to 3-letter codes
Returns
-------
np.ndarray
Shape (chain_length, 20) where:
- 0.0 means allowed
- 1.0 means disallowed (will be converted to -inf logit bias in model)
Notes
-----
Overlapping constraints use **intersection** semantics: if multiple
constraints cover the same position, only amino acids allowed by ALL
of them survive. For example, ``allowed: AG`` at pos 1..10 followed
by ``allowed: GS`` at pos 5..15 results in only G being allowed at
positions 5-10 (the intersection of {A,G} and {G,S}).
"""
num_aa = len(canonical_tokens) # Should be 20
constraint_mask = np.zeros((chain_length, num_aa), dtype=np.float32)
for constraint in constraints_spec:
# Parse position(s)
position_spec = constraint.get("position")
if position_spec is None:
raise ValueError("residue_constraints: 'position' is required")
# Use parse_range to handle single positions and ranges (1-indexed)
positions = parse_range(str(position_spec), c_start=0, c_end=chain_length)
# Validate positions are within bounds
for pos in positions:
if pos < 0 or pos >= chain_length:
raise ValueError(
f"Position {pos + 1} is out of bounds for chain of length {chain_length}"
)
# Parse amino acid specification
allowed = constraint.get("allowed", None)
disallowed = constraint.get("disallowed", None)
# Validate: cannot have both allowed and disallowed
if allowed is not None and disallowed is not None:
raise ValueError(
f"Position {position_spec}: cannot specify both 'allowed' and 'disallowed'"
)
if allowed is None and disallowed is None:
raise ValueError(
f"Position {position_spec}: must specify either 'allowed' or 'disallowed'"
)
if allowed is not None:
# Whitelist mode: block all except specified AAs
# Uses np.maximum to accumulate with existing constraints (intersection semantics):
# if a position already has constraints, only AAs allowed by BOTH survive.
aa_list = _normalize_aa_spec(allowed)
if len(aa_list) == 0:
raise ValueError(
f"Position {position_spec}: 'allowed' cannot be empty"
)
aa_indices = _convert_aa_names_to_indices(
aa_list, canonical_tokens, prot_letter_to_token
)
new_block = np.ones(num_aa, dtype=np.float32)
for idx in aa_indices:
new_block[idx] = 0.0
for pos in positions:
constraint_mask[pos, :] = np.maximum(constraint_mask[pos, :], new_block)
elif disallowed is not None:
# Blacklist mode: only block specified
# Normalize input: supports both "CM" (string) and [C, M] (list)
aa_list = _normalize_aa_spec(disallowed)
aa_indices = _convert_aa_names_to_indices(
aa_list, canonical_tokens, prot_letter_to_token
)
for pos in positions:
for idx in aa_indices:
constraint_mask[pos, idx] = 1.0 # Block specified
return constraint_mask
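
The intersection semantics described in the docstring can be reproduced with plain lists, where an element-wise `max` plays the role of `np.maximum`. The sketch below replays the docstring's own example (`allowed: AG` at 1..10, then `allowed: GS` at 5..15). `AMINO_ACIDS` uses an illustrative alphabetical ordering that is not necessarily the order of `const.canonical_tokens`.

```python
from typing import List

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # illustrative ordering (assumption)


def allowed_block(allowed: str) -> List[float]:
    """Whitelist row: 0.0 for allowed residues, 1.0 for everything else."""
    return [0.0 if aa in allowed else 1.0 for aa in AMINO_ACIDS]


def merge(existing: List[float], new_block: List[float]) -> List[float]:
    """Element-wise max: a residue stays allowed only if every rule allows it."""
    return [max(a, b) for a, b in zip(existing, new_block)]


def surviving(row: List[float]) -> str:
    return "".join(aa for aa, blocked in zip(AMINO_ACIDS, row) if blocked == 0.0)


chain_length = 15
mask = [[0.0] * len(AMINO_ACIDS) for _ in range(chain_length)]
for start, end, allowed in [(1, 10, "AG"), (5, 15, "GS")]:  # 1-indexed, inclusive
    block = allowed_block(allowed)
    for pos in range(start - 1, end):
        mask[pos] = merge(mask[pos], block)

print(surviving(mask[0]), surviving(mask[4]), surviving(mask[10]))
```

If two overlapping whitelists share no residue, the merged row is all ones. For a designed position, that is the case the new `DesignInfo` validation rejects with "some designed positions have all amino acids disallowed".
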
def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
extra_mols: dict[str, Mol] = {}
parsed_chains: dict[str, ParsedChain] = {}
@@ -766,6 +951,9 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
seq[idx] = code
cyclic = item[entity_type].get("cyclic", False)
symmetric_group = item[entity_type].get("symmetric_group", 0)
if symmetric_group is None:
symmetric_group = 0
# Parse a polymer
parsed_chain = parse_polymer(
@@ -776,10 +964,14 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
components=mols,
cyclic=cyclic,
mol_dir=mol_dir,
symmetric_group=symmetric_group,
)
# Parse a non-polymer
elif (entity_type == "ligand") and "ccd" in (item[entity_type]):
symmetric_group = item[entity_type].get("symmetric_group", 0)
if symmetric_group is None:
symmetric_group = 0
seq = item[entity_type]["ccd"]
if isinstance(seq, str):
seq = [seq]
@@ -806,6 +998,7 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
type=const.chain_type_ids["NONPOLYMER"],
cyclic_period=0,
sequence=None,
symmetric_group=symmetric_group,
)
assert not item[entity_type].get("cyclic", False), (
@@ -813,6 +1006,9 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
)
elif (entity_type == "ligand") and ("smiles" in item[entity_type]):
symmetric_group = item[entity_type].get("symmetric_group", 0)
if symmetric_group is None:
symmetric_group = 0
seq = item[entity_type]["smiles"]
mol = AllChem.MolFromSmiles(seq)
mol = AllChem.AddHs(mol)
@@ -848,6 +1044,7 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
type=const.chain_type_ids["NONPOLYMER"],
cyclic_period=0,
sequence=None,
symmetric_group=symmetric_group,
)
assert not item[entity_type].get("cyclic", False), (
@@ -944,6 +1141,31 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
else:
ss_type.extend([const.ss_type_ids["UNSPECIFIED"]] * num)
# Parse residue_constraints for per-residue amino acid restrictions
entry = item[entity_type]
constraints_spec = entry.get("residue_constraints", None)
ids = item[entity_type]["id"]
num_chains = 1 if isinstance(ids, str) else len(ids)
res_aa_constraint_list = []
for _ in range(num_chains):
if constraints_spec is not None and entity_type == "protein":
res_aa_constraints = parse_residue_constraints(
constraints_spec,
chain_length=num,
canonical_tokens=const.canonical_tokens,
prot_letter_to_token=const.prot_letter_to_token,
)
else:
# No constraints: all 20 amino acids allowed (zeros)
res_aa_constraints = np.zeros((num, len(const.canonical_tokens)), dtype=np.float32)
res_aa_constraint_list.append(res_aa_constraints)
# Concatenate constraint masks for all chain copies
if res_aa_constraint_list:
res_aa_constraint_mask = np.concatenate(res_aa_constraint_list, axis=0)
else:
res_aa_constraint_mask = np.zeros((0, len(const.canonical_tokens)), dtype=np.float32)
# Add as many parsed_chains as provided ids
if entity_type in {"protein", "dna", "rna", "ligand"}:
ids = item[entity_type]["id"]
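The hunk above builds one constraint mask per chain copy and concatenates them, using 1.0 for a blocked amino acid and 0.0 for an allowed one. A minimal NumPy sketch of that convention follows. `ALPHABET`, `block_mask`, and `mask_for_entity` are hypothetical names standing in for `const.canonical_tokens` and the parser helpers, and the 0-based positions are an assumption of this sketch rather than something the diff confirms.

```python
import numpy as np

# Hypothetical 20-letter alphabet standing in for const.canonical_tokens.
ALPHABET = list("ACDEFGHIKLMNPQRSTVWY")


def block_mask(chain_length, disallowed_by_pos):
    """Return a (chain_length, 20) float mask: 1.0 = blocked, 0.0 = allowed."""
    mask = np.zeros((chain_length, len(ALPHABET)), dtype=np.float32)
    for pos, letters in disallowed_by_pos.items():
        for letter in letters:
            mask[pos, ALPHABET.index(letter)] = 1.0
    return mask


def mask_for_entity(chain_ids, chain_length, disallowed_by_pos=None):
    """Repeat the per-chain mask once per chain id and stack along axis 0."""
    ids = [chain_ids] if isinstance(chain_ids, str) else list(chain_ids)
    per_chain = [block_mask(chain_length, disallowed_by_pos or {}) for _ in ids]
    return np.concatenate(per_chain, axis=0)


# Two copies of a 3-residue chain; C and M are blocked at position 0 of each.
mask = mask_for_entity(["A", "B"], chain_length=3, disallowed_by_pos={0: "CM"})
```

An entity without constraints produces an all-zero block of the same shape, which is why file-derived chains can simply be padded with zeros further down.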
@@ -971,6 +1193,7 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
chain_to_msa,
fuse_info,
ligand_id,
res_aa_constraint_mask,
)
@@ -1202,6 +1425,7 @@ class YamlDesignParser:
res_design_mask = np.array([], dtype=bool)
res_bind_type = np.array([], dtype=np.int32)
ss_type = np.array([], dtype=np.int32)
res_aa_constraint_mask = np.zeros((0, len(const.canonical_tokens)), dtype=np.float32)
chain_to_msa = {}
global_asym_id = 0
@@ -1225,6 +1449,7 @@ class YamlDesignParser:
entity_chain_to_msa,
fuse_info,
ligand_id,
new_res_aa_constraint_mask,
) = parse_entity(
item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto
)
@@ -1233,6 +1458,7 @@ class YamlDesignParser:
extra_mols.update(new_extra_mols)
res_bind_type = np.concatenate([res_bind_type, new_res_bind_type])
ss_type = np.concatenate([ss_type, new_ss_type])
res_aa_constraint_mask = np.concatenate([res_aa_constraint_mask, new_res_aa_constraint_mask], axis=0)
for asym_id, (chain_name, chain) in enumerate(
parsed_chains.items()
):
@@ -1260,6 +1486,7 @@ class YamlDesignParser:
res_idx,
res_num,
chain.cyclic_period,
chain.symmetric_group,
)
)
chain_to_idx[chain_name] = asym_id
@@ -1405,11 +1632,16 @@ class YamlDesignParser:
fbind_types,
fss_type,
file_chain_to_msa,
file_chain_symmetric_group,
fuse_info,
new_extra_mols,
file_msa_flag,
ligand_id,
) = self.parse_file(item, mols, mol_dir, ligand_id, base_file_path)
# Apply symmetric_group to chains from file
for chain_id, sym_group in file_chain_symmetric_group.items():
chain_mask = new_data.chains["name"] == chain_id
new_data.chains["symmetric_group"][chain_mask] = sym_group
if fuse_info["fuse"]:
if fuse_info["target_id"] in total_renaming.keys():
fuse_info["target_id"] = total_renaming[
@@ -1430,6 +1662,9 @@ class YamlDesignParser:
res_design_mask = np.concatenate([res_design_mask, new_design_mask])
res_bind_type = np.concatenate([res_bind_type, fbind_types])
ss_type = np.concatenate([ss_type, fss_type])
# File entities have no residue constraints — pad with zeros (all AAs allowed)
file_constraint_mask = np.zeros((len(new_design_mask), len(const.canonical_tokens)), dtype=np.float32)
res_aa_constraint_mask = np.concatenate([res_aa_constraint_mask, file_constraint_mask], axis=0)
extra_mols.update(new_extra_mols)
if len(renaming) > 0:
msg = f"\nChain ids conflict with existing chain ids. Renaming with {renaming}. This is for the structure from '{path}'."
@@ -1484,18 +1719,65 @@ class YamlDesignParser:
raise ValueError(msg)
# Map index
if all_parsed_chains[c1].sampleidx_to_specidx is not None:
if (
c1 in all_parsed_chains.keys()
and all_parsed_chains[c1].sampleidx_to_specidx is not None
):
r1 = np.where(all_parsed_chains[c1].sampleidx_to_specidx == r1)[0][
0
].item()
if all_parsed_chains[c2].sampleidx_to_specidx is not None:
c1, r1, a1 = atom_idx_map[(c1, r1, a1)]
else:
# The chain comes from a file, so we use the residue index directly
chain = data.chains[data.chains["name"] == c1]
c1 = chain["asym_id"].item()
res_start = chain["res_idx"].item()
res_end = chain["res_idx"].item() + chain["res_num"].item()
residues = data.residues[res_start:res_end]
residue = residues[residues["res_idx"] == r1]
r1 = res_start + residue["res_idx"].item()
atom_start = residue["atom_idx"].item()
atom_end = residue["atom_idx"].item() + residue["atom_num"].item()
atoms = data.atoms[atom_start:atom_end]
assert a1 in atoms["name"], (
f"Atom {a1} not found in residue {r1} of chain {c1}"
)
a1 = np.where(atoms["name"] == a1)[0].item()
a1 = (
residue["atom_idx"].item() + a1
) # THIS STILL NEEDS TO BE CORRECTED
if (
c2 in all_parsed_chains.keys()
and all_parsed_chains[c2].sampleidx_to_specidx is not None
):
r2 = np.where(all_parsed_chains[c2].sampleidx_to_specidx == r2)[0][
0
].item()
c2, r2, a2 = atom_idx_map[(c2, r2, a2)]
else:
# The chain comes from a file, so we use the residue index directly
chain = data.chains[data.chains["name"] == c2]
c2 = chain["asym_id"].item()
c1, r1, a1 = atom_idx_map[(c1, r1, a1)]
c2, r2, a2 = atom_idx_map[(c2, r2, a2)]
res_start = chain["res_idx"].item()
res_end = chain["res_idx"].item() + chain["res_num"].item()
residues = data.residues[res_start:res_end]
residue = residues[residues["res_idx"] == r2]
r2 = res_start + residue["res_idx"].item()
atom_start = residue["atom_idx"].item()
atom_end = residue["atom_idx"].item() + residue["atom_num"].item()
atoms = data.atoms[atom_start:atom_end]
assert a2 in atoms["name"], (
f"Atom {a2} not found in residue {r2} of chain {c2}"
)
a2 = np.where(atoms["name"] == a2)[0].item()
a2 = (
residue["atom_idx"].item() + a2
) # THIS STILL NEEDS TO BE CORRECTED
covalents.append((c1, c2, r1, r2, a1, a2))
elif "total_len" in constraints:
continue
@@ -1552,6 +1834,7 @@ class YamlDesignParser:
res_structure_groups=structure_groups,
res_binding_type=res_bind_type,
res_ss_types=ss_type,
res_aa_constraint_mask=res_aa_constraint_mask,
)
DesignInfo.is_valid(design_info)
@@ -1651,6 +1934,7 @@ class YamlDesignParser:
# Construct include mask from include entries
file_chain_to_msa = {}
file_chain_symmetric_group = {}
if isinstance(include, str):
if include == "all":
include_mask = np.ones(num_res)
@@ -1665,12 +1949,15 @@ class YamlDesignParser:
msg = f"Misspecified chain in include with missing 'id' for file with path {path}."
raise ValueError(msg)
chain_id = chain["id"]
if chain_id not in structure.chains["name"]:
msg = f"Specified chain id {chain_id} not in file {path}."
raise ValueError(msg)
if "msa" in chain:
file_chain_to_msa[chain_id] = chain["msa"]
if "symmetric_group" in chain:
file_chain_symmetric_group[chain_id] = chain["symmetric_group"]
data_chain = structure.chains[structure.chains["name"] == chain_id]
c_start = data_chain["res_idx"].item()
@@ -1970,8 +2257,12 @@ class YamlDesignParser:
fss_type[indices] = const.ss_type_ids["SHEET"]
# Parse and apply design insertions
# First pass: collect insertions and coordinate lengths for symmetric chains
if design_insertions is not None:
num_inserted = defaultdict(int)
# Group insertions by (symmetric_group, res_index) to coordinate variable lengths
symmetric_length_cache = {} # (sym_group, res_index) -> sampled_length
for list_element in design_insertions:
insertion = list_element["insertion"]
if "id" not in insertion:
@@ -1985,10 +2276,24 @@ class YamlDesignParser:
res_index += num_inserted[chain_id]
ss_insert_type = insertion.get("secondary_structure", "UNSPECIFIED")
num_residues_spec = insertion["num_residues"]
num_residues_range = parse_range(num_residues_spec)
# Check if this chain has a symmetric_group
chain_sym_group = file_chain_symmetric_group.get(chain_id, 0)
# If chain has symmetric_group > 0, coordinate length with other symmetric chains
if chain_sym_group > 0:
cache_key = (chain_sym_group, res_index, str(num_residues_spec))
if cache_key in symmetric_length_cache:
num_residues = symmetric_length_cache[cache_key]
else:
num_residues = np.random.choice(num_residues_range).item()
symmetric_length_cache[cache_key] = num_residues
else:
num_residues = np.random.choice(num_residues_range).item()
# Add 1 because parse_range is normally used for indexing, where 1-based inputs are converted to 0-based indices
num_residues = insertion["num_residues"]
num_residues = parse_range(num_residues)
num_residues = np.random.choice(num_residues).item()
num_residues += 1
num_inserted[chain_id] += num_residues
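The length coordination in this hunk can be sketched in isolation. The idea is that chains sharing a `symmetric_group` greater than 0 reuse the first length sampled for a given (group, residue index, spec) key, while ungrouped chains sample independently. `sample_insert_lengths` is a hypothetical helper. It uses the standard-library `random` module instead of `np.random.choice`, and it ignores the `num_inserted` offset that the real code adds to `res_index`.

```python
import random


def sample_insert_lengths(insertions, chain_sym_group, rng):
    """Sample one insertion length per entry, tied within symmetric groups.

    insertions: list of (chain_id, res_index, candidate_lengths) tuples.
    chain_sym_group: dict mapping chain_id to its group (0 or missing = none).
    """
    cache = {}  # (group, res_index, spec) -> sampled length
    sampled = []
    for chain_id, res_index, lengths in insertions:
        group = chain_sym_group.get(chain_id, 0)
        if group > 0:
            key = (group, res_index, str(lengths))
            if key not in cache:
                cache[key] = rng.choice(lengths)
            num_residues = cache[key]
        else:
            num_residues = rng.choice(lengths)
        sampled.append((chain_id, num_residues))
    return sampled


rng = random.Random(0)
spec = [3, 4, 5, 6]
result = sample_insert_lengths(
    [("A", 10, spec), ("B", 10, spec), ("C", 10, spec)],
    chain_sym_group={"A": 1, "B": 1},
    rng=rng,
)
```

Chains A and B always receive the same length here; chain C is free to differ.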
@@ -2122,6 +2427,7 @@ class YamlDesignParser:
fbind_types,
fss_type,
file_chain_to_msa,
file_chain_symmetric_group,
fuse_info,
extra_mols,
file_msa_flag,

View File

@@ -50,11 +50,13 @@ class TokenData:
design: bool
binding_type: int
structure_group: int
aa_constraint_mask: np.ndarray # Shape: (20,) - per-residue AA constraints
ccd: np.ndarray
target_msa_mask: bool
design_ss_mask: bool
feature_asym_id: int
feature_res_idx: int
symmetric_group: int
def compute_frame(
@@ -265,11 +267,13 @@ class Tokenizer:
design=False,
binding_type=const.binding_type_ids["UNSPECIFIED"],
structure_group=0,
aa_constraint_mask=np.zeros(20, dtype=np.float32),
ccd=convert_ccd(res["name"]),
target_msa_mask=0,
design_ss_mask=0,
feature_asym_id=chain["asym_id"],
feature_res_idx=res["res_idx"],
symmetric_group=chain["symmetric_group"],
)
token_data.append(tokendata_to_tuple(token))
@@ -329,11 +333,13 @@ class Tokenizer:
design=False,
binding_type=const.binding_type_ids["UNSPECIFIED"],
structure_group=0,
aa_constraint_mask=np.zeros(20, dtype=np.float32),
ccd=convert_ccd(res["name"]),
target_msa_mask=0,
design_ss_mask=0,
feature_asym_id=chain["asym_id"],
feature_res_idx=res["res_idx"],
symmetric_group=chain["symmetric_group"],
)
token_data.append(tokendata_to_tuple(token))
@@ -387,11 +393,13 @@ class Tokenizer:
design=False,
binding_type=const.binding_type_ids["UNSPECIFIED"],
structure_group=0,
aa_constraint_mask=np.zeros(20, dtype=np.float32),
ccd=convert_ccd(res["name"]),
target_msa_mask=0,
design_ss_mask=0,
feature_asym_id=chain["asym_id"],
feature_res_idx=res["res_idx"],
symmetric_group=chain["symmetric_group"],
)
token_data.append(tokendata_to_tuple(token))

View File

@@ -279,6 +279,22 @@ def compute_ptms(logits, x_preds, feats, multiplicity):
dim=1,
).values
# iPTM between designed residues and any token from a different chain
design_residue_iptm_mask = (
maski[:, :, None]
* mask_pad[:, None, :]
* mask_pad[:, :, None]
* (asym_id[:, None, :] != asym_id[:, :, None])
* (
is_design_token[:, :, None] + is_design_token[:, None, :]
).clamp(max=1)
)
design_residue_iptm = torch.max(
torch.sum(tm_expected_value * design_residue_iptm_mask, dim=-1)
/ (torch.sum(design_residue_iptm_mask, dim=-1) + 1e-5),
dim=1,
).values
design_ptm_mask = (
maski[:, :, None]
* mask_pad[:, None, :]
@@ -394,6 +410,7 @@ def compute_ptms(logits, x_preds, feats, multiplicity):
protein_iptm,
chain_pair_iptm,
design_to_target_iptm,
design_residue_iptm,
design_iptm,
design_iiptm,
target_ptm,
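To make the new `design_residue_iptm_mask` easier to read, here is a hedged single-sample NumPy sketch of the pair selection it performs: a pair (i, j) counts when the two tokens sit in different chains, both are unpadded, and at least one of them is a designed token. The frame mask `maski` from the original is left out, and `design_residue_pair_mask` is a hypothetical name.

```python
import numpy as np


def design_residue_pair_mask(asym_id, is_design, pad_mask):
    """All inputs are 1-D arrays over tokens for a single sample."""
    different_chain = asym_id[:, None] != asym_id[None, :]
    # 1 where either token of the pair is designed (the clamp(max=1) step).
    either_design = np.clip(is_design[:, None] + is_design[None, :], 0, 1)
    valid = pad_mask[:, None] * pad_mask[None, :]
    return different_chain * either_design * valid


asym_id = np.array([0, 0, 1, 1])
is_design = np.array([1, 0, 0, 0])  # only token 0 is designed
pad_mask = np.array([1, 1, 1, 1])
pair_mask = design_residue_pair_mask(asym_id, is_design, pad_mask)
```

Token 1 is in the same chain as the designed token but is not designed itself, so its cross-chain pairs are excluded, unlike a chain-level design-to-target mask.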

View File

@@ -84,6 +84,7 @@ class PairformerModule(nn.Module):
pair_mask,
chunk_size_tri_attn,
use_kernels=use_kernels,
use_reentrant=False,
)
else:
s, z = layer(

View File

@@ -177,7 +177,7 @@ def compute_bond_loss(pred_atom_coords, true_coords, feats):
pred_bond_coords = pred_atom_coords[
:, feats["connections_edge_index"][index_batch]
]
true_bond_coords = pred_atom_coords[
true_bond_coords = true_coords[
:, feats["connections_edge_index"][index_batch]
]
pred_bond_lengths = torch.linalg.norm(
@@ -190,4 +190,4 @@ def compute_bond_loss(pred_atom_coords, true_coords, feats):
num_bonds += pred_bond_lengths.shape[1]
if num_bonds > 0:
bond_loss /= num_bonds
return bond_loss, num_bonds
return bond_loss.mean(), num_bonds
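The first hunk of this file changes the reference bond coordinates from `pred_atom_coords` to `true_coords`. A small NumPy sketch shows why that matters: when both bond lengths come from the prediction, the loss is zero regardless of how wrong the prediction is. The squared-error form and the name `bond_length_loss` are assumptions for illustration; the diff does not show the exact loss expression.

```python
import numpy as np


def bond_length_loss(pred_coords, true_coords, edge_index):
    """Mean squared error between predicted and reference bond lengths.

    edge_index has shape (2, num_bonds) and holds the atom indices of each bond.
    """
    src, dst = edge_index
    pred_len = np.linalg.norm(pred_coords[src] - pred_coords[dst], axis=-1)
    true_len = np.linalg.norm(true_coords[src] - true_coords[dst], axis=-1)
    return float(np.mean((pred_len - true_len) ** 2))


true = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
pred = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
edges = np.array([[0], [1]])

fixed = bond_length_loss(pred, true, edges)   # compares against the ground truth
buggy = bond_length_loss(pred, pred, edges)   # old behaviour: always zero
```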

View File

@@ -557,7 +557,7 @@ class ConfidenceHeads(nn.Module):
interaction_pae=interaction_pae,
min_design_to_target_pae=min_design_to_target_pae
if feats["design_mask"].sum() > 0
else torch.nan,
else torch.tensor([torch.nan]),
min_interaction_pae=min_interaction_pae,
)
@@ -581,6 +581,7 @@ class ConfidenceHeads(nn.Module):
protein_iptm,
pair_chains_iptm,
design_to_target_iptm,
design_residue_iptm,
design_iptm,
design_iiptm,
target_ptm,
@@ -596,6 +597,7 @@ class ConfidenceHeads(nn.Module):
out_dict["protein_iptm"] = protein_iptm
out_dict["pair_chains_iptm"] = pair_chains_iptm
out_dict["design_to_target_iptm"] = design_to_target_iptm
out_dict["design_residue_iptm"] = design_residue_iptm
out_dict["design_iptm"] = design_iptm
out_dict["design_iiptm"] = design_iiptm
out_dict["target_ptm"] = target_ptm
@@ -612,6 +614,7 @@ class ConfidenceHeads(nn.Module):
out_dict["protein_iptm"] = torch.zeros_like(complex_plddt)
out_dict["pair_chains_iptm"] = torch.zeros_like(complex_plddt)
out_dict["design_to_target_iptm"] = torch.zeros_like(complex_plddt)
out_dict["design_residue_iptm"] = torch.zeros_like(complex_plddt)
out_dict["design_iptm"] = torch.zeros_like(complex_plddt)
out_dict["design_iiptm"] = torch.zeros_like(complex_plddt)
out_dict["target_ptm"] = torch.zeros_like(complex_plddt)

View File

@@ -829,7 +829,7 @@ class AtomDiffusion(Module):
if add_bond_loss:
bond_loss, num_bonds = compute_bond_loss(
pred_atom_coords=out_dict["sample_atom_coords"].float(),
pred_atom_coords=out_dict["denoised_atom_coords"].float(),
true_coords=atom_coords_aligned_ground_truth,
feats=feats,
)

View File

@@ -1,4 +1,5 @@
from typing import Dict, Tuple, List
from typing import Dict, Tuple, List, Optional
import warnings
import torch
from torch import Tensor, nn
@@ -38,6 +39,75 @@ def softmax_dropout(
)
def build_constraint_logit_mask(
num_nodes: int,
aa_constraint_mask: Optional[Tensor],
inverse_fold_restriction: list[str],
canonical_tokens: list[str],
inf: float,
device: torch.device,
) -> Tensor:
"""Build per-position inverse-folding logit mask.
The mask uses additive logit bias semantics:
0.0 = allowed, -inf = disallowed.
"""
num_aa = len(canonical_tokens)
has_per_residue_constraints = False
if aa_constraint_mask is None:
per_residue_blocked = torch.zeros(
num_nodes, num_aa, dtype=torch.bool, device=device
)
else:
expected_shape = (num_nodes, num_aa)
if aa_constraint_mask.shape != expected_shape:
warnings.warn(
f"aa_constraint_mask shape mismatch: "
f"got {aa_constraint_mask.shape}, expected {expected_shape}. "
f"Ignoring per-residue constraints.",
RuntimeWarning,
stacklevel=2,
)
per_residue_blocked = torch.zeros(
num_nodes, num_aa, dtype=torch.bool, device=device
)
else:
has_per_residue_constraints = True
per_residue_blocked = aa_constraint_mask.to(device=device) > 0
global_blocked = torch.zeros(num_aa, dtype=torch.bool, device=device)
for res_type in inverse_fold_restriction:
global_blocked[canonical_tokens.index(res_type)] = True
combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
all_blocked = combined_blocked.all(dim=1)
if all_blocked.any() and has_per_residue_constraints:
blocked_positions = torch.where(all_blocked)[0].tolist()
warnings.warn(
f"Positions {blocked_positions} have all amino acids blocked by the "
f"combination of per-residue constraints and '--inverse_fold_avoid'. "
f"Relaxing per-residue constraints for these positions.",
RuntimeWarning,
stacklevel=2,
)
per_residue_blocked = per_residue_blocked.clone()
per_residue_blocked[all_blocked] = False
combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
still_all_blocked = combined_blocked.all(dim=1)
if still_all_blocked.any():
blocked_positions = torch.where(still_all_blocked)[0].tolist()
raise ValueError(
f"Inverse folding has no valid amino acids at token positions "
f"{blocked_positions} after applying '--inverse_fold_avoid'. "
f"Reduce global restrictions to keep at least one amino acid."
)
return combined_blocked.to(dtype=torch.float32) * (-inf)
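The additive-mask semantics of `build_constraint_logit_mask` can be illustrated with a small NumPy sketch. It assumes a toy alphabet of three amino acids and omits the warnings. `logit_mask` is a hypothetical stand-in, not the function from the diff.

```python
import numpy as np

INF = 10**6  # same large finite value the decoder uses in place of infinity


def logit_mask(per_residue_blocked, global_blocked):
    """Combine boolean block masks into an additive logit bias (0 or -INF)."""
    per_residue_blocked = per_residue_blocked.copy()
    combined = per_residue_blocked | global_blocked[None, :]
    # Positions with nothing left: drop the per-residue constraints there.
    dead = combined.all(axis=1)
    per_residue_blocked[dead] = False
    combined = per_residue_blocked | global_blocked[None, :]
    if combined.all(axis=1).any():
        raise ValueError("global restrictions block every amino acid")
    return combined.astype(np.float32) * (-INF)


per_residue = np.array([[True, False, False], [False, True, True]])
global_block = np.array([True, False, False])
bias = logit_mask(per_residue, global_block)

logits = np.array([[5.0, 1.0, 0.0], [5.0, 0.0, 2.0]])
choice = np.argmax(logits + bias, axis=-1)
```

Position 1 would have every option blocked once the global restriction is applied, so its per-residue constraints are relaxed and only the global block remains.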
class MLPAttnGNN(nn.Module):
def __init__(
self,
@@ -464,6 +534,7 @@ class InverseFoldingDecoder(nn.Module):
num_decoder_layers: int = 3,
inverse_fold_restriction: List[str] = [],
sampling_temperature: float = 0.1,
tie_symmetric_sequences: bool = True,
**kwargs, # old checkpoint compatibility
):
super().__init__()
@@ -485,6 +556,7 @@ class InverseFoldingDecoder(nn.Module):
self.num_decoder_layers = num_decoder_layers
self.inverse_fold_restriction = inverse_fold_restriction
self.sampling_temperature = sampling_temperature
self.tie_symmetric_sequences = tie_symmetric_sequences
self.decoder_layers = nn.ModuleList()
self.inf = 10**6
@@ -507,6 +579,43 @@ class InverseFoldingDecoder(nn.Module):
# init the output of the predictor to be zero
self.predictor.weight.zero_()
def _build_symmetric_groups(
self,
feats: Dict[str, Tensor],
valid_mask: Tensor,
design_mask: Tensor,
) -> Tuple[Dict[int, List[int]], Dict[int, int]]:
"""Build mapping from positions to symmetric groups for homomer tying."""
from collections import defaultdict
symmetric_group = feats["symmetric_group"][valid_mask]
res_idx = feats["feature_residue_index"][valid_mask]
# Group by (symmetric_group, res_idx) - positions that should share sequence
key_to_positions = defaultdict(list)
num_nodes = symmetric_group.shape[0]
for i in range(num_nodes):
if design_mask[i]:
group = symmetric_group[i].item()
if group > 0: # 0 = no group
key = (group, res_idx[i].item())
key_to_positions[key].append(i)
# Build symmetric groups (only groups with >1 member)
sym_groups = {}
position_to_group = {}
group_id = 0
for positions in key_to_positions.values():
if len(positions) > 1:
sym_groups[group_id] = positions
for pos in positions:
position_to_group[pos] = group_id
group_id += 1
return sym_groups, position_to_group
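A plain-Python sketch of the grouping performed by `_build_symmetric_groups`, using lists instead of tensors. Designed positions that share the same (symmetric group, residue index) key are tied together, and keys with a single member are dropped. The function name mirrors the method, but this standalone version and its list-based signature are illustrative assumptions.

```python
from collections import defaultdict


def build_symmetric_groups(symmetric_group, res_idx, design_mask):
    """Return (sym_groups, position_to_group) for tied positions."""
    key_to_positions = defaultdict(list)
    for i, (group, res, designed) in enumerate(
        zip(symmetric_group, res_idx, design_mask)
    ):
        if designed and group > 0:  # group 0 means "no symmetric group"
            key_to_positions[(group, res)].append(i)

    sym_groups, position_to_group = {}, {}
    for positions in key_to_positions.values():
        if len(positions) > 1:
            group_id = len(sym_groups)
            sym_groups[group_id] = positions
            for pos in positions:
                position_to_group[pos] = group_id
    return sym_groups, position_to_group


# Two 2-residue chains in group 1, plus one ungrouped token.
sym_groups, position_to_group = build_symmetric_groups(
    symmetric_group=[1, 1, 1, 1, 0],
    res_idx=[0, 1, 0, 1, 0],
    design_mask=[True, True, True, True, True],
)
```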
def forward(self, s, z, edge_idx, valid_mask, feats):
with torch.no_grad():
src_idx, dst_idx = edge_idx[0], edge_idx[1]
@@ -549,14 +658,17 @@ class InverseFoldingDecoder(nn.Module):
f"num_design: {num_design}, num_not_design: {num_not_design}"
)
# Create restriction mask that sets the probability of excluded residues to 0
if len(self.inverse_fold_restriction) > 0:
restriction_mask = torch.zeros(len(const.canonical_tokens), device=s.device)
for res_type in self.inverse_fold_restriction:
restriction_mask[const.canonical_tokens.index(res_type)] = -self.inf
restriction_mask = restriction_mask.unsqueeze(0)
else:
restriction_mask = torch.zeros(len(const.canonical_tokens), device=s.device)
constraint_mask = None
if "aa_constraint_mask" in feats:
constraint_mask = feats["aa_constraint_mask"][valid_mask]
per_residue_mask = build_constraint_logit_mask(
num_nodes=num_nodes,
aa_constraint_mask=constraint_mask,
inverse_fold_restriction=self.inverse_fold_restriction,
canonical_tokens=const.canonical_tokens,
inf=self.inf,
device=s.device,
)
order = torch.randperm(num_nodes, device=s.device).cpu().numpy().tolist()
# Non-design residues are not sampled and used as the condition. So the order should filter them out.
@@ -572,33 +684,64 @@ class InverseFoldingDecoder(nn.Module):
else:
decoded_seq = torch.zeros(num_nodes, const.num_tokens, device=s.device)
logits = torch.zeros(num_nodes, const.num_tokens, device=s.device)
# Build symmetric groups for homomer tying
if self.tie_symmetric_sequences and "symmetric_group" in feats:
sym_groups, position_to_group = self._build_symmetric_groups(
feats, valid_mask, design_mask
)
sampled = set(torch.where(~design_mask)[0].cpu().numpy().tolist()) if num_not_design > 0 else set()
else:
sym_groups, position_to_group = {}, {}
sampled = set()
src_idx, dst_idx = edge_idx[0], edge_idx[1]
# decoding in order
for i in order:
s_i = s[i : i + 1]
edge_mask_i = dst_idx == i
z_i = z[edge_mask_i]
src_idx_i = src_idx[edge_mask_i]
res_type = decoded_seq[src_idx_i]
res_rep = self.seq_to_s(res_type)
neighbors_rep_i = torch.concat([z_i, s[src_idx_i] + res_rep], dim=-1)
# Skip if already sampled (symmetric position was processed earlier)
if self.tie_symmetric_sequences and i in sampled:
continue
for layer in self.decoder_layers:
s_i = layer.sample(s_i, neighbors_rep_i)
# Get symmetric positions (or just [i] if no symmetry)
if self.tie_symmetric_sequences and i in position_to_group:
positions = sym_groups[position_to_group[i]]
else:
positions = [i]
logits_i = self.predictor(s_i)
logits[i] = logits_i
# Aggregate logits from all symmetric positions
aggregated_logits = None
for pos in positions:
s_pos = s[pos : pos + 1]
edge_mask_pos = dst_idx == pos
z_pos = z[edge_mask_pos]
src_idx_pos = src_idx[edge_mask_pos]
res_type_pos = decoded_seq[src_idx_pos]
res_rep = self.seq_to_s(res_type_pos)
neighbors_rep_pos = torch.concat([z_pos, s[src_idx_pos] + res_rep], dim=-1)
s_temp = s_pos
for layer in self.decoder_layers:
s_temp = layer.sample(s_temp, neighbors_rep_pos)
logits_pos = self.predictor(s_temp)
if aggregated_logits is None:
aggregated_logits = logits_pos
else:
aggregated_logits = aggregated_logits + logits_pos
# Average logits across symmetric positions
aggregated_logits = aggregated_logits / len(positions)
# Sample from aggregated logits
pred_canonical = (
logits_i[
aggregated_logits[
:,
const.canonicals_offset : len(const.canonical_tokens)
+ const.canonicals_offset,
]
+ restriction_mask
+ per_residue_mask[i : i + 1] # Position-specific mask
)
ids_canonical = torch.argmax(pred_canonical, dim=-1)
if self.sampling_temperature is None:
ids_canonical = torch.argmax(pred_canonical, dim=-1)
else:
@@ -609,7 +752,13 @@ class InverseFoldingDecoder(nn.Module):
ids = ids_canonical + const.canonicals_offset
pred_one_hot = F.one_hot(ids, num_classes=const.num_tokens)
decoded_seq[i] = pred_one_hot
# Apply same residue to all symmetric positions
for pos in positions:
decoded_seq[pos] = pred_one_hot
logits[pos] = aggregated_logits
if self.tie_symmetric_sequences:
sampled.add(pos)
n_tokens = valid_mask.shape[1]
res_type = torch.zeros(1, n_tokens, self.num_res_type, device=s.device)
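The tying step in the decoding loop above averages the logits of all positions in a symmetric group and assigns one residue to every member. The sketch below shows only that aggregation, under simplifying assumptions: the logits are fixed up front (the real decoder recomputes them autoregressively from the residues decoded so far), and the choice is a plain argmax rather than temperature sampling. `tied_argmax` is a hypothetical name.

```python
import numpy as np


def tied_argmax(logits, groups):
    """logits: (num_positions, num_aa); groups: list of position lists."""
    ids = np.full(logits.shape[0], -1)
    for positions in groups:
        mean_logits = logits[positions].mean(axis=0)  # average across copies
        ids[positions] = int(np.argmax(mean_logits))  # same residue for all
    return ids


logits = np.array(
    [
        [2.0, 0.0, 0.0],  # position 0 prefers residue 0
        [0.0, 1.5, 0.0],  # position 1 alone would prefer residue 1
        [0.0, 0.0, 3.0],  # position 2 is untied
    ]
)
ids = tied_argmax(logits, groups=[[0, 1], [2]])
```

Positions 0 and 1 end up with the same residue even though position 1 would have chosen differently on its own.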

View File

@@ -88,6 +88,9 @@ class BoltzMasker(Module):
new["disto_target"] = clone["disto_target"]
new["token_pair_mask"] = clone["token_pair_mask"]
new["binding_type"] = clone["binding_type"]
# Per-residue amino acid constraints for inverse folding
if "aa_constraint_mask" in clone:
new["aa_constraint_mask"] = clone["aa_constraint_mask"]
new["structure_group"] = clone["structure_group"]
new["cyclic"] = clone["cyclic"]
new["modified"] = clone["modified"]
@@ -101,6 +104,7 @@ class BoltzMasker(Module):
new["res_type_clone"] = clone["res_type_clone"]
new["feature_residue_index"] = clone["feature_residue_index"]
new["feature_asym_id"] = clone["feature_asym_id"]
new["symmetric_group"] = clone["symmetric_group"]
new["token_to_res"] = clone["token_to_res"]
new["token_bonds"] = torch.zeros_like(
clone["token_bonds"]

View File

@@ -286,6 +286,7 @@ class TemplateModule(nn.Module):
)
self.u_proj = nn.Linear(template_dim, token_z, bias=False)
if miniformer_blocks:
self.pairformer = MiniformerNoSeqModule(
template_dim,
@@ -330,7 +331,6 @@ class TemplateModule(nn.Module):
"""
# Load relevant features
asym_id = feats["asym_id"]
res_type = feats["template_restype"]
frame_rot = feats["template_frame_rot"]
frame_t = feats["template_frame_t"]
@@ -338,6 +338,7 @@ class TemplateModule(nn.Module):
cb_coords = feats["template_cb"]
ca_coords = feats["template_ca"]
cb_mask = feats["template_mask_cb"]
visibility_ids = feats["visibility_ids"]
template_mask = feats["template_mask"].any(dim=2).float()
num_templates = template_mask.sum(dim=1)
num_templates = num_templates.clamp(min=1)
@@ -349,10 +350,11 @@ class TemplateModule(nn.Module):
b_cb_mask = b_cb_mask[..., None]
b_frame_mask = b_frame_mask[..., None]
# Compute asym mask, template features only attend within the same chain
# Compute visibility mask; template features only attend within the same visibility group
B, T = res_type.shape[:2] # noqa: N806
asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float()
asym_mask = asym_mask[:, None].expand(-1, T, -1, -1)
tmlp_pair_mask = (
visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :]
).float()
# Compute template features
with torch.autocast(device_type="cuda", enabled=False):
@@ -375,7 +377,8 @@ class TemplateModule(nn.Module):
# Concatenate input features
a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
a_tij = torch.cat(a_tij, dim=-1)
a_tij = a_tij * asym_mask.unsqueeze(-1)
a_tij = a_tij * tmlp_pair_mask.unsqueeze(-1)
res_type_i = res_type[:, :, :, None]
res_type_j = res_type[:, :, None, :]
res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
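The change from `asym_mask` to `tmlp_pair_mask` swaps the label that decides which token pairs may exchange template features. A single-template NumPy sketch of the comparison is given below. The example `visibility_ids` values are hypothetical, and the real tensors carry extra batch and template dimensions.

```python
import numpy as np


def pair_mask(ids):
    """ids: (num_tokens,) integer labels. 1.0 where both tokens share a label."""
    return (ids[:, None] == ids[None, :]).astype(np.float32)


asym_id = np.array([0, 0, 1, 1])
visibility_ids = np.array([0, 0, 0, 1])  # hypothetical: token 2 visible to chain 0

chain_mask = pair_mask(asym_id)
visibility_mask = pair_mask(visibility_ids)
```

With chain-based masking, tokens 1 and 2 never see each other's template features; with visibility-based masking they can, while tokens 2 and 3 are separated despite sharing a chain.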
@@ -669,6 +672,7 @@ class MSAModule(nn.Module):
chunk_size_outer_product,
chunk_size_tri_attn,
use_kernels=use_kernels,
use_reentrant=False,
)
else:
z, m = self.layers[i](

View File

@@ -30,7 +30,7 @@ data:
fail_if_no_designs: true
output_dir: null
keys_dict_out: ["min_interaction_pae", "min_design_to_target_pae", "interaction_pae", "ligand_iptm", "protein_iptm", "iptm", "design_iptm", "design_iiptm", "design_to_target_iptm", "design_ptm", "target_ptm", "ptm"]
keys_dict_out: ["min_interaction_pae", "min_design_to_target_pae", "interaction_pae", "ligand_iptm", "protein_iptm", "iptm", "design_iptm", "design_iiptm", "design_to_target_iptm", "design_residue_iptm", "design_ptm", "target_ptm", "ptm"]
writer:
_target_: boltzgen.task.predict.writer.FoldingWriter

View File

@@ -0,0 +1,484 @@
_target_: boltzgen.task.train.train.Training
trainer:
accelerator: gpu
devices: 8
precision: bf16-mixed
gradient_clip_val: 10.0
accumulate_grad_batches: 1
max_epochs: -1
num_sanity_val_steps: 3
log_every_n_steps: 1
wandb:
group: boltzgen
project: boltzgen
entity: yourwandb
name: a_big_run_resume3
slurm: true
output: workdir
strict_loading: false
resume: null
pretrained: null
debug: false
save_every_n_train_steps: 2500
disable_checkpoint: false
matmul_precision: null
save_top_k: -1
data:
datasets:
- _target_: boltzgen.task.train.data.DatasetConfig
target_dir: ./training_data/targets
msa_dir: ./training_data/msa
prob: 1.0
filters:
- _target_: boltzgen.data.filter.dynamic.size.SizeFilter
min_chains: 1
max_chains: 300
- _target_: boltzgen.data.filter.dynamic.date.DateFilter
date: "2023-06-01"
ref: released
- _target_: boltzgen.data.filter.dynamic.resolution.ResolutionFilter
resolution: 9.0
sampler:
_target_: boltzgen.data.sample.cluster.ClusterSampler
cropper:
_target_: boltzgen.data.crop.multimer.MultimerCropper
neighborhood_sizes: [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 ]
split: ./src/boltzgen/resources/splits/validation_ids_boltz2_all.txt
symmetry_correction: false
val_group: "RCSB"
tokenizer:
_target_: boltzgen.data.tokenize.tokenizer.Tokenizer
atomize_modified_residues: false
featurizer:
_target_: boltzgen.data.feature.featurizer.Featurizer
moldir: ./training_data/mols
max_tokens: 512
max_atoms: 5120
max_seqs: 4096
pad_to_max_tokens: true
pad_to_max_atoms: true
pad_to_max_seqs: true
samples_per_epoch: 100000
batch_size: 1
num_workers: 2
random_seed: 42
pin_memory: false
overfit: null
return_train_symmetries: false
return_val_symmetries: false
atoms_per_window_queries: 32
min_dist: 2.0
max_dist: 22.0
num_bins: 64
single_sequence_prop_training: 0.1
msa_sampling_training: true
# Design
design: true
backbone_only: false
atom14: true
atom37: false
selector:
_target_: boltzgen.data.select.protein.ProteinSelector
design_neighborhood_sizes: [2, 4, 6, 8, 10, 12, 14, 16, 18]
substructure_neighborhood_sizes: [2, 4, 6, 8, 10, 12, 24]
structure_condition_prob: 0.4
distance_noise_std: 1
run_selection: true
specify_binding_sites: true
ss_condition_prob: 0.1
select_all: false
# Design datasets
monomer_split: ./src/boltzgen/resources/splits/val_monomers_boltzgen_min50_max220.txt
monomer_target_dir: ./training_data/targets
monomer_target_structure_condition: true
monomer_seq_len: 100
ligand_split: ./src/boltzgen/resources/splits/val_ccd_pdb_pairs_boltzgen.txt
ligand_target_dir: ./training_data/targets
ligand_seq_len: 100
model:
_target_: boltzgen.model.models.boltz.Boltz
atom_s: 128
atom_z: 16
token_s: 384
token_z: 128
num_bins: 64
atom_feature_dim: 388
atoms_per_window_queries: 32
atoms_per_window_keys: 128
use_miniformer: false
ema: true
ema_decay: 0.999
exclude_ions_from_lddt: true
num_val_datasets: 1 # New
ignore_ckpt_shape_mismatch: false # New
aggregate_distogram: true # New
bond_type_feature: true
predict_bfactor: true
checkpoint_diffusion_conditioning: true
use_kernels: true
validators:
- _target_: boltzgen.model.validation.design.DesignValidator
val_names: ["RCSB"]
confidence_prediction: ${model.confidence_prediction}
atom14: ${data.atom14}
atom37: ${data.atom37}
masker_args:
mask: true
mask_backbone: false
mask_disto: true
embedder_args:
atom_encoder_depth: 3
atom_encoder_heads: 4
add_mol_type_feat: true
add_method_conditioning: true
add_modified_flag: true
add_cyclic_flag: true
add_design_mask_flag: true
add_binding_specification: true
add_ss_specification: true
freeze_template_weights: true
use_templates: true
template_args:
template_dim: 64
template_blocks: 2
activation_checkpointing: false
use_token_distances: true
token_distance_args:
token_distance_dim: 64
token_distance_blocks: 2
use_token_distance_feats: true
distance_gaussian_dim: 32
activation_checkpointing: true
msa_args:
msa_s: 64
msa_blocks: 4
msa_dropout: 0.15
z_dropout: 0.25
miniformer_blocks: false
pairwise_head_width: 32
pairwise_num_heads: 4
use_paired_feature: true
activation_checkpointing: true
pairformer_args:
num_blocks: 64
num_heads: 16
dropout: 0.25
post_layer_norm: false
activation_checkpointing: true
score_model_args:
sigma_data: 16
dim_fourier: 256
atom_encoder_depth: 3
atom_encoder_heads: 4
# token level args
token_layers: 1
token_transformer_depth: 24
token_transformer_heads: 16
diffusion_pairformer_args:
num_blocks: 0
num_heads: 2
dropout: 0
use_s_to_z: false
atom_decoder_depth: 3
atom_decoder_heads: 4
conditioning_transition_layers: 2
transformer_post_ln: false
activation_checkpointing: true
confidence_prediction: false
structure_prediction_training: true
training_args:
recycling_steps: 3
sampling_steps: 20
diffusion_multiplicity: 32
diffusion_samples: 1
confidence_loss_weight: 1e-4
diffusion_loss_weight: 4.0
distogram_loss_weight: 3e-2
bfactor_loss_weight: 1e-3
adam_beta_1: 0.9
adam_beta_2: 0.95
adam_eps: 0.00000001
lr_scheduler: af3
base_lr: 0.0
max_lr: 0.0005
lr_warmup_no_steps: 1000
lr_start_decay_after_n_steps: 50000
lr_decay_every_n_steps: 50000
lr_decay_factor: 0.95
weight_decay: 0.003
weight_decay_exclude: true
validation_args:
recycling_steps: 3
sampling_steps: 200
diffusion_samples: 1
symmetry_correction: false
diffusion_process_args:
sigma_min: 0.0004 # min noise level
sigma_max: 160.0 # max noise level
sigma_data: 16.0 # standard deviation of data distribution
rho: 7 # controls the sampling schedule
P_mean: -1.2 # mean of log-normal distribution from which noise is drawn for training
P_std: 1.5 # standard deviation of log-normal distribution from which noise is drawn for training
gamma_0: 0.8
gamma_min: 1.0
noise_scale: 1.0
step_scale: 1.0
mse_rotational_alignment: true
coordinate_augmentation: true
alignment_reverse_diff: true
synchronize_sigmas: false
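For orientation, the `diffusion_process_args` values above can be read through the usual EDM-style formulas. The sketch below is an assumption about how such parameters are typically used, not a copy of the project's code: a Karras-style schedule interpolating between `sigma_max` and `sigma_min` with exponent `rho`, and a log-normal training noise level scaled by `sigma_data`. Whether the sampling schedule is additionally scaled by `sigma_data` is not visible in this diff, so it is left unscaled here. Both function names are hypothetical.

```python
import math
import random


def karras_schedule(num_steps, sigma_min=0.0004, sigma_max=160.0, rho=7):
    """Noise levels from sigma_max down to sigma_min, spaced in sigma^(1/rho)."""
    inv_rho = 1.0 / rho
    high, low = sigma_max ** inv_rho, sigma_min ** inv_rho
    return [
        (high + i / (num_steps - 1) * (low - high)) ** rho
        for i in range(num_steps)
    ]


def training_sigma(rng, sigma_data=16.0, p_mean=-1.2, p_std=1.5):
    """Draw a training noise level: sigma_data * exp(N(p_mean, p_std^2))."""
    return sigma_data * math.exp(p_mean + p_std * rng.gauss(0.0, 1.0))


sigmas = karras_schedule(num_steps=200)
sample = training_sigma(random.Random(0))
```

A larger `rho` concentrates more of the steps at low noise levels, which is the sense in which it "controls the sampling schedule".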
diffusion_loss_args:
add_smooth_lddt_loss: true
add_bond_loss: false
nucleotide_loss_weight: 5.0
ligand_loss_weight: 10.0
refolding_validator:
_target_: boltzgen.model.validation.refolding.RefoldingValidator
val_names: ["RCSB"]
step_scale: 1.5
noise_scale: 0.75
atom14: ${data.atom14}
atom37: ${data.atom37}
val_monomer: ${data.monomer_split}
val_ligand: ${data.ligand_split}
analyze_task:
_target_: boltzgen.task.analyze.analyze.Analyze
name: ${name}
debug: ${debug}
design_dir: null
num_processes: 1
# Common metrics to compute
affinity_metrics: false
allatom_fold_metrics: true
backbone_fold_metrics: true
noncovalents_original: false
noncovalents_refolded: false
delta_sasa_original: false
delta_sasa_refolded: false
largest_hydrophobic: false
largest_hydrophobic_refolded: false
run_clustering: false
# Liability analysis
liability_analysis: false
liability_modality: peptide
liability_peptide_type: linear
# Uncommon metrics
diversity_original: true
diversity_refolded: true
diversity_per_target_original: false
diversity_per_target_refolded: false
novelty_original: false
novelty_refolded: false
novelty_per_target_original: false
novelty_per_target_refolded: false
wandb: null
data:
_target_: boltzgen.task.predict.data_from_generated.FromGeneratedDataModule
cfg:
_target_: boltzgen.task.predict.data_from_generated.DataConfig
tokenizer:
_target_: boltzgen.data.tokenize.tokenizer.Tokenizer
atomize_modified_residues: false
featurizer:
_target_: boltzgen.data.feature.featurizer.Featurizer
suffix: .cif
suffix_metadata: .npz
suffix_native: _native.cif
samples_per_target: 1
num_targets: 100000000
moldir: ./training_data/mols
batch_size: 1
num_workers: 4
pin_memory: true
return_native: true
folding_checkpoint: ./training_data/boltz2_fold.ckpt
folding_args:
recycling_steps: 3
sampling_steps: 200
diffusion_samples: 1
folding_model_args:
atom_s: 128
atom_z: 16
token_s: 384
token_z: 128
num_bins: 64
atom_feature_dim: 388
atoms_per_window_queries: 32
atoms_per_window_keys: 128
compile_pairformer: false
compile_templates: false
compile_msa: false
use_miniformer: false
ema: true
ema_decay: 0.999
exclude_ions_from_lddt: true
num_val_datasets: 4
ignore_ckpt_shape_mismatch: false
aggregate_distogram: true
bond_type_feature: true
conditioning_cutoff_min: 4.0
conditioning_cutoff_max: 20.0
use_templates: true
predict_bfactor: true
checkpoint_diffusion_conditioning: false
use_kernels: true
validators: null
embedder_args:
atom_encoder_depth: 3
atom_encoder_heads: 4
add_mol_type_feat: true
add_method_conditioning: true
add_modified_flag: true
add_cyclic_flag: true
msa_args:
msa_s: 64
msa_blocks: 4
msa_dropout: 0.15
z_dropout: 0.25
miniformer_blocks: false
pairwise_head_width: 32
pairwise_num_heads: 4
use_paired_feature: true
activation_checkpointing: false
template_args:
template_dim: 64
template_blocks: 2
activation_checkpointing: false
pairformer_args:
num_blocks: 64
num_heads: 16
dropout: 0.25
post_layer_norm: false
activation_checkpointing: false
score_model_args:
sigma_data: 16
dim_fourier: 256
atom_encoder_depth: 3
atom_encoder_heads: 4
token_transformer_depth: 24
token_transformer_heads: 16
atom_decoder_depth: 3
atom_decoder_heads: 4
conditioning_transition_layers: 2
transformer_post_ln: false
activation_checkpointing: false
confidence_prediction: true
affinity_prediction: false
structure_prediction_training: true
affinity_model_args:
num_dist_bins: 64
max_dist: 22
no_trunk_feats: false
add_s_to_z_prod: false
add_s_input_to_s: false
confidence_args:
num_plddt_bins: 50
num_pde_bins: 64
num_pae_bins: 64
training_args:
recycling_steps: 3
sampling_steps: 20
diffusion_multiplicity: 48
diffusion_samples: 1
affinity_loss_weight: 3e-3
confidence_loss_weight: 1e-4
diffusion_loss_weight: 4.0
distogram_loss_weight: 3e-2
bfactor_loss_weight: 1e-3
adam_beta_1: 0.9
adam_beta_2: 0.95
adam_eps: 0.00000001
lr_scheduler: af3
base_lr: 0.0
max_lr: 0.001
lr_warmup_no_steps: 1000
lr_start_decay_after_n_steps: 50000
lr_decay_every_n_steps: 50000
lr_decay_factor: 0.95
weight_decay: 0.003
weight_decay_exclude: true
validation_args:
recycling_steps: 3
sampling_steps: 200
diffusion_samples: 5
symmetry_correction: false
diffusion_process_args:
sigma_min: 0.0004 # min noise level
sigma_max: 160.0 # max noise level
sigma_data: 16.0 # standard deviation of data distribution
rho: 7 # controls the sampling schedule
P_mean: -1.2 # mean of log-normal distribution from which noise is drawn for training
P_std: 1.5 # standard deviation of log-normal distribution from which noise is drawn for training
gamma_0: 0.8
gamma_min: 1.0
noise_scale: 1.0
step_scale: 1.0
mse_rotational_alignment: true
coordinate_augmentation: true
alignment_reverse_diff: true
synchronize_sigmas: false
diffusion_loss_args:
add_smooth_lddt_loss: true
add_bond_loss: false
nucleotide_loss_weight: 5.0
ligand_loss_weight: 10.0
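
Review note: the `sigma_min`, `sigma_max` and `rho` comments in `diffusion_process_args` describe a sampling schedule. The sketch below shows the spacing from Karras et al. (EDM) that these parameter names usually refer to, using the values from this config and `sampling_steps: 200`. It is an illustration under that assumption, not the code BoltzGen runs. The function name `karras_sigmas` is hypothetical, and an implementation may additionally rescale the levels (for example by `sigma_data`).

```python
def karras_sigmas(num_steps, sigma_min=0.0004, sigma_max=160.0, rho=7.0):
    """Return num_steps noise levels running from sigma_max down to sigma_min.

    Hypothetical helper: interpolates linearly in sigma**(1/rho) space, so a
    larger rho concentrates more steps near sigma_min.
    """
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    inv_rho = 1.0 / rho
    high = sigma_max ** inv_rho
    low = sigma_min ** inv_rho
    return [
        (high + i / (num_steps - 1) * (low - high)) ** rho
        for i in range(num_steps)
    ]


# Values taken from the config above: 200 sampling steps, rho = 7.
sigmas = karras_sigmas(200)
```

The list starts at `sigma_max` and ends at `sigma_min`; raising `rho` spends a larger share of the 200 steps at low noise.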


@@ -0,0 +1,38 @@
{
"design_iiptm": {
"mean": 0.3828427453913554,
"std": 0.10621994686776445
},
"design_ptm": {
"mean": 0.6548983461494845,
"std": 0.041925380367595154
},
"min_design_to_target_pae": {
"mean": 8.665374909044512,
"std": 3.8643855151416453
},
"design_hydrophobicity": {
"mean": 43.95844996189343,
"std": 7.565721364606706
},
"design_largest_hydrophobic_patch_refolded": {
"mean": 608.7238163271427,
"std": 440.3199461402019
},
"delta_sasa_refolded": {
"mean": 1368.936707104108,
"std": 514.8573173243741
},
"plip_saltbridge_refolded": {
"mean": 1.746922289263117,
"std": 1.9177749177139145
},
"plip_hbonds_refolded": {
"mean": 7.099706654392513,
"std": 4.456755351208993
},
"affinity_probability_binary1": {
"mean": 0.4726671722992441,
"std": 0.19092615494037402
}
}
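
Review note: each entry above stores a `mean` and a `std` for one metric. The filter task elsewhere in this diff uses such pairs as `(value - mean) / std`. A minimal sketch of that standardisation follows, with two entries copied from the statistics above. The names `NORM_STATS` and `z_score` are illustrative and do not come from the repository.

```python
import json

# Two entries copied from the statistics above; the other keys have the same shape.
NORM_STATS = json.loads(
    """
    {
      "design_iiptm": {"mean": 0.3828427453913554, "std": 0.10621994686776445},
      "design_ptm": {"mean": 0.6548983461494845, "std": 0.041925380367595154}
    }
    """
)


def z_score(metric, value, stats=NORM_STATS):
    """Standardise a raw metric value with its stored mean and standard deviation."""
    entry = stats[metric]
    return (value - entry["mean"]) / entry["std"]


# A design sitting exactly at the reference mean gets a z-score of zero.
at_mean = z_score("design_iiptm", 0.3828427453913554)
```

A positive z-score means the design is above the reference mean for that metric. Whether higher is better depends on the metric.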


@@ -104,6 +104,7 @@ class Analyze(Task):
foldseek_binary: str = "/data/rbg/users/hstark/foldseek/bin/foldseek",
skip_specific_ids: List[str] = None,
designfolding_metrics: bool = False,
use_design_mask_for_target: bool = False,
) -> None:
"""Initialize the task.
@@ -152,6 +153,7 @@ class Analyze(Task):
self.wandb = wandb
self.slurm = slurm
self.diversity_subset = diversity_subset
self.use_design_mask_for_target = use_design_mask_for_target
# Prevent each worker process from spawning its own multithreaded pools
torch.set_num_threads(1)
@@ -565,12 +567,50 @@ class Analyze(Task):
"designed_chain_sequence": design_chain_seq,
}
# Add per-chain sequences to csv when designing multiple chains
design_token_indices = torch.where(feat["design_mask"].bool() & feat["token_pad_mask"].bool())[0]
designed_chain_ids = feat["asym_id"][design_token_indices].unique().tolist()
if len(designed_chain_ids) > 1:
for chain_id in designed_chain_ids:
chain_mask = feat["asym_id"] == chain_id
# Full chain sequence
chain_res_types = res_type_argmax[chain_mask]
full_chain_seq = "".join(
[
const.prot_token_to_letter.get(const.tokens[t], "X")
for t in chain_res_types
]
)
# Designed residues only from this chain
design_chain_mask = feat["design_mask"].bool() & feat["token_pad_mask"].bool() & chain_mask
design_res_types = res_type_argmax[design_chain_mask]
design_seq = "".join(
[
const.prot_token_to_letter.get(const.tokens[t], "X")
for t in design_res_types
]
)
metrics[f"designed_sequence_{chain_id}"] = design_seq
metrics[f"full_sequence_{chain_id}"] = full_chain_seq
target_id = re.search(rf"{self.data.cfg.target_id_regex}", sample_id).group(1)
# Get masks
design_mask = feat["design_mask"].bool()
chain_design_mask = feat["chain_design_mask"].bool()
design_resolved_mask = design_mask & feat["token_resolved_mask"].bool()
target_resolved_mask = ~design_mask & feat["token_resolved_mask"].bool()
# For symmetric designs where all chains have designed residues, use design_mask
# instead of chain_design_mask so "target" = non-designed residues (not empty)
if self.use_design_mask_for_target:
target_resolved_mask = (~design_mask) & feat["token_resolved_mask"].bool()
else:
target_resolved_mask = (~chain_design_mask) & feat["token_resolved_mask"].bool()
atom_design_resolved_mask = (
(feat["atom_to_token"].float() @ design_resolved_mask.unsqueeze(-1).float())
.bool()
@@ -584,11 +624,10 @@ class Analyze(Task):
atom_resolved_mask = feat["atom_resolved_mask"]
resolved_atoms_design_mask = atom_design_resolved_mask[atom_resolved_mask]
resolved_atoms_target_mask = atom_target_resolved_mask[atom_resolved_mask]
design_mask_for_chain = feat["asym_id"] == design_chain_id
atom_chain_mask = (
(
feat["atom_to_token"].float()
@ design_mask_for_chain.unsqueeze(-1).float()
@ chain_design_mask.unsqueeze(-1).float()
)
.bool()
.squeeze()
@@ -669,7 +708,11 @@ class Analyze(Task):
delta_sasa_orig,
design_sasa_unbound,
design_sasa_bound,
) = get_delta_sasa(path, resolved_atoms_target_mask)
) = get_delta_sasa(
path,
atom_target_mask=resolved_atoms_target_mask,
atom_design_mask=resolved_atoms_design_mask,
)
metrics["delta_sasa_original"] = delta_sasa_orig
metrics["design_sasa_unbound_original"] = design_sasa_unbound
metrics["design_sasa_bound_original"] = design_sasa_bound
@@ -1039,12 +1082,21 @@ class Analyze(Task):
if self.delta_sasa_refolded:
cif_path_refolded = self.refold_cif_dir / f"{feat['id']}.cif"
if not cif_path_refolded.exists():
msg = f"Refolded cif path does not exist. This can happen if a process was interrupted between writing the refold .npz file and the refold .cif file. Missing path: {cif_path_refolded}"
print(msg)
return None
# Compute delta sasa
(
delta_sasa_refolded,
design_sasa_unbound,
design_sasa_bound,
) = get_delta_sasa(cif_path_refolded, resolved_atoms_target_mask)
) = get_delta_sasa(
cif_path_refolded,
atom_target_mask=resolved_atoms_target_mask,
atom_design_mask=resolved_atoms_design_mask,
)
metrics["delta_sasa_refolded"] = delta_sasa_refolded
metrics["design_sasa_unbound_refolded"] = design_sasa_unbound


@@ -97,7 +97,7 @@ def make_histogram(
def get_best_folding_sample(folded):
confidence = 0.8 * folded["design_to_target_iptm"] + 0.2 * folded["design_ptm"]
best_idx = np.argmax(confidence)
# TODO: remove the "if k in folded"
best_sample = {
k: folded[k][best_idx] for k in const.eval_keys_confidence if k in folded
@@ -188,10 +188,19 @@ def count_noncovalents(feat):
biotite_array, _ = hydride.add_hydrogen(biotite_array)
hbond = biotite.structure.hbond(biotite_array)
donor_idxs, acceptor_idxs = hbond[:, 0], hbond[:, 2]
cross_design_hbonds = (
biotite_array.is_design[donor_idxs] != biotite_array.is_design[acceptor_idxs]
donor_design_hbonds = int(
(
biotite_array.is_design[donor_idxs]
& ~biotite_array.is_chain_design[acceptor_idxs]
).sum()
)
metrics["plip_hbonds"] = int(cross_design_hbonds.sum())
acceptor_design_hbonds = int(
(
~biotite_array.is_chain_design[donor_idxs]
& biotite_array.is_design[acceptor_idxs]
).sum()
)
metrics["plip_hbonds"] = donor_design_hbonds + acceptor_design_hbonds
# saltbridges
pos_atoms = biotite_array[biotite_array.charge > 0]
@@ -204,10 +213,13 @@ def count_noncovalents(feat):
(pos_neg_distances > 0.5) & (pos_neg_distances < 5.5)
)
# only keep the ones between design and non design
cross_design_saltbridges = (
pos_atoms.is_design[pos_idxs] != neg_atoms.is_design[neg_idxs]
pos_design_sb = int(
(pos_atoms.is_design[pos_idxs] & ~neg_atoms.is_chain_design[neg_idxs]).sum()
)
metrics["plip_saltbridge"] = int(cross_design_saltbridges.sum())
neg_design_sb = int(
(~pos_atoms.is_chain_design[pos_idxs] & neg_atoms.is_design[neg_idxs]).sum()
)
metrics["plip_saltbridge"] = pos_design_sb + neg_design_sb
else:
metrics["plip_saltbridge"] = 0
return metrics
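
Review note: this hunk replaces the symmetric test `is_design[a] != is_design[b]` with two directional tests against `is_chain_design`. As a result, a contact between a designed residue and a non-designed residue of the same designed chain no longer counts. The sketch below reproduces that counting rule on plain Python lists. The function name and the toy atoms are made up for illustration.

```python
def count_cross_interface(idx_a, idx_b, is_design, is_chain_design):
    """Count atom pairs linking a designed residue to an atom outside every designed chain."""
    count = 0
    for a, b in zip(idx_a, idx_b):
        # First partner designed, second partner in a pure target chain.
        count += int(is_design[a] and not is_chain_design[b])
        # Second partner designed, first partner in a pure target chain.
        count += int(is_design[b] and not is_chain_design[a])
    return count


# Toy system: atoms 0-1 are designed, atom 2 is a fixed residue in the same
# designed chain, atoms 3-4 belong to the target chain.
is_design = [True, True, False, False, False]
is_chain_design = [True, True, True, False, False]

donors = [0, 0, 3, 2]
acceptors = [3, 2, 1, 4]
n_pairs = count_cross_interface(donors, acceptors, is_design, is_chain_design)
```

Only the pairs `(0, 3)` and `(3, 1)` cross the interface. The pair `(0, 2)` would have counted under the old `!=` rule but does not under the new one.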
@@ -629,9 +641,14 @@ def largest_hydrophobic_patch_area(cif_path, distance_cutoff=6.0):
return max_patch_area
def get_delta_sasa(path, atom_design_mask):
def get_delta_sasa(
path,
atom_target_mask,
atom_design_mask,
):
stack = _load_stack(path)
atoms = stack[0]
res = [
r.decode().strip() if isinstance(r, bytes) else str(r).strip()
for r in atoms.res_name
@@ -645,23 +662,41 @@ def get_delta_sasa(path, atom_design_mask):
radii = np.array(
[_radius(rn, an, el) for rn, an, el in zip(res, atm, elem)], dtype=float
)
area = sasa(atoms, probe_radius=1.4, point_number=960, vdw_radii=radii)
design_bound = area[atom_design_mask].sum()
ligand_atoms = atoms[atom_design_mask]
lig_res = [r for r, m in zip(res, atom_design_mask) if m]
lig_atm = [a for a, m in zip(atm, atom_design_mask) if m]
lig_elem = [e for e, m in zip(elem, atom_design_mask) if m]
bound_mask = atom_design_mask | atom_target_mask
atoms_bound = atoms[bound_mask]
radii_bound = radii[bound_mask]
area_bound = sasa(
atoms_bound,
probe_radius=1.4,
point_number=960,
vdw_radii=radii_bound,
)
target_in_bound = atom_target_mask[bound_mask]
target_bound = area_bound[target_in_bound].sum()
target_atoms = atoms[atom_target_mask]
target_res = [r for r, m in zip(res, atom_target_mask) if m]
target_atm = [a for a, m in zip(atm, atom_target_mask) if m]
target_elem = [e for e, m in zip(elem, atom_target_mask) if m]
radii_lig = np.array(
[_radius(rn, an, el) for rn, an, el in zip(lig_res, lig_atm, lig_elem)],
[_radius(rn, an, el) for rn, an, el in zip(target_res, target_atm, target_elem)],
dtype=float,
)
ligand_area = sasa(
ligand_atoms, probe_radius=1.4, point_number=960, vdw_radii=radii_lig
target_area = sasa(
target_atoms,
probe_radius=1.4,
point_number=960,
vdw_radii=radii_lig,
)
delta = ligand_area.sum() - design_bound
return delta, ligand_area.sum(), design_bound
delta = target_area.sum() - target_bound
return delta, target_area.sum(), target_bound
def compute_ss_metrics(dssp_pred, ss_conditioning_metricsed):


@@ -1,3 +1,4 @@
import json
from boltzgen.utils.quiet import quiet_startup
@@ -224,7 +225,7 @@ class Filter(Task):
if metrics_override is not None:
for k in metrics_override:
if metrics_override[k] is None:
del self.metrics[k]
self.metrics.pop(k, None)
else:
self.metrics[k] = metrics_override[k]
@@ -284,12 +285,12 @@ class Filter(Task):
{
"feature": "GLY_fraction",
"lower_is_better": True,
"threshold": 0.2,
"threshold": 0.3,
},
{
"feature": "GLU_fraction",
"lower_is_better": True,
"threshold": 0.2,
"threshold": 0.3,
},
{
"feature": "LEU_fraction",
@@ -299,7 +300,7 @@ class Filter(Task):
{
"feature": "VAL_fraction",
"lower_is_better": True,
"threshold": 0.2,
"threshold": 0.3,
},
]
)
@@ -312,6 +313,7 @@ class Filter(Task):
self.load_dataframe()
self.reset_outdir()
self.filter_df()
self.absolute_metrics()
self.sort_df()
self.optimize_diversity()
self.write_outdir()
@@ -379,6 +381,8 @@ class Filter(Task):
"design_largest_hydrophobic_patch_refolded"
]
df["neg_min_interaction_pae"] = -df["min_interaction_pae"]
df["neg_filter_rmsd"] = -df["filter_rmsd"]
df["neg_filter_rmsd_design"] = -df["filter_rmsd_design"]
df["has_x"] = df["designed_sequence"].str.contains("X")
self.df = df
@@ -418,6 +422,61 @@ class Filter(Task):
)
print("\n")
def absolute_metrics(self):
norm_path = Path("src/boltzgen/resources/metrics_normalization.json")
if not norm_path.exists():
return
with norm_path.open("r") as f:
norm_stats = json.load(f)
for col, stats in norm_stats.items():
mean = stats["mean"]
std = stats["std"]
if col in self.df.columns:
self.df[col + "_z"] = (self.df[col] - mean) / std
importances = {
"affinity_probability_binary1": 1.5,
"design_iiptm": 1.0,
"design_ptm": 0.5,
"min_design_to_target_pae": -1.0, # lower is better
"design_hydrophobicity": -0.125, # lower is better
"design_largest_hydrophobic_patch_refolded": -0.15, # lower is better
"delta_sasa_refolded": 0.25,
"plip_saltbridge_refolded": 0.25,
"plip_hbonds_refolded": 0.25,
}
self.df["absolute_score"] = 0.0
for base_col, weight in importances.items():
if base_col in self.df.columns:
norm_col = base_col + "_z"
self.df["absolute_score"] += weight * self.df[norm_col]
total_importance = sum(abs(w) for w in importances.values())
self.df["absolute_score"] /= total_importance
self.df["structure_confidence"] = 0.0
weight_sum = 0
for col in ["design_iiptm", "design_ptm", "min_design_to_target_pae"]:
weight = importances[col]
norm_col = col + "_z"
weight_sum += abs(weight)
self.df["structure_confidence"] += weight * self.df[norm_col]
self.df["structure_confidence"] /= weight_sum
for flt in self.filters:
feat = flt["feature"]
filter_col = f"pass_{feat}_filter"
if "fraction" in feat:
# If this is a "fraction" feature, meaning a res_type fraction filter, only apply the penalty if num_design > 8
mask_fail = (self.df["num_design"] > 8) & (self.df[filter_col] == False)
else:
mask_fail = self.df[filter_col] == False
self.df.loc[mask_fail, "absolute_score"] *= 0.1
def sort_df(self):
rank_df = pd.DataFrame(index=self.df.index)
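
Review note: `absolute_metrics` combines z-scored metrics into a weighted sum, divides by the total absolute weight, and multiplies by 0.1 once for every failed filter. The sketch below restates that arithmetic for a single design using plain dictionaries instead of a DataFrame. The three weights are copied from `importances`; the z-scores and the function name `absolute_score` are illustrative.

```python
# Subset of the weights from `importances`; negative means lower is better.
IMPORTANCES = {
    "design_iiptm": 1.0,
    "design_ptm": 0.5,
    "min_design_to_target_pae": -1.0,
}


def absolute_score(z_scores, importances, failed_filters=0, penalty=0.1):
    """Weighted mean of z-scores, scaled by `penalty` once per failed filter."""
    total = sum(abs(w) for w in importances.values())
    weighted = sum(
        w * z_scores[name] for name, w in importances.items() if name in z_scores
    )
    return (weighted / total) * penalty ** failed_filters


# Illustrative z-scores for one design (not real data).
z = {"design_iiptm": 1.0, "design_ptm": 2.0, "min_design_to_target_pae": -1.0}
score = absolute_score(z, IMPORTANCES)
penalised = absolute_score(z, IMPORTANCES, failed_filters=1)
```

Multiplying by 0.1 shrinks the score toward zero, so a design whose score is already negative moves up when it fails a filter. Whether that is intended is not clear from the diff. The relative `norm_path` also means the step is skipped silently unless the process runs from the repository root.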
@@ -584,6 +643,12 @@ class Filter(Task):
heapq.heappush(heap, (-gain, i))
buckets = np.zeros(len(self.size_buckets) + 1)
first = selected[0]
first_len = len(self.df_m["sequence"][first])
for idx, bucket_size in enumerate(self.size_buckets):
if first_len >= bucket_size["min"] and first_len < bucket_size["max"]:
buckets[idx] += 1
break
for _ in tqdm(
range(k - 1), desc="Performing lazy greedy diversity optimization."
):
@@ -749,6 +814,12 @@ class Filter(Task):
hist_metrics = list(dict.fromkeys(hist_metrics))
extra_pairs = list(dict.fromkeys(extra_pairs))
# Prepend any active ranking metrics not already in the lists
extra_ranking = [m for m in self.metrics if m in self.df.columns and m not in hist_metrics]
hist_metrics = extra_ranking + hist_metrics
summary_metrics = extra_ranking + summary_metrics
extra_pairs = [("num_design", m) for m in extra_ranking] + extra_pairs
if self.use_affinity:
summary_metrics.insert(2, "affinity_probability_binary1")
hist_metrics.insert(2, "affinity_probability_binary1")
@@ -883,7 +954,6 @@ class Filter(Task):
pdf_path = self.outdir / "results_overview.pdf"
pdf = PdfPages(pdf_path)
def _ensure_width(fig, target_w=8.5):
w, h = fig.get_size_inches()
if abs(w - target_w) > 0.01:
@@ -1231,8 +1301,6 @@ class Filter(Task):
pdf.close()
print(
"A description of metrics and summarizing plots was written to:", pdf_path
)


@@ -2,6 +2,7 @@ from dataclasses import dataclass
from pathlib import Path
import random
import re
import warnings
from typing import Dict, List, Optional
from collections import defaultdict
from rdkit.Chem import Mol
@@ -198,6 +199,7 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
return_native: bool = False,
reference_metadata_dir: Optional[Path] = None,
target_templates: bool = False,
design_mask_templates: bool = False,
compute_affinity: bool = False,
design: bool = False,
backbone_only: bool = False,
@@ -232,6 +234,7 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
self.return_native = return_native
self.reference_metadata_dir = reference_metadata_dir
self.target_templates = target_templates
self.design_mask_templates = design_mask_templates
self.compute_affinity = compute_affinity
self.design = design
self.backbone_only = backbone_only
@@ -317,8 +320,28 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
if "binding_type" in metadata:
binding_type = metadata["binding_type"]
# Per-residue amino acid constraints for inverse folding
aa_constraint_mask = None
if "aa_constraint_mask" in metadata:
loaded_mask = metadata["aa_constraint_mask"]
# Validate the loaded mask is a proper array with expected shape
if (
isinstance(loaded_mask, np.ndarray)
and loaded_mask.ndim == 2
and loaded_mask.shape[1] == 20 # 20 canonical amino acids
):
aa_constraint_mask = loaded_mask
else:
warnings.warn(
f"Invalid aa_constraint_mask in NPZ: "
f"type={type(loaded_mask)}, shape={getattr(loaded_mask, 'shape', 'N/A')}. "
f"Expected ndarray with shape (N, 20). Ignoring constraints.",
RuntimeWarning,
stacklevel=2,
)
# Get features
feat = self.get_feat(generated_path, design_mask, ss_type, binding_type)
feat = self.get_feat(generated_path, design_mask, ss_type, binding_type, aa_constraint_mask)
# Get native features
if self.return_native:
@@ -332,7 +355,7 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
return feat
def get_feat(self, path, design_mask, ss_type=None, binding_type=None):
def get_feat(self, path, design_mask, ss_type=None, binding_type=None, aa_constraint_mask=None):
# Load design
if self.extra_mol_dir is not None:
mols = {
@@ -460,6 +483,9 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
features["ss_type"] = torch.from_numpy(ss_type).long()
if binding_type is not None:
features["binding_type"] = torch.from_numpy(binding_type).long()
# Per-residue amino acid constraints for inverse folding
if aa_constraint_mask is not None:
features["aa_constraint_mask"] = torch.from_numpy(aa_constraint_mask).float()
# If we do not want the design mask to impact the featurizer (e.g. represent atoms as atom14), we set the design mask only here.
if not self.design:
@@ -490,7 +516,10 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
# Set templates
if self.target_templates:
template_mask = ~features["chain_design_mask"].numpy()
if self.design_mask_templates:
template_mask = ~features["design_mask"].numpy()
else:
template_mask = ~features["chain_design_mask"].numpy()
templates_features = template_from_tokens(tokenized, template_mask)
else:
# Compute template features
@@ -525,6 +554,7 @@ class FromGeneratedDataModule(pl.LightningDataModule):
return_native: bool = False,
compute_affinity: bool = False,
target_templates: bool = False,
design_mask_templates: bool = False,
skip_existing: bool = False,
skip_existing_kind: str = None,
legacy_gen_suffix: str = "_gen.cif",
@@ -552,6 +582,7 @@ class FromGeneratedDataModule(pl.LightningDataModule):
self.legacy_metadata_suffix = legacy_metadata_suffix
self.compute_affinity = compute_affinity
self.target_templates = target_templates
self.design_mask_templates = design_mask_templates
self.extra_features = extra_features
self.disulfide_prob = cfg.disulfide_prob
self.disulfide_on = cfg.disulfide_on
@@ -583,6 +614,7 @@ class FromGeneratedDataModule(pl.LightningDataModule):
return_native=self.return_native,
reference_metadata_dir=self.reference_metadata_dir,
target_templates=self.target_templates,
design_mask_templates=self.design_mask_templates,
compute_affinity=self.compute_affinity,
design=self.cfg.design,
backbone_only=self.cfg.backbone_only,
@@ -790,6 +822,7 @@ class FromGeneratedDataModule(pl.LightningDataModule):
return_native=self.return_native,
reference_metadata_dir=self.reference_metadata_dir,
target_templates=self.target_templates,
design_mask_templates=self.design_mask_templates,
compute_affinity=self.compute_affinity,
design=self.cfg.design,
backbone_only=self.cfg.backbone_only,


@@ -216,6 +216,10 @@ class PredictionDataset(torch.utils.data.Dataset):
tokenized.tokens["structure_group"] = design_info.res_structure_groups[
token_to_res
]
# Transfer per-residue amino acid constraints (shape: num_tokens x 20)
tokenized.tokens["aa_constraint_mask"] = design_info.res_aa_constraint_mask[
token_to_res
]
# Propagate design mask to obtain chain_design_mask (True whenever something is covalently bound to any residue that is in a chain that contains a design residue).
chain_design_mask = tokenized.tokens["design_mask"].astype(bool)


@@ -384,34 +384,31 @@ class DesignWriter(BasePredictionWriter):
# Write metadata
metadata_path = f"{self.outdir}/{file_name}.npz"
token_mask = sample["token_pad_mask"].bool()
np.savez_compressed(
metadata_path,
design_mask=design_mask[sample["token_pad_mask"].bool()]
.cpu()
.numpy(),
inverse_fold_design_mask=sample["inverse_fold_design_mask"][
sample["token_pad_mask"].bool()
]
.cpu()
.numpy()
if "inverse_fold_design_mask" in sample
else None,
mol_type=sample["mol_type"][sample["token_pad_mask"].bool()]
.cpu()
.numpy(),
ss_type=sample["ss_type"][sample["token_pad_mask"].bool()]
.cpu()
.numpy(),
token_resolved_mask=sample["token_resolved_mask"][
sample["token_pad_mask"].bool()
]
.cpu()
.numpy(),
binding_type=binding_type[sample["token_pad_mask"].bool()]
.cpu()
.numpy(),
)
# Build metadata dict with required fields
metadata_dict = {
"design_mask": design_mask[token_mask].cpu().numpy(),
"mol_type": sample["mol_type"][token_mask].cpu().numpy(),
"ss_type": sample["ss_type"][token_mask].cpu().numpy(),
"token_resolved_mask": sample["token_resolved_mask"][token_mask].cpu().numpy(),
"binding_type": binding_type[token_mask].cpu().numpy(),
}
# Add optional fields only if they have valid values (avoid None -> object array)
if "inverse_fold_design_mask" in sample:
metadata_dict["inverse_fold_design_mask"] = (
sample["inverse_fold_design_mask"][token_mask].cpu().numpy()
)
# Per-residue amino acid constraints (for inverse folding step)
# Only save if constraints exist AND have non-zero values
if "aa_constraint_mask" in batch:
aa_mask = batch["aa_constraint_mask"][0]
if aa_mask.any(): # Only save if there are actual constraints
metadata_dict["aa_constraint_mask"] = aa_mask[token_mask].cpu().numpy()
np.savez_compressed(metadata_path, **metadata_dict)
# Write trajectories
if self.save_traj:
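
Review note: the writer now builds a dict and skips absent optional fields instead of passing `None` to `np.savez_compressed`. The sketch below illustrates the reason given in the code comment. NumPy stores `None` as an object array, which `np.load` refuses to read with its default `allow_pickle=False`. The helper name `save_metadata` is hypothetical, and the arrays are toy data written to in-memory buffers.

```python
import io

import numpy as np


def save_metadata(buffer, required, optional):
    """Write required arrays plus only those optional arrays that are not None."""
    payload = dict(required)
    payload.update({k: v for k, v in optional.items() if v is not None})
    np.savez_compressed(buffer, **payload)


# Skipping the missing optional field keeps the archive free of object arrays.
good = io.BytesIO()
save_metadata(
    good,
    {"design_mask": np.array([True, False])},
    {"inverse_fold_design_mask": None},
)
good.seek(0)
with np.load(good) as data:
    keys = sorted(data.files)

# Passing None directly stores an object array that a default np.load rejects.
bad = io.BytesIO()
np.savez_compressed(bad, inverse_fold_design_mask=None)
bad.seek(0)
try:
    np.load(bad)["inverse_fold_design_mask"]
    none_loads_cleanly = True
except ValueError:
    none_loads_cleanly = False
```

Readers of the `.npz` file can then test `"inverse_fold_design_mask" in data` rather than unpickling a placeholder.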

tests/conftest.py

@@ -0,0 +1,79 @@
"""Optional test configuration for mocking heavy dependencies.
Enable with:
pytest --mock-heavy-deps tests/test_residue_constraints.py
By default no mocking is performed, so integration tests run against
real dependencies.
"""
import sys
from unittest.mock import MagicMock
def _install_mock(name: str) -> None:
"""Install a mock module (and parent packages) into sys.modules."""
parts = name.split(".")
for i in range(len(parts)):
mod_name = ".".join(parts[: i + 1])
if mod_name not in sys.modules:
sys.modules[mod_name] = MagicMock()
# Heavy dependencies that schema.py imports transitively but are NOT
# needed by the three constraint-parsing functions under test.
_MOCK_MODULES = [
"torch",
"torch.nn",
"torch.nn.functional",
"torch.utils",
"torch.utils.data",
"pytorch_lightning",
"hydra",
"hydra.core",
"hydra.core.config_store",
"einops",
"einx",
"mashumaro",
"biotite",
"biotite.structure",
"biotite.structure.io",
"biotite.structure.io.pdbx",
"pydssp",
"logomaker",
"hydride",
"gemmi",
"pdbeccdutils",
"pdbeccdutils.core",
"pdbeccdutils.core.ccd_reader",
"edit_distance",
"huggingface_hub",
"nvidia_ml_py",
"cuequivariance_ops_cu12",
"cuequivariance_ops_torch_cu12",
"cuequivariance_torch",
"numba",
"sklearn",
"sklearn.cluster",
"sklearn.neighbors",
"pandas",
"matplotlib",
"matplotlib.pyplot",
"tqdm",
"Bio",
"Bio.PDB",
]
def pytest_addoption(parser) -> None:
parser.addoption(
"--mock-heavy-deps",
action="store_true",
default=False,
help="Mock heavy optional dependencies for parser-only unit tests.",
)
def pytest_configure(config) -> None:
if config.getoption("--mock-heavy-deps"):
for mod in _MOCK_MODULES:
_install_mock(mod)
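
Review note: the effect of `_install_mock` can be checked in isolation. Once every prefix of a dotted name is present in `sys.modules`, importing that name returns the placeholder without touching the file system. The sketch below uses a made-up package name, `hypothetical_heavy_pkg`, so it cannot collide with a real dependency.

```python
import importlib
import sys
from unittest.mock import MagicMock


def install_mock(name):
    """Register a MagicMock for `name` and for each of its parent packages."""
    parts = name.split(".")
    for i in range(len(parts)):
        mod_name = ".".join(parts[: i + 1])
        # Leave modules that are already imported (real or mocked) untouched.
        if mod_name not in sys.modules:
            sys.modules[mod_name] = MagicMock()


install_mock("hypothetical_heavy_pkg.sub.module")

# The import system finds the entry in sys.modules and returns it as-is.
mod = importlib.import_module("hypothetical_heavy_pkg.sub.module")
```

Because of the `not in sys.modules` guard, a dependency that was really imported before `pytest_configure` runs stays real, so the flag only mocks what has not yet been loaded.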


@@ -0,0 +1,90 @@
"""Integration tests for inverse-folding constraint mask composition."""
import pytest
torch = pytest.importorskip("torch")
from boltzgen.data import const
from boltzgen.model.modules.inverse_fold import build_constraint_logit_mask
INF = 10**6
def _allowed_only_mask(allowed_tokens: list[str]) -> torch.Tensor:
"""Build a single-row mask where only `allowed_tokens` are permitted."""
num_aa = len(const.canonical_tokens)
mask = torch.ones((1, num_aa), dtype=torch.float32)
for token in allowed_tokens:
mask[0, const.canonical_tokens.index(token)] = 0.0
return mask
def test_conflict_allowed_and_global_avoid_keeps_global_restriction() -> None:
cys_idx = const.canonical_tokens.index("CYS")
aa_constraint_mask = _allowed_only_mask(["CYS"])
with pytest.warns(RuntimeWarning, match="Relaxing per-residue constraints"):
out = build_constraint_logit_mask(
num_nodes=1,
aa_constraint_mask=aa_constraint_mask,
inverse_fold_restriction=["CYS"],
canonical_tokens=const.canonical_tokens,
inf=INF,
device=torch.device("cpu"),
)
# Global avoid must still block CYS after conflict handling.
assert out[0, cys_idx].item() == -INF
# All other residues remain available.
assert (out[0] == 0).sum().item() == len(const.canonical_tokens) - 1
def test_non_conflicting_constraints_compose_correctly() -> None:
ala_idx = const.canonical_tokens.index("ALA")
cys_idx = const.canonical_tokens.index("CYS")
aa_constraint_mask = _allowed_only_mask(["ALA"])
out = build_constraint_logit_mask(
num_nodes=1,
aa_constraint_mask=aa_constraint_mask,
inverse_fold_restriction=["CYS"],
canonical_tokens=const.canonical_tokens,
inf=INF,
device=torch.device("cpu"),
)
# Only ALA should remain available.
assert out[0, ala_idx].item() == 0.0
assert out[0, cys_idx].item() == -INF
assert (out[0] == 0).sum().item() == 1
def test_global_restrictions_that_block_all_raise() -> None:
with pytest.raises(ValueError, match="no valid amino acids"):
build_constraint_logit_mask(
num_nodes=1,
aa_constraint_mask=None,
inverse_fold_restriction=const.canonical_tokens,
canonical_tokens=const.canonical_tokens,
inf=INF,
device=torch.device("cpu"),
)
def test_shape_mismatch_ignores_per_residue_mask() -> None:
bad_shape = torch.zeros((2, 20), dtype=torch.float32)
with pytest.warns(RuntimeWarning, match="shape mismatch"):
out = build_constraint_logit_mask(
num_nodes=1,
aa_constraint_mask=bad_shape,
inverse_fold_restriction=[],
canonical_tokens=const.canonical_tokens,
inf=INF,
device=torch.device("cpu"),
)
# No restrictions should remain after ignoring mismatched input.
assert out.shape == (1, len(const.canonical_tokens))
assert torch.all(out == 0)
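
Review note: the four tests above pin down how per-residue masks and the global restriction list are expected to compose. The sketch below is a hypothetical pure-Python restatement of that behaviour, not the body of `build_constraint_logit_mask`, which this diff does not show. It assumes a mask value above zero means "blocked" and uses a three-token alphabet for brevity.

```python
import warnings


def sketch_constraint_logit_mask(num_nodes, aa_constraint_mask, restricted, tokens, inf=10**6):
    """Return per-node logit offsets: 0 for allowed tokens, -inf for blocked ones."""
    blocked_globally = [tok in restricted for tok in tokens]
    if all(blocked_globally):
        raise ValueError("Global restrictions leave no valid amino acids")

    if aa_constraint_mask is not None and len(aa_constraint_mask) != num_nodes:
        warnings.warn("aa_constraint_mask shape mismatch; ignoring it", RuntimeWarning)
        aa_constraint_mask = None

    rows = []
    for i in range(num_nodes):
        if aa_constraint_mask is None:
            per_residue = [False] * len(tokens)
        else:
            per_residue = [value > 0 for value in aa_constraint_mask[i]]
        combined = [g or p for g, p in zip(blocked_globally, per_residue)]
        if all(combined):
            # Conflict: keep the global restriction, drop the per-residue one.
            warnings.warn(f"Relaxing per-residue constraints at node {i}", RuntimeWarning)
            combined = list(blocked_globally)
        rows.append([-inf if blocked else 0 for blocked in combined])
    return rows


TOKENS = ["ALA", "CYS", "GLY"]  # toy alphabet; the real one has 20 entries

# Only ALA is allowed at the single position and CYS is avoided globally.
compatible = sketch_constraint_logit_mask(1, [[0.0, 1.0, 1.0]], ["CYS"], TOKENS)
```

In the conflicting case (only CYS allowed at a position while CYS is avoided globally), this sketch falls back to the global restriction alone, which is what the first test asserts.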


@@ -0,0 +1,381 @@
"""Unit tests for per-residue amino acid constraint parsing.
Tests parse_residue_constraints(), _normalize_aa_spec(), and
_convert_aa_names_to_indices() from boltzgen.data.parse.schema.
"""
import numpy as np
import pytest
from boltzgen.data import const
from boltzgen.data.parse.schema import (
_convert_aa_names_to_indices,
_normalize_aa_spec,
parse_residue_constraints,
)
# Shorthand fixtures
CANONICAL = const.canonical_tokens # 20 three-letter codes
LETTER_MAP = const.prot_letter_to_token # e.g. {"A": "ALA", ...}
# ============================================================================
# _normalize_aa_spec
# ============================================================================
class TestNormalizeAASpec:
"""Tests for _normalize_aa_spec helper."""
def test_single_letter(self):
assert _normalize_aa_spec("A") == ["A"]
def test_multi_letter_string(self):
assert _normalize_aa_spec("AGS") == ["A", "G", "S"]
def test_long_string(self):
assert _normalize_aa_spec("AVILMFYW") == list("AVILMFYW")
def test_three_letter_code(self):
assert _normalize_aa_spec("ALA") == ["ALA"]
def test_three_letter_not_valid(self):
# "AGS" is 3 chars but NOT a valid 3-letter code → split into 1-letter
assert _normalize_aa_spec("AGS") == ["A", "G", "S"]
def test_list_format_single_letters(self):
assert _normalize_aa_spec(["A", "G", "S"]) == ["A", "G", "S"]
def test_list_format_three_letter(self):
assert _normalize_aa_spec(["ALA", "GLY"]) == ["ALA", "GLY"]
def test_lowercase_normalised(self):
assert _normalize_aa_spec("ags") == ["A", "G", "S"]
def test_whitespace_stripped(self):
assert _normalize_aa_spec(" AG ") == ["A", "G"]
def test_invalid_type_raises(self):
with pytest.raises(ValueError):
_normalize_aa_spec(123)
# ============================================================================
# _convert_aa_names_to_indices
# ============================================================================
class TestConvertAANamesToIndices:
"""Tests for _convert_aa_names_to_indices helper."""
def test_single_letter_a(self):
indices = _convert_aa_names_to_indices(["A"], CANONICAL, LETTER_MAP)
assert indices == [CANONICAL.index("ALA")]
def test_single_letter_c(self):
indices = _convert_aa_names_to_indices(["C"], CANONICAL, LETTER_MAP)
assert indices == [CANONICAL.index("CYS")]
def test_three_letter_code(self):
indices = _convert_aa_names_to_indices(["ALA", "GLY"], CANONICAL, LETTER_MAP)
assert indices == [CANONICAL.index("ALA"), CANONICAL.index("GLY")]
def test_mixed_formats(self):
indices = _convert_aa_names_to_indices(["A", "GLY"], CANONICAL, LETTER_MAP)
assert indices == [CANONICAL.index("ALA"), CANONICAL.index("GLY")]
def test_all_20_aas(self):
all_letters = list("ACDEFGHIKLMNPQRSTVWY")
indices = _convert_aa_names_to_indices(all_letters, CANONICAL, LETTER_MAP)
assert len(indices) == 20
assert len(set(indices)) == 20 # all unique
def test_invalid_letter_raises(self):
with pytest.raises(ValueError, match="Unknown amino acid"):
_convert_aa_names_to_indices(["X"], CANONICAL, LETTER_MAP)
def test_invalid_three_letter_raises(self):
with pytest.raises(ValueError, match="Unknown amino acid"):
_convert_aa_names_to_indices(["ZZZ"], CANONICAL, LETTER_MAP)
# ============================================================================
# parse_residue_constraints — valid inputs
# ============================================================================

class TestParseResidueConstraintsValid:
    """Tests for parse_residue_constraints with valid YAML specs."""

    def test_empty_list_returns_zeros(self):
        mask = parse_residue_constraints([], 10, CANONICAL, LETTER_MAP)
        assert mask.shape == (10, 20)
        assert mask.sum() == 0.0

    def test_single_allowed(self):
        spec = [{"position": 1, "allowed": "A"}]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        # Position 0 (1-indexed=1): only ALA allowed (0.0), rest blocked (1.0)
        assert mask[0, ala_idx] == 0.0
        assert mask[0].sum() == 19.0  # 19 blocked, 1 allowed
        # Other positions untouched
        assert mask[1:].sum() == 0.0

    def test_single_disallowed(self):
        spec = [{"position": 3, "disallowed": "CM"}]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        cys_idx = CANONICAL.index("CYS")
        met_idx = CANONICAL.index("MET")
        # Position 2 (1-indexed=3): CYS and MET blocked
        assert mask[2, cys_idx] == 1.0
        assert mask[2, met_idx] == 1.0
        assert mask[2].sum() == 2.0  # only 2 blocked

    def test_range_positions(self):
        spec = [{"position": "3..5", "disallowed": "C"}]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        cys_idx = CANONICAL.index("CYS")
        # Positions 2,3,4 (1-indexed 3,4,5) should have CYS blocked
        for pos in [2, 3, 4]:
            assert mask[pos, cys_idx] == 1.0
        # Other positions untouched
        for pos in [0, 1, 5, 6, 7, 8, 9]:
            assert mask[pos, cys_idx] == 0.0

    def test_allowed_multiple_aas(self):
        spec = [{"position": 8, "allowed": "AGS"}]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        gly_idx = CANONICAL.index("GLY")
        ser_idx = CANONICAL.index("SER")
        # Position 7 (1-indexed=8): only A,G,S allowed
        assert mask[7, ala_idx] == 0.0
        assert mask[7, gly_idx] == 0.0
        assert mask[7, ser_idx] == 0.0
        assert mask[7].sum() == 17.0  # 17 blocked

    def test_list_format_allowed(self):
        spec = [{"position": 1, "allowed": ["A", "G"]}]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        gly_idx = CANONICAL.index("GLY")
        assert mask[0, ala_idx] == 0.0
        assert mask[0, gly_idx] == 0.0
        assert mask[0].sum() == 18.0

    def test_multiple_constraints_no_overlap(self):
        spec = [
            {"position": 1, "allowed": "A"},
            {"position": 5, "allowed": "P"},
        ]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        pro_idx = CANONICAL.index("PRO")
        assert mask[0, ala_idx] == 0.0
        assert mask[0].sum() == 19.0
        assert mask[4, pro_idx] == 0.0
        assert mask[4].sum() == 19.0
        # Middle positions untouched
        assert mask[1:4].sum() == 0.0
    # ------------------------------------------------------------------
    # Intersection semantics (overlapping constraints)
    # ------------------------------------------------------------------

    def test_overlapping_allowed_intersection(self):
        """Two allowed constraints on same position → only common AAs survive."""
        spec = [
            {"position": 1, "allowed": "AG"},
            {"position": 1, "allowed": "GS"},
        ]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        gly_idx = CANONICAL.index("GLY")
        ala_idx = CANONICAL.index("ALA")
        ser_idx = CANONICAL.index("SER")
        # Only G is in both sets
        assert mask[0, gly_idx] == 0.0  # allowed
        assert mask[0, ala_idx] == 1.0  # blocked (not in 2nd)
        assert mask[0, ser_idx] == 1.0  # blocked (not in 1st)
        assert mask[0].sum() == 19.0  # only GLY allowed

    def test_overlapping_allowed_range_intersection(self):
        """Overlapping ranges intersect at overlap positions."""
        spec = [
            {"position": "1..5", "allowed": "AG"},
            {"position": "3..7", "allowed": "GS"},
        ]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        gly_idx = CANONICAL.index("GLY")
        ala_idx = CANONICAL.index("ALA")
        ser_idx = CANONICAL.index("SER")
        # Positions 0,1 (1-indexed 1,2): only AG (first constraint only)
        assert mask[0, ala_idx] == 0.0
        assert mask[0, gly_idx] == 0.0
        assert mask[0].sum() == 18.0
        # Positions 2,3,4 (1-indexed 3,4,5): intersection of {A,G} and {G,S} = {G}
        for pos in [2, 3, 4]:
            assert mask[pos, gly_idx] == 0.0
            assert mask[pos, ala_idx] == 1.0
            assert mask[pos, ser_idx] == 1.0
            assert mask[pos].sum() == 19.0
        # Positions 5,6 (1-indexed 6,7): only GS (second constraint only)
        assert mask[5, gly_idx] == 0.0
        assert mask[5, ser_idx] == 0.0
        assert mask[5].sum() == 18.0

    def test_allowed_then_disallowed_same_position(self):
        """allowed + disallowed on same position: disallowed narrows the set."""
        spec = [
            {"position": 5, "allowed": "AGILMV"},
            {"position": 5, "disallowed": "CM"},
        ]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        met_idx = CANONICAL.index("MET")
        ala_idx = CANONICAL.index("ALA")
        # M was in allowed set but then blocked by disallowed
        assert mask[4, met_idx] == 1.0
        # A was in allowed set and not disallowed
        assert mask[4, ala_idx] == 0.0

    def test_disallowed_then_allowed_same_position(self):
        """Order independent: disallowed then allowed gives same result."""
        spec_ab = [
            {"position": 5, "allowed": "AGILMV"},
            {"position": 5, "disallowed": "CM"},
        ]
        spec_ba = [
            {"position": 5, "disallowed": "CM"},
            {"position": 5, "allowed": "AGILMV"},
        ]
        mask_ab = parse_residue_constraints(spec_ab, 10, CANONICAL, LETTER_MAP)
        mask_ba = parse_residue_constraints(spec_ba, 10, CANONICAL, LETTER_MAP)
        np.testing.assert_array_equal(mask_ab, mask_ba)

    def test_disjoint_allowed_sets_all_blocked(self):
        """Two allowed sets with no overlap → all 20 AAs blocked."""
        spec = [
            {"position": 1, "allowed": "AG"},
            {"position": 1, "allowed": "VILM"},
        ]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        # All 20 blocked at position 0
        assert mask[0].sum() == 20.0

    def test_multiple_disallowed_accumulate(self):
        """Multiple disallowed on same position: union of blocked sets."""
        spec = [
            {"position": 1, "disallowed": "CM"},
            {"position": 1, "disallowed": "WK"},
        ]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        cys_idx = CANONICAL.index("CYS")
        met_idx = CANONICAL.index("MET")
        trp_idx = CANONICAL.index("TRP")
        lys_idx = CANONICAL.index("LYS")
        assert mask[0, cys_idx] == 1.0
        assert mask[0, met_idx] == 1.0
        assert mask[0, trp_idx] == 1.0
        assert mask[0, lys_idx] == 1.0
        assert mask[0].sum() == 4.0

    def test_dtype_and_shape(self):
        spec = [{"position": 1, "allowed": "A"}]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        assert mask.dtype == np.float32
        assert mask.shape == (10, 20)
# ============================================================================
# parse_residue_constraints — error paths
# ============================================================================

class TestParseResidueConstraintsErrors:
    """Tests for parse_residue_constraints with invalid YAML specs."""

    def test_missing_position(self):
        spec = [{"allowed": "A"}]
        with pytest.raises(ValueError, match="position.*required"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_position_out_of_bounds_high(self):
        spec = [{"position": 11, "allowed": "A"}]
        with pytest.raises(ValueError, match="out of bounds"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_position_out_of_bounds_zero(self):
        # Position 0 is invalid (1-indexed); parse_range catches this
        spec = [{"position": 0, "allowed": "A"}]
        with pytest.raises(ValueError, match="1 indexed|out of bounds"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_both_allowed_and_disallowed(self):
        spec = [{"position": 1, "allowed": "A", "disallowed": "C"}]
        with pytest.raises(ValueError, match="cannot specify both"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_neither_allowed_nor_disallowed(self):
        spec = [{"position": 1}]
        with pytest.raises(ValueError, match="must specify either"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_empty_allowed(self):
        spec = [{"position": 1, "allowed": ""}]
        with pytest.raises(ValueError, match="cannot be empty"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_invalid_amino_acid_code(self):
        spec = [{"position": 1, "allowed": "X"}]
        with pytest.raises(ValueError, match="Unknown amino acid"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_invalid_amino_acid_in_disallowed(self):
        spec = [{"position": 1, "disallowed": "XZ"}]
        with pytest.raises(ValueError, match="Unknown amino acid"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
# ============================================================================
# Regression: original test case (no overlaps)
# ============================================================================

class TestOriginalTestCase:
    """Regression test matching residue_constraints_test.yaml."""

    def test_original_yaml_constraints(self):
        """Matches the constraints from example/residue_constraints_test.yaml."""
        spec = [
            {"position": 1, "allowed": "A"},
            {"position": "3..5", "disallowed": "CM"},
            {"position": 8, "allowed": "AGS"},
            {"position": 10, "allowed": "P"},
        ]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        cys_idx = CANONICAL.index("CYS")
        met_idx = CANONICAL.index("MET")
        gly_idx = CANONICAL.index("GLY")
        ser_idx = CANONICAL.index("SER")
        pro_idx = CANONICAL.index("PRO")
        # Position 1: only A
        assert mask[0, ala_idx] == 0.0
        assert mask[0].sum() == 19.0
        # Positions 3-5: C and M blocked
        for pos in [2, 3, 4]:
            assert mask[pos, cys_idx] == 1.0
            assert mask[pos, met_idx] == 1.0
            assert mask[pos].sum() == 2.0
        # Position 8: only A, G, S
        assert mask[7, ala_idx] == 0.0
        assert mask[7, gly_idx] == 0.0
        assert mask[7, ser_idx] == 0.0
        assert mask[7].sum() == 17.0
        # Position 10: only P
        assert mask[9, pro_idx] == 0.0
        assert mask[9].sum() == 19.0
        # Unconstrained positions (1-indexed 2, 6, 7, 9) are all zeros
        for pos in [1, 5, 6, 8]:
            assert mask[pos].sum() == 0.0
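
The tests above pin down the mask semantics: 1.0 means blocked, overlapping `allowed` entries intersect, and `disallowed` entries accumulate. The following is a minimal pure-Python sketch of those semantics only, not the actual `parse_residue_constraints` implementation. The name `sketch_constraint_mask`, the local `CANONICAL` ordering, and the local `LETTER_MAP` are hypothetical stand-ins. It returns nested lists rather than the `np.float32` array the tests expect, handles one-letter codes only, and does no validation.

```python
# Hypothetical stand-ins for the CANONICAL and LETTER_MAP names used by the tests;
# the real ordering in the code base may differ.
CANONICAL = [
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
]
LETTER_MAP = dict(zip("ARNDCQEGHILKMFPSTWYV", CANONICAL))


def sketch_constraint_mask(spec, length):
    """Return `length` rows of 20 floats; 1.0 marks a blocked amino acid."""
    mask = [[0.0] * len(CANONICAL) for _ in range(length)]
    for entry in spec:
        pos = entry["position"]
        if isinstance(pos, str):
            # "3..5" style range: 1-indexed and inclusive on both ends
            start, end = (int(p) for p in pos.split(".."))
        else:
            start = end = pos
        key = "allowed" if "allowed" in entry else "disallowed"
        named = {CANONICAL.index(LETTER_MAP[aa]) for aa in entry[key]}
        # "allowed" blocks everything outside the named set;
        # "disallowed" blocks the named set itself.
        blocked = set(range(len(CANONICAL))) - named if key == "allowed" else named
        # Blocks only ever accumulate, so repeated "allowed" entries intersect
        # and the result does not depend on entry order.
        for row in mask[start - 1:end]:
            for col in blocked:
                row[col] = 1.0
    return mask


overlap = [{"position": 1, "allowed": "AG"}, {"position": 1, "allowed": "GS"}]
mask = sketch_constraint_mask(overlap, 5)
print(sum(mask[0]))  # 19.0: only GLY remains allowed at position 1
```

Because a cell is only ever set to 1.0 and never cleared, the order-independence checked by `test_disallowed_then_allowed_same_position` follows directly in this sketch.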