fix: resolve inverse-fold constraint conflicts with global avoid

Idead by buerbaumer@ and quality checked with Codex 5.3
This commit is contained in:
Harald Buerbaumer
2026-02-16 18:18:40 +01:00
committed by Hannes Stärk
parent 0822cb3b71
commit 39c87c65d6
3 changed files with 188 additions and 56 deletions

View File

@@ -1,4 +1,4 @@
from typing import Dict, Tuple, List
from typing import Dict, Tuple, List, Optional
import warnings
import torch
@@ -39,6 +39,75 @@ def softmax_dropout(
)
def build_constraint_logit_mask(
num_nodes: int,
aa_constraint_mask: Optional[Tensor],
inverse_fold_restriction: list[str],
canonical_tokens: list[str],
inf: float,
device: torch.device,
) -> Tensor:
"""Build per-position inverse-folding logit mask.
The mask uses additive logit bias semantics:
0.0 = allowed, -inf = disallowed.
"""
num_aa = len(canonical_tokens)
has_per_residue_constraints = False
if aa_constraint_mask is None:
per_residue_blocked = torch.zeros(
num_nodes, num_aa, dtype=torch.bool, device=device
)
else:
expected_shape = (num_nodes, num_aa)
if aa_constraint_mask.shape != expected_shape:
warnings.warn(
f"aa_constraint_mask shape mismatch: "
f"got {aa_constraint_mask.shape}, expected {expected_shape}. "
f"Ignoring per-residue constraints.",
RuntimeWarning,
stacklevel=2,
)
per_residue_blocked = torch.zeros(
num_nodes, num_aa, dtype=torch.bool, device=device
)
else:
has_per_residue_constraints = True
per_residue_blocked = aa_constraint_mask.to(device=device) > 0
global_blocked = torch.zeros(num_aa, dtype=torch.bool, device=device)
for res_type in inverse_fold_restriction:
global_blocked[canonical_tokens.index(res_type)] = True
combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
all_blocked = combined_blocked.all(dim=1)
if all_blocked.any() and has_per_residue_constraints:
blocked_positions = torch.where(all_blocked)[0].tolist()
warnings.warn(
f"Positions {blocked_positions} have all amino acids blocked by the "
f"combination of per-residue constraints and '--inverse_fold_avoid'. "
f"Relaxing per-residue constraints for these positions.",
RuntimeWarning,
stacklevel=2,
)
per_residue_blocked = per_residue_blocked.clone()
per_residue_blocked[all_blocked] = False
combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
still_all_blocked = combined_blocked.all(dim=1)
if still_all_blocked.any():
blocked_positions = torch.where(still_all_blocked)[0].tolist()
raise ValueError(
f"Inverse folding has no valid amino acids at token positions "
f"{blocked_positions} after applying '--inverse_fold_avoid'. "
f"Reduce global restrictions to keep at least one amino acid."
)
return combined_blocked.to(dtype=torch.float32) * (-inf)
class MLPAttnGNN(nn.Module):
def __init__(
self,
@@ -589,53 +658,17 @@ class InverseFoldingDecoder(nn.Module):
f"num_design: {num_design}, num_not_design: {num_not_design}"
)
# Create per-residue restriction mask
# Initialize with per-residue constraints from YAML if available
constraint_mask = None
if "aa_constraint_mask" in feats:
constraint_mask = feats["aa_constraint_mask"][valid_mask]
# Validate shape: should be (num_nodes, 20) for 20 canonical amino acids
expected_shape = (num_nodes, len(const.canonical_tokens))
if constraint_mask.shape != expected_shape:
warnings.warn(
f"aa_constraint_mask shape mismatch: "
f"got {constraint_mask.shape}, expected {expected_shape}. "
f"Ignoring per-residue constraints.",
RuntimeWarning,
stacklevel=2,
)
per_residue_mask = torch.zeros(num_nodes, len(const.canonical_tokens), device=s.device)
else:
# Check for positions with ALL AAs blocked (would leave no valid options).
# This can happen when per-residue 'allowed' and global '--inverse_fold_avoid'
# combine to eliminate every amino acid at a position.
all_blocked = constraint_mask.sum(dim=1) >= len(const.canonical_tokens)
if all_blocked.any():
blocked_positions = torch.where(all_blocked)[0].tolist()
warnings.warn(
f"Positions {blocked_positions} have ALL amino acids blocked "
f"(likely conflict between per-residue 'allowed' and global "
f"'--inverse_fold_avoid'). Relaxing constraints for these positions "
f"to allow all amino acids.",
RuntimeWarning,
stacklevel=2,
)
constraint_mask[all_blocked] = 0 # Allow all AAs for fully-blocked positions
# Per-residue constraints: values 0=allowed, 1=disallowed
per_residue_mask = constraint_mask.clone()
# Convert binary mask to logit bias: 0 -> 0, 1 -> -inf
per_residue_mask = per_residue_mask * (-self.inf)
else:
# No per-residue constraints: all AAs allowed
per_residue_mask = torch.zeros(num_nodes, len(const.canonical_tokens), device=s.device)
# Add global restriction (backward compatible with --inverse_fold_avoid CLI)
if len(self.inverse_fold_restriction) > 0:
global_mask = torch.zeros(len(const.canonical_tokens), device=s.device)
for res_type in self.inverse_fold_restriction:
global_mask[const.canonical_tokens.index(res_type)] = -self.inf
# Combine: global restrictions apply to ALL positions additively
per_residue_mask = per_residue_mask + global_mask.unsqueeze(0)
per_residue_mask = build_constraint_logit_mask(
num_nodes=num_nodes,
aa_constraint_mask=constraint_mask,
inverse_fold_restriction=self.inverse_fold_restriction,
canonical_tokens=const.canonical_tokens,
inf=self.inf,
device=s.device,
)
order = torch.randperm(num_nodes, device=s.device).cpu().numpy().tolist()
# Non-design residues are not sampled and used as the condition. So the order should filter them out.

