fix: pad constraint mask for file entities, add tests and examples

- Fix crash when YAML combines residue_constraints with file entities
  (missing res_aa_constraint_mask concatenation for CIF-loaded chains)
- Add CPU-only unit tests for constraint parsing (30 tests, no GPU needed)
- Add residue_constraints_test.yaml example
- Update showcase YAML with residue_constraints demo

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Harald Buerbaumer
2026-02-15 10:02:29 +01:00
committed by Hannes Stärk
parent 6d73b21d33
commit 0822cb3b71
5 changed files with 495 additions and 1 deletions

View File

@@ -1,7 +1,14 @@
entities:
- protein:
- protein:
id: G
sequence: 15..20AAAAAAVTTTT18PPP # range between 15 and 20 inclusive on both sides
residue_constraints:
- position: 1
allowed: A # Only Alanine at position 1
- position: 3..5
disallowed: CM # No Cysteine or Methionine at positions 3-5
- position: 8
allowed: AGS # Only Ala, Gly, or Ser at position 8
- protein:
id: R
sequence: 3..5C6C3 # Random number of design residues between 3 and 5, then a Cystein, then 6 design residues, then ...

View File

@@ -0,0 +1,33 @@
# Test file for per-residue amino acid constraints
# This tests the new residue_constraints feature
#
# Usage:
# boltzgen check example/residue_constraints_test.yaml
# boltzgen run example/residue_constraints_test.yaml --output test_output/ --steps design inverse_folding --num_designs 50
#
# Note: Use --num_designs 50 (not 5) for statistically meaningful verification.
# With only 5 designs, blacklist constraints have a ~21% false-pass probability.
#
# Note: Supports both string format ("AGS") and list format ([A, G, S])
# String format is preferred for consistency with sequence/binding_types
entities:
- protein:
id: A
sequence: 10
residue_constraints:
# Position 1: Force Alanine only
- position: 1
allowed: A
# Positions 3-5: Exclude Cysteine and Methionine
- position: 3..5
disallowed: CM
# Position 8: Allow only small amino acids
- position: 8
allowed: AGS
# Position 10: Force Proline
- position: 10
allowed: P

View File

@@ -1662,6 +1662,9 @@ class YamlDesignParser:
res_design_mask = np.concatenate([res_design_mask, new_design_mask])
res_bind_type = np.concatenate([res_bind_type, fbind_types])
ss_type = np.concatenate([ss_type, fss_type])
# File entities have no residue constraints — pad with zeros (all AAs allowed)
file_constraint_mask = np.zeros((len(new_design_mask), len(const.canonical_tokens)), dtype=np.float32)
res_aa_constraint_mask = np.concatenate([res_aa_constraint_mask, file_constraint_mask], axis=0)
extra_mols.update(new_extra_mols)
if len(renaming) > 0:
msg = f"\nChain ids conflict with existing chain ids. Renaming with {renaming}. This is for the structure from '{path}'."

70
tests/conftest.py Normal file
View File

