mirror of
https://github.com/HannesStark/boltzgen.git
synced 2026-06-04 11:54:23 +08:00
Merge branch 'main' into ipsae
25
Dockerfile
@@ -12,11 +12,14 @@ ENV DEBIAN_FRONTEND=noninteractive \

 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
-    python3.10 \
-    python3.10-dev \
-    python3-pip \
-    python3-venv \
-    python3-wheel \
+    software-properties-common \
+    curl \
+    && add-apt-repository -y ppa:deadsnakes/ppa \
+    && apt-get update \
+    && apt-get install -y --no-install-recommends \
+    python3.11 \
+    python3.11-dev \
+    python3.11-venv \
     build-essential \
     git \
     cmake \
@@ -30,9 +33,10 @@ RUN apt-get update && \
     libboost-all-dev \
     && rm -rf /var/lib/apt/lists/*

-RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 && \
-    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 && \
-    python -m pip install --upgrade pip setuptools setuptools_scm wheel
+RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1 && \
+    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 && \
+    curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11 && \
+    python3.11 -m pip install --upgrade pip setuptools setuptools_scm wheel

 WORKDIR /app

@@ -56,7 +60,4 @@ RUN groupadd --gid ${USER_GID} ${USERNAME} && \
 RUN mkdir -p "${HF_HOME}" && chown -R ${USER_UID}:${USER_GID} "${HF_HOME}"

 USER ${USERNAME}
-WORKDIR /workspace
-
-ENTRYPOINT ["boltzgen"]
-CMD ["--help"]
+WORKDIR /workspace
52
README.md
@@ -4,7 +4,8 @@
 <img src="assets/boltzgen.png" alt="BoltzGen logo" width="60%">

 [Paper](https://hannes-stark.com/assets/boltzgen.pdf) |
-[Slack](https://boltz.bio/join-slack) <br> <br>
+[Slack](https://boltz.bio/join-slack) |
+[Video](https://www.youtube.com/watch?v=9d_QWUUI1Qo) <br> <br>

 </div>

@@ -73,10 +74,10 @@ docker build -t boltzgen .
 # Run an example
 mkdir -p workdir # output
 mkdir -p cache # where models will be downloaded to
-docker run --rm --gpus all -v "$(realpath workdir)":/workdir -v "$(realpath cache)":/cache -v "$(realpath example)":/example \
-  boltzgen run /example/vanilla_protein/1g13prot.yaml --output /workdir/test \
-  --protocol protein-anything \
-  --num_designs 2
+docker run --rm --gpus all -v "$(realpath workdir)":/workdir -v "$(realpath cache)":/cache -v "$(realpath example)":/example boltzgen \
+  boltzgen run /example/vanilla_protein/1g13prot.yaml --output /workdir/test \
+  --protocol protein-anything \
+  --num_designs 2
 ```

 In the example above, the model weights are downloaded the first time the image is run. To bake the weights into the image at build time, run:
@@ -144,7 +145,7 @@ When the pipeline completes your output directory will have:
 - `/final_designs_metrics_<budget>.csv` — metrics for the selected final set.
 - `/results_overview.pdf` — plots

-# Protocols
+# Protocols

 | Protocol (design-target) | Appropriate for | Major config differences |
 |--------------------------|---------------------------------------------------------------------------|------------------------|
@@ -153,6 +154,7 @@ When the pipeline completes your output directory will have:
 | protein-small_molecule | Design proteins to bind small molecules | Includes binding affinity prediction. Includes `design folding` step. |
 | antibody-anything | Design antibody CDRs | No Cys are generated in inverse folding. No `design folding` step. Don't compute largest hydrophobic patch. |
 | nanobody-anything | Design nanobody CDRs | Same settings as antibody-anything |
+| protein-redesign | Redesign or optimize existing proteins | No `design folding` step. Uses `design_mask` for target/template definition. |

 All configuration parameters can be overridden using the `--config` option; see `boltzgen run --help` or the `Advanced Users` section below for details.

@@ -182,6 +184,7 @@ We provide many example `.yaml` files in the `example/` directory, including:
 - `example/fab_targets/pdl1.yaml`
 - `example/denovo_zinc_finger_against_dna/zinc_finger.yaml`
 - `example/protein_binding_small_molecule/chorismite.yaml`
+- `example/small_molecule_from_file_and_smiles/4g37.yaml`

 Small example of a protein design against a target protein without binding site specified:
 ```yaml
@@ -311,6 +314,33 @@ constraints:

 ```

+## Symmetric complex design (inverse-folding only)
+
+For symmetric complexes (e.g., homo-dimers), you can tie sequence generation during inverse folding by specifying the `symmetric_group` for each symmetric chain. The `protein-redesign` protocol allows for scoring complexes without separate binders/targets.
+
+```yaml
+entities:
+  - file:
+      path: symmetric_dimer.cif
+      include:
+        - chain:
+            id: A
+            res_index: 100..300
+            symmetric_group: 1  # Link chains A and B for symmetric sampling
+        - chain:
+            id: B
+            res_index: 100..300
+            symmetric_group: 1  # Same group = same sampled insertion length
+
+      # Mark residues to redesign on both chains
+      design:
+        - chain:
+            id: A
+            res_index: 200..210
+        - chain:
+            id: B
+            res_index: 200..210
+```

 # Running only specific pipeline steps

@@ -325,6 +355,16 @@ boltzgen run example/cyclotide/3ivq.yaml \
   --num_designs 2
 ```

+If you want to run only the inverse folding and subsequent design evaluation steps (but not the backbone design step), you can also run:
+
+**Run only inverse_folding step:**
+```bash
+boltzgen run example/inverse_folding/1brs.yaml \
+  --output workbench/if-only \
+  --only_inverse_fold \
+  --inverse_fold_num_sequences 2
+```
+
 **Available steps:**
 - `design` - Generate num_design candidates using the diffusion model based on your design specification
 - `inverse_folding` - Redesign sequences from the previous step using our inverse folding model
@@ -22,6 +22,7 @@ We provide many example `.yaml` files in the `example/` directory, including:
 - [fab_targets/pdl1.yaml](fab_targets/pdl1.yaml)
 - [denovo_zinc_finger_against_dna/zinc_finger.yaml](denovo_zinc_finger_against_dna/zinc_finger.yaml)
 - [protein_binding_small_molecule/chorismite.yaml](protein_binding_small_molecule/chorismite.yaml)
+- [small_molecule_from_file_and_smiles/4g37.yaml](small_molecule_from_file_and_smiles/4g37.yaml)

 Small example of a protein design against a target protein without binding site specified:
 ```yaml
@@ -343,6 +344,9 @@ constraints:
       atom2: [Q, 1, CK] # Connect sulfur of Cys-4 in chain R to atom CK in ligand Q
 ```

+We now support constraint specifications for small molecules taken from the input file and from SMILES. See `example/small_molecule_from_file_and_smiles/4g37.yaml`. Brief guidelines:
+* Small molecules from the file: look up the atom's `atom_name` in the CCD and specify it.
+* Small molecules from SMILES: count the index of the target atom among the atoms of the same element in the SMILES string, and specify the element symbol followed by that index (e.g. `C6` for the 6th carbon in the SMILES).

 Here is a comprehensive list of all the keys from your YAML file with explanations for each.

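The SMILES naming guideline above (element symbol plus its running count, e.g. `C6` for the sixth carbon) can be checked with a short script. This is only a sketch: `smiles_atom_labels` is a hypothetical helper, not part of BoltzGen; it handles simple SMILES only (organic-subset atoms and bracket atoms), and it assumes the index is 1-based and counted per element in written order, which is inferred from the guideline rather than confirmed against BoltzGen's parser.

```python
import re
from collections import Counter

# Bracket atoms ([NH3+]), two-letter halogens, then one-letter organic-subset
# atoms including lowercase aromatic forms. Ring digits, bonds and parentheses
# match none of the alternatives and are therefore skipped.
_ATOM = re.compile(r"\[([^\]]+)\]|Cl|Br|[BCNOPSFI]|[bcnops]")
# Inside brackets: optional isotope digits, then the element symbol.
_BRACKET_ELEMENT = re.compile(r"\d*([A-Z][a-z]?|[a-z]{1,2})")


def smiles_atom_labels(smiles: str) -> list[str]:
    """Return labels such as C1, C2, O1 for each atom, in written order."""
    counts: Counter[str] = Counter()
    labels = []
    for match in _ATOM.finditer(smiles):
        token = match.group(0)
        if match.group(1) is not None:  # bracket atom: keep only the element
            inner = _BRACKET_ELEMENT.match(match.group(1))
            if inner is None:  # e.g. wildcard atoms; not handled in this sketch
                continue
            token = inner.group(1)
        element = token.capitalize()  # aromatic "c" counts as carbon
        counts[element] += 1
        labels.append(f"{element}{counts[element]}")
    return labels


labels = smiles_atom_labels("CC(=O)NCCNC(C)=O")
print(labels)
```

For the ligand in `4g37.yaml` this labels the two terminal methyl carbons `C1` and `C6`, the names used in that file's `bond` constraints.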
@@ -1,7 +1,14 @@
 entities:
-  - protein:
+  - protein:
       id: G
       sequence: 15..20AAAAAAVTTTT18PPP # range between 15 and 20 inclusive on both sides
+      residue_constraints:
+        - position: 1
+          allowed: A  # Only Alanine at position 1
+        - position: 3..5
+          disallowed: CM  # No Cysteine or Methionine at positions 3-5
+        - position: 8
+          allowed: AGS  # Only Ala, Gly, or Ser at position 8
   - protein:
       id: R
       sequence: 3..5C6C3 # Random number of design residues between 3 and 5, then a Cysteine, then 6 design residues, then ...
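To make the `residue_constraints` semantics concrete: elsewhere in this diff the constraints are stored as a per-residue mask of 20 values, where 0 means allowed and 1 means disallowed. The sketch below builds such a mask from `allowed` / `disallowed` entries like the ones above. It is an illustration under assumptions, not BoltzGen's `parse_residue_constraints`: the helper names, the alphabetical one-letter ordering in `AA_ORDER`, and 1-based inclusive positions are all assumed here.

```python
AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"  # assumed column order, for illustration only


def parse_positions(spec) -> list[int]:
    """Turn 8 or "3..5" into a list of 1-based positions."""
    text = str(spec)
    if ".." in text:
        start, end = (int(part) for part in text.split(".."))
        return list(range(start, end + 1))
    return [int(text)]


def build_constraint_mask(constraints: list[dict], chain_length: int) -> list[list[int]]:
    """Return a chain_length x 20 mask: 0 = allowed, 1 = disallowed."""
    mask = [[0] * len(AA_ORDER) for _ in range(chain_length)]
    for entry in constraints:
        for pos in parse_positions(entry["position"]):
            if not 1 <= pos <= chain_length:
                raise ValueError(f"position {pos} outside chain of length {chain_length}")
            row = mask[pos - 1]
            if "allowed" in entry:
                # Whitelist: block everything that is not listed.
                for i, aa in enumerate(AA_ORDER):
                    if aa not in entry["allowed"]:
                        row[i] = 1
            for aa in entry.get("disallowed", ""):
                # Blacklist: block only the listed residues.
                row[AA_ORDER.index(aa)] = 1
    return mask


constraints = [
    {"position": 1, "allowed": "A"},
    {"position": "3..5", "disallowed": "CM"},
    {"position": 8, "allowed": "AGS"},
]
mask = build_constraint_mask(constraints, chain_length=10)
allowed_at_8 = [aa for aa, blocked in zip(AA_ORDER, mask[7]) if not blocked]
print(allowed_at_8)  # ['A', 'G', 'S']
```

A whitelist therefore blocks 19 or fewer columns and a blacklist blocks only the listed ones; positions without any entry keep an all-zero row.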
8312
example/inverse_folding/1brs.cif
Normal file
File diff suppressed because it is too large
16
example/inverse_folding/1brs.yaml
Normal file
@@ -0,0 +1,16 @@
+entities:
+  - file:
+      path: 1brs.cif
+
+      # Include target and binder
+      include:
+        - chain:
+            id: A
+        - chain:
+            id: D
+
+      # Redesign specified binder residues
+      design:
+        - chain:
+            id: D
+            res_index: 33..46
33
example/residue_constraints_test.yaml
Normal file
@@ -0,0 +1,33 @@
+# Test file for per-residue amino acid constraints
+# This tests the new residue_constraints feature
+#
+# Usage:
+#   boltzgen check example/residue_constraints_test.yaml
+#   boltzgen run example/residue_constraints_test.yaml --output test_output/ --steps design inverse_folding --num_designs 50
+#
+# Note: Use --num_designs 50 (not 5) for statistically meaningful verification.
+#       With only 5 designs, blacklist constraints have a ~21% false-pass probability.
+#
+# Note: Supports both string format ("AGS") and list format ([A, G, S])
+#       String format is preferred for consistency with sequence/binding_types
+
+entities:
+  - protein:
+      id: A
+      sequence: 10
+      residue_constraints:
+        # Position 1: Force Alanine only
+        - position: 1
+          allowed: A
+
+        # Positions 3-5: Exclude Cysteine and Methionine
+        - position: 3..5
+          disallowed: CM
+
+        # Position 8: Allow only small amino acids
+        - position: 8
+          allowed: AGS
+
+        # Position 10: Force Proline
+        - position: 10
+          allowed: P
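The ~21% figure in the header comment of this test file can be reproduced with a back-of-the-envelope calculation. Assume (this is an assumption, the file does not state its model) that a sampler ignoring the constraint picks each of the 20 amino acids uniformly and independently. The chance that it still avoids both C and M at all three blacklisted positions (3..5) in every one of n designs is then (18/20)^(3n). The function name and defaults below are illustrative.

```python
def false_pass_probability(
    num_designs: int,
    positions: int = 3,      # positions 3..5 carry the blacklist
    disallowed: int = 2,     # C and M
    alphabet: int = 20,
) -> float:
    """Chance that an unconstrained uniform sampler never hits a disallowed residue."""
    per_position = (alphabet - disallowed) / alphabet
    return per_position ** (positions * num_designs)


p5 = false_pass_probability(5)
p50 = false_pass_probability(50)
print(f"5 designs:  {p5:.3f}")
print(f"50 designs: {p50:.1e}")
```

Under these assumptions 5 designs give roughly 0.21, matching the comment, while 50 designs push the false-pass chance well below one in a million.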
4026
example/small_molecule_from_file_and_smiles/4g37.pdb
Normal file
File diff suppressed because it is too large
37
example/small_molecule_from_file_and_smiles/4g37.yaml
Normal file
@@ -0,0 +1,37 @@
+entities:
+  - protein:
+      id: B
+      sequence: 1..3C1..2C1..3 ###1to3 residues before first CYS, 1to2 residues, second CYS, followed by 1to3 residues
+  - file:
+      path: 4g37.pdb
+
+      include:
+        - chain:
+            id: A
+        - chain:
+            id: C # XLX
+  - ligand:
+      id: D
+      smiles: "CC(=O)NCCNC(C)=O"
+
+constraints:
+  ## constraints from the file.
+  # covalent bond between C3 in XLX (from CCD) and SG in chain A of 4g37.pdb
+  - bond:
+      atom1: [A, 105, SG]
+      atom2: [C, 1, C3]
+  # additional covalent bond between C11 in XLX (from CCD) and SG in chain B (design)
+  - bond:
+      atom1: [B, 2, SG]
+      atom2: [C, 1, C11]
+  ## constraints from the smiles.
+  - bond:
+      # covalent bond between C1 in D (from smiles) and OG in chain A of 4g37.pdb
+      atom1: [A, 339, OG]
+      atom2: [D, 1, C1]
+  # additional covalent bond between C6 in D (from smiles) and SG in chain B (design)
+  - bond:
+      atom1: [B, 4, SG]
+      atom2: [D, 1, C6]
+
+#boltzgen run example/small_molecule_from_file_and_smiles/4g37.yaml --output workbench --protocol protein-anything --num_designs 2 --devices 1 --steps design
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "boltzgen"
 requires-python = ">=3.11"
-version = "0.1.8"
+version = "0.3.0"
 readme = { file = "PYPI_DESCRIPTION.md", content-type = "text/markdown" }
 description = "Protein design"
 dependencies = [
@@ -34,7 +34,7 @@ if [[ "$MODE" == "submit" ]]; then
 elif [[ "$MODE" == "process" ]]; then

     boltzgen merge "$OUT"/task-* --output "$MERGED_OUT"
-    boltzgen run "$DESIGN_SPEC" --steps filtering --output "$MERGED_OUT"
+    boltzgen run "$DESIGN_SPEC" --steps filtering --protocol protein-anything --output "$MERGED_OUT"

 else
     echo "Usage: $0 {submit|process}"
@@ -100,6 +100,16 @@ protocol_configs = {
         "analysis": ["largest_hydrophobic=false", "largest_hydrophobic_refolded=false"],
         "filtering": ["filter_cysteine=true"],
     },
+    "protein-redesign": {
+        # For redesigning/optimizing existing proteins (e.g., symmetric dimers)
+        # where all chains may have designed residues. Skips design_folding and
+        # uses design_mask (not chain_design_mask) for target/template definition.
+        "folding": ["data.design_mask_templates=true"],
+        "analysis": ["use_design_mask_for_target=true"],
+        "filtering": [
+            "metrics_override={design_to_target_iptm: null, neg_min_design_to_target_pae: null, design_ptm: null, plip_hbonds_refolded: null, plip_saltbridge_refolded: null, delta_sasa_refolded: null, plip_hbonds: null, plip_saltbridge: null, delta_sasa_original: null, design_residue_iptm: 1, iptm: 2, ptm: 3, neg_filter_rmsd_design: 4}",
+        ],
+    },
 }
 assert all(
     step_name in step_names for cfg in protocol_configs.values() for step_name in cfg
@@ -981,6 +991,30 @@ class BinderDesignPipeline:
             )

         if args.only_inverse_fold:
+            exclude_residues = []
+            inverse_fold_avoid = (
+                args.inverse_fold_avoid
+                if args.inverse_fold_avoid is not None
+                else (
+                    "C"
+                    if protocol
+                    in [
+                        "peptide-anything",
+                        "nanobody-anything",
+                        "antibody-anything",
+                    ]
+                    else ""
+                )
+            )
+
+            for one_letter_code in inverse_fold_avoid:
+                exclude_residues.append(const.prot_letter_to_token[one_letter_code])
+
+            if len(exclude_residues) > 0:
+                print(
+                    f"Inverse fold will avoid the following residues: {exclude_residues}"
+                )
+            print(f"Inverse-folded designs will be saved to: {output_dir}")
             # Designs from inverse folding
             self.steps.append(
                 PipelineStep(
@@ -991,9 +1025,12 @@ class BinderDesignPipeline:
|
||||
f"data.cfg.yaml_path=[{', '.join(str(s) for s in args.design_spec)}]",
|
||||
f"trainer.devices={devices}",
|
||||
f"data.cfg.multiplicity={getattr(args, 'inverse_fold_num_sequences', 10)}",
|
||||
f"data.cfg.skip_existing={args.reuse}",
|
||||
f"data.cfg.output_dir={output_dir}",
|
||||
f"override.use_kernels={use_kernels}",
|
||||
f"checkpoint={get_artifact_path(args, args.inverse_fold_checkpoint)}",
|
||||
f"data.cfg.moldir={moldir}",
|
||||
f"override.inverse_fold_args.inverse_fold_restriction=[{', '.join(exclude_residues)}]",
|
||||
]
|
||||
+ config_args_by_step.get("inverse_folding", []),
|
||||
)
|
||||
|
||||
@@ -3469,11 +3469,17 @@ eval_keys_confidence = [
|
||||
"design_iptm",
|
||||
"design_iiptm",
|
||||
"design_to_target_iptm",
|
||||
"design_residue_iptm",
|
||||
"target_ptm",
|
||||
"design_ptm",
|
||||
"design_ipsae_min",
|
||||
"design_to_target_ipsae",
|
||||
"target_to_design_ipsae",
|
||||
"ligand_iptm",
|
||||
"complex_plddt",
|
||||
"complex_iplddt",
|
||||
"complex_pde",
|
||||
"complex_ipde",
|
||||
]
|
||||
|
||||
eval_keys_affinity = [
|
||||
@@ -3522,6 +3528,7 @@ token_features = [
     "res_type_clone",
     "is_standard",
     "design_mask",
+    "aa_constraint_mask",
     "binding_type",
     "structure_group",
     "token_bonds",
@@ -1,4 +1,5 @@
 import json
+import warnings
 from dataclasses import asdict, dataclass
 from pathlib import Path
 import re
@@ -291,6 +292,7 @@ Chain = [
     ("res_idx", np.dtype("i4")),
     ("res_num", np.dtype("i4")),
     ("cyclic_period", np.dtype("i4")),
+    ("symmetric_group", np.dtype("i4")),
 ]

 Interface = [
@@ -907,7 +909,7 @@ class Structure(NumpySerializable):
         old_to_new = {old.item(): new for new, old in enumerate(atom_indices)}
         old_to_new_res = {old.item(): new for new, old in enumerate(res_indices)}

-        old_to_new_res_chain = {}
+        res_chain_map = {}

         for i in range(len(residues)):
             original_atom_range = np.arange(
@@ -966,8 +968,7 @@ class Structure(NumpySerializable):
             chain_start = orig_chain["res_idx"]
             chain_end = orig_chain["res_idx"] + orig_chain["res_num"]
             chain_res_indices = [r for r in res_indices if chain_start <= r < chain_end]
-            chain_res_indices -= orig_chain["res_idx"]
-            old_to_new_res_chain[i] = {
+            res_chain_map[i] = {
                 old.item(): new for new, old in enumerate(chain_res_indices)
             }

@@ -996,8 +997,8 @@ class Structure(NumpySerializable):
                 chain_atom_start = chain["atom_idx"]
                 chain_atom_end = chain["atom_idx"] + chain["atom_num"]
                 if chain_atom_start <= res["atom_idx"] < chain_atom_end:
-                    res_idx_item = residues[i]["res_idx"].item()
-                    residues[i]["res_idx"] = old_to_new_res_chain[chain_idx].get(
+                    res_idx_item = res_indices[i]
+                    residues[i]["res_idx"] = res_chain_map[chain_idx].get(
                         res_idx_item
                     )

@@ -1216,7 +1217,8 @@ class Structure(NumpySerializable):
                 len(atom_data),
                 0,
                 len(res_data),
-                0,
+                0,  # cyclic_period
+                0,  # symmetric_group
             )
         ]

@@ -1523,7 +1525,8 @@ class Structure(NumpySerializable):
                     num_atoms,
                     total_res,
                     len(chain_res_selector),
-                    0,
+                    0,  # cyclic_period
+                    0,  # symmetric_group
                 )
             )
             total_res += len(chain_res_selector)
@@ -1759,6 +1762,18 @@ def biotite_array_from_feat(feat):
         atom_pad_mask & atom_resolved_mask
     ].bool()

+    # add chain design mask
+    chain_design_mask = feat["chain_design_mask"].bool()
+    atom_chain_design_mask = (
+        (feat["atom_to_token"].float() @ chain_design_mask.unsqueeze(-1).float())
+        .bool()
+        .squeeze()
+    )
+    atom_array.add_annotation("is_chain_design", bool)
+    atom_array.is_chain_design = atom_chain_design_mask[
+        atom_pad_mask & atom_resolved_mask
+    ].bool()
+
     return atom_array


@@ -1916,6 +1931,7 @@ class DesignInfo(NumpySerializable):
     res_structure_groups: npt.NDArray[np.int_]
     res_ss_types: npt.NDArray[np.int_]
     res_binding_type: npt.NDArray[np.int_]
+    res_aa_constraint_mask: npt.NDArray[np.float32]  # Shape: (num_residues, 20), 0=allowed, 1=disallowed

     @classmethod
     def is_valid(self, info: "DesignInfo") -> bool:
@@ -1925,6 +1941,7 @@ class DesignInfo(NumpySerializable):
             len(info.res_design_mask) == len(info.res_structure_groups)
             and len(info.res_structure_groups) == len(info.res_ss_types)
             and len(info.res_ss_types) == len(info.res_binding_type)
+            and len(info.res_aa_constraint_mask) == len(info.res_design_mask)
         ), (
             "There must be a bug in the code. All residue level design info objects should have the same length."
         )
@@ -1941,6 +1958,22 @@ class DesignInfo(NumpySerializable):
             msg = "Misspecified design info. There were residues that have a secondary structure type specified but are not set to be designed."
             raise ValueError(msg)

+        # Validate residue constraints
+        has_constraints = info.res_aa_constraint_mask.any(axis=1)
+        if any(has_constraints & ~info.res_design_mask.astype(bool)):
+            warnings.warn(
+                "Residue constraints specified for non-designed residues "
+                "will be ignored during inverse folding.",
+                UserWarning,
+                stacklevel=2,
+            )
+
+        # Check if any designed position has ALL amino acids blocked
+        all_blocked = info.res_aa_constraint_mask.all(axis=1)
+        if any(all_blocked & info.res_design_mask.astype(bool)):
+            msg = "Invalid residue constraints: some designed positions have all amino acids disallowed."
+            raise ValueError(msg)
+
         return True


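A small standalone illustration of the two checks added to `DesignInfo.is_valid` in the hunk above: constraints on non-designed residues only produce a warning, while a designed residue with all 20 amino acids blocked is an error. The sketch uses plain Python lists instead of NumPy arrays, and `check_constraints` is a hypothetical stand-in, not the BoltzGen method.

```python
import warnings


def check_constraints(constraint_mask: list[list[int]], design_mask: list[bool]) -> bool:
    """Mirror the two residue-constraint checks; 1 in the mask means disallowed."""
    for row, designed in zip(constraint_mask, design_mask):
        if any(row) and not designed:
            warnings.warn(
                "Residue constraints on non-designed residues will be ignored.",
                UserWarning,
                stacklevel=2,
            )
            break  # one warning is enough
    for row, designed in zip(constraint_mask, design_mask):
        if designed and all(row):
            raise ValueError("A designed position has all amino acids disallowed.")
    return True


free = [0] * 20
one_blocked = [0] * 20
one_blocked[1] = 1       # block a single amino acid
all_blocked = [1] * 20

# Valid: every designed residue keeps at least one allowed amino acid.
ok = check_constraints([free, one_blocked], [True, True])

# Invalid: the second designed residue has nothing left to sample.
try:
    check_constraints([free, all_blocked], [True, True])
    rejected = False
except ValueError:
    rejected = True

# Constraint on a fixed (non-designed) residue: accepted, but with a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_constraints([one_blocked, free], [False, True])
warned = len(caught) == 1
```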
@@ -2035,11 +2068,13 @@ Token = [
     ("design_mask", np.dtype("?")),
     ("binding_type", np.dtype("i4")),
     ("structure_group", np.dtype("i4")),
+    ("aa_constraint_mask", np.dtype("20f4")),  # Per-residue AA constraints: 20 floats (one per canonical AA)
     ("ccd", np.dtype("5i4")),
     ("target_msa_mask", np.dtype("?")),
     ("design_ss_mask", np.dtype("?")),
     ("feature_asym_id", np.dtype("i4")),
     ("feature_res_idx", np.dtype("i4")),
+    ("symmetric_group", np.dtype("i4")),
 ]

 TokenBond = [
@@ -701,6 +701,12 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
     res_type = from_numpy(token_data["res_type"]).long()
     is_standard = from_numpy(token_data["is_standard"])
     design = from_numpy(token_data["design_mask"]).long()
+    # Per-residue amino acid constraint mask (shape: num_tokens x 20)
+    if "aa_constraint_mask" in token_data.dtype.names:
+        aa_constraint_mask = from_numpy(token_data["aa_constraint_mask"]).float()
+    else:
+        # Default: no constraints (all zeros = all AAs allowed)
+        aa_constraint_mask = torch.zeros(len(token_data), len(const.canonical_tokens))
     res_type = one_hot(res_type, num_classes=const.num_tokens)
     modified = from_numpy(token_data["modified"]).long()
     ccd = from_numpy(token_data["ccd"]).long()
@@ -711,6 +717,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
     design_ss_mask = from_numpy(token_data["design_ss_mask"])
     feature_residue_index = from_numpy(token_data["feature_res_idx"]).long()
     feature_asym_id = from_numpy(token_data["feature_asym_id"]).long()
+    symmetric_group = from_numpy(token_data["symmetric_group"]).long()
     token_to_res = from_numpy(data.token_to_res).long()

     method = (
@@ -863,6 +870,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
         res_type = pad_dim(res_type, 0, pad_len)
         is_standard = pad_dim(is_standard, 0, pad_len)
         design = pad_dim(design, 0, pad_len)
+        aa_constraint_mask = pad_dim(aa_constraint_mask, 0, pad_len)
         binding_type = pad_dim(binding_type, 0, pad_len)
         structure_group = pad_dim(structure_group, 0, pad_len)
         pad_mask = pad_dim(pad_mask, 0, pad_len)
@@ -887,6 +895,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
         design_ss_mask = pad_dim(design_ss_mask, 0, pad_len)
         feature_residue_index = pad_dim(feature_residue_index, 0, pad_len)
         feature_asym_id = pad_dim(feature_asym_id, 0, pad_len)
+        symmetric_group = pad_dim(symmetric_group, 0, pad_len)
         token_to_res = pad_dim(token_to_res, 0, pad_len)
     token_features = {
         "token_index": token_index,
@@ -899,6 +908,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
         "res_type_clone": res_type.clone(),
         "is_standard": is_standard,
         "design_mask": design,
+        "aa_constraint_mask": aa_constraint_mask,
         "binding_type": binding_type,
         "structure_group": structure_group,
         "token_bonds": bonds,
@@ -921,6 +931,7 @@ def process_token_features( # noqa: C901, PLR0915, PLR0912
         "design_ss_mask": design_ss_mask,
         "feature_residue_index": feature_residue_index,
         "feature_asym_id": feature_asym_id,
+        "symmetric_group": symmetric_group,
         "ligand_affinity_mask": ligand_affinity_mask,
         "token_to_res": token_to_res,
     }
@@ -1,8 +1,9 @@
 import itertools
+import os
 import pickle
 from pathlib import Path
 import random
-from typing import Dict
+from typing import Dict, List
 import zipfile

 import numpy as np
@@ -14,10 +15,44 @@ from boltzgen.data import const
 from boltzgen.data.pad import pad_dim
 from boltzgen.model.loss.confidence import lddt_dist

-MOLDIR_ZIP_CACHE = {}  # moldir -> open zip file

+MOLDIR_ZIP_CACHE: Dict[
+    tuple[int, Path], zipfile.ZipFile
+] = {}  # moldir -> open zip file
+

-def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
+def _get_zipfile(moldir: Path) -> zipfile.ZipFile:
+    """
+    Retrieve a cached ZipFile object for the given molecule directory zip file.
+
+    If the ZipFile for the provided path and current process does not exist in the cache,
+    it is created and stored. Otherwise, the cached instance is returned.
+
+    Parameters
+    ----------
+    moldir : Path
+        Path to the .zip file containing molecule data.
+
+    Returns
+    -------
+    zipfile.ZipFile
+        An open ZipFile object for reading molecule data.
+
+    Notes
+    -----
+    The cache is keyed by both process ID and zip file path to prevent issues with forked processes.
+    """
+    pid = os.getpid()
+    key = (pid, moldir)
+
+    zf = MOLDIR_ZIP_CACHE.get(key)
+    if zf is None:
+        zf = zipfile.ZipFile(moldir, "r")
+        MOLDIR_ZIP_CACHE[key] = zf
+    return zf
+
+
+def load_molecules(moldir: str, molecules: List[str]) -> Dict[str, Mol]:
     """Load the given input data.

     Parameters
@@ -33,12 +68,10 @@ def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
         The loaded molecules.
     """
     moldir = Path(moldir)
-    loaded_mols = {}
+    loaded_mols: Dict[str, Mol] = {}

     if moldir.is_file() and moldir.suffix == ".zip":
-        if moldir not in MOLDIR_ZIP_CACHE:
-            MOLDIR_ZIP_CACHE[moldir] = zipfile.ZipFile(moldir, "r")
-        zip_file = MOLDIR_ZIP_CACHE[moldir]
+        zip_file = _get_zipfile(moldir)
         for molecule in molecules:
             pkl_filename = f"{molecule}.pkl"
             with zip_file.open(pkl_filename, "r") as f:
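The per-process cache key used by `_get_zipfile` can be tried out in isolation. A file handle opened before a fork is inherited by the child together with its read offset, so worker processes that reuse a parent's `ZipFile` can interfere with each other; keying on `(pid, path)` makes each process open its own handle. The sketch below restates the pattern with the standard library only. `get_zipfile`, `_ZIP_CACHE` and the temporary archive are illustrative names, not BoltzGen code.

```python
import os
import tempfile
import zipfile
from pathlib import Path

_ZIP_CACHE: dict[tuple[int, Path], zipfile.ZipFile] = {}


def get_zipfile(path: Path) -> zipfile.ZipFile:
    """Return an open ZipFile for the current process, opening it on first use."""
    key = (os.getpid(), path)
    zf = _ZIP_CACHE.get(key)
    if zf is None:
        zf = zipfile.ZipFile(path, "r")
        _ZIP_CACHE[key] = zf
    return zf


# Build a tiny archive to exercise the cache.
tmp_dir = tempfile.mkdtemp()
archive = Path(tmp_dir) / "mols.zip"
with zipfile.ZipFile(archive, "w") as out:
    out.writestr("ALA.pkl", b"placeholder bytes")

first = get_zipfile(archive)
second = get_zipfile(archive)
same_handle = first is second        # cached within one process
payload = first.read("ALA.pkl")
cache_keys = list(_ZIP_CACHE)        # one entry: (this pid, archive path)
```

A forked worker would compute a different `os.getpid()` and therefore miss the cache and open a fresh handle, at the cost of one extra open per process.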
@@ -621,12 +654,16 @@ def get_chain_symmetries(cropped, backbone_only, atom14, atom37, max_n_symmetrie
         crop_to_all_atom_map.shape[0]
     )
     connections_edge_index = []
+    crop_atom_set = set(crop_to_all_atom_map.astype(np.int64))
     for connection in structure.bonds:
         if (connection["chain_1"] == connection["chain_2"]) and (
             connection["res_1"] == connection["res_2"]
         ):
             continue
-        connections_edge_index.append([connection["atom_1"], connection["atom_2"]])
+        atom_1, atom_2 = connection["atom_1"], connection["atom_2"]
+        # Only include bonds where BOTH atoms are in the crop
+        if atom_1 in crop_atom_set and atom_2 in crop_atom_set:
+            connections_edge_index.append([atom_1, atom_2])
     if len(connections_edge_index) > 0:
         connections_edge_index = np.array(connections_edge_index, dtype=np.int64).T
         connections_edge_index = all_atom_to_crop_map[connections_edge_index]
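The change above filters bonds before remapping them to crop-local indices. Presumably the remapping table is only meaningful for atoms inside the crop, so a bond with one endpoint outside it would otherwise be mapped to a wrong index. The same filter-then-remap idea, sketched with plain Python containers (`crop_bonds` and its arguments are illustrative, not the BoltzGen function):

```python
def crop_bonds(bonds: list[tuple[int, int]], crop_atoms: list[int]) -> list[tuple[int, int]]:
    """Keep bonds fully inside the crop and renumber atoms to crop-local indices."""
    to_crop = {atom: i for i, atom in enumerate(crop_atoms)}
    kept = []
    for atom_1, atom_2 in bonds:
        # Only include bonds where both atoms are in the crop.
        if atom_1 in to_crop and atom_2 in to_crop:
            kept.append((to_crop[atom_1], to_crop[atom_2]))
    return kept


bonds = [(2, 7), (7, 40), (5, 9)]   # global atom indices
crop_atoms = [2, 5, 7, 9]           # atoms that survived cropping
cropped = crop_bonds(bonds, crop_atoms)
print(cropped)  # [(0, 2), (1, 3)]
```

The bond `(7, 40)` is dropped because atom 40 is not part of the crop.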
@@ -679,6 +679,8 @@ def parse_polymer( # noqa: C901, PLR0915, PLR0912
         If the alignment fails.

     """
+    assert entity_poly_seq is not None
+
     # Since the polymer object already contains the global idx, we don't need to perform the alignment
     sequence = [_entity[1] for _entity in entity_poly_seq]

@@ -911,8 +913,8 @@ def parse_mmcif( # noqa: C901, PLR0915, PLR0912
         Path to the MMCIF file.

     use_original_res_idx : bool
-        Uses the res_idx for the res_idx in the Residues in the returned structure
-        that was in the mmcif file for each residue instead of using the index in the
+        Uses the res_idx for the res_idx in the Residues in the returned structure
+        that was in the mmcif file for each residue instead of using the index in the
         seqres that is obtained after aligning the seqres to the sequence of amino acids from the present residues.

     Returns
@@ -1152,6 +1154,7 @@ def mmcif_from_block( # noqa: C901, PLR0915, PLR0912
                 mols=mols,
                 moldir=moldir,
                 use_original_res_idx=use_original_res_idx,
+                entity_poly_seq=entity_poly_seq[entity.name],
             )
             if parsed_polymer is not None:
                 ensemble_chains[ref_chain_map[subchain_id]] = parsed_polymer
@@ -1255,6 +1258,7 @@ def mmcif_from_block( # noqa: C901, PLR0915, PLR0912
                 res_idx,
                 res_num,
                 0,  # cyclic period
+                0,  # symmetric_group (default, can be overridden via YAML)
             )
         )
         chain_to_idx[chain.name] = asym_id
@@ -117,6 +117,7 @@ def patch_chain_names(structure: ParsedStructure) -> ParsedStructure:
                 chain["res_idx"],
                 chain["res_num"],
                 chain["cyclic_period"],
+                chain["symmetric_group"],
             )
         )
     chains_new = np.array(chains_new, dtype=Chain)
@@ -144,6 +144,7 @@ class ParsedChain:
     cyclic_period: int
     sequence: Optional[str] = None
     sampleidx_to_specidx: Optional[np.ndarray] = None
+    symmetric_group: int = 0


 @dataclass(frozen=True)
@@ -335,6 +336,12 @@ yaml_keys = [
     "leaving_atoms",
     "atom",
     "use_assembly",
+    "symmetric_group",
+    # Per-residue amino acid constraints
+    "residue_constraints",
+    "position",
+    "allowed",
+    "disallowed",
 ]


@@ -483,6 +490,7 @@ def parse_polymer(
     components: dict[str, Mol],
     cyclic: bool,
     mol_dir: Path,
+    symmetric_group: int = 0,
 ) -> Optional[ParsedChain]:
     """Process a sequence into a chain object.

@@ -630,6 +638,7 @@ def parse_polymer(
         cyclic_period=cyclic_period,
         sequence=raw_sequence,
         sampleidx_to_specidx=sampleidx_to_specidx,
+        symmetric_group=symmetric_group,
     )


@@ -675,6 +684,182 @@ def parse_range(ranges, c_start=0, c_end=None):
     return indices


+def _normalize_aa_spec(aa_spec) -> list[str]:
+    """Normalize amino acid specification to a list of individual codes.
+
+    Supports both BoltzGen conventions:
+    - String format: "AGS" (concatenated 1-letter codes, consistent with sequence/binding_types)
+    - List format: [A, G, S] or [ALA, GLY, SER]
+
+    Parameters
+    ----------
+    aa_spec : str or list
+        Amino acid specification in string or list format
+
+    Returns
+    -------
+    list[str]
+        List of individual amino acid codes
+    """
+    if isinstance(aa_spec, str):
+        # String format: "AGS" -> ["A", "G", "S"]
+        # Handle both "AGS" and "ALA" (single 3-letter code)
+        aa_spec = aa_spec.strip().upper()
+        if len(aa_spec) <= 3 and aa_spec.isalpha():
+            # Could be single 3-letter code like "ALA" or 1-3 single letters like "A", "AG", "AGS"
+            # Check if it's a valid 3-letter code
+            if len(aa_spec) == 3 and aa_spec in ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"]:
+                return [aa_spec]
+            # Otherwise treat as concatenated 1-letter codes
+            return list(aa_spec)
+        else:
+            # Longer string: treat as concatenated 1-letter codes
+            return list(aa_spec)
+    elif isinstance(aa_spec, list):
+        # List format: [A, G, S] or [ALA, GLY, SER]
+        return [str(x).strip().upper() for x in aa_spec]
+    else:
+        raise ValueError(f"Invalid amino acid specification: {aa_spec}")
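One consequence of the branch above is worth spelling out: a three-letter string is read as a single residue whenever it happens to be a valid three-letter code, so `"ASP"` means aspartate rather than Ala/Ser/Pro, while `"AGS"` is split into three one-letter codes. The condensed restatement below reproduces that behavior for experimentation. It is a paraphrase of the function in this diff, with `normalize_aa_spec` and `THREE_LETTER` as local illustrative names.

```python
THREE_LETTER = {
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
}


def normalize_aa_spec(aa_spec) -> list[str]:
    """Split an amino acid spec into individual 1- or 3-letter codes."""
    if isinstance(aa_spec, str):
        text = aa_spec.strip().upper()
        if text in THREE_LETTER:
            return [text]       # a single three-letter code
        return list(text)       # concatenated one-letter codes
    if isinstance(aa_spec, list):
        return [str(item).strip().upper() for item in aa_spec]
    raise ValueError(f"Invalid amino acid specification: {aa_spec}")


examples = {spec: normalize_aa_spec(spec) for spec in ("AGS", "ASP", "cm")}
from_list = normalize_aa_spec(["ala", "G"])
print(examples)
print(from_list)
```

To constrain Ala, Ser and Pro individually, the list form `[A, S, P]` sidesteps the ambiguity.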
|
||||
|
||||
|
||||
def _convert_aa_names_to_indices(
    aa_names: list,
    canonical_tokens: list[str],
    prot_letter_to_token: dict[str, str],
) -> list[int]:
    """Convert amino acid names (1-letter or 3-letter) to canonical token indices.

    Parameters
    ----------
    aa_names : list
        List of amino acid names (1-letter like 'A' or 3-letter like 'ALA')
    canonical_tokens : list[str]
        List of canonical 3-letter amino acid codes
    prot_letter_to_token : dict[str, str]
        Mapping from 1-letter to 3-letter codes

    Returns
    -------
    list[int]
        List of indices into canonical_tokens
    """
    indices = []
    for name in aa_names:
        name = str(name).strip().upper()
        # Convert 1-letter to 3-letter if needed
        if len(name) == 1:
            if name not in prot_letter_to_token:
                raise ValueError(f"Unknown amino acid code: {name}")
            name = prot_letter_to_token[name]

        # Find index in canonical_tokens
        if name not in canonical_tokens:
            raise ValueError(f"Unknown amino acid: {name}")
        indices.append(canonical_tokens.index(name))

    return indices
|
||||
|
||||
|
||||
def parse_residue_constraints(
    constraints_spec: list,
    chain_length: int,
    canonical_tokens: list[str],
    prot_letter_to_token: dict[str, str],
) -> np.ndarray:
    """Parse residue_constraints into a per-residue constraint mask.

    Parameters
    ----------
    constraints_spec : list
        List of constraint specifications from YAML
    chain_length : int
        Length of the chain (number of residues)
    canonical_tokens : list[str]
        List of canonical 3-letter amino acid codes (20 AAs)
    prot_letter_to_token : dict[str, str]
        Mapping from 1-letter to 3-letter codes

    Returns
    -------
    np.ndarray
        Shape (chain_length, 20) where:
        - 0.0 means allowed
        - 1.0 means disallowed (will be converted to -inf logit bias in model)

    Notes
    -----
    Overlapping constraints use **intersection** semantics: if multiple
    constraints cover the same position, only amino acids allowed by ALL
    of them survive. For example, ``allowed: AG`` at pos 1..10 followed
    by ``allowed: GS`` at pos 5..15 results in only G being allowed at
    positions 5-10 (the intersection of {A,G} and {G,S}).
    """
    num_aa = len(canonical_tokens)  # Should be 20
    constraint_mask = np.zeros((chain_length, num_aa), dtype=np.float32)

    for constraint in constraints_spec:
        # Parse position(s)
        position_spec = constraint.get("position")
        if position_spec is None:
            raise ValueError("residue_constraints: 'position' is required")

        # Use parse_range to handle single positions and ranges (1-indexed)
        positions = parse_range(str(position_spec), c_start=0, c_end=chain_length)

        # Validate positions are within bounds
        for pos in positions:
            if pos < 0 or pos >= chain_length:
                raise ValueError(
                    f"Position {pos + 1} is out of bounds for chain of length {chain_length}"
                )

        # Parse amino acid specification
        allowed = constraint.get("allowed", None)
        disallowed = constraint.get("disallowed", None)

        # Validate: cannot have both allowed and disallowed
        if allowed is not None and disallowed is not None:
            raise ValueError(
                f"Position {position_spec}: cannot specify both 'allowed' and 'disallowed'"
            )

        if allowed is None and disallowed is None:
            raise ValueError(
                f"Position {position_spec}: must specify either 'allowed' or 'disallowed'"
            )

        if allowed is not None:
            # Whitelist mode: block all except specified AAs
            # Uses np.maximum to accumulate with existing constraints (intersection semantics):
            # if a position already has constraints, only AAs allowed by BOTH survive.
            aa_list = _normalize_aa_spec(allowed)
            if len(aa_list) == 0:
                raise ValueError(
                    f"Position {position_spec}: 'allowed' cannot be empty"
                )
            aa_indices = _convert_aa_names_to_indices(
                aa_list, canonical_tokens, prot_letter_to_token
            )
            new_block = np.ones(num_aa, dtype=np.float32)
            for idx in aa_indices:
                new_block[idx] = 0.0
            for pos in positions:
                constraint_mask[pos, :] = np.maximum(constraint_mask[pos, :], new_block)

        elif disallowed is not None:
            # Blacklist mode: only block specified
            # Normalize input: supports both "CM" (string) and [C, M] (list)
            aa_list = _normalize_aa_spec(disallowed)
            aa_indices = _convert_aa_names_to_indices(
                aa_list, canonical_tokens, prot_letter_to_token
            )
            for pos in positions:
                for idx in aa_indices:
                    constraint_mask[pos, idx] = 1.0  # Block specified

    return constraint_mask
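Review note on the intersection semantics of `parse_residue_constraints`: because each `allowed` entry is turned into a block vector and merged with an elementwise maximum, a blocked amino acid can never become allowed again by a later entry. The following is a minimal sketch of that accumulation on a toy three-letter alphabet, using plain lists instead of NumPy. `AAS`, `allowed_to_block` and `accumulate` are hypothetical helpers, not names from the diff.

```python
AAS = ["A", "G", "S"]  # toy alphabet; the real mask has 20 columns


def allowed_to_block(allowed):
    """1.0 marks a blocked amino acid, 0.0 an allowed one."""
    return [0.0 if aa in allowed else 1.0 for aa in AAS]


def accumulate(row, block):
    """Elementwise maximum, the list analogue of np.maximum."""
    return [max(r, b) for r, b in zip(row, block)]


row = [0.0, 0.0, 0.0]                          # start: everything allowed
row = accumulate(row, allowed_to_block("AG"))  # first constraint
row = accumulate(row, allowed_to_block("GS"))  # overlapping second constraint
surviving = [aa for aa, blocked in zip(AAS, row) if blocked == 0.0]
print(surviving)  # ['G']
```

One consequence worth keeping in mind: two overlapping `allowed` entries with no amino acid in common leave a position with every column blocked.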
|
||||
|
||||
|
||||
def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
|
||||
extra_mols: dict[str, Mol] = {}
|
||||
parsed_chains: dict[str, ParsedChain] = {}
|
||||
@@ -766,6 +951,9 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
|
||||
seq[idx] = code
|
||||
|
||||
cyclic = item[entity_type].get("cyclic", False)
|
||||
symmetric_group = item[entity_type].get("symmetric_group", 0)
|
||||
if symmetric_group is None:
|
||||
symmetric_group = 0
|
||||
|
||||
# Parse a polymer
|
||||
parsed_chain = parse_polymer(
|
||||
@@ -776,10 +964,14 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
|
||||
components=mols,
|
||||
cyclic=cyclic,
|
||||
mol_dir=mol_dir,
|
||||
symmetric_group=symmetric_group,
|
||||
)
|
||||
|
||||
# Parse a non-polymer
|
||||
elif (entity_type == "ligand") and "ccd" in (item[entity_type]):
|
||||
symmetric_group = item[entity_type].get("symmetric_group", 0)
|
||||
if symmetric_group is None:
|
||||
symmetric_group = 0
|
||||
seq = item[entity_type]["ccd"]
|
||||
if isinstance(seq, str):
|
||||
seq = [seq]
|
||||
@@ -806,6 +998,7 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
|
||||
type=const.chain_type_ids["NONPOLYMER"],
|
||||
cyclic_period=0,
|
||||
sequence=None,
|
||||
symmetric_group=symmetric_group,
|
||||
)
|
||||
|
||||
assert not item[entity_type].get("cyclic", False), (
|
||||
@@ -813,6 +1006,9 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
|
||||
)
|
||||
|
||||
elif (entity_type == "ligand") and ("smiles" in item[entity_type]):
|
||||
symmetric_group = item[entity_type].get("symmetric_group", 0)
|
||||
if symmetric_group is None:
|
||||
symmetric_group = 0
|
||||
seq = item[entity_type]["smiles"]
|
||||
mol = AllChem.MolFromSmiles(seq)
|
||||
mol = AllChem.AddHs(mol)
|
||||
@@ -848,6 +1044,7 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
|
||||
type=const.chain_type_ids["NONPOLYMER"],
|
||||
cyclic_period=0,
|
||||
sequence=None,
|
||||
symmetric_group=symmetric_group,
|
||||
)
|
||||
|
||||
assert not item[entity_type].get("cyclic", False), (
|
||||
@@ -944,6 +1141,31 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
|
||||
else:
|
||||
ss_type.extend([const.ss_type_ids["UNSPECIFIED"]] * num)
|
||||
|
||||
# Parse residue_constraints for per-residue amino acid restrictions
|
||||
entry = item[entity_type]
|
||||
constraints_spec = entry.get("residue_constraints", None)
|
||||
ids = item[entity_type]["id"]
|
||||
num_chains = 1 if isinstance(ids, str) else len(ids)
|
||||
res_aa_constraint_list = []
|
||||
for _ in range(num_chains):
|
||||
if constraints_spec is not None and entity_type == "protein":
|
||||
res_aa_constraints = parse_residue_constraints(
|
||||
constraints_spec,
|
||||
chain_length=num,
|
||||
canonical_tokens=const.canonical_tokens,
|
||||
prot_letter_to_token=const.prot_letter_to_token,
|
||||
)
|
||||
else:
|
||||
# No constraints: all 20 amino acids allowed (zeros)
|
||||
res_aa_constraints = np.zeros((num, len(const.canonical_tokens)), dtype=np.float32)
|
||||
res_aa_constraint_list.append(res_aa_constraints)
|
||||
|
||||
# Concatenate constraint masks for all chain copies
|
||||
if res_aa_constraint_list:
|
||||
res_aa_constraint_mask = np.concatenate(res_aa_constraint_list, axis=0)
|
||||
else:
|
||||
res_aa_constraint_mask = np.zeros((0, len(const.canonical_tokens)), dtype=np.float32)
|
||||
|
||||
# Add as many parsed_chains as provided ids
|
||||
if entity_type in {"protein", "dna", "rna", "ligand"}:
|
||||
ids = item[entity_type]["id"]
|
||||
@@ -971,6 +1193,7 @@ def parse_entity(item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto):
|
||||
chain_to_msa,
|
||||
fuse_info,
|
||||
ligand_id,
|
||||
res_aa_constraint_mask,
|
||||
)
|
||||
|
||||
|
||||
@@ -1202,6 +1425,7 @@ class YamlDesignParser:
|
||||
res_design_mask = np.array([], dtype=bool)
|
||||
res_bind_type = np.array([], dtype=np.int32)
|
||||
ss_type = np.array([], dtype=np.int32)
|
||||
res_aa_constraint_mask = np.zeros((0, len(const.canonical_tokens)), dtype=np.float32)
|
||||
chain_to_msa = {}
|
||||
|
||||
global_asym_id = 0
|
||||
@@ -1225,6 +1449,7 @@ class YamlDesignParser:
|
||||
entity_chain_to_msa,
|
||||
fuse_info,
|
||||
ligand_id,
|
||||
new_res_aa_constraint_mask,
|
||||
) = parse_entity(
|
||||
item, mols, mol_dir, ligand_id, is_msa_custom, is_msa_auto
|
||||
)
|
||||
@@ -1233,6 +1458,7 @@ class YamlDesignParser:
|
||||
extra_mols.update(new_extra_mols)
|
||||
res_bind_type = np.concatenate([res_bind_type, new_res_bind_type])
|
||||
ss_type = np.concatenate([ss_type, new_ss_type])
|
||||
res_aa_constraint_mask = np.concatenate([res_aa_constraint_mask, new_res_aa_constraint_mask], axis=0)
|
||||
for asym_id, (chain_name, chain) in enumerate(
|
||||
parsed_chains.items()
|
||||
):
|
||||
@@ -1260,6 +1486,7 @@ class YamlDesignParser:
|
||||
res_idx,
|
||||
res_num,
|
||||
chain.cyclic_period,
|
||||
chain.symmetric_group,
|
||||
)
|
||||
)
|
||||
chain_to_idx[chain_name] = asym_id
|
||||
@@ -1405,11 +1632,16 @@ class YamlDesignParser:
|
||||
fbind_types,
|
||||
fss_type,
|
||||
file_chain_to_msa,
|
||||
file_chain_symmetric_group,
|
||||
fuse_info,
|
||||
new_extra_mols,
|
||||
file_msa_flag,
|
||||
ligand_id,
|
||||
) = self.parse_file(item, mols, mol_dir, ligand_id, base_file_path)
|
||||
# Apply symmetric_group to chains from file
|
||||
for chain_id, sym_group in file_chain_symmetric_group.items():
|
||||
chain_mask = new_data.chains["name"] == chain_id
|
||||
new_data.chains["symmetric_group"][chain_mask] = sym_group
|
||||
if fuse_info["fuse"]:
|
||||
if fuse_info["target_id"] in total_renaming.keys():
|
||||
fuse_info["target_id"] = total_renaming[
|
||||
@@ -1430,6 +1662,9 @@ class YamlDesignParser:
|
||||
res_design_mask = np.concatenate([res_design_mask, new_design_mask])
|
||||
res_bind_type = np.concatenate([res_bind_type, fbind_types])
|
||||
ss_type = np.concatenate([ss_type, fss_type])
|
||||
# File entities have no residue constraints — pad with zeros (all AAs allowed)
|
||||
file_constraint_mask = np.zeros((len(new_design_mask), len(const.canonical_tokens)), dtype=np.float32)
|
||||
res_aa_constraint_mask = np.concatenate([res_aa_constraint_mask, file_constraint_mask], axis=0)
|
||||
extra_mols.update(new_extra_mols)
|
||||
if len(renaming) > 0:
|
||||
msg = f"\nChain ids conflict with existing chain ids. Renaming with {renaming}. This is for the structure from '{path}'."
|
||||
@@ -1484,18 +1719,65 @@ class YamlDesignParser:
|
||||
ValueError(msg)
|
||||
|
||||
# Map index
|
||||
if all_parsed_chains[c1].sampleidx_to_specidx is not None:
|
||||
if (
|
||||
c1 in all_parsed_chains.keys()
|
||||
and all_parsed_chains[c1].sampleidx_to_specidx is not None
|
||||
):
|
||||
r1 = np.where(all_parsed_chains[c1].sampleidx_to_specidx == r1)[0][
|
||||
0
|
||||
].item()
|
||||
if all_parsed_chains[c2].sampleidx_to_specidx is not None:
|
||||
c1, r1, a1 = atom_idx_map[(c1, r1, a1)]
|
||||
else:
|
||||
# we have a chain coming from a file where we just use the residue index
|
||||
chain = data.chains[data.chains["name"] == c1]
|
||||
c1 = chain["asym_id"].item()
|
||||
|
||||
res_start = chain["res_idx"].item()
|
||||
res_end = chain["res_idx"].item() + chain["res_num"].item()
|
||||
residues = data.residues[res_start:res_end]
|
||||
residue = residues[residues["res_idx"] == r1]
|
||||
r1 = res_start + residue["res_idx"].item()
|
||||
|
||||
atom_start = residue["atom_idx"].item()
|
||||
atom_end = residue["atom_idx"].item() + residue["atom_num"].item()
|
||||
atoms = data.atoms[atom_start:atom_end]
|
||||
assert a1 in atoms["name"], (
|
||||
f"Atom {a1} not found in residue {r1} of chain {c1}"
|
||||
)
|
||||
a1 = np.where(atoms["name"] == a1)[0].item()
|
||||
a1 = (
|
||||
residue["atom_idx"].item() + a1
|
||||
) # THIS STILL NEEDS TO BE CORRECTED
|
||||
|
||||
if (
|
||||
c2 in all_parsed_chains.keys()
|
||||
and all_parsed_chains[c2].sampleidx_to_specidx is not None
|
||||
):
|
||||
r2 = np.where(all_parsed_chains[c2].sampleidx_to_specidx == r2)[0][
|
||||
0
|
||||
].item()
|
||||
c2, r2, a2 = atom_idx_map[(c2, r2, a2)]
|
||||
else:
|
||||
# we have a chain coming from a file where we just use the residue index
|
||||
chain = data.chains[data.chains["name"] == c2]
|
||||
c2 = chain["asym_id"].item()
|
||||
|
||||
c1, r1, a1 = atom_idx_map[(c1, r1, a1)]
|
||||
c2, r2, a2 = atom_idx_map[(c2, r2, a2)]
|
||||
res_start = chain["res_idx"].item()
|
||||
res_end = chain["res_idx"].item() + chain["res_num"].item()
|
||||
residues = data.residues[res_start:res_end]
|
||||
residue = residues[residues["res_idx"] == r2]
|
||||
r2 = res_start + residue["res_idx"].item()
|
||||
|
||||
atom_start = residue["atom_idx"].item()
|
||||
atom_end = residue["atom_idx"].item() + residue["atom_num"].item()
|
||||
atoms = data.atoms[atom_start:atom_end]
|
||||
assert a2 in atoms["name"], (
|
||||
f"Atom {a2} not found in residue {r2} of chain {c2}"
|
||||
)
|
||||
a2 = np.where(atoms["name"] == a2)[0].item()
|
||||
a2 = (
|
||||
residue["atom_idx"].item() + a2
|
||||
) # THIS STILL NEEDS TO BE CORRECTED
|
||||
covalents.append((c1, c2, r1, r2, a1, a2))
|
||||
elif "total_len" in constraints:
|
||||
continue
|
||||
@@ -1552,6 +1834,7 @@ class YamlDesignParser:
|
||||
res_structure_groups=structure_groups,
|
||||
res_binding_type=res_bind_type,
|
||||
res_ss_types=ss_type,
|
||||
res_aa_constraint_mask=res_aa_constraint_mask,
|
||||
)
|
||||
DesignInfo.is_valid(design_info)
|
||||
|
||||
@@ -1651,6 +1934,7 @@ class YamlDesignParser:
|
||||
|
||||
# Construct include mask from include entries
|
||||
file_chain_to_msa = {}
|
||||
file_chain_symmetric_group = {}
|
||||
if isinstance(include, str):
|
||||
if include == "all":
|
||||
include_mask = np.ones(num_res)
|
||||
@@ -1665,12 +1949,15 @@ class YamlDesignParser:
|
||||
msg = f"Misspecified chain in include with missing 'id' for file with path {path}."
|
||||
raise ValueError(msg)
|
||||
chain_id = chain["id"]
|
||||
|
||||
if chain_id not in structure.chains["name"]:
|
||||
msg = f"Specified chain id {chain_id} not in file {path}."
|
||||
raise ValueError(msg)
|
||||
|
||||
if "msa" in chain:
|
||||
file_chain_to_msa[chain_id] = chain["msa"]
|
||||
if "symmetric_group" in chain:
|
||||
file_chain_symmetric_group[chain_id] = chain["symmetric_group"]
|
||||
data_chain = structure.chains[structure.chains["name"] == chain_id]
|
||||
|
||||
c_start = data_chain["res_idx"].item()
|
||||
@@ -1970,8 +2257,12 @@ class YamlDesignParser:
|
||||
fss_type[indices] = const.ss_type_ids["SHEET"]
|
||||
|
||||
# Parse and apply design insertions
|
||||
# First pass: collect insertions and coordinate lengths for symmetric chains
|
||||
if design_insertions is not None:
|
||||
num_inserted = defaultdict(int)
|
||||
# Group insertions by (symmetric_group, res_index) to coordinate variable lengths
|
||||
symmetric_length_cache = {} # (sym_group, res_index) -> sampled_length
|
||||
|
||||
for list_element in design_insertions:
|
||||
insertion = list_element["insertion"]
|
||||
if "id" not in insertion:
|
||||
@@ -1985,10 +2276,24 @@ class YamlDesignParser:
|
||||
res_index += num_inserted[chain_id]
|
||||
ss_insert_type = insertion.get("secondary_structure", "UNSPECIFIED")
|
||||
|
||||
num_residues_spec = insertion["num_residues"]
|
||||
num_residues_range = parse_range(num_residues_spec)
|
||||
|
||||
# Check if this chain has a symmetric_group
|
||||
chain_sym_group = file_chain_symmetric_group.get(chain_id, 0)
|
||||
|
||||
# If chain has symmetric_group > 0, coordinate length with other symmetric chains
|
||||
if chain_sym_group > 0:
|
||||
cache_key = (chain_sym_group, res_index, str(num_residues_spec))
|
||||
if cache_key in symmetric_length_cache:
|
||||
num_residues = symmetric_length_cache[cache_key]
|
||||
else:
|
||||
num_residues = np.random.choice(num_residues_range).item()
|
||||
symmetric_length_cache[cache_key] = num_residues
|
||||
else:
|
||||
num_residues = np.random.choice(num_residues_range).item()
|
||||
|
||||
# We add +1 because the parse_range function is usually used for indexing where we then convert the 1 based inputs to 0 indexing
|
||||
num_residues = insertion["num_residues"]
|
||||
num_residues = parse_range(num_residues)
|
||||
num_residues = np.random.choice(num_residues).item()
|
||||
num_residues += 1
|
||||
num_inserted[chain_id] += num_residues
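Review note on `symmetric_length_cache`: the intent appears to be that chains in the same symmetric group draw an insertion length once and then reuse it. The sketch below shows that pattern with the standard library only. `sample_insert_length` and its parameters are hypothetical; the diff itself draws with `np.random.choice` over the output of `parse_range`, and the spec string here serves only as part of the cache key.

```python
import random


def sample_insert_length(cache, sym_group, res_index, spec, choices, rng):
    """Draw a length, reusing an earlier draw for the same symmetric key."""
    if sym_group <= 0:
        # Group 0 means "no symmetry": always draw independently.
        return rng.choice(choices)
    key = (sym_group, res_index, str(spec))
    if key not in cache:
        cache[key] = rng.choice(choices)
    return cache[key]


rng = random.Random(0)
cache = {}
choices = [3, 4, 5, 6, 7, 8]
first = sample_insert_length(cache, 1, 10, "3..8", choices, rng)
second = sample_insert_length(cache, 1, 10, "3..8", choices, rng)
ungrouped = sample_insert_length(cache, 0, 10, "3..8", choices, rng)
print(first == second, len(cache))
```

Keying on the spec string as well as the residue index keeps two different insertions at the same index from sharing a length by accident.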
|
||||
|
||||
@@ -2122,6 +2427,7 @@ class YamlDesignParser:
|
||||
fbind_types,
|
||||
fss_type,
|
||||
file_chain_to_msa,
|
||||
file_chain_symmetric_group,
|
||||
fuse_info,
|
||||
extra_mols,
|
||||
file_msa_flag,
|
||||
|
||||
@@ -50,11 +50,13 @@ class TokenData:
|
||||
design: bool
|
||||
binding_type: int
|
||||
structure_group: int
|
||||
aa_constraint_mask: np.ndarray # Shape: (20,) - per-residue AA constraints
|
||||
ccd: np.ndarray
|
||||
target_msa_mask: bool
|
||||
design_ss_mask: bool
|
||||
feature_asym_id: int
|
||||
feature_res_idx: int
|
||||
symmetric_group: int
|
||||
|
||||
|
||||
def compute_frame(
|
||||
@@ -265,11 +267,13 @@ class Tokenizer:
|
||||
design=False,
|
||||
binding_type=const.binding_type_ids["UNSPECIFIED"],
|
||||
structure_group=0,
|
||||
aa_constraint_mask=np.zeros(20, dtype=np.float32),
|
||||
ccd=convert_ccd(res["name"]),
|
||||
target_msa_mask=0,
|
||||
design_ss_mask=0,
|
||||
feature_asym_id=chain["asym_id"],
|
||||
feature_res_idx=res["res_idx"],
|
||||
symmetric_group=chain["symmetric_group"],
|
||||
)
|
||||
token_data.append(tokendata_to_tuple(token))
|
||||
|
||||
@@ -329,11 +333,13 @@ class Tokenizer:
|
||||
design=False,
|
||||
binding_type=const.binding_type_ids["UNSPECIFIED"],
|
||||
structure_group=0,
|
||||
aa_constraint_mask=np.zeros(20, dtype=np.float32),
|
||||
ccd=convert_ccd(res["name"]),
|
||||
target_msa_mask=0,
|
||||
design_ss_mask=0,
|
||||
feature_asym_id=chain["asym_id"],
|
||||
feature_res_idx=res["res_idx"],
|
||||
symmetric_group=chain["symmetric_group"],
|
||||
)
|
||||
token_data.append(tokendata_to_tuple(token))
|
||||
|
||||
@@ -387,11 +393,13 @@ class Tokenizer:
|
||||
design=False,
|
||||
binding_type=const.binding_type_ids["UNSPECIFIED"],
|
||||
structure_group=0,
|
||||
aa_constraint_mask=np.zeros(20, dtype=np.float32),
|
||||
ccd=convert_ccd(res["name"]),
|
||||
target_msa_mask=0,
|
||||
design_ss_mask=0,
|
||||
feature_asym_id=chain["asym_id"],
|
||||
feature_res_idx=res["res_idx"],
|
||||
symmetric_group=chain["symmetric_group"],
|
||||
)
|
||||
token_data.append(tokendata_to_tuple(token))
|
||||
|
||||
|
||||
@@ -279,6 +279,22 @@ def compute_ptms(logits, x_preds, feats, multiplicity):
|
||||
dim=1,
|
||||
).values
|
||||
|
||||
# iPTM between designed residues and any token from a different chain
|
||||
design_residue_iptm_mask = (
|
||||
maski[:, :, None]
|
||||
* mask_pad[:, None, :]
|
||||
* mask_pad[:, :, None]
|
||||
* (asym_id[:, None, :] != asym_id[:, :, None])
|
||||
* (
|
||||
is_design_token[:, :, None] + is_design_token[:, None, :]
|
||||
).clamp(max=1)
|
||||
)
|
||||
design_residue_iptm = torch.max(
|
||||
torch.sum(tm_expected_value * design_residue_iptm_mask, dim=-1)
|
||||
/ (torch.sum(design_residue_iptm_mask, dim=-1) + 1e-5),
|
||||
dim=1,
|
||||
).values
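Review note on `design_residue_iptm`: the mask keeps a pair (i, j) when the two tokens sit on different chains and at least one of them is a designed token, then takes the best per-row average. A pure-Python sketch of that reduction follows. It ignores the `maski` and `mask_pad` factors, whose definitions lie outside this hunk, and `design_residue_iptm_sketch` is a hypothetical name.

```python
def design_residue_iptm_sketch(tm, asym_id, is_design, eps=1e-5):
    """Max over rows of the masked mean of tm[i][j]."""
    n = len(asym_id)
    row_scores = []
    for i in range(n):
        num = den = 0.0
        for j in range(n):
            # Pair counts if chains differ and either token is designed.
            keep = asym_id[i] != asym_id[j] and (is_design[i] or is_design[j])
            w = 1.0 if keep else 0.0
            num += tm[i][j] * w
            den += w
        row_scores.append(num / (den + eps))
    return max(row_scores)


tm = [
    [0.9, 0.2, 0.4],
    [0.2, 0.9, 0.6],
    [0.4, 0.6, 0.9],
]
score = design_residue_iptm_sketch(tm, asym_id=[0, 0, 1], is_design=[False, False, True])
print(round(score, 3))  # 0.6
```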
|
||||
|
||||
design_ptm_mask = (
|
||||
maski[:, :, None]
|
||||
* mask_pad[:, None, :]
|
||||
@@ -394,6 +410,7 @@ def compute_ptms(logits, x_preds, feats, multiplicity):
|
||||
protein_iptm,
|
||||
chain_pair_iptm,
|
||||
design_to_target_iptm,
|
||||
design_residue_iptm,
|
||||
design_iptm,
|
||||
design_iiptm,
|
||||
target_ptm,
|
||||
|
||||
@@ -84,6 +84,7 @@ class PairformerModule(nn.Module):
|
||||
pair_mask,
|
||||
chunk_size_tri_attn,
|
||||
use_kernels=use_kernels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
s, z = layer(
|
||||
|
||||
@@ -177,7 +177,7 @@ def compute_bond_loss(pred_atom_coords, true_coords, feats):
|
||||
pred_bond_coords = pred_atom_coords[
|
||||
:, feats["connections_edge_index"][index_batch]
|
||||
]
|
||||
true_bond_coords = pred_atom_coords[
|
||||
true_bond_coords = true_coords[
|
||||
:, feats["connections_edge_index"][index_batch]
|
||||
]
|
||||
pred_bond_lengths = torch.linalg.norm(
|
||||
@@ -190,4 +190,4 @@ def compute_bond_loss(pred_atom_coords, true_coords, feats):
|
||||
num_bonds += pred_bond_lengths.shape[1]
|
||||
if num_bonds > 0:
|
||||
bond_loss /= num_bonds
|
||||
return bond_loss, num_bonds
|
||||
return bond_loss.mean(), num_bonds
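Review note on the `true_bond_coords` fix: with the old line, the reference bond lengths were computed from the predicted coordinates, so the comparison was between a quantity and itself. A toy sketch makes the effect visible. It assumes an absolute-difference penalty on bond lengths, which is an assumption since the actual penalty lies outside this hunk, and `bond_length_error` is a hypothetical helper.

```python
import math


def bond_length_error(pred, true, bonds):
    """Mean absolute difference between predicted and reference bond lengths."""
    total = 0.0
    for a, b in bonds:
        total += abs(math.dist(pred[a], pred[b]) - math.dist(true[a], true[b]))
    return total / len(bonds)


pred = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
true = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]
bonds = [(0, 1)]

old_behaviour = bond_length_error(pred, pred, bonds)  # reference taken from pred
new_behaviour = bond_length_error(pred, true, bonds)  # reference taken from true
print(old_behaviour, new_behaviour)  # 0.0 0.5
```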
|
||||
|
||||
@@ -557,7 +557,7 @@ class ConfidenceHeads(nn.Module):
|
||||
interaction_pae=interaction_pae,
|
||||
min_design_to_target_pae=min_design_to_target_pae
|
||||
if feats["design_mask"].sum() > 0
|
||||
else torch.nan,
|
||||
else torch.tensor([torch.nan]),
|
||||
min_interaction_pae=min_interaction_pae,
|
||||
)
|
||||
|
||||
@@ -581,6 +581,7 @@ class ConfidenceHeads(nn.Module):
|
||||
protein_iptm,
|
||||
pair_chains_iptm,
|
||||
design_to_target_iptm,
|
||||
design_residue_iptm,
|
||||
design_iptm,
|
||||
design_iiptm,
|
||||
target_ptm,
|
||||
@@ -596,6 +597,7 @@ class ConfidenceHeads(nn.Module):
|
||||
out_dict["protein_iptm"] = protein_iptm
|
||||
out_dict["pair_chains_iptm"] = pair_chains_iptm
|
||||
out_dict["design_to_target_iptm"] = design_to_target_iptm
|
||||
out_dict["design_residue_iptm"] = design_residue_iptm
|
||||
out_dict["design_iptm"] = design_iptm
|
||||
out_dict["design_iiptm"] = design_iiptm
|
||||
out_dict["target_ptm"] = target_ptm
|
||||
@@ -612,6 +614,7 @@ class ConfidenceHeads(nn.Module):
|
||||
out_dict["protein_iptm"] = torch.zeros_like(complex_plddt)
|
||||
out_dict["pair_chains_iptm"] = torch.zeros_like(complex_plddt)
|
||||
out_dict["design_to_target_iptm"] = torch.zeros_like(complex_plddt)
|
||||
out_dict["design_residue_iptm"] = torch.zeros_like(complex_plddt)
|
||||
out_dict["design_iptm"] = torch.zeros_like(complex_plddt)
|
||||
out_dict["design_iiptm"] = torch.zeros_like(complex_plddt)
|
||||
out_dict["target_ptm"] = torch.zeros_like(complex_plddt)
|
||||
|
||||
@@ -829,7 +829,7 @@ class AtomDiffusion(Module):
|
||||
|
||||
if add_bond_loss:
|
||||
bond_loss, num_bonds = compute_bond_loss(
|
||||
pred_atom_coords=out_dict["sample_atom_coords"].float(),
|
||||
pred_atom_coords=out_dict["denoised_atom_coords"].float(),
|
||||
true_coords=atom_coords_aligned_ground_truth,
|
||||
feats=feats,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Dict, Tuple, List
|
||||
from typing import Dict, Tuple, List, Optional
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@@ -38,6 +39,75 @@ def softmax_dropout(
|
||||
)
|
||||
|
||||
|
||||
def build_constraint_logit_mask(
    num_nodes: int,
    aa_constraint_mask: Optional[Tensor],
    inverse_fold_restriction: list[str],
    canonical_tokens: list[str],
    inf: float,
    device: torch.device,
) -> Tensor:
    """Build per-position inverse-folding logit mask.

    The mask uses additive logit bias semantics:
    0.0 = allowed, -inf = disallowed.
    """
    num_aa = len(canonical_tokens)
    has_per_residue_constraints = False

    if aa_constraint_mask is None:
        per_residue_blocked = torch.zeros(
            num_nodes, num_aa, dtype=torch.bool, device=device
        )
    else:
        expected_shape = (num_nodes, num_aa)
        if aa_constraint_mask.shape != expected_shape:
            warnings.warn(
                f"aa_constraint_mask shape mismatch: "
                f"got {aa_constraint_mask.shape}, expected {expected_shape}. "
                f"Ignoring per-residue constraints.",
                RuntimeWarning,
                stacklevel=2,
            )
            per_residue_blocked = torch.zeros(
                num_nodes, num_aa, dtype=torch.bool, device=device
            )
        else:
            has_per_residue_constraints = True
            per_residue_blocked = aa_constraint_mask.to(device=device) > 0

    global_blocked = torch.zeros(num_aa, dtype=torch.bool, device=device)
    for res_type in inverse_fold_restriction:
        global_blocked[canonical_tokens.index(res_type)] = True

    combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
    all_blocked = combined_blocked.all(dim=1)

    if all_blocked.any() and has_per_residue_constraints:
        blocked_positions = torch.where(all_blocked)[0].tolist()
        warnings.warn(
            f"Positions {blocked_positions} have all amino acids blocked by the "
            f"combination of per-residue constraints and '--inverse_fold_avoid'. "
            f"Relaxing per-residue constraints for these positions.",
            RuntimeWarning,
            stacklevel=2,
        )
        per_residue_blocked = per_residue_blocked.clone()
        per_residue_blocked[all_blocked] = False
        combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)

    still_all_blocked = combined_blocked.all(dim=1)
    if still_all_blocked.any():
        blocked_positions = torch.where(still_all_blocked)[0].tolist()
        raise ValueError(
            f"Inverse folding has no valid amino acids at token positions "
            f"{blocked_positions} after applying '--inverse_fold_avoid'. "
            f"Reduce global restrictions to keep at least one amino acid."
        )

    return combined_blocked.to(dtype=torch.float32) * (-inf)
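Review note on the fallback order in `build_constraint_logit_mask`: per-residue constraints are dropped for any position where they would leave nothing to sample once combined with the global restriction, and only a global restriction that blocks everything is treated as a hard error. The sketch below reproduces that order with plain lists of booleans. `combine_blocked` is a hypothetical simplification that omits the shape check and the tensor conversion.

```python
import warnings


def combine_blocked(per_residue, global_blocked):
    """Merge per-position and global block flags, relaxing dead positions."""

    def merge(rows):
        return [[p or g for p, g in zip(row, global_blocked)] for row in rows]

    combined = merge(per_residue)
    dead = [i for i, row in enumerate(combined) if all(row)]
    if dead:
        warnings.warn(f"Relaxing per-residue constraints at {dead}", RuntimeWarning)
        relaxed = [
            [False] * len(row) if i in dead else row
            for i, row in enumerate(per_residue)
        ]
        combined = merge(relaxed)
    if any(all(row) for row in combined):
        raise ValueError("global restrictions block every amino acid")
    return combined


# Toy alphabet [A, G, S]; G is avoided globally.
per_residue = [
    [True, False, True],   # position 0 allows only G, which is avoided globally
    [True, False, False],  # position 1 blocks A
]
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = combine_blocked(per_residue, [False, True, False])
print(result)  # [[False, True, False], [True, True, False]]
```

Position 0 ends up with only the global restriction applied, which matches the warning text in the diff.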
|
||||
|
||||
|
||||
class MLPAttnGNN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -464,6 +534,7 @@ class InverseFoldingDecoder(nn.Module):
|
||||
num_decoder_layers: int = 3,
|
||||
inverse_fold_restriction: List[str] = [],
|
||||
sampling_temperature: float = 0.1,
|
||||
tie_symmetric_sequences: bool = True,
|
||||
**kwargs, # old checkpoint compatibility
|
||||
):
|
||||
super().__init__()
|
||||
@@ -485,6 +556,7 @@ class InverseFoldingDecoder(nn.Module):
|
||||
self.num_decoder_layers = num_decoder_layers
|
||||
self.inverse_fold_restriction = inverse_fold_restriction
|
||||
self.sampling_temperature = sampling_temperature
|
||||
self.tie_symmetric_sequences = tie_symmetric_sequences
|
||||
|
||||
self.decoder_layers = nn.ModuleList()
|
||||
self.inf = 10**6
|
||||
@@ -507,6 +579,43 @@ class InverseFoldingDecoder(nn.Module):
|
||||
# init the output of the predictor to be zero
|
||||
self.predictor.weight.zero_()
|
||||
|
||||
    def _build_symmetric_groups(
        self,
        feats: Dict[str, Tensor],
        valid_mask: Tensor,
        design_mask: Tensor,
    ) -> Tuple[Dict[int, List[int]], Dict[int, int]]:
        """Build mapping from positions to symmetric groups for homomer tying."""
        from collections import defaultdict

        symmetric_group = feats["symmetric_group"][valid_mask]
        res_idx = feats["feature_residue_index"][valid_mask]

        # Group by (symmetric_group, res_idx) - positions that should share sequence
        key_to_positions = defaultdict(list)

        num_nodes = symmetric_group.shape[0]
        for i in range(num_nodes):
            if design_mask[i]:
                group = symmetric_group[i].item()
                if group > 0:  # 0 = no group
                    key = (group, res_idx[i].item())
                    key_to_positions[key].append(i)

        # Build symmetric groups (only groups with >1 member)
        sym_groups = {}
        position_to_group = {}
        group_id = 0

        for positions in key_to_positions.values():
            if len(positions) > 1:
                sym_groups[group_id] = positions
                for pos in positions:
                    position_to_group[pos] = group_id
                group_id += 1

        return sym_groups, position_to_group
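Review note on `_build_symmetric_groups`: tokens are tied when they share both a non-zero `symmetric_group` and the same residue index, so two copies of a three-residue chain produce three tied pairs. A list-based sketch of the same grouping follows. `build_groups` is a hypothetical free function that takes plain Python lists in place of the tensors in `feats`.

```python
from collections import defaultdict


def build_groups(symmetric_group, res_idx, design_mask):
    """Group designed positions by (symmetric_group, res_idx)."""
    key_to_positions = defaultdict(list)
    for i, (group, res, designed) in enumerate(
        zip(symmetric_group, res_idx, design_mask)
    ):
        if designed and group > 0:  # 0 = no group
            key_to_positions[(group, res)].append(i)

    sym_groups, position_to_group = {}, {}
    for positions in key_to_positions.values():
        if len(positions) > 1:  # singletons are not tied to anything
            gid = len(sym_groups)
            sym_groups[gid] = positions
            for pos in positions:
                position_to_group[pos] = gid
    return sym_groups, position_to_group


# Two 3-residue copies in group 1, plus one ungrouped token.
sym_groups, position_to_group = build_groups(
    symmetric_group=[1, 1, 1, 1, 1, 1, 0],
    res_idx=[0, 1, 2, 0, 1, 2, 0],
    design_mask=[True] * 7,
)
print(sym_groups)  # {0: [0, 3], 1: [1, 4], 2: [2, 5]}
```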
|
||||
|
||||
def forward(self, s, z, edge_idx, valid_mask, feats):
|
||||
with torch.no_grad():
|
||||
src_idx, dst_idx = edge_idx[0], edge_idx[1]
|
||||
@@ -549,14 +658,17 @@ class InverseFoldingDecoder(nn.Module):
|
||||
f"num_design: {num_design}, num_not_design: {num_not_design}"
|
||||
)
|
||||
|
||||
# Create restriction mask that sets the probability of excluded residues to 0
|
||||
if len(self.inverse_fold_restriction) > 0:
|
||||
restriction_mask = torch.zeros(len(const.canonical_tokens), device=s.device)
|
||||
for res_type in self.inverse_fold_restriction:
|
||||
restriction_mask[const.canonical_tokens.index(res_type)] = -self.inf
|
||||
restriction_mask = restriction_mask.unsqueeze(0)
|
||||
else:
|
||||
restriction_mask = torch.zeros(len(const.canonical_tokens), device=s.device)
|
||||
constraint_mask = None
|
||||
if "aa_constraint_mask" in feats:
|
||||
constraint_mask = feats["aa_constraint_mask"][valid_mask]
|
||||
per_residue_mask = build_constraint_logit_mask(
|
||||
num_nodes=num_nodes,
|
||||
aa_constraint_mask=constraint_mask,
|
||||
inverse_fold_restriction=self.inverse_fold_restriction,
|
||||
canonical_tokens=const.canonical_tokens,
|
||||
inf=self.inf,
|
||||
device=s.device,
|
||||
)
|
||||
|
||||
order = torch.randperm(num_nodes, device=s.device).cpu().numpy().tolist()
|
||||
# Non-design residues are not sampled and used as the condition. So the order should filter them out.
|
||||
@@ -572,33 +684,64 @@ class InverseFoldingDecoder(nn.Module):
|
||||
else:
|
||||
decoded_seq = torch.zeros(num_nodes, const.num_tokens, device=s.device)
|
||||
logits = torch.zeros(num_nodes, const.num_tokens, device=s.device)
|
||||
|
||||
# Build symmetric groups for homomer tying
|
||||
if self.tie_symmetric_sequences and "symmetric_group" in feats:
|
||||
sym_groups, position_to_group = self._build_symmetric_groups(
|
||||
feats, valid_mask, design_mask
|
||||
)
|
||||
sampled = set(torch.where(~design_mask)[0].cpu().numpy().tolist()) if num_not_design > 0 else set()
|
||||
else:
|
||||
sym_groups, position_to_group = {}, {}
|
||||
sampled = set()
|
||||
|
||||
src_idx, dst_idx = edge_idx[0], edge_idx[1]
|
||||
|
||||
# decoding in order
|
||||
for i in order:
|
||||
s_i = s[i : i + 1]
|
||||
edge_mask_i = dst_idx == i
|
||||
z_i = z[edge_mask_i]
|
||||
src_idx_i = src_idx[edge_mask_i]
|
||||
res_type = decoded_seq[src_idx_i]
|
||||
res_rep = self.seq_to_s(res_type)
|
||||
neighbors_rep_i = torch.concat([z_i, s[src_idx_i] + res_rep], dim=-1)
|
||||
# Skip if already sampled (symmetric position was processed earlier)
|
||||
if self.tie_symmetric_sequences and i in sampled:
|
||||
continue
|
||||
|
||||
for layer in self.decoder_layers:
|
||||
s_i = layer.sample(s_i, neighbors_rep_i)
|
||||
# Get symmetric positions (or just [i] if no symmetry)
|
||||
if self.tie_symmetric_sequences and i in position_to_group:
|
||||
positions = sym_groups[position_to_group[i]]
|
||||
else:
|
||||
positions = [i]
|
||||
|
||||
logits_i = self.predictor(s_i)
|
||||
logits[i] = logits_i
|
||||
# Aggregate logits from all symmetric positions
|
||||
aggregated_logits = None
|
||||
for pos in positions:
|
||||
s_pos = s[pos : pos + 1]
|
||||
edge_mask_pos = dst_idx == pos
|
||||
z_pos = z[edge_mask_pos]
|
||||
src_idx_pos = src_idx[edge_mask_pos]
|
||||
res_type_pos = decoded_seq[src_idx_pos]
|
||||
res_rep = self.seq_to_s(res_type_pos)
|
||||
neighbors_rep_pos = torch.concat([z_pos, s[src_idx_pos] + res_rep], dim=-1)
|
||||
|
||||
s_temp = s_pos
|
||||
for layer in self.decoder_layers:
|
||||
s_temp = layer.sample(s_temp, neighbors_rep_pos)
|
||||
|
||||
logits_pos = self.predictor(s_temp)
|
||||
if aggregated_logits is None:
|
||||
aggregated_logits = logits_pos
|
||||
else:
|
||||
aggregated_logits = aggregated_logits + logits_pos
|
||||
|
||||
# Average logits across symmetric positions
|
||||
aggregated_logits = aggregated_logits / len(positions)
|
||||
|
||||
# Sample from aggregated logits
|
||||
pred_canonical = (
|
||||
logits_i[
|
||||
aggregated_logits[
|
||||
:,
|
||||
const.canonicals_offset : len(const.canonical_tokens)
|
||||
+ const.canonicals_offset,
|
||||
]
|
||||
+ restriction_mask
|
||||
+ per_residue_mask[i : i + 1] # Position-specific mask
|
||||
)
|
||||
ids_canonical = torch.argmax(pred_canonical, dim=-1)
|
||||
if self.sampling_temperature is None:
|
||||
ids_canonical = torch.argmax(pred_canonical, dim=-1)
|
||||
else:
|
||||
@@ -609,7 +752,13 @@ class InverseFoldingDecoder(nn.Module):
|
||||
|
||||
ids = ids_canonical + const.canonicals_offset
|
||||
pred_one_hot = F.one_hot(ids, num_classes=const.num_tokens)
|
||||
decoded_seq[i] = pred_one_hot
|
||||
|
||||
# Apply same residue to all symmetric positions
|
||||
for pos in positions:
|
||||
decoded_seq[pos] = pred_one_hot
|
||||
logits[pos] = aggregated_logits
|
||||
if self.tie_symmetric_sequences:
|
||||
sampled.add(pos)
|
||||
|
||||
n_tokens = valid_mask.shape[1]
|
||||
res_type = torch.zeros(1, n_tokens, self.num_res_type, device=s.device)
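The homomer tying in the decoding loop above averages the logits of all positions in a symmetric group and then writes one sampled residue to every member. A minimal, framework-free sketch of that idea follows. It assumes plain Python lists in place of tensors and a greedy argmax in place of temperature sampling; `average_logits`, `tie_and_pick`, and the example data are illustrative names and values, not part of the repository.

```python
def average_logits(rows):
    """Element-wise mean of equally long logit rows."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def tie_and_pick(logits_by_pos, groups):
    """Return {position: token_id} with one shared token per symmetric group."""
    choice = {}
    for members in groups:
        avg = average_logits([logits_by_pos[p] for p in members])
        token = max(range(len(avg)), key=avg.__getitem__)  # greedy argmax
        for p in members:
            choice[p] = token  # same residue at every symmetric copy
    return choice


# Hypothetical logits over three tokens for positions 0, 3 and 5.
logits_by_pos = {0: [2.0, 0.0, 1.0], 3: [0.0, 1.0, 3.0], 5: [0.5, 4.0, 0.0]}
tied = tie_and_pick(logits_by_pos, groups=[[0, 3], [5]])
```

Decoded on its own, position 0 would pick token 0. After averaging with its symmetric partner at position 3, both positions receive token 2.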
|
||||
|
||||
@@ -88,6 +88,9 @@ class BoltzMasker(Module):
|
||||
new["disto_target"] = clone["disto_target"]
|
||||
new["token_pair_mask"] = clone["token_pair_mask"]
|
||||
new["binding_type"] = clone["binding_type"]
|
||||
# Per-residue amino acid constraints for inverse folding
|
||||
if "aa_constraint_mask" in clone:
|
||||
new["aa_constraint_mask"] = clone["aa_constraint_mask"]
|
||||
new["structure_group"] = clone["structure_group"]
|
||||
new["cyclic"] = clone["cyclic"]
|
||||
new["modified"] = clone["modified"]
|
||||
@@ -101,6 +104,7 @@ class BoltzMasker(Module):
|
||||
new["res_type_clone"] = clone["res_type_clone"]
|
||||
new["feature_residue_index"] = clone["feature_residue_index"]
|
||||
new["feature_asym_id"] = clone["feature_asym_id"]
|
||||
new["symmetric_group"] = clone["symmetric_group"]
|
||||
new["token_to_res"] = clone["token_to_res"]
|
||||
new["token_bonds"] = torch.zeros_like(
|
||||
clone["token_bonds"]
|
||||
|
||||
@@ -286,6 +286,7 @@ class TemplateModule(nn.Module):
|
||||
)
|
||||
self.u_proj = nn.Linear(template_dim, token_z, bias=False)
|
||||
|
||||
|
||||
if miniformer_blocks:
|
||||
self.pairformer = MiniformerNoSeqModule(
|
||||
template_dim,
|
||||
@@ -330,7 +331,6 @@ class TemplateModule(nn.Module):
|
||||
|
||||
"""
|
||||
# Load relevant features
|
||||
asym_id = feats["asym_id"]
|
||||
res_type = feats["template_restype"]
|
||||
frame_rot = feats["template_frame_rot"]
|
||||
frame_t = feats["template_frame_t"]
|
||||
@@ -338,6 +338,7 @@ class TemplateModule(nn.Module):
|
||||
cb_coords = feats["template_cb"]
|
||||
ca_coords = feats["template_ca"]
|
||||
cb_mask = feats["template_mask_cb"]
|
||||
visibility_ids = feats["visibility_ids"]
|
||||
template_mask = feats["template_mask"].any(dim=2).float()
|
||||
num_templates = template_mask.sum(dim=1)
|
||||
num_templates = num_templates.clamp(min=1)
|
||||
@@ -349,10 +350,11 @@ class TemplateModule(nn.Module):
|
||||
b_cb_mask = b_cb_mask[..., None]
|
||||
b_frame_mask = b_frame_mask[..., None]
|
||||
|
||||
# Compute asym mask, template features only attend within the same chain
|
||||
# Compute visibility mask, template features only attend within the same visibility
|
||||
B, T = res_type.shape[:2] # noqa: N806
|
||||
asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float()
|
||||
asym_mask = asym_mask[:, None].expand(-1, T, -1, -1)
|
||||
tmlp_pair_mask = (
|
||||
visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :]
|
||||
).float()
|
||||
|
||||
# Compute template features
|
||||
with torch.autocast(device_type="cuda", enabled=False):
|
||||
@@ -375,7 +377,8 @@ class TemplateModule(nn.Module):
|
||||
# Concatenate input features
|
||||
a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
|
||||
a_tij = torch.cat(a_tij, dim=-1)
|
||||
a_tij = a_tij * asym_mask.unsqueeze(-1)
|
||||
a_tij = a_tij * tmlp_pair_mask.unsqueeze(-1)
|
||||
|
||||
res_type_i = res_type[:, :, :, None]
|
||||
res_type_j = res_type[:, :, None, :]
|
||||
res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
|
||||
@@ -669,6 +672,7 @@ class MSAModule(nn.Module):
|
||||
chunk_size_outer_product,
|
||||
chunk_size_tri_attn,
|
||||
use_kernels=use_kernels,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
z, m = self.layers[i](
|
||||
|
||||
@@ -30,7 +30,7 @@ data:
|
||||
fail_if_no_designs: true
|
||||
output_dir: null
|
||||
|
||||
keys_dict_out: ["min_interaction_pae", "min_design_to_target_pae", "interaction_pae", "ligand_iptm", "protein_iptm", "iptm", "design_iptm", "design_iiptm", "design_to_target_iptm", "design_ptm", "target_ptm", "ptm"]
|
||||
keys_dict_out: ["min_interaction_pae", "min_design_to_target_pae", "interaction_pae", "ligand_iptm", "protein_iptm", "iptm", "design_iptm", "design_iiptm", "design_to_target_iptm", "design_residue_iptm", "design_ptm", "target_ptm", "ptm"]
|
||||
|
||||
writer:
|
||||
_target_: boltzgen.task.predict.writer.FoldingWriter
|
||||
|
||||
484
src/boltzgen/resources/config/train/boltzgen.no_distillation.yaml
Executable file
@@ -0,0 +1,484 @@
|
||||
_target_: boltzgen.task.train.train.Training
|
||||
|
||||
trainer:
|
||||
accelerator: gpu
|
||||
devices: 8
|
||||
precision: bf16-mixed
|
||||
gradient_clip_val: 10.0
|
||||
accumulate_grad_batches: 1
|
||||
max_epochs: -1
|
||||
num_sanity_val_steps: 3
|
||||
log_every_n_steps: 1
|
||||
|
||||
wandb:
|
||||
group: boltzgen
|
||||
project: boltzgen
|
||||
entity: yourwandb
|
||||
|
||||
name: a_big_run_resume3
|
||||
slurm: true
|
||||
output: workdir
|
||||
strict_loading: false
|
||||
resume: null
|
||||
pretrained: null
|
||||
debug: false
|
||||
save_every_n_train_steps: 2500
|
||||
disable_checkpoint: false
|
||||
matmul_precision: null
|
||||
save_top_k: -1
|
||||
|
||||
data:
|
||||
datasets:
|
||||
- _target_: boltzgen.task.train.data.DatasetConfig
|
||||
target_dir: ./training_data/targets
|
||||
msa_dir: ./training_data/msa
|
||||
prob: 1.0
|
||||
filters:
|
||||
- _target_: boltzgen.data.filter.dynamic.size.SizeFilter
|
||||
min_chains: 1
|
||||
max_chains: 300
|
||||
- _target_: boltzgen.data.filter.dynamic.date.DateFilter
|
||||
date: "2023-06-01"
|
||||
ref: released
|
||||
- _target_: boltzgen.data.filter.dynamic.resolution.ResolutionFilter
|
||||
resolution: 9.0
|
||||
sampler:
|
||||
_target_: boltzgen.data.sample.cluster.ClusterSampler
|
||||
cropper:
|
||||
_target_: boltzgen.data.crop.multimer.MultimerCropper
|
||||
neighborhood_sizes: [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 ]
|
||||
split: ./src/boltzgen/resources/splits/validation_ids_boltz2_all.txt
|
||||
symmetry_correction: false
|
||||
val_group: "RCSB"
|
||||
|
||||
tokenizer:
|
||||
_target_: boltzgen.data.tokenize.tokenizer.Tokenizer
|
||||
atomize_modified_residues: false
|
||||
featurizer:
|
||||
_target_: boltzgen.data.feature.featurizer.Featurizer
|
||||
moldir: ./training_data/mols
|
||||
max_tokens: 512
|
||||
max_atoms: 5120
|
||||
max_seqs: 4096
|
||||
pad_to_max_tokens: true
|
||||
pad_to_max_atoms: true
|
||||
pad_to_max_seqs: true
|
||||
samples_per_epoch: 100000
|
||||
batch_size: 1
|
||||
num_workers: 2
|
||||
random_seed: 42
|
||||
pin_memory: false
|
||||
overfit: null
|
||||
return_train_symmetries: false
|
||||
return_val_symmetries: false
|
||||
|
||||
|
||||
atoms_per_window_queries: 32
|
||||
min_dist: 2.0
|
||||
max_dist: 22.0
|
||||
num_bins: 64
|
||||
single_sequence_prop_training: 0.1
|
||||
msa_sampling_training: true
|
||||
|
||||
|
||||
# Design
|
||||
design: true
|
||||
backbone_only: false
|
||||
atom14: true
|
||||
atom37: false
|
||||
selector:
|
||||
_target_: boltzgen.data.select.protein.ProteinSelector
|
||||
design_neighborhood_sizes: [2, 4, 6,8,10,12,14,16,18]
|
||||
substructure_neighborhood_sizes: [2,4,6,8,10,12,24]
|
||||
structure_condition_prob: 0.4
|
||||
distance_noise_std: 1
|
||||
run_selection: true
|
||||
specify_binding_sites: true
|
||||
ss_condition_prob: 0.1
|
||||
select_all: false
|
||||
|
||||
# Design datasets
|
||||
monomer_split: ./src/boltzgen/resources/splits/val_monomers_boltzgen_min50_max220.txt
|
||||
monomer_target_dir: ./training_data/targets
|
||||
monomer_target_structure_condition: true
|
||||
monomer_seq_len: 100
|
||||
|
||||
ligand_split: ./src/boltzgen/resources/splits/val_ccd_pdb_pairs_boltzgen.txt
|
||||
ligand_target_dir: ./training_data/targets
|
||||
ligand_seq_len: 100
|
||||
|
||||
|
||||
model:
|
||||
_target_: boltzgen.model.models.boltz.Boltz
|
||||
atom_s: 128
|
||||
atom_z: 16
|
||||
token_s: 384
|
||||
token_z: 128
|
||||
num_bins: 64
|
||||
atom_feature_dim: 388
|
||||
atoms_per_window_queries: 32
|
||||
atoms_per_window_keys: 128
|
||||
use_miniformer: false
|
||||
ema: true
|
||||
ema_decay: 0.999
|
||||
exclude_ions_from_lddt: true
|
||||
num_val_datasets: 1 # New
|
||||
ignore_ckpt_shape_mismatch: false # New
|
||||
aggregate_distogram: true # New
|
||||
bond_type_feature: true
|
||||
predict_bfactor: true
|
||||
checkpoint_diffusion_conditioning: true
|
||||
use_kernels: true
|
||||
|
||||
|
||||
validators:
|
||||
- _target_: boltzgen.model.validation.design.DesignValidator
|
||||
val_names: ["RCSB"]
|
||||
confidence_prediction: ${model.confidence_prediction}
|
||||
atom14: ${data.atom14}
|
||||
atom37: ${data.atom37}
|
||||
|
||||
masker_args:
|
||||
mask: true
|
||||
mask_backbone: false
|
||||
mask_disto: true
|
||||
|
||||
embedder_args:
|
||||
atom_encoder_depth: 3
|
||||
atom_encoder_heads: 4
|
||||
add_mol_type_feat: true
|
||||
add_method_conditioning: true
|
||||
add_modified_flag: true
|
||||
add_cyclic_flag: true
|
||||
add_design_mask_flag: true
|
||||
add_binding_specification: true
|
||||
add_ss_specification: true
|
||||
|
||||
freeze_template_weights: true
|
||||
use_templates: true
|
||||
template_args:
|
||||
template_dim: 64
|
||||
template_blocks: 2
|
||||
activation_checkpointing: false
|
||||
|
||||
|
||||
use_token_distances: true
|
||||
token_distance_args:
|
||||
token_distance_dim: 64
|
||||
token_distance_blocks: 2
|
||||
use_token_distance_feats: true
|
||||
distance_gaussian_dim: 32
|
||||
activation_checkpointing: true
|
||||
|
||||
|
||||
msa_args:
|
||||
msa_s: 64
|
||||
msa_blocks: 4
|
||||
msa_dropout: 0.15
|
||||
z_dropout: 0.25
|
||||
miniformer_blocks: false
|
||||
pairwise_head_width: 32
|
||||
pairwise_num_heads: 4
|
||||
use_paired_feature: true
|
||||
activation_checkpointing: true
|
||||
|
||||
|
||||
pairformer_args:
|
||||
num_blocks: 64
|
||||
num_heads: 16
|
||||
dropout: 0.25
|
||||
post_layer_norm: false
|
||||
activation_checkpointing: true
|
||||
|
||||
|
||||
score_model_args:
|
||||
sigma_data: 16
|
||||
dim_fourier: 256
|
||||
atom_encoder_depth: 3
|
||||
atom_encoder_heads: 4
|
||||
|
||||
# token level args
|
||||
token_layers: 1
|
||||
token_transformer_depth: 24
|
||||
token_transformer_heads: 16
|
||||
diffusion_pairformer_args:
|
||||
num_blocks: 0
|
||||
num_heads: 2
|
||||
dropout: 0
|
||||
use_s_to_z: false
|
||||
|
||||
|
||||
|
||||
atom_decoder_depth: 3
|
||||
atom_decoder_heads: 4
|
||||
conditioning_transition_layers: 2
|
||||
transformer_post_ln: false
|
||||
activation_checkpointing: true
|
||||
|
||||
confidence_prediction: false
|
||||
structure_prediction_training: true
|
||||
|
||||
training_args:
|
||||
recycling_steps: 3
|
||||
sampling_steps: 20
|
||||
diffusion_multiplicity: 32
|
||||
diffusion_samples: 1
|
||||
confidence_loss_weight: 1e-4
|
||||
diffusion_loss_weight: 4.0
|
||||
distogram_loss_weight: 3e-2
|
||||
bfactor_loss_weight: 1e-3
|
||||
adam_beta_1: 0.9
|
||||
adam_beta_2: 0.95
|
||||
adam_eps: 0.00000001
|
||||
lr_scheduler: af3
|
||||
base_lr: 0.0
|
||||
max_lr: 0.0005
|
||||
lr_warmup_no_steps: 1000
|
||||
lr_start_decay_after_n_steps: 50000
|
||||
lr_decay_every_n_steps: 50000
|
||||
lr_decay_factor: 0.95
|
||||
weight_decay: 0.003
|
||||
weight_decay_exclude: true
|
||||
|
||||
validation_args:
|
||||
recycling_steps: 3
|
||||
sampling_steps: 200
|
||||
diffusion_samples: 1
|
||||
symmetry_correction: false
|
||||
|
||||
diffusion_process_args:
|
||||
sigma_min: 0.0004 # min noise level
|
||||
sigma_max: 160.0 # max noise level
|
||||
sigma_data: 16.0 # standard deviation of data distribution
|
||||
rho: 7 # controls the sampling schedule
|
||||
P_mean: -1.2 # mean of log-normal distribution from which noise is drawn for training
|
||||
P_std: 1.5 # standard deviation of log-normal distribution from which noise is drawn for training
|
||||
gamma_0: 0.8
|
||||
gamma_min: 1.0
|
||||
noise_scale: 1.0
|
||||
step_scale: 1.0
|
||||
mse_rotational_alignment: true
|
||||
coordinate_augmentation: true
|
||||
alignment_reverse_diff: true
|
||||
synchronize_sigmas: false
|
||||
|
||||
diffusion_loss_args:
|
||||
add_smooth_lddt_loss: true
|
||||
add_bond_loss: false
|
||||
nucleotide_loss_weight: 5.0
|
||||
ligand_loss_weight: 10.0
|
||||
|
||||
refolding_validator:
|
||||
_target_: boltzgen.model.validation.refolding.RefoldingValidator
|
||||
val_names: ["RCSB"]
|
||||
step_scale: 1.5
|
||||
noise_scale: 0.75
|
||||
atom14: ${data.atom14}
|
||||
atom37: ${data.atom37}
|
||||
val_monomer: ${data.monomer_split}
|
||||
val_ligand: ${data.ligand_split}
|
||||
analyze_task:
|
||||
_target_: boltzgen.task.analyze.analyze.Analyze
|
||||
name: ${name}
|
||||
debug: ${debug}
|
||||
design_dir: null
|
||||
num_processes: 1
|
||||
|
||||
# Common metrics to compute
|
||||
affinity_metrics: false
|
||||
allatom_fold_metrics: true
|
||||
backbone_fold_metrics: true
|
||||
noncovalents_original: false
|
||||
noncovalents_refolded: false
|
||||
delta_sasa_original: false
|
||||
delta_sasa_refolded: false
|
||||
largest_hydrophobic: false
|
||||
largest_hydrophobic_refolded: false
|
||||
run_clustering: false
|
||||
|
||||
# Liability analysis
|
||||
liability_analysis: false
|
||||
liability_modality: peptide
|
||||
liability_peptide_type: linear
|
||||
|
||||
# Uncommon metrics
|
||||
diversity_original: true
|
||||
diversity_refolded: true
|
||||
diversity_per_target_original: false
|
||||
diversity_per_target_refolded: false
|
||||
novelty_original: false
|
||||
novelty_refolded: false
|
||||
novelty_per_target_original: false
|
||||
novelty_per_target_refolded: false
|
||||
|
||||
wandb: null
|
||||
|
||||
data:
|
||||
_target_: boltzgen.task.predict.data_from_generated.FromGeneratedDataModule
|
||||
cfg:
|
||||
_target_: boltzgen.task.predict.data_from_generated.DataConfig
|
||||
tokenizer:
|
||||
_target_: boltzgen.data.tokenize.tokenizer.Tokenizer
|
||||
atomize_modified_residues: false
|
||||
featurizer:
|
||||
_target_: boltzgen.data.feature.featurizer.Featurizer
|
||||
|
||||
suffix: .cif
|
||||
suffix_metadata: .npz
|
||||
suffix_native: _native.cif
|
||||
samples_per_target: 1
|
||||
num_targets: 100000000
|
||||
moldir: ./training_data/mols
|
||||
|
||||
batch_size: 1
|
||||
num_workers: 4
|
||||
pin_memory: true
|
||||
return_native: true
|
||||
|
||||
folding_checkpoint: ./training_data/boltz2_fold.ckpt
|
||||
|
||||
folding_args:
|
||||
recycling_steps: 3
|
||||
sampling_steps: 200
|
||||
diffusion_samples: 1
|
||||
|
||||
folding_model_args:
|
||||
atom_s: 128
|
||||
atom_z: 16
|
||||
token_s: 384
|
||||
token_z: 128
|
||||
num_bins: 64
|
||||
atom_feature_dim: 388
|
||||
atoms_per_window_queries: 32
|
||||
atoms_per_window_keys: 128
|
||||
compile_pairformer: false
|
||||
compile_templates: false
|
||||
compile_msa: false
|
||||
use_miniformer: false
|
||||
ema: true
|
||||
ema_decay: 0.999
|
||||
exclude_ions_from_lddt: true
|
||||
num_val_datasets: 4
|
||||
ignore_ckpt_shape_mismatch: false
|
||||
aggregate_distogram: true
|
||||
bond_type_feature: true
|
||||
conditioning_cutoff_min: 4.0
|
||||
conditioning_cutoff_max: 20.0
|
||||
use_templates: true
|
||||
predict_bfactor: true
|
||||
checkpoint_diffusion_conditioning: false
|
||||
use_kernels: true
|
||||
|
||||
validators: null
|
||||
|
||||
embedder_args:
|
||||
atom_encoder_depth: 3
|
||||
atom_encoder_heads: 4
|
||||
add_mol_type_feat: true
|
||||
add_method_conditioning: true
|
||||
add_modified_flag: true
|
||||
add_cyclic_flag: true
|
||||
|
||||
msa_args:
|
||||
msa_s: 64
|
||||
msa_blocks: 4
|
||||
msa_dropout: 0.15
|
||||
z_dropout: 0.25
|
||||
miniformer_blocks: false
|
||||
pairwise_head_width: 32
|
||||
pairwise_num_heads: 4
|
||||
use_paired_feature: true
|
||||
activation_checkpointing: false
|
||||
|
||||
|
||||
template_args:
|
||||
template_dim: 64
|
||||
template_blocks: 2
|
||||
activation_checkpointing: false
|
||||
|
||||
|
||||
pairformer_args:
|
||||
num_blocks: 64
|
||||
num_heads: 16
|
||||
dropout: 0.25
|
||||
post_layer_norm: false
|
||||
activation_checkpointing: false
|
||||
|
||||
|
||||
score_model_args:
|
||||
sigma_data: 16
|
||||
dim_fourier: 256
|
||||
atom_encoder_depth: 3
|
||||
atom_encoder_heads: 4
|
||||
token_transformer_depth: 24
|
||||
token_transformer_heads: 16
|
||||
atom_decoder_depth: 3
|
||||
atom_decoder_heads: 4
|
||||
conditioning_transition_layers: 2
|
||||
transformer_post_ln: false
|
||||
activation_checkpointing: false
|
||||
|
||||
confidence_prediction: true
|
||||
affinity_prediction: false
|
||||
structure_prediction_training: true
|
||||
affinity_model_args:
|
||||
num_dist_bins: 64
|
||||
max_dist: 22
|
||||
no_trunk_feats: false
|
||||
add_s_to_z_prod: false
|
||||
add_s_input_to_s: false
|
||||
confidence_args:
|
||||
num_plddt_bins: 50
|
||||
num_pde_bins: 64
|
||||
num_pae_bins: 64
|
||||
|
||||
training_args:
|
||||
recycling_steps: 3
|
||||
sampling_steps: 20
|
||||
diffusion_multiplicity: 48
|
||||
diffusion_samples: 1
|
||||
affinity_loss_weight: 3e-3
|
||||
confidence_loss_weight: 1e-4
|
||||
diffusion_loss_weight: 4.0
|
||||
distogram_loss_weight: 3e-2
|
||||
bfactor_loss_weight: 1e-3
|
||||
adam_beta_1: 0.9
|
||||
adam_beta_2: 0.95
|
||||
adam_eps: 0.00000001
|
||||
lr_scheduler: af3
|
||||
base_lr: 0.0
|
||||
max_lr: 0.001
|
||||
lr_warmup_no_steps: 1000
|
||||
lr_start_decay_after_n_steps: 50000
|
||||
lr_decay_every_n_steps: 50000
|
||||
lr_decay_factor: 0.95
|
||||
weight_decay: 0.003
|
||||
weight_decay_exclude: true
|
||||
|
||||
validation_args:
|
||||
recycling_steps: 3
|
||||
sampling_steps: 200
|
||||
diffusion_samples: 5
|
||||
symmetry_correction: false
|
||||
|
||||
diffusion_process_args:
|
||||
sigma_min: 0.0004 # min noise level
|
||||
sigma_max: 160.0 # max noise level
|
||||
sigma_data: 16.0 # standard deviation of data distribution
|
||||
rho: 7 # controls the sampling schedule
|
||||
P_mean: -1.2 # mean of log-normal distribution from which noise is drawn for training
|
||||
P_std: 1.5 # standard deviation of log-normal distribution from which noise is drawn for training
|
||||
gamma_0: 0.8
|
||||
gamma_min: 1.0
|
||||
noise_scale: 1.0
|
||||
step_scale: 1.0
|
||||
mse_rotational_alignment: true
|
||||
coordinate_augmentation: true
|
||||
alignment_reverse_diff: true
|
||||
synchronize_sigmas: false
|
||||
|
||||
diffusion_loss_args:
|
||||
add_smooth_lddt_loss: true
|
||||
add_bond_loss: false
|
||||
nucleotide_loss_weight: 5.0
|
||||
ligand_loss_weight: 10.0
|
||||
38
src/boltzgen/resources/metrics_normalization.json
Normal file
@@ -0,0 +1,38 @@
|
||||
{
  "design_iiptm": {
    "mean": 0.3828427453913554,
    "std": 0.10621994686776445
  },
  "design_ptm": {
    "mean": 0.6548983461494845,
    "std": 0.041925380367595154
  },
  "min_design_to_target_pae": {
    "mean": 8.665374909044512,
    "std": 3.8643855151416453
  },
  "design_hydrophobicity": {
    "mean": 43.95844996189343,
    "std": 7.565721364606706
  },
  "design_largest_hydrophobic_patch_refolded": {
    "mean": 608.7238163271427,
    "std": 440.3199461402019
  },
  "delta_sasa_refolded": {
    "mean": 1368.936707104108,
    "std": 514.8573173243741
  },
  "plip_saltbridge_refolded": {
    "mean": 1.746922289263117,
    "std": 1.9177749177139145
  },
  "plip_hbonds_refolded": {
    "mean": 7.099706654392513,
    "std": 4.456755351208993
  },
  "affinity_probability_binary1": {
    "mean": 0.4726671722992441,
    "std": 0.19092615494037402
  }
}
|
||||
@@ -104,6 +104,7 @@ class Analyze(Task):
|
||||
foldseek_binary: str = "/data/rbg/users/hstark/foldseek/bin/foldseek",
|
||||
skip_specific_ids: List[str] = None,
|
||||
designfolding_metrics: bool = False,
|
||||
use_design_mask_for_target: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the task.
|
||||
|
||||
@@ -152,6 +153,7 @@ class Analyze(Task):
|
||||
self.wandb = wandb
|
||||
self.slurm = slurm
|
||||
self.diversity_subset = diversity_subset
|
||||
self.use_design_mask_for_target = use_design_mask_for_target
|
||||
|
||||
# Prevent each worker process from spawning its own multithreaded pools
|
||||
torch.set_num_threads(1)
|
||||
@@ -565,12 +567,50 @@ class Analyze(Task):
|
||||
"designed_chain_sequence": design_chain_seq,
|
||||
}
|
||||
|
||||
# Add per-chain sequences to csv when designing multiple chains
|
||||
design_token_indices = torch.where(feat["design_mask"].bool() & feat["token_pad_mask"].bool())[0]
|
||||
designed_chain_ids = feat["asym_id"][design_token_indices].unique().tolist()
|
||||
if len(designed_chain_ids) > 1:
|
||||
for chain_id in designed_chain_ids:
|
||||
chain_mask = feat["asym_id"] == chain_id
|
||||
|
||||
# Full chain sequence
|
||||
chain_res_types = res_type_argmax[chain_mask]
|
||||
full_chain_seq = "".join(
|
||||
[
|
||||
const.prot_token_to_letter.get(const.tokens[t], "X")
|
||||
for t in chain_res_types
|
||||
]
|
||||
)
|
||||
|
||||
# Designed residues only from this chain
|
||||
design_chain_mask = feat["design_mask"].bool() & feat["token_pad_mask"].bool() & chain_mask
|
||||
design_res_types = res_type_argmax[design_chain_mask]
|
||||
design_seq = "".join(
|
||||
[
|
||||
const.prot_token_to_letter.get(const.tokens[t], "X")
|
||||
for t in design_res_types
|
||||
]
|
||||
)
|
||||
|
||||
metrics[f"designed_sequence_{chain_id}"] = design_seq
|
||||
metrics[f"full_sequence_{chain_id}"] = full_chain_seq
|
||||
|
||||
|
||||
target_id = re.search(rf"{self.data.cfg.target_id_regex}", sample_id).group(1)
|
||||
|
||||
# Get masks
|
||||
design_mask = feat["design_mask"].bool()
|
||||
chain_design_mask = feat["chain_design_mask"].bool()
|
||||
|
||||
design_resolved_mask = design_mask & feat["token_resolved_mask"].bool()
|
||||
target_resolved_mask = ~design_mask & feat["token_resolved_mask"].bool()
|
||||
|
||||
# For symmetric designs where all chains have designed residues, use design_mask
|
||||
# instead of chain_design_mask so "target" = non-designed residues (not empty)
|
||||
if self.use_design_mask_for_target:
|
||||
target_resolved_mask = (~design_mask) & feat["token_resolved_mask"].bool()
|
||||
else:
|
||||
target_resolved_mask = (~chain_design_mask) & feat["token_resolved_mask"].bool()
|
||||
atom_design_resolved_mask = (
|
||||
(feat["atom_to_token"].float() @ design_resolved_mask.unsqueeze(-1).float())
|
||||
.bool()
|
||||
@@ -584,11 +624,10 @@ class Analyze(Task):
|
||||
atom_resolved_mask = feat["atom_resolved_mask"]
|
||||
resolved_atoms_design_mask = atom_design_resolved_mask[atom_resolved_mask]
|
||||
resolved_atoms_target_mask = atom_target_resolved_mask[atom_resolved_mask]
|
||||
design_mask_for_chain = feat["asym_id"] == design_chain_id
|
||||
atom_chain_mask = (
|
||||
(
|
||||
feat["atom_to_token"].float()
|
||||
@ design_mask_for_chain.unsqueeze(-1).float()
|
||||
@ chain_design_mask.unsqueeze(-1).float()
|
||||
)
|
||||
.bool()
|
||||
.squeeze()
|
||||
@@ -669,7 +708,11 @@ class Analyze(Task):
|
||||
delta_sasa_orig,
|
||||
design_sasa_unbound,
|
||||
design_sasa_bound,
|
||||
) = get_delta_sasa(path, resolved_atoms_target_mask)
|
||||
) = get_delta_sasa(
|
||||
path,
|
||||
atom_target_mask=resolved_atoms_target_mask,
|
||||
atom_design_mask=resolved_atoms_design_mask,
|
||||
)
|
||||
metrics["delta_sasa_original"] = delta_sasa_orig
|
||||
metrics["design_sasa_unbound_original"] = design_sasa_unbound
|
||||
metrics["design_sasa_bound_original"] = design_sasa_bound
|
||||
@@ -1039,12 +1082,21 @@ class Analyze(Task):
|
||||
if self.delta_sasa_refolded:
|
||||
cif_path_refolded = self.refold_cif_dir / f"{feat['id']}.cif"
|
||||
|
||||
if not cif_path_refolded.exists():
|
||||
msg = f"Refolded cif path does not exist. This can happen if a process was interrupted between writing the refold .npz file and the refold .cif file. Missing path: {cif_path_refolded}"
|
||||
print(msg)
|
||||
return None
|
||||
|
||||
# Compute delta sasa
|
||||
(
|
||||
delta_sasa_refolded,
|
||||
design_sasa_unbound,
|
||||
design_sasa_bound,
|
||||
) = get_delta_sasa(cif_path_refolded, resolved_atoms_target_mask)
|
||||
) = get_delta_sasa(
|
||||
cif_path_refolded,
|
||||
atom_target_mask=resolved_atoms_target_mask,
|
||||
atom_design_mask=resolved_atoms_design_mask,
|
||||
)
|
||||
|
||||
metrics["delta_sasa_refolded"] = delta_sasa_refolded
|
||||
metrics["design_sasa_unbound_refolded"] = design_sasa_unbound
|
||||
|
||||
@@ -97,7 +97,7 @@ def make_histogram(
|
||||
def get_best_folding_sample(folded):
|
||||
confidence = 0.8 * folded["design_to_target_iptm"] + 0.2 * folded["design_ptm"]
|
||||
best_idx = np.argmax(confidence)
|
||||
|
||||
|
||||
# TODO: remove the "if k in folded"
|
||||
best_sample = {
|
||||
k: folded[k][best_idx] for k in const.eval_keys_confidence if k in folded
|
||||
@@ -188,10 +188,19 @@ def count_noncovalents(feat):
|
||||
biotite_array, _ = hydride.add_hydrogen(biotite_array)
|
||||
hbond = biotite.structure.hbond(biotite_array)
|
||||
donor_idxs, acceptor_idxs = hbond[:, 0], hbond[:, 2]
|
||||
cross_design_hbonds = (
|
||||
biotite_array.is_design[donor_idxs] != biotite_array.is_design[acceptor_idxs]
|
||||
donor_design_hbonds = int(
|
||||
(
|
||||
biotite_array.is_design[donor_idxs]
|
||||
& ~biotite_array.is_chain_design[acceptor_idxs]
|
||||
).sum()
|
||||
)
|
||||
metrics["plip_hbonds"] = int(cross_design_hbonds.sum())
|
||||
acceptor_design_hbonds = int(
|
||||
(
|
||||
~biotite_array.is_chain_design[donor_idxs]
|
||||
& biotite_array.is_design[acceptor_idxs]
|
||||
).sum()
|
||||
)
|
||||
metrics["plip_hbonds"] = donor_design_hbonds + acceptor_design_hbonds
|
||||
|
||||
# saltbridges
|
||||
pos_atoms = biotite_array[biotite_array.charge > 0]
|
||||
@@ -204,10 +213,13 @@ def count_noncovalents(feat):
|
||||
(pos_neg_distances > 0.5) & (pos_neg_distances < 5.5)
|
||||
)
|
||||
# only keep the ones between design and non design
|
||||
cross_design_saltbridges = (
|
||||
pos_atoms.is_design[pos_idxs] != neg_atoms.is_design[neg_idxs]
|
||||
pos_design_sb = int(
|
||||
(pos_atoms.is_design[pos_idxs] & ~neg_atoms.is_chain_design[neg_idxs]).sum()
|
||||
)
|
||||
metrics["plip_saltbridge"] = int(cross_design_saltbridges.sum())
|
||||
neg_design_sb = int(
|
||||
(~pos_atoms.is_chain_design[pos_idxs] & neg_atoms.is_design[neg_idxs]).sum()
|
||||
)
|
||||
metrics["plip_saltbridge"] = pos_design_sb + neg_design_sb
|
||||
else:
|
||||
metrics["plip_saltbridge"] = 0
|
||||
return metrics
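The revised counting rule in `count_noncovalents` keeps a contact only when one atom is a designed residue and the other sits on a chain that contains no designed residues. The sketch below restates that rule over plain index pairs, assuming boolean lists instead of biotite annotation arrays. `count_design_target_pairs` and the four-atom example are hypothetical and only illustrate how the new rule differs from the earlier `is_design[a] != is_design[b]` test.

```python
def count_design_target_pairs(pairs, is_design, is_chain_design):
    """Count pairs linking a designed atom to an atom on a non-designed chain."""
    count = 0
    for a, b in pairs:
        a_design_b_target = is_design[a] and not is_chain_design[b]
        a_target_b_design = (not is_chain_design[a]) and is_design[b]
        count += int(a_design_b_target) + int(a_target_b_design)
    return count


# Atoms 0 and 1 are on the designed chain, but only atom 0 is a designed residue.
# Atoms 2 and 3 are on the target chain.
is_design = [True, False, False, False]
is_chain_design = [True, True, False, False]
pairs = [(0, 2), (1, 3), (3, 0), (0, 1)]

new_count = count_design_target_pairs(pairs, is_design, is_chain_design)
old_count = sum(is_design[a] != is_design[b] for a, b in pairs)
```

The pair `(0, 1)` is an intra-chain contact between a designed and a fixed residue of the same chain: the earlier rule counts it, the revised rule does not.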
|
||||
@@ -629,9 +641,14 @@ def largest_hydrophobic_patch_area(cif_path, distance_cutoff=6.0):
|
||||
return max_patch_area
|
||||
|
||||
|
||||
def get_delta_sasa(path, atom_design_mask):
|
||||
def get_delta_sasa(
|
||||
path,
|
||||
atom_target_mask,
|
||||
atom_design_mask,
|
||||
):
|
||||
stack = _load_stack(path)
|
||||
atoms = stack[0]
|
||||
|
||||
res = [
|
||||
r.decode().strip() if isinstance(r, bytes) else str(r).strip()
|
||||
for r in atoms.res_name
|
||||
@@ -645,23 +662,41 @@ def get_delta_sasa(path, atom_design_mask):
|
||||
radii = np.array(
|
||||
[_radius(rn, an, el) for rn, an, el in zip(res, atm, elem)], dtype=float
|
||||
)
|
||||
area = sasa(atoms, probe_radius=1.4, point_number=960, vdw_radii=radii)
|
||||
design_bound = area[atom_design_mask].sum()
|
||||
|
||||
ligand_atoms = atoms[atom_design_mask]
|
||||
lig_res = [r for r, m in zip(res, atom_design_mask) if m]
|
||||
lig_atm = [a for a, m in zip(atm, atom_design_mask) if m]
|
||||
lig_elem = [e for e, m in zip(elem, atom_design_mask) if m]
|
||||
|
||||
bound_mask = atom_design_mask | atom_target_mask
|
||||
atoms_bound = atoms[bound_mask]
|
||||
radii_bound = radii[bound_mask]
|
||||
|
||||
area_bound = sasa(
|
||||
atoms_bound,
|
||||
probe_radius=1.4,
|
||||
point_number=960,
|
||||
vdw_radii=radii_bound,
|
||||
)
|
||||
|
||||
target_in_bound = atom_target_mask[bound_mask]
|
||||
target_bound = area_bound[target_in_bound].sum()
|
||||
|
||||
|
||||
|
||||
target_atoms = atoms[atom_target_mask]
|
||||
target_res = [r for r, m in zip(res, atom_target_mask) if m]
|
||||
target_atm = [a for a, m in zip(atm, atom_target_mask) if m]
|
||||
target_elem = [e for e, m in zip(elem, atom_target_mask) if m]
|
||||
|
||||
radii_lig = np.array(
|
||||
[_radius(rn, an, el) for rn, an, el in zip(lig_res, lig_atm, lig_elem)],
|
||||
[_radius(rn, an, el) for rn, an, el in zip(target_res, target_atm, target_elem)],
|
||||
dtype=float,
|
||||
)
|
||||
ligand_area = sasa(
|
||||
ligand_atoms, probe_radius=1.4, point_number=960, vdw_radii=radii_lig
|
||||
target_area = sasa(
|
||||
target_atoms,
|
||||
probe_radius=1.4,
|
||||
point_number=960,
|
||||
vdw_radii=radii_lig,
|
||||
)
|
||||
delta = ligand_area.sum() - design_bound
|
||||
return delta, ligand_area.sum(), design_bound
|
||||
delta = target_area.sum() - target_bound
|
||||
return delta, target_area.sum(), target_bound
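The new `get_delta_sasa` computes areas on the design-plus-target subset and then has to re-express the target mask in that subset's index space (`atom_target_mask[bound_mask]`). The following sketch isolates that bookkeeping with plain lists. The per-atom areas are invented numbers standing in for the output of a SASA routine, and `restrict_mask` and `delta_from_areas` are hypothetical helpers, not functions from the repository.

```python
def restrict_mask(mask, keep):
    """Re-express `mask` in the index space of the atoms selected by `keep`."""
    return [m for m, k in zip(mask, keep) if k]


def delta_from_areas(unbound_target_areas, bound_areas, target_in_bound):
    """Unbound target area minus the target's share of the complex area."""
    unbound = sum(unbound_target_areas)
    bound = sum(a for a, is_t in zip(bound_areas, target_in_bound) if is_t)
    return unbound - bound, unbound, bound


atom_design_mask = [True, True, False, False, False]
atom_target_mask = [False, False, True, True, False]
bound_mask = [d or t for d, t in zip(atom_design_mask, atom_target_mask)]
target_in_bound = restrict_mask(atom_target_mask, bound_mask)

# Invented per-atom areas (assumption, not real SASA values).
area_bound = [10.0, 12.0, 3.0, 5.0]  # design and target atoms together
area_target_alone = [7.0, 9.0]       # target atoms with the design removed
delta, unbound, bound = delta_from_areas(
    area_target_alone, area_bound, target_in_bound
)
```

The fifth atom belongs to neither mask, so it is dropped from the complex before any area is summed.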
|
||||
|
||||
|
||||
def compute_ss_metrics(dssp_pred, ss_conditioning_metricsed):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from boltzgen.utils.quiet import quiet_startup
|
||||
|
||||
|
||||
@@ -224,7 +225,7 @@ class Filter(Task):
|
||||
if not metrics_override is None:
|
||||
for k in metrics_override:
|
||||
if metrics_override[k] is None:
|
||||
del self.metrics[k]
|
||||
self.metrics.pop(k, None)
|
||||
else:
|
||||
self.metrics[k] = metrics_override[k]
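Switching from `del self.metrics[k]` to `self.metrics.pop(k, None)` makes a `None` override tolerant of keys that are not present. A small sketch of the same override semantics on a plain dictionary, assuming a hypothetical helper `apply_overrides` that returns a copy rather than mutating in place:

```python
def apply_overrides(metrics, overrides):
    """Merge overrides into a copy of metrics; a None value removes the key."""
    merged = dict(metrics)
    for key, value in (overrides or {}).items():
        if value is None:
            merged.pop(key, None)  # no KeyError if the key is already absent
        else:
            merged[key] = value
    return merged


result = apply_overrides(
    {"design_ptm": 1, "plip_hbonds_refolded": 2},
    {"design_ptm": None, "delta_sasa_refolded": 3, "not_present": None},
)
```

With `del`, the `"not_present": None` entry would raise `KeyError`; with `pop(key, None)` it is silently ignored.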
|
||||
|
||||
@@ -284,12 +285,12 @@ class Filter(Task):
|
||||
{
|
||||
"feature": "GLY_fraction",
|
||||
"lower_is_better": True,
|
||||
"threshold": 0.2,
|
||||
"threshold": 0.3,
|
||||
},
|
||||
{
|
||||
"feature": "GLU_fraction",
|
||||
"lower_is_better": True,
|
||||
"threshold": 0.2,
|
||||
"threshold": 0.3,
|
||||
},
|
||||
{
|
||||
"feature": "LEU_fraction",
|
||||
@@ -299,7 +300,7 @@ class Filter(Task):
|
||||
{
|
||||
"feature": "VAL_fraction",
|
||||
"lower_is_better": True,
|
||||
"threshold": 0.2,
|
||||
"threshold": 0.3,
|
||||
},
|
||||
]
|
||||
)
|
||||
@@ -312,6 +313,7 @@ class Filter(Task):
|
||||
self.load_dataframe()
|
||||
self.reset_outdir()
|
||||
self.filter_df()
|
||||
self.absolute_metrics()
|
||||
self.sort_df()
|
||||
self.optimize_diversity()
|
||||
self.write_outdir()
|
||||
@@ -379,6 +381,8 @@ class Filter(Task):
|
||||
"design_largest_hydrophobic_patch_refolded"
|
||||
]
|
||||
df["neg_min_interaction_pae"] = -df["min_interaction_pae"]
|
||||
df["neg_filter_rmsd"] = -df["filter_rmsd"]
|
||||
df["neg_filter_rmsd_design"] = -df["filter_rmsd_design"]
|
||||
df["has_x"] = df["designed_sequence"].str.contains("X")
|
||||
self.df = df
|
||||
|
||||
@@ -418,6 +422,61 @@ class Filter(Task):
|
||||
)
|
||||
print("\n")
|
||||
|
||||
    def absolute_metrics(self):
        norm_path = Path("src/boltzgen/resources/metrics_normalization.json")
        if not norm_path.exists():
            return

        with norm_path.open("r") as f:
            norm_stats = json.load(f)

        for col, stats in norm_stats.items():
            mean = stats["mean"]
            std = stats["std"]
            if col in self.df.columns:
                self.df[col + "_z"] = (self.df[col] - mean) / std

        importances = {
            "affinity_probability_binary1": 1.5,
            "design_iiptm": 1.0,
            "design_ptm": 0.5,
            "min_design_to_target_pae": -1.0,  # lower is better
            "design_hydrophobicity": -0.125,  # lower is better
            "design_largest_hydrophobic_patch_refolded": -0.15,  # lower is better
            "delta_sasa_refolded": 0.25,
            "plip_saltbridge_refolded": 0.25,
            "plip_hbonds_refolded": 0.25,
        }

        self.df["absolute_score"] = 0.0
        for base_col, weight in importances.items():
            if base_col in self.df.columns:
                norm_col = base_col + "_z"
                self.df["absolute_score"] += weight * self.df[norm_col]
        total_importance = sum(abs(w) for w in importances.values())
        self.df["absolute_score"] /= total_importance

        self.df["structure_confidence"] = 0.0
        weight_sum = 0
        for col in ["design_iiptm", "design_ptm", "min_design_to_target_pae"]:
            weight = importances[col]
            norm_col = col + "_z"
            weight_sum += abs(weight)
            self.df["structure_confidence"] += weight * self.df[norm_col]
        self.df["structure_confidence"] /= weight_sum

        for flt in self.filters:
            feat = flt["feature"]
            filter_col = f"pass_{feat}_filter"
            if "fraction" in feat:
                # If this is a "fraction" feature, meaning a res_type fraction filter, only apply the penalty if num_design > 8
                mask_fail = (self.df["num_design"] > 8) & (self.df[filter_col] == False)
            else:
                mask_fail = self.df[filter_col] == False

            self.df.loc[mask_fail, "absolute_score"] *= 0.1
|
||||
|
||||
|
||||
def sort_df(self):
|
||||
rank_df = pd.DataFrame(index=self.df.index)
|
||||
|
||||
@@ -584,6 +643,12 @@ class Filter(Task):
|
||||
heapq.heappush(heap, (-gain, i))
|
||||
|
||||
buckets = np.zeros(len(self.size_buckets) + 1)
|
||||
first = selected[0]
|
||||
first_len = len(self.df_m["sequence"][first])
|
||||
for idx, bucket_size in enumerate(self.size_buckets):
|
||||
if first_len >= bucket_size["min"] and first_len < bucket_size["max"]:
|
||||
buckets[idx] += 1
|
||||
break
|
||||
for _ in tqdm(
|
||||
range(k - 1), desc="Performing lazy greedy diversity optimization."
|
||||
):
|
||||
@@ -749,6 +814,12 @@ class Filter(Task):
|
||||
hist_metrics = list(dict.fromkeys(hist_metrics))
|
||||
extra_pairs = list(dict.fromkeys(extra_pairs))
|
||||
|
||||
# Prepend any active ranking metrics not already in the lists
|
||||
extra_ranking = [m for m in self.metrics if m in self.df.columns and m not in hist_metrics]
|
||||
hist_metrics = extra_ranking + hist_metrics
|
||||
summary_metrics = extra_ranking + summary_metrics
|
||||
extra_pairs = [("num_design", m) for m in extra_ranking] + extra_pairs
|
||||
|
||||
if self.use_affinity:
|
||||
summary_metrics.insert(2, "affinity_probability_binary1")
|
||||
hist_metrics.insert(2, "affinity_probability_binary1")
|
||||
@@ -883,7 +954,6 @@ class Filter(Task):
|
||||
pdf_path = self.outdir / f"results_overview.pdf"
|
||||
pdf = PdfPages(pdf_path)
|
||||
|
||||
|
||||
def _ensure_width(fig, target_w=8.5):
|
||||
w, h = fig.get_size_inches()
|
||||
if abs(w - target_w) > 0.01:
|
||||
@@ -1231,8 +1301,6 @@ class Filter(Task):
|
||||
|
||||
pdf.close()
|
||||
|
||||
|
||||
|
||||
print(
|
||||
"A description of metrics and summarizing plots was written to:", pdf_path
|
||||
)
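The `absolute_metrics` method added in this file combines z-scored metrics into one weighted score. The sketch below reproduces that arithmetic in plain Python, under assumptions: `weighted_z_score` and the sample numbers are hypothetical, rows are plain dicts rather than a DataFrame, and metrics missing from either the row or the normalization table are skipped (the diff itself only checks the DataFrame columns).

```python
def weighted_z_score(rows, norm_stats, importances):
    """Return one score per row: sum_k(w_k * z_k) / sum_k(|w_k|).

    rows: list of dicts mapping metric name -> raw value.
    norm_stats: metric name -> {"mean": float, "std": float}.
    importances: metric name -> signed weight (negative = lower is better).
    """
    # As in the diff, the divisor uses every weight, not only the ones applied.
    total = sum(abs(w) for w in importances.values())
    scores = []
    for row in rows:
        score = 0.0
        for name, weight in importances.items():
            # Assumption of this sketch: skip metrics we cannot normalize.
            if name not in row or name not in norm_stats:
                continue
            stats = norm_stats[name]
            z = (row[name] - stats["mean"]) / stats["std"]
            score += weight * z
        scores.append(score / total)
    return scores


# Hypothetical normalization statistics and designs, for illustration only.
norm_stats = {
    "design_ptm": {"mean": 0.5, "std": 0.25},
    "min_design_to_target_pae": {"mean": 10.0, "std": 5.0},
}
importances = {"design_ptm": 0.5, "min_design_to_target_pae": -1.0}
rows = [
    {"design_ptm": 0.75, "min_design_to_target_pae": 5.0},
    {"design_ptm": 0.25, "min_design_to_target_pae": 15.0},
]
scores = weighted_z_score(rows, norm_stats, importances)
```

One property of the penalty step is worth keeping in mind when reading the diff: multiplying by 0.1 shrinks a score toward zero, so a design with a negative score moves up, not down, when it fails a filter.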
|
||||
|
||||
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import random
|
||||
import re
|
||||
import warnings
|
||||
from typing import Dict, List, Optional
|
||||
from collections import defaultdict
|
||||
from rdkit.Chem import Mol
|
||||
@@ -198,6 +199,7 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
|
||||
return_native: bool = False,
|
||||
reference_metadata_dir: Optional[Path] = None,
|
||||
target_templates: bool = False,
|
||||
design_mask_templates: bool = False,
|
||||
compute_affinity: bool = False,
|
||||
design: bool = False,
|
||||
backbone_only: bool = False,
|
||||
@@ -232,6 +234,7 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
|
||||
self.return_native = return_native
|
||||
self.reference_metadata_dir = reference_metadata_dir
|
||||
self.target_templates = target_templates
|
||||
self.design_mask_templates = design_mask_templates
|
||||
self.compute_affinity = compute_affinity
|
||||
self.design = design
|
||||
self.backbone_only = backbone_only
|
||||
@@ -317,8 +320,28 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
|
||||
if "binding_type" in metadata:
|
||||
binding_type = metadata["binding_type"]
|
||||
|
||||
# Per-residue amino acid constraints for inverse folding
|
||||
aa_constraint_mask = None
|
||||
if "aa_constraint_mask" in metadata:
|
||||
loaded_mask = metadata["aa_constraint_mask"]
|
||||
# Validate the loaded mask is a proper array with expected shape
|
||||
if (
|
||||
isinstance(loaded_mask, np.ndarray)
|
||||
and loaded_mask.ndim == 2
|
||||
and loaded_mask.shape[1] == 20 # 20 canonical amino acids
|
||||
):
|
||||
aa_constraint_mask = loaded_mask
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Invalid aa_constraint_mask in NPZ: "
|
||||
f"type={type(loaded_mask)}, shape={getattr(loaded_mask, 'shape', 'N/A')}. "
|
||||
f"Expected ndarray with shape (N, 20). Ignoring constraints.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Get features
|
||||
feat = self.get_feat(generated_path, design_mask, ss_type, binding_type)
|
||||
feat = self.get_feat(generated_path, design_mask, ss_type, binding_type, aa_constraint_mask)
|
||||
|
||||
# Get native features
|
||||
if self.return_native:
|
||||
@@ -332,7 +355,7 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
|
||||
|
||||
return feat
|
||||
|
||||
def get_feat(self, path, design_mask, ss_type=None, binding_type=None):
|
||||
def get_feat(self, path, design_mask, ss_type=None, binding_type=None, aa_constraint_mask=None):
|
||||
# Load design
|
||||
if self.extra_mol_dir is not None:
|
||||
mols = {
|
||||
@@ -460,6 +483,9 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
|
||||
features["ss_type"] = torch.from_numpy(ss_type).long()
|
||||
if binding_type is not None:
|
||||
features["binding_type"] = torch.from_numpy(binding_type).long()
|
||||
# Per-residue amino acid constraints for inverse folding
|
||||
if aa_constraint_mask is not None:
|
||||
features["aa_constraint_mask"] = torch.from_numpy(aa_constraint_mask).float()
|
||||
|
||||
# If we do not want the design mask to impact the featurizer (e.g. represent atoms as atom14), we set the design mask only here.
|
||||
if not self.design:
|
||||
@@ -490,7 +516,10 @@ class FromGeneratedDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Set templates
|
||||
if self.target_templates:
|
||||
template_mask = ~features["chain_design_mask"].numpy()
|
||||
if self.design_mask_templates:
|
||||
template_mask = ~features["design_mask"].numpy()
|
||||
else:
|
||||
template_mask = ~features["chain_design_mask"].numpy()
|
||||
templates_features = template_from_tokens(tokenized, template_mask)
|
||||
else:
|
||||
# Compute template features
|
||||
@@ -525,6 +554,7 @@ class FromGeneratedDataModule(pl.LightningDataModule):
|
||||
return_native: bool = False,
|
||||
compute_affinity: bool = False,
|
||||
target_templates: bool = False,
|
||||
design_mask_templates: bool = False,
|
||||
skip_existing: bool = False,
|
||||
skip_existing_kind: str = None,
|
||||
legacy_gen_suffix: str = "_gen.cif",
|
||||
@@ -552,6 +582,7 @@ class FromGeneratedDataModule(pl.LightningDataModule):
|
||||
self.legacy_metadata_suffix = legacy_metadata_suffix
|
||||
self.compute_affinity = compute_affinity
|
||||
self.target_templates = target_templates
|
||||
self.design_mask_templates = design_mask_templates
|
||||
self.extra_features = extra_features
|
||||
self.disulfide_prob = cfg.disulfide_prob
|
||||
self.disulfide_on = cfg.disulfide_on
|
||||
@@ -583,6 +614,7 @@ class FromGeneratedDataModule(pl.LightningDataModule):
|
||||
return_native=self.return_native,
|
||||
reference_metadata_dir=self.reference_metadata_dir,
|
||||
target_templates=self.target_templates,
|
||||
design_mask_templates=self.design_mask_templates,
|
||||
compute_affinity=self.compute_affinity,
|
||||
design=self.cfg.design,
|
||||
backbone_only=self.cfg.backbone_only,
|
||||
@@ -790,6 +822,7 @@ class FromGeneratedDataModule(pl.LightningDataModule):
|
||||
return_native=self.return_native,
|
||||
reference_metadata_dir=self.reference_metadata_dir,
|
||||
target_templates=self.target_templates,
|
||||
design_mask_templates=self.design_mask_templates,
|
||||
compute_affinity=self.compute_affinity,
|
||||
design=self.cfg.design,
|
||||
backbone_only=self.cfg.backbone_only,
|
||||
|
||||
@@ -216,6 +216,10 @@ class PredictionDataset(torch.utils.data.Dataset):
|
||||
tokenized.tokens["structure_group"] = design_info.res_structure_groups[
|
||||
token_to_res
|
||||
]
|
||||
# Transfer per-residue amino acid constraints (shape: num_tokens x 20)
|
||||
tokenized.tokens["aa_constraint_mask"] = design_info.res_aa_constraint_mask[
|
||||
token_to_res
|
||||
]
|
||||
|
||||
# Propagate design mask to obtain chain_design_mask (True whenever something is covalently bound to any residue that is in a chain that contains a design residue).
|
||||
chain_design_mask = tokenized.tokens["design_mask"].astype(bool)
|
||||
|
||||
@@ -384,34 +384,31 @@ class DesignWriter(BasePredictionWriter):
|
||||
|
||||
# Write metadata
|
||||
metadata_path = f"{self.outdir}/{file_name}.npz"
|
||||
token_mask = sample["token_pad_mask"].bool()
|
||||
|
||||
np.savez_compressed(
|
||||
metadata_path,
|
||||
design_mask=design_mask[sample["token_pad_mask"].bool()]
|
||||
.cpu()
|
||||
.numpy(),
|
||||
inverse_fold_design_mask=sample["inverse_fold_design_mask"][
|
||||
sample["token_pad_mask"].bool()
|
||||
]
|
||||
.cpu()
|
||||
.numpy()
|
||||
if "inverse_fold_design_mask" in sample
|
||||
else None,
|
||||
mol_type=sample["mol_type"][sample["token_pad_mask"].bool()]
|
||||
.cpu()
|
||||
.numpy(),
|
||||
ss_type=sample["ss_type"][sample["token_pad_mask"].bool()]
|
||||
.cpu()
|
||||
.numpy(),
|
||||
token_resolved_mask=sample["token_resolved_mask"][
|
||||
sample["token_pad_mask"].bool()
|
||||
]
|
||||
.cpu()
|
||||
.numpy(),
|
||||
binding_type=binding_type[sample["token_pad_mask"].bool()]
|
||||
.cpu()
|
||||
.numpy(),
|
||||
)
|
||||
# Build metadata dict with required fields
|
||||
metadata_dict = {
|
||||
"design_mask": design_mask[token_mask].cpu().numpy(),
|
||||
"mol_type": sample["mol_type"][token_mask].cpu().numpy(),
|
||||
"ss_type": sample["ss_type"][token_mask].cpu().numpy(),
|
||||
"token_resolved_mask": sample["token_resolved_mask"][token_mask].cpu().numpy(),
|
||||
"binding_type": binding_type[token_mask].cpu().numpy(),
|
||||
}
|
||||
|
||||
# Add optional fields only if they have valid values (avoid None -> object array)
|
||||
if "inverse_fold_design_mask" in sample:
|
||||
metadata_dict["inverse_fold_design_mask"] = (
|
||||
sample["inverse_fold_design_mask"][token_mask].cpu().numpy()
|
||||
)
|
||||
|
||||
# Per-residue amino acid constraints (for inverse folding step)
|
||||
# Only save if constraints exist AND have non-zero values
|
||||
if "aa_constraint_mask" in batch:
|
||||
aa_mask = batch["aa_constraint_mask"][0]
|
||||
if aa_mask.any(): # Only save if there are actual constraints
|
||||
metadata_dict["aa_constraint_mask"] = aa_mask[token_mask].cpu().numpy()
|
||||
|
||||
np.savez_compressed(metadata_path, **metadata_dict)
|
||||
|
||||
# Write trajectories
|
||||
if self.save_traj:
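The comment "avoid None -> object array" in the hunk above refers to how NumPy stores a `None` keyword value. The sketch below illustrates the difference using an in-memory buffer; `save_metadata` is a hypothetical helper written for this example, not part of the repository.

```python
import io

import numpy as np


def save_metadata(buffer, required, optional):
    """Write arrays to an .npz buffer, skipping optional entries that are None."""
    payload = dict(required)
    for key, value in optional.items():
        if value is not None:
            payload[key] = value
    np.savez_compressed(buffer, **payload)


design_mask = np.array([True, False, True])

# Passing None directly: NumPy wraps it in a 0-d object array and pickles it.
buf = io.BytesIO()
np.savez_compressed(buf, design_mask=design_mask, inverse_fold_design_mask=None)
buf.seek(0)
with np.load(buf) as data:  # allow_pickle defaults to False
    try:
        data["inverse_fold_design_mask"]
        none_entry_loaded = True
    except ValueError:
        # Object arrays cannot be read back without allow_pickle=True.
        none_entry_loaded = False

# Building the dict first: the optional key is simply absent.
buf2 = io.BytesIO()
save_metadata(buf2, {"design_mask": design_mask}, {"inverse_fold_design_mask": None})
buf2.seek(0)
with np.load(buf2) as data:
    saved_keys = list(data.files)
```

With the dict-based approach, a reader can test `"inverse_fold_design_mask" in metadata` instead of loading a pickled placeholder.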
|
||||
|
||||
79
tests/conftest.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Optional test configuration for mocking heavy dependencies.

Enable with:
    pytest --mock-heavy-deps tests/test_residue_constraints.py

By default no mocking is performed, so integration tests run against
real dependencies.
"""
import sys
from unittest.mock import MagicMock


def _install_mock(name: str) -> None:
    """Install a mock module (and parent packages) into sys.modules."""
    parts = name.split(".")
    for i in range(len(parts)):
        mod_name = ".".join(parts[: i + 1])
        if mod_name not in sys.modules:
            sys.modules[mod_name] = MagicMock()


# Heavy dependencies that schema.py imports transitively but are NOT
# needed by the three constraint-parsing functions under test.
_MOCK_MODULES = [
    "torch",
    "torch.nn",
    "torch.nn.functional",
    "torch.utils",
    "torch.utils.data",
    "pytorch_lightning",
    "hydra",
    "hydra.core",
    "hydra.core.config_store",
    "einops",
    "einx",
    "mashumaro",
    "biotite",
    "biotite.structure",
    "biotite.structure.io",
    "biotite.structure.io.pdbx",
    "pydssp",
    "logomaker",
    "hydride",
    "gemmi",
    "pdbeccdutils",
    "pdbeccdutils.core",
    "pdbeccdutils.core.ccd_reader",
    "edit_distance",
    "huggingface_hub",
    "nvidia_ml_py",
    "cuequivariance_ops_cu12",
    "cuequivariance_ops_torch_cu12",
    "cuequivariance_torch",
    "numba",
    "sklearn",
    "sklearn.cluster",
    "sklearn.neighbors",
    "pandas",
    "matplotlib",
    "matplotlib.pyplot",
    "tqdm",
    "Bio",
    "Bio.PDB",
]


def pytest_addoption(parser) -> None:
    parser.addoption(
        "--mock-heavy-deps",
        action="store_true",
        default=False,
        help="Mock heavy optional dependencies for parser-only unit tests.",
    )


def pytest_configure(config) -> None:
    if config.getoption("--mock-heavy-deps"):
        for mod in _MOCK_MODULES:
            _install_mock(mod)
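The mocking approach in `conftest.py` relies on the import system returning whatever object is already registered in `sys.modules`. The sketch below demonstrates that behavior in isolation. It copies the logic of `_install_mock` under a different name, and the package name `hypothetical_heavy_pkg` is made up for the example.

```python
import importlib
import sys
from unittest.mock import MagicMock


def install_mock(name):
    """Register a MagicMock for `name` and each of its parent packages."""
    parts = name.split(".")
    for i in range(len(parts)):
        mod_name = ".".join(parts[: i + 1])
        # Already-registered modules (real or mocked) are left untouched.
        if mod_name not in sys.modules:
            sys.modules[mod_name] = MagicMock()


install_mock("hypothetical_heavy_pkg.sub.mod")

# The import machinery finds the entry in sys.modules and returns it
# without searching the filesystem.
mod = importlib.import_module("hypothetical_heavy_pkg.sub.mod")
result = mod.some_function(1, 2)  # any attribute access or call yields a mock
```

Because existing entries are not replaced, this only has an effect when it runs before the real package is first imported, which is presumably why the diff installs the mocks in `pytest_configure`.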
|
||||
90
tests/test_inverse_fold_constraint_masks.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Integration tests for inverse-folding constraint mask composition."""

import pytest

torch = pytest.importorskip("torch")

from boltzgen.data import const
from boltzgen.model.modules.inverse_fold import build_constraint_logit_mask


INF = 10**6


def _allowed_only_mask(allowed_tokens: list[str]) -> torch.Tensor:
    """Build a single-row mask where only `allowed_tokens` are permitted."""
    num_aa = len(const.canonical_tokens)
    mask = torch.ones((1, num_aa), dtype=torch.float32)
    for token in allowed_tokens:
        mask[0, const.canonical_tokens.index(token)] = 0.0
    return mask


def test_conflict_allowed_and_global_avoid_keeps_global_restriction() -> None:
    cys_idx = const.canonical_tokens.index("CYS")
    aa_constraint_mask = _allowed_only_mask(["CYS"])

    with pytest.warns(RuntimeWarning, match="Relaxing per-residue constraints"):
        out = build_constraint_logit_mask(
            num_nodes=1,
            aa_constraint_mask=aa_constraint_mask,
            inverse_fold_restriction=["CYS"],
            canonical_tokens=const.canonical_tokens,
            inf=INF,
            device=torch.device("cpu"),
        )

    # Global avoid must still block CYS after conflict handling.
    assert out[0, cys_idx].item() == -INF
    # All other residues remain available.
    assert (out[0] == 0).sum().item() == len(const.canonical_tokens) - 1


def test_non_conflicting_constraints_compose_correctly() -> None:
    ala_idx = const.canonical_tokens.index("ALA")
    cys_idx = const.canonical_tokens.index("CYS")
    aa_constraint_mask = _allowed_only_mask(["ALA"])

    out = build_constraint_logit_mask(
        num_nodes=1,
        aa_constraint_mask=aa_constraint_mask,
        inverse_fold_restriction=["CYS"],
        canonical_tokens=const.canonical_tokens,
        inf=INF,
        device=torch.device("cpu"),
    )

    # Only ALA should remain available.
    assert out[0, ala_idx].item() == 0.0
    assert out[0, cys_idx].item() == -INF
    assert (out[0] == 0).sum().item() == 1


def test_global_restrictions_that_block_all_raise() -> None:
    with pytest.raises(ValueError, match="no valid amino acids"):
        build_constraint_logit_mask(
            num_nodes=1,
            aa_constraint_mask=None,
            inverse_fold_restriction=const.canonical_tokens,
            canonical_tokens=const.canonical_tokens,
            inf=INF,
            device=torch.device("cpu"),
        )


def test_shape_mismatch_ignores_per_residue_mask() -> None:
    bad_shape = torch.zeros((2, 20), dtype=torch.float32)

    with pytest.warns(RuntimeWarning, match="shape mismatch"):
        out = build_constraint_logit_mask(
            num_nodes=1,
            aa_constraint_mask=bad_shape,
            inverse_fold_restriction=[],
            canonical_tokens=const.canonical_tokens,
            inf=INF,
            device=torch.device("cpu"),
        )

    # No restrictions should remain after ignoring mismatched input.
    assert out.shape == (1, len(const.canonical_tokens))
    assert torch.all(out == 0)
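The four tests above pin down how a global restriction list and a per-residue mask are expected to combine. The sketch below implements those observable rules with plain lists. It is not the repository's `build_constraint_logit_mask`, whose source is not shown in this diff; `compose_logit_mask`, its signature, and the warning and error wording are assumptions chosen to mirror the test expectations.

```python
import warnings


def compose_logit_mask(num_nodes, per_residue_mask, global_avoid, tokens, inf=10**6):
    """Return num_nodes rows holding 0 (allowed) or -inf (blocked) per token.

    per_residue_mask: rows of 1.0 (blocked) / 0.0 (allowed), or None.
    global_avoid: token names blocked at every position.
    """
    avoid = set(global_avoid)
    blocked_globally = [tok in avoid for tok in tokens]
    if all(blocked_globally):
        raise ValueError("Global restrictions leave no valid amino acids.")
    global_row = [-inf if b else 0.0 for b in blocked_globally]

    # A mask whose shape does not match is dropped with a warning.
    if per_residue_mask is not None and (
        len(per_residue_mask) != num_nodes
        or any(len(row) != len(tokens) for row in per_residue_mask)
    ):
        warnings.warn("aa_constraint_mask shape mismatch; ignoring it.", RuntimeWarning)
        per_residue_mask = None

    out = []
    for i in range(num_nodes):
        row = list(global_row)
        if per_residue_mask is not None:
            combined = [
                -inf if (per_residue_mask[i][j] > 0 or blocked_globally[j]) else 0.0
                for j in range(len(tokens))
            ]
            if any(v == 0.0 for v in combined):
                row = combined
            else:
                # Conflict: nothing left, so fall back to the global row only.
                warnings.warn(
                    f"Relaxing per-residue constraints at node {i}.", RuntimeWarning
                )
        out.append(row)
    return out


TOKENS = ["ALA", "CYS", "GLY"]  # shortened alphabet for the example
only_ala = compose_logit_mask(1, [[0.0, 1.0, 1.0]], ["CYS"], TOKENS, inf=10)
```

In this sketch the global list always wins: when a position allows only a globally avoided residue, the per-residue constraint is the one that gets relaxed, matching the first test.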
|
||||
381
tests/test_residue_constraints.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Unit tests for per-residue amino acid constraint parsing.

Tests parse_residue_constraints(), _normalize_aa_spec(), and
_convert_aa_names_to_indices() from boltzgen.data.parse.schema.
"""

import numpy as np
import pytest

from boltzgen.data import const
from boltzgen.data.parse.schema import (
    _convert_aa_names_to_indices,
    _normalize_aa_spec,
    parse_residue_constraints,
)

# Shorthand fixtures
CANONICAL = const.canonical_tokens  # 20 three-letter codes
LETTER_MAP = const.prot_letter_to_token  # e.g. {"A": "ALA", ...}


# ============================================================================
# _normalize_aa_spec
# ============================================================================

class TestNormalizeAASpec:
    """Tests for _normalize_aa_spec helper."""

    def test_single_letter(self):
        assert _normalize_aa_spec("A") == ["A"]

    def test_multi_letter_string(self):
        assert _normalize_aa_spec("AGS") == ["A", "G", "S"]

    def test_long_string(self):
        assert _normalize_aa_spec("AVILMFYW") == list("AVILMFYW")

    def test_three_letter_code(self):
        assert _normalize_aa_spec("ALA") == ["ALA"]

    def test_three_letter_not_valid(self):
        # "AGS" is 3 chars but NOT a valid 3-letter code → split into 1-letter
        assert _normalize_aa_spec("AGS") == ["A", "G", "S"]

    def test_list_format_single_letters(self):
        assert _normalize_aa_spec(["A", "G", "S"]) == ["A", "G", "S"]

    def test_list_format_three_letter(self):
        assert _normalize_aa_spec(["ALA", "GLY"]) == ["ALA", "GLY"]

    def test_lowercase_normalised(self):
        assert _normalize_aa_spec("ags") == ["A", "G", "S"]

    def test_whitespace_stripped(self):
        assert _normalize_aa_spec(" AG ") == ["A", "G"]

    def test_invalid_type_raises(self):
        with pytest.raises(ValueError):
            _normalize_aa_spec(123)


# ============================================================================
# _convert_aa_names_to_indices
# ============================================================================

class TestConvertAANamesToIndices:
    """Tests for _convert_aa_names_to_indices helper."""

    def test_single_letter_a(self):
        indices = _convert_aa_names_to_indices(["A"], CANONICAL, LETTER_MAP)
        assert indices == [CANONICAL.index("ALA")]

    def test_single_letter_c(self):
        indices = _convert_aa_names_to_indices(["C"], CANONICAL, LETTER_MAP)
        assert indices == [CANONICAL.index("CYS")]

    def test_three_letter_code(self):
        indices = _convert_aa_names_to_indices(["ALA", "GLY"], CANONICAL, LETTER_MAP)
        assert indices == [CANONICAL.index("ALA"), CANONICAL.index("GLY")]

    def test_mixed_formats(self):
        indices = _convert_aa_names_to_indices(["A", "GLY"], CANONICAL, LETTER_MAP)
        assert indices == [CANONICAL.index("ALA"), CANONICAL.index("GLY")]

    def test_all_20_aas(self):
        all_letters = list("ACDEFGHIKLMNPQRSTVWY")
        indices = _convert_aa_names_to_indices(all_letters, CANONICAL, LETTER_MAP)
        assert len(indices) == 20
        assert len(set(indices)) == 20  # all unique

    def test_invalid_letter_raises(self):
        with pytest.raises(ValueError, match="Unknown amino acid"):
            _convert_aa_names_to_indices(["X"], CANONICAL, LETTER_MAP)

    def test_invalid_three_letter_raises(self):
        with pytest.raises(ValueError, match="Unknown amino acid"):
            _convert_aa_names_to_indices(["ZZZ"], CANONICAL, LETTER_MAP)


# ============================================================================
# parse_residue_constraints — valid inputs
# ============================================================================

class TestParseResidueConstraintsValid:
    """Tests for parse_residue_constraints with valid YAML specs."""

    def test_empty_list_returns_zeros(self):
        mask = parse_residue_constraints([], 10, CANONICAL, LETTER_MAP)
        assert mask.shape == (10, 20)
        assert mask.sum() == 0.0

    def test_single_allowed(self):
        spec = [{"position": 1, "allowed": "A"}]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        # Position 0 (1-indexed=1): only ALA allowed (0.0), rest blocked (1.0)
        assert mask[0, ala_idx] == 0.0
        assert mask[0].sum() == 19.0  # 19 blocked, 1 allowed
        # Other positions untouched
        assert mask[1:].sum() == 0.0

    def test_single_disallowed(self):
        spec = [{"position": 3, "disallowed": "CM"}]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        cys_idx = CANONICAL.index("CYS")
        met_idx = CANONICAL.index("MET")
        # Position 2 (1-indexed=3): CYS and MET blocked
        assert mask[2, cys_idx] == 1.0
        assert mask[2, met_idx] == 1.0
        assert mask[2].sum() == 2.0  # only 2 blocked

    def test_range_positions(self):
        spec = [{"position": "3..5", "disallowed": "C"}]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        cys_idx = CANONICAL.index("CYS")
        # Positions 2,3,4 (1-indexed 3,4,5) should have CYS blocked
        for pos in [2, 3, 4]:
            assert mask[pos, cys_idx] == 1.0
        # Other positions untouched
        for pos in [0, 1, 5, 6, 7, 8, 9]:
            assert mask[pos, cys_idx] == 0.0

    def test_allowed_multiple_aas(self):
        spec = [{"position": 8, "allowed": "AGS"}]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        gly_idx = CANONICAL.index("GLY")
        ser_idx = CANONICAL.index("SER")
        # Position 7 (1-indexed=8): only A,G,S allowed
        assert mask[7, ala_idx] == 0.0
        assert mask[7, gly_idx] == 0.0
        assert mask[7, ser_idx] == 0.0
        assert mask[7].sum() == 17.0  # 17 blocked

    def test_list_format_allowed(self):
        spec = [{"position": 1, "allowed": ["A", "G"]}]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        gly_idx = CANONICAL.index("GLY")
        assert mask[0, ala_idx] == 0.0
        assert mask[0, gly_idx] == 0.0
        assert mask[0].sum() == 18.0

    def test_multiple_constraints_no_overlap(self):
        spec = [
            {"position": 1, "allowed": "A"},
            {"position": 5, "allowed": "P"},
        ]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        ala_idx = CANONICAL.index("ALA")
        pro_idx = CANONICAL.index("PRO")
        assert mask[0, ala_idx] == 0.0
        assert mask[0].sum() == 19.0
        assert mask[4, pro_idx] == 0.0
        assert mask[4].sum() == 19.0
        # Middle positions untouched
        assert mask[1:4].sum() == 0.0

    # ------------------------------------------------------------------
    # Intersection semantics (overlapping constraints)
    # ------------------------------------------------------------------

    def test_overlapping_allowed_intersection(self):
        """Two allowed constraints on same position → only common AAs survive."""
        spec = [
            {"position": 1, "allowed": "AG"},
            {"position": 1, "allowed": "GS"},
        ]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        gly_idx = CANONICAL.index("GLY")
        ala_idx = CANONICAL.index("ALA")
        ser_idx = CANONICAL.index("SER")
        # Only G is in both sets
        assert mask[0, gly_idx] == 0.0  # allowed
        assert mask[0, ala_idx] == 1.0  # blocked (not in 2nd)
        assert mask[0, ser_idx] == 1.0  # blocked (not in 1st)
        assert mask[0].sum() == 19.0  # only GLY allowed

    def test_overlapping_allowed_range_intersection(self):
        """Overlapping ranges intersect at overlap positions."""
        spec = [
            {"position": "1..5", "allowed": "AG"},
            {"position": "3..7", "allowed": "GS"},
        ]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        gly_idx = CANONICAL.index("GLY")
        ala_idx = CANONICAL.index("ALA")
        ser_idx = CANONICAL.index("SER")
        # Positions 0,1 (1-indexed 1,2): only AG (first constraint only)
        assert mask[0, ala_idx] == 0.0
        assert mask[0, gly_idx] == 0.0
        assert mask[0].sum() == 18.0
        # Positions 2,3,4 (1-indexed 3,4,5): intersection of {A,G} and {G,S} = {G}
        for pos in [2, 3, 4]:
            assert mask[pos, gly_idx] == 0.0
            assert mask[pos, ala_idx] == 1.0
            assert mask[pos, ser_idx] == 1.0
            assert mask[pos].sum() == 19.0
        # Positions 5,6 (1-indexed 6,7): only GS (second constraint only)
        assert mask[5, gly_idx] == 0.0
        assert mask[5, ser_idx] == 0.0
        assert mask[5].sum() == 18.0

    def test_allowed_then_disallowed_same_position(self):
        """allowed + disallowed on same position: disallowed narrows the set."""
        spec = [
            {"position": 5, "allowed": "AGILMV"},
            {"position": 5, "disallowed": "CM"},
        ]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        met_idx = CANONICAL.index("MET")
        ala_idx = CANONICAL.index("ALA")
        # M was in allowed set but then blocked by disallowed
        assert mask[4, met_idx] == 1.0
        # A was in allowed set and not disallowed
        assert mask[4, ala_idx] == 0.0

    def test_disallowed_then_allowed_same_position(self):
        """Order independent: disallowed then allowed gives same result."""
        spec_ab = [
            {"position": 5, "allowed": "AGILMV"},
            {"position": 5, "disallowed": "CM"},
        ]
        spec_ba = [
            {"position": 5, "disallowed": "CM"},
            {"position": 5, "allowed": "AGILMV"},
        ]
        mask_ab = parse_residue_constraints(spec_ab, 10, CANONICAL, LETTER_MAP)
        mask_ba = parse_residue_constraints(spec_ba, 10, CANONICAL, LETTER_MAP)
        np.testing.assert_array_equal(mask_ab, mask_ba)

    def test_disjoint_allowed_sets_all_blocked(self):
        """Two allowed sets with no overlap → all 20 AAs blocked."""
        spec = [
            {"position": 1, "allowed": "AG"},
            {"position": 1, "allowed": "VILM"},
        ]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        # All 20 blocked at position 0
        assert mask[0].sum() == 20.0

    def test_multiple_disallowed_accumulate(self):
        """Multiple disallowed on same position: union of blocked sets."""
        spec = [
            {"position": 1, "disallowed": "CM"},
            {"position": 1, "disallowed": "WK"},
        ]
        mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
        cys_idx = CANONICAL.index("CYS")
        met_idx = CANONICAL.index("MET")
        trp_idx = CANONICAL.index("TRP")
        lys_idx = CANONICAL.index("LYS")
        assert mask[0, cys_idx] == 1.0
        assert mask[0, met_idx] == 1.0
        assert mask[0, trp_idx] == 1.0
        assert mask[0, lys_idx] == 1.0
        assert mask[0].sum() == 4.0

    def test_dtype_and_shape(self):
        spec = [{"position": 1, "allowed": "A"}]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
        assert mask.dtype == np.float32
        assert mask.shape == (10, 20)


# ============================================================================
# parse_residue_constraints — error paths
# ============================================================================

class TestParseResidueConstraintsErrors:
    """Tests for parse_residue_constraints with invalid YAML specs."""

    def test_missing_position(self):
        spec = [{"allowed": "A"}]
        with pytest.raises(ValueError, match="position.*required"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_position_out_of_bounds_high(self):
        spec = [{"position": 11, "allowed": "A"}]
        with pytest.raises(ValueError, match="out of bounds"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_position_out_of_bounds_zero(self):
        # Position 0 is invalid (1-indexed); parse_range catches this
        spec = [{"position": 0, "allowed": "A"}]
        with pytest.raises(ValueError, match="1 indexed|out of bounds"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_both_allowed_and_disallowed(self):
        spec = [{"position": 1, "allowed": "A", "disallowed": "C"}]
        with pytest.raises(ValueError, match="cannot specify both"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_neither_allowed_nor_disallowed(self):
        spec = [{"position": 1}]
        with pytest.raises(ValueError, match="must specify either"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_empty_allowed(self):
        spec = [{"position": 1, "allowed": ""}]
        with pytest.raises(ValueError, match="cannot be empty"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_invalid_amino_acid_code(self):
        spec = [{"position": 1, "allowed": "X"}]
        with pytest.raises(ValueError, match="Unknown amino acid"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

    def test_invalid_amino_acid_in_disallowed(self):
        spec = [{"position": 1, "disallowed": "XZ"}]
        with pytest.raises(ValueError, match="Unknown amino acid"):
            parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)


# ============================================================================
# Regression: original test case (no overlaps)
# ============================================================================

class TestOriginalTestCase:
    """Regression test matching residue_constraints_test.yaml."""

    def test_original_yaml_constraints(self):
        """Matches the constraints from example/residue_constraints_test.yaml."""
        spec = [
            {"position": 1, "allowed": "A"},
            {"position": "3..5", "disallowed": "CM"},
            {"position": 8, "allowed": "AGS"},
            {"position": 10, "allowed": "P"},
        ]
        mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)

        ala_idx = CANONICAL.index("ALA")
        cys_idx = CANONICAL.index("CYS")
        met_idx = CANONICAL.index("MET")
        gly_idx = CANONICAL.index("GLY")
        ser_idx = CANONICAL.index("SER")
        pro_idx = CANONICAL.index("PRO")

        # Position 1: only A
        assert mask[0, ala_idx] == 0.0
        assert mask[0].sum() == 19.0

        # Positions 3-5: C and M blocked
        for pos in [2, 3, 4]:
            assert mask[pos, cys_idx] == 1.0
            assert mask[pos, met_idx] == 1.0
            assert mask[pos].sum() == 2.0

        # Position 8: only A, G, S
        assert mask[7, ala_idx] == 0.0
        assert mask[7, gly_idx] == 0.0
        assert mask[7, ser_idx] == 0.0
        assert mask[7].sum() == 17.0

        # Position 10: only P
        assert mask[9, pro_idx] == 0.0
        assert mask[9].sum() == 19.0

        # Unconstrained positions (1-indexed 2, 6, 7, 9; 0-indexed 1, 5, 6, 8) are all zeros
        for pos in [1, 5, 6, 8]:
            assert mask[pos].sum() == 0.0
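The mask semantics exercised by these tests (1.0 = blocked, 0.0 = allowed, 1-indexed positions, inclusive `a..b` ranges, and overlapping constraints that only ever add blocks) can be sketched in a few lines. This is an illustration and not the repository's `parse_residue_constraints`: `build_mask` is hypothetical, it uses one-letter codes directly instead of the three-letter canonical tokens, and it omits most of the validation tested above.

```python
def build_mask(constraints, num_residues, letters):
    """Return a num_residues x len(letters) mask of 0.0 (allowed) / 1.0 (blocked)."""
    mask = [[0.0] * len(letters) for _ in range(num_residues)]
    for c in constraints:
        pos = c["position"]
        # Accept a single 1-indexed position or an inclusive "start..end" range.
        if isinstance(pos, str) and ".." in pos:
            start, end = (int(p) for p in pos.split(".."))
        else:
            start = end = int(pos)
        if start < 1 or end > num_residues:
            raise ValueError(f"position {pos} out of bounds")
        if "allowed" in c:
            # "allowed" blocks everything that is not listed.
            blocked = [aa not in c["allowed"] for aa in letters]
        else:
            blocked = [aa in c["disallowed"] for aa in letters]
        for i in range(start - 1, end):
            for j, is_blocked in enumerate(blocked):
                if is_blocked:
                    mask[i][j] = 1.0  # blocks accumulate; nothing is unblocked
    return mask


LETTERS = list("ACGMS")  # shortened alphabet for the example
example = build_mask(
    [
        {"position": "1..2", "allowed": "AG"},
        {"position": 2, "allowed": "GS"},
        {"position": 3, "disallowed": "CM"},
    ],
    3,
    LETTERS,
)
```

Because every constraint can only set entries to 1.0, the result is independent of constraint order, and two `allowed` sets on the same position behave as an intersection, which is what `test_disallowed_then_allowed_same_position` and `test_overlapping_allowed_intersection` check.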
|
||||