View File

@@ -1,15 +1,12 @@
"""Test configuration: mock heavy dependencies for CPU-only unit tests.
"""Optional test configuration for mocking heavy dependencies.
The per-residue constraint functions under test (_normalize_aa_spec,
_convert_aa_names_to_indices, parse_residue_constraints) only use numpy
and the boltzgen.data.const module. However, schema.py transitively
imports torch, pytorch_lightning, etc. via other boltzgen modules.
Enable with:
pytest --mock-heavy-deps tests/test_residue_constraints.py
This conftest patches those heavy imports so tests can run without GPU
libraries installed — achieving the "Level 1: No GPU, fast" goal.
By default no mocking is performed, so integration tests run against
real dependencies.
"""
import sys
from types import ModuleType
from unittest.mock import MagicMock
@@ -66,5 +63,17 @@ _MOCK_MODULES = [
"Bio.PDB",
]
for mod in _MOCK_MODULES:
_install_mock(mod)
def pytest_addoption(parser) -> None:
parser.addoption(
"--mock-heavy-deps",
action="store_true",
default=False,
help="Mock heavy optional dependencies for parser-only unit tests.",
)
def pytest_configure(config) -> None:
if config.getoption("--mock-heavy-deps"):
for mod in _MOCK_MODULES:
_install_mock(mod)

View File

@@ -0,0 +1,90 @@
"""Integration tests for inverse-folding constraint mask composition."""
import pytest
torch = pytest.importorskip("torch")
from boltzgen.data import const
from boltzgen.model.modules.inverse_fold import build_constraint_logit_mask
INF = 10**6
def _allowed_only_mask(allowed_tokens: list[str]) -> torch.Tensor:
"""Build a single-row mask where only `allowed_tokens` are permitted."""
num_aa = len(const.canonical_tokens)
mask = torch.ones((1, num_aa), dtype=torch.float32)
for token in allowed_tokens:
mask[0, const.canonical_tokens.index(token)] = 0.0
return mask
def test_conflict_allowed_and_global_avoid_keeps_global_restriction() -> None:
cys_idx = const.canonical_tokens.index("CYS")
aa_constraint_mask = _allowed_only_mask(["CYS"])
with pytest.warns(RuntimeWarning, match="Relaxing per-residue constraints"):
out = build_constraint_logit_mask(
num_nodes=1,
aa_constraint_mask=aa_constraint_mask,
inverse_fold_restriction=["CYS"],
canonical_tokens=const.canonical_tokens,
inf=INF,
device=torch.device("cpu"),
)
# Global avoid must still block CYS after conflict handling.
assert out[0, cys_idx].item() == -INF
# All other residues remain available.
assert (out[0] == 0).sum().item() == len(const.canonical_tokens) - 1
def test_non_conflicting_constraints_compose_correctly() -> None:
ala_idx = const.canonical_tokens.index("ALA")
cys_idx = const.canonical_tokens.index("CYS")
aa_constraint_mask = _allowed_only_mask(["ALA"])
out = build_constraint_logit_mask(
num_nodes=1,
aa_constraint_mask=aa_constraint_mask,
inverse_fold_restriction=["CYS"],
canonical_tokens=const.canonical_tokens,
inf=INF,
device=torch.device("cpu"),
)
# Only ALA should remain available.
assert out[0, ala_idx].item() == 0.0
assert out[0, cys_idx].item() == -INF
assert (out[0] == 0).sum().item() == 1
def test_global_restrictions_that_block_all_raise() -> None:
with pytest.raises(ValueError, match="no valid amino acids"):
build_constraint_logit_mask(
num_nodes=1,
aa_constraint_mask=None,
inverse_fold_restriction=const.canonical_tokens,
canonical_tokens=const.canonical_tokens,
inf=INF,
device=torch.device("cpu"),
)
def test_shape_mismatch_ignores_per_residue_mask() -> None:
bad_shape = torch.zeros((2, 20), dtype=torch.float32)
with pytest.warns(RuntimeWarning, match="shape mismatch"):
out = build_constraint_logit_mask(
num_nodes=1,
aa_constraint_mask=bad_shape,
inverse_fold_restriction=[],
canonical_tokens=const.canonical_tokens,
inf=INF,
device=torch.device("cpu"),
)
# No restrictions should remain after ignoring mismatched input.
assert out.shape == (1, len(const.canonical_tokens))
assert torch.all(out == 0)