@@ -0,0 +1,70 @@
"""Test configuration: mock heavy dependencies for CPU-only unit tests.
The per-residue constraint functions under test (_normalize_aa_spec,
_convert_aa_names_to_indices, parse_residue_constraints) only use numpy
and the boltzgen.data.const module. However, schema.py transitively
imports torch, pytorch_lightning, etc. via other boltzgen modules.
This conftest patches those heavy imports so tests can run without GPU
libraries installed — achieving the "Level 1: No GPU, fast" goal.
"""
import sys
from types import ModuleType
from unittest.mock import MagicMock
def _install_mock(name: str) -> None:
"""Install a mock module (and parent packages) into sys.modules."""
parts = name.split(".")
for i in range(len(parts)):
mod_name = ".".join(parts[: i + 1])
if mod_name not in sys.modules:
sys.modules[mod_name] = MagicMock()
# Heavy dependencies that schema.py imports transitively but are NOT
# needed by the three constraint-parsing functions under test.
_MOCK_MODULES = [
"torch",
"torch.nn",
"torch.nn.functional",
"torch.utils",
"torch.utils.data",
"pytorch_lightning",
"hydra",
"hydra.core",
"hydra.core.config_store",
"einops",
"einx",
"mashumaro",
"biotite",
"biotite.structure",
"biotite.structure.io",
"biotite.structure.io.pdbx",
"pydssp",
"logomaker",
"hydride",
"gemmi",
"pdbeccdutils",
"pdbeccdutils.core",
"pdbeccdutils.core.ccd_reader",
"edit_distance",
"huggingface_hub",
"nvidia_ml_py",
"cuequivariance_ops_cu12",
"cuequivariance_ops_torch_cu12",
"cuequivariance_torch",
"numba",
"sklearn",
"sklearn.cluster",
"sklearn.neighbors",
"pandas",
"matplotlib",
"matplotlib.pyplot",
"tqdm",
"Bio",
"Bio.PDB",
]
for mod in _MOCK_MODULES:
_install_mock(mod)

View File

@@ -0,0 +1,381 @@
"""Unit tests for per-residue amino acid constraint parsing.
Tests parse_residue_constraints(), _normalize_aa_spec(), and
_convert_aa_names_to_indices() from boltzgen.data.parse.schema.
"""
import numpy as np
import pytest
from boltzgen.data import const
from boltzgen.data.parse.schema import (
_convert_aa_names_to_indices,
_normalize_aa_spec,
parse_residue_constraints,
)
# Shorthand fixtures
CANONICAL = const.canonical_tokens # 20 three-letter codes
LETTER_MAP = const.prot_letter_to_token # e.g. {"A": "ALA", ...}
# ============================================================================
# _normalize_aa_spec
# ============================================================================
class TestNormalizeAASpec:
"""Tests for _normalize_aa_spec helper."""
def test_single_letter(self):
assert _normalize_aa_spec("A") == ["A"]
def test_multi_letter_string(self):
assert _normalize_aa_spec("AGS") == ["A", "G", "S"]
def test_long_string(self):
assert _normalize_aa_spec("AVILMFYW") == list("AVILMFYW")
def test_three_letter_code(self):
assert _normalize_aa_spec("ALA") == ["ALA"]
def test_three_letter_not_valid(self):
# "AGS" is 3 chars but NOT a valid 3-letter code → split into 1-letter
assert _normalize_aa_spec("AGS") == ["A", "G", "S"]
def test_list_format_single_letters(self):
assert _normalize_aa_spec(["A", "G", "S"]) == ["A", "G", "S"]
def test_list_format_three_letter(self):
assert _normalize_aa_spec(["ALA", "GLY"]) == ["ALA", "GLY"]
def test_lowercase_normalised(self):
assert _normalize_aa_spec("ags") == ["A", "G", "S"]
def test_whitespace_stripped(self):
assert _normalize_aa_spec(" AG ") == ["A", "G"]
def test_invalid_type_raises(self):
with pytest.raises(ValueError):
_normalize_aa_spec(123)
# ============================================================================
# _convert_aa_names_to_indices
# ============================================================================
class TestConvertAANamesToIndices:
"""Tests for _convert_aa_names_to_indices helper."""
def test_single_letter_a(self):
indices = _convert_aa_names_to_indices(["A"], CANONICAL, LETTER_MAP)
assert indices == [CANONICAL.index("ALA")]
def test_single_letter_c(self):
indices = _convert_aa_names_to_indices(["C"], CANONICAL, LETTER_MAP)
assert indices == [CANONICAL.index("CYS")]
def test_three_letter_code(self):
indices = _convert_aa_names_to_indices(["ALA", "GLY"], CANONICAL, LETTER_MAP)
assert indices == [CANONICAL.index("ALA"), CANONICAL.index("GLY")]
def test_mixed_formats(self):
indices = _convert_aa_names_to_indices(["A", "GLY"], CANONICAL, LETTER_MAP)
assert indices == [CANONICAL.index("ALA"), CANONICAL.index("GLY")]
def test_all_20_aas(self):
all_letters = list("ACDEFGHIKLMNPQRSTVWY")
indices = _convert_aa_names_to_indices(all_letters, CANONICAL, LETTER_MAP)
assert len(indices) == 20
assert len(set(indices)) == 20 # all unique
def test_invalid_letter_raises(self):
with pytest.raises(ValueError, match="Unknown amino acid"):
_convert_aa_names_to_indices(["X"], CANONICAL, LETTER_MAP)
def test_invalid_three_letter_raises(self):
with pytest.raises(ValueError, match="Unknown amino acid"):
_convert_aa_names_to_indices(["ZZZ"], CANONICAL, LETTER_MAP)
# ============================================================================
# parse_residue_constraints — valid inputs
# ============================================================================
class TestParseResidueConstraintsValid:
"""Tests for parse_residue_constraints with valid YAML specs."""
def test_empty_list_returns_zeros(self):
mask = parse_residue_constraints([], 10, CANONICAL, LETTER_MAP)
assert mask.shape == (10, 20)
assert mask.sum() == 0.0
def test_single_allowed(self):
spec = [{"position": 1, "allowed": "A"}]
mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
ala_idx = CANONICAL.index("ALA")
# Position 0 (1-indexed=1): only ALA allowed (0.0), rest blocked (1.0)
assert mask[0, ala_idx] == 0.0
assert mask[0].sum() == 19.0 # 19 blocked, 1 allowed
# Other positions untouched
assert mask[1:].sum() == 0.0
def test_single_disallowed(self):
spec = [{"position": 3, "disallowed": "CM"}]
mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
cys_idx = CANONICAL.index("CYS")
met_idx = CANONICAL.index("MET")
# Position 2 (1-indexed=3): CYS and MET blocked
assert mask[2, cys_idx] == 1.0
assert mask[2, met_idx] == 1.0
assert mask[2].sum() == 2.0 # only 2 blocked
def test_range_positions(self):
spec = [{"position": "3..5", "disallowed": "C"}]
mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
cys_idx = CANONICAL.index("CYS")
# Positions 2,3,4 (1-indexed 3,4,5) should have CYS blocked
for pos in [2, 3, 4]:
assert mask[pos, cys_idx] == 1.0
# Other positions untouched
for pos in [0, 1, 5, 6, 7, 8, 9]:
assert mask[pos, cys_idx] == 0.0
def test_allowed_multiple_aas(self):
spec = [{"position": 8, "allowed": "AGS"}]
mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
ala_idx = CANONICAL.index("ALA")
gly_idx = CANONICAL.index("GLY")
ser_idx = CANONICAL.index("SER")
# Position 7 (1-indexed=8): only A,G,S allowed
assert mask[7, ala_idx] == 0.0
assert mask[7, gly_idx] == 0.0
assert mask[7, ser_idx] == 0.0
assert mask[7].sum() == 17.0 # 17 blocked
def test_list_format_allowed(self):
spec = [{"position": 1, "allowed": ["A", "G"]}]
mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
ala_idx = CANONICAL.index("ALA")
gly_idx = CANONICAL.index("GLY")
assert mask[0, ala_idx] == 0.0
assert mask[0, gly_idx] == 0.0
assert mask[0].sum() == 18.0
def test_multiple_constraints_no_overlap(self):
spec = [
{"position": 1, "allowed": "A"},
{"position": 5, "allowed": "P"},
]
mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
ala_idx = CANONICAL.index("ALA")
pro_idx = CANONICAL.index("PRO")
assert mask[0, ala_idx] == 0.0
assert mask[0].sum() == 19.0
assert mask[4, pro_idx] == 0.0
assert mask[4].sum() == 19.0
# Middle positions untouched
assert mask[1:4].sum() == 0.0
# ------------------------------------------------------------------
# Intersection semantics (overlapping constraints)
# ------------------------------------------------------------------
def test_overlapping_allowed_intersection(self):
"""Two allowed constraints on same position → only common AAs survive."""
spec = [
{"position": 1, "allowed": "AG"},
{"position": 1, "allowed": "GS"},
]
mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
gly_idx = CANONICAL.index("GLY")
ala_idx = CANONICAL.index("ALA")
ser_idx = CANONICAL.index("SER")
# Only G is in both sets
assert mask[0, gly_idx] == 0.0 # allowed
assert mask[0, ala_idx] == 1.0 # blocked (not in 2nd)
assert mask[0, ser_idx] == 1.0 # blocked (not in 1st)
assert mask[0].sum() == 19.0 # only GLY allowed
def test_overlapping_allowed_range_intersection(self):
"""Overlapping ranges intersect at overlap positions."""
spec = [
{"position": "1..5", "allowed": "AG"},
{"position": "3..7", "allowed": "GS"},
]
mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
gly_idx = CANONICAL.index("GLY")
ala_idx = CANONICAL.index("ALA")
ser_idx = CANONICAL.index("SER")
# Positions 0,1 (1-indexed 1,2): only AG (first constraint only)
assert mask[0, ala_idx] == 0.0
assert mask[0, gly_idx] == 0.0
assert mask[0].sum() == 18.0
# Positions 2,3,4 (1-indexed 3,4,5): intersection of {A,G} and {G,S} = {G}
for pos in [2, 3, 4]:
assert mask[pos, gly_idx] == 0.0
assert mask[pos, ala_idx] == 1.0
assert mask[pos, ser_idx] == 1.0
assert mask[pos].sum() == 19.0
# Positions 5,6 (1-indexed 6,7): only GS (second constraint only)
assert mask[5, gly_idx] == 0.0
assert mask[5, ser_idx] == 0.0
assert mask[5].sum() == 18.0
def test_allowed_then_disallowed_same_position(self):
"""allowed + disallowed on same position: disallowed narrows the set."""
spec = [
{"position": 5, "allowed": "AGILMV"},
{"position": 5, "disallowed": "CM"},
]
mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
met_idx = CANONICAL.index("MET")
ala_idx = CANONICAL.index("ALA")
# M was in allowed set but then blocked by disallowed
assert mask[4, met_idx] == 1.0
# A was in allowed set and not disallowed
assert mask[4, ala_idx] == 0.0
def test_disallowed_then_allowed_same_position(self):
"""Order independent: disallowed then allowed gives same result."""
spec_ab = [
{"position": 5, "allowed": "AGILMV"},
{"position": 5, "disallowed": "CM"},
]
spec_ba = [
{"position": 5, "disallowed": "CM"},
{"position": 5, "allowed": "AGILMV"},
]
mask_ab = parse_residue_constraints(spec_ab, 10, CANONICAL, LETTER_MAP)
mask_ba = parse_residue_constraints(spec_ba, 10, CANONICAL, LETTER_MAP)
np.testing.assert_array_equal(mask_ab, mask_ba)
def test_disjoint_allowed_sets_all_blocked(self):
"""Two allowed sets with no overlap → all 20 AAs blocked."""
spec = [
{"position": 1, "allowed": "AG"},
{"position": 1, "allowed": "VILM"},
]
mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
# All 20 blocked at position 0
assert mask[0].sum() == 20.0
def test_multiple_disallowed_accumulate(self):
"""Multiple disallowed on same position: union of blocked sets."""
spec = [
{"position": 1, "disallowed": "CM"},
{"position": 1, "disallowed": "WK"},
]
mask = parse_residue_constraints(spec, 5, CANONICAL, LETTER_MAP)
cys_idx = CANONICAL.index("CYS")
met_idx = CANONICAL.index("MET")
trp_idx = CANONICAL.index("TRP")
lys_idx = CANONICAL.index("LYS")
assert mask[0, cys_idx] == 1.0
assert mask[0, met_idx] == 1.0
assert mask[0, trp_idx] == 1.0
assert mask[0, lys_idx] == 1.0
assert mask[0].sum() == 4.0
def test_dtype_and_shape(self):
spec = [{"position": 1, "allowed": "A"}]
mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
assert mask.dtype == np.float32
assert mask.shape == (10, 20)
# ============================================================================
# parse_residue_constraints — error paths
# ============================================================================
class TestParseResidueConstraintsErrors:
"""Tests for parse_residue_constraints with invalid YAML specs."""
def test_missing_position(self):
spec = [{"allowed": "A"}]
with pytest.raises(ValueError, match="position.*required"):
parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
def test_position_out_of_bounds_high(self):
spec = [{"position": 11, "allowed": "A"}]
with pytest.raises(ValueError, match="out of bounds"):
parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
def test_position_out_of_bounds_zero(self):
# Position 0 is invalid (1-indexed); parse_range catches this
spec = [{"position": 0, "allowed": "A"}]
with pytest.raises(ValueError, match="1 indexed|out of bounds"):
parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
def test_both_allowed_and_disallowed(self):
spec = [{"position": 1, "allowed": "A", "disallowed": "C"}]
with pytest.raises(ValueError, match="cannot specify both"):
parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
def test_neither_allowed_nor_disallowed(self):
spec = [{"position": 1}]
with pytest.raises(ValueError, match="must specify either"):
parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
def test_empty_allowed(self):
spec = [{"position": 1, "allowed": ""}]
with pytest.raises(ValueError, match="cannot be empty"):
parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
def test_invalid_amino_acid_code(self):
spec = [{"position": 1, "allowed": "X"}]
with pytest.raises(ValueError, match="Unknown amino acid"):
parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
def test_invalid_amino_acid_in_disallowed(self):
spec = [{"position": 1, "disallowed": "XZ"}]
with pytest.raises(ValueError, match="Unknown amino acid"):
parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
# ============================================================================
# Regression: original test case (no overlaps)
# ============================================================================
class TestOriginalTestCase:
"""Regression test matching residue_constraints_test.yaml."""
def test_original_yaml_constraints(self):
"""Matches the constraints from example/residue_constraints_test.yaml."""
spec = [
{"position": 1, "allowed": "A"},
{"position": "3..5", "disallowed": "CM"},
{"position": 8, "allowed": "AGS"},
{"position": 10, "allowed": "P"},
]
mask = parse_residue_constraints(spec, 10, CANONICAL, LETTER_MAP)
ala_idx = CANONICAL.index("ALA")
cys_idx = CANONICAL.index("CYS")
met_idx = CANONICAL.index("MET")
gly_idx = CANONICAL.index("GLY")
ser_idx = CANONICAL.index("SER")
pro_idx = CANONICAL.index("PRO")
# Position 1: only A
assert mask[0, ala_idx] == 0.0
assert mask[0].sum() == 19.0
# Positions 3-5: C and M blocked
for pos in [2, 3, 4]:
assert mask[pos, cys_idx] == 1.0
assert mask[pos, met_idx] == 1.0
assert mask[pos].sum() == 2.0
# Position 8: only A, G, S
assert mask[7, ala_idx] == 0.0
assert mask[7, gly_idx] == 0.0
assert mask[7, ser_idx] == 0.0
assert mask[7].sum() == 17.0
# Position 10: only P
assert mask[9, pro_idx] == 0.0
assert mask[9].sum() == 19.0
# Unconstrained positions (2, 6, 7, 9 in 0-indexed) are all zeros
for pos in [1, 5, 6, 8]:
assert mask[pos].sum() == 0.0