mirror of
https://github.com/HannesStark/boltzgen.git
synced 2026-06-04 11:54:23 +08:00
fix: resolve inverse-fold constraint conflicts with global avoid
Idea by buerbaumer@; quality-checked with Codex 5.3
This commit is contained in:
committed by
Hannes Stärk
parent
0822cb3b71
commit
39c87c65d6
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Tuple, List
|
||||
from typing import Dict, Tuple, List, Optional
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
@@ -39,6 +39,75 @@ def softmax_dropout(
|
||||
)
|
||||
|
||||
|
||||
def build_constraint_logit_mask(
    num_nodes: int,
    aa_constraint_mask: Optional[Tensor],
    inverse_fold_restriction: list[str],
    canonical_tokens: list[str],
    inf: float,
    device: torch.device,
) -> Tensor:
    """Build per-position inverse-folding logit mask.

    The mask uses additive logit bias semantics:
    0.0 = allowed, -inf = disallowed.

    Args:
        num_nodes: Number of token positions (rows of the returned mask).
        aa_constraint_mask: Optional per-residue mask. Entries ``> 0`` mark an
            amino acid as blocked at that position. It is only used when its
            shape is ``(num_nodes, len(canonical_tokens))``; any other shape
            triggers a ``RuntimeWarning`` and the mask is ignored.
        inverse_fold_restriction: Residue names blocked at every position.
        canonical_tokens: Ordered residue vocabulary; column ``i`` of the
            result corresponds to ``canonical_tokens[i]``.
        inf: Magnitude of the penalty written for blocked entries. May be a
            large finite number or ``float("inf")``.
        device: Device on which the mask is built.

    Returns:
        A float32 tensor of shape ``(num_nodes, len(canonical_tokens))`` holding
        ``0.0`` for allowed and ``-inf`` for blocked amino acids.

    Raises:
        ValueError: If ``inverse_fold_restriction`` names a residue that is not
            in ``canonical_tokens``, or if the global restrictions alone leave
            some position without any allowed amino acid.
    """
    num_aa = len(canonical_tokens)
    expected_shape = (num_nodes, num_aa)
    has_per_residue_constraints = False

    # Default: nothing is blocked per residue. Only replaced when a
    # well-shaped constraint mask is supplied.
    per_residue_blocked = torch.zeros(
        num_nodes, num_aa, dtype=torch.bool, device=device
    )
    if aa_constraint_mask is not None:
        if aa_constraint_mask.shape != expected_shape:
            warnings.warn(
                f"aa_constraint_mask shape mismatch: "
                f"got {aa_constraint_mask.shape}, expected {expected_shape}. "
                f"Ignoring per-residue constraints.",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            has_per_residue_constraints = True
            per_residue_blocked = aa_constraint_mask.to(device=device) > 0

    global_blocked = torch.zeros(num_aa, dtype=torch.bool, device=device)
    for res_type in inverse_fold_restriction:
        try:
            token_index = canonical_tokens.index(res_type)
        except ValueError:
            # list.index's own message ("x is not in list") does not tell the
            # user which option was wrong; keep the exception type, fix the text.
            raise ValueError(
                f"Unknown residue type {res_type!r} in '--inverse_fold_avoid'; "
                f"expected one of {list(canonical_tokens)}."
            ) from None
        global_blocked[token_index] = True

    combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
    all_blocked = combined_blocked.all(dim=1)

    if all_blocked.any() and has_per_residue_constraints:
        blocked_positions = torch.where(all_blocked)[0].tolist()
        warnings.warn(
            f"Positions {blocked_positions} have all amino acids blocked by the "
            f"combination of per-residue constraints and '--inverse_fold_avoid'. "
            f"Relaxing per-residue constraints for these positions.",
            RuntimeWarning,
            stacklevel=2,
        )
        # Clone so the relaxation never writes into a tensor that may share
        # storage with the caller's mask; the global restriction stays in force.
        per_residue_blocked = per_residue_blocked.clone()
        per_residue_blocked[all_blocked] = False
        combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)

    # After relaxation only the global restriction can still block everything.
    still_all_blocked = combined_blocked.all(dim=1)
    if still_all_blocked.any():
        blocked_positions = torch.where(still_all_blocked)[0].tolist()
        raise ValueError(
            f"Inverse folding has no valid amino acids at token positions "
            f"{blocked_positions} after applying '--inverse_fold_avoid'. "
            f"Reduce global restrictions to keep at least one amino acid."
        )

    # Fill instead of multiplying: `0 * -inf` is NaN when `inf` is a true
    # infinity, and multiplication also leaves -0.0 in the allowed entries.
    logit_mask = torch.zeros(num_nodes, num_aa, dtype=torch.float32, device=device)
    return logit_mask.masked_fill(combined_blocked, -inf)
|
||||
|
||||
|
||||
class MLPAttnGNN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -589,53 +658,17 @@ class InverseFoldingDecoder(nn.Module):
|
||||
f"num_design: {num_design}, num_not_design: {num_not_design}"
|
||||
)
|
||||
|
||||
# Create per-residue restriction mask
|
||||
# Initialize with per-residue constraints from YAML if available
|
||||
constraint_mask = None
|
||||
if "aa_constraint_mask" in feats:
|
||||
constraint_mask = feats["aa_constraint_mask"][valid_mask]
|
||||
# Validate shape: should be (num_nodes, 20) for 20 canonical amino acids
|
||||
expected_shape = (num_nodes, len(const.canonical_tokens))
|
||||
if constraint_mask.shape != expected_shape:
|
||||
warnings.warn(
|
||||
f"aa_constraint_mask shape mismatch: "
|
||||
f"got {constraint_mask.shape}, expected {expected_shape}. "
|
||||
f"Ignoring per-residue constraints.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
per_residue_mask = build_constraint_logit_mask(
|
||||
num_nodes=num_nodes,
|
||||
aa_constraint_mask=constraint_mask,
|
||||
inverse_fold_restriction=self.inverse_fold_restriction,
|
||||
canonical_tokens=const.canonical_tokens,
|
||||
inf=self.inf,
|
||||
device=s.device,
|
||||
)
|
||||
per_residue_mask = torch.zeros(num_nodes, len(const.canonical_tokens), device=s.device)
|
||||
else:
|
||||
# Check for positions with ALL AAs blocked (would leave no valid options).
|
||||
# This can happen when per-residue 'allowed' and global '--inverse_fold_avoid'
|
||||
# combine to eliminate every amino acid at a position.
|
||||
all_blocked = constraint_mask.sum(dim=1) >= len(const.canonical_tokens)
|
||||
if all_blocked.any():
|
||||
blocked_positions = torch.where(all_blocked)[0].tolist()
|
||||
warnings.warn(
|
||||
f"Positions {blocked_positions} have ALL amino acids blocked "
|
||||
f"(likely conflict between per-residue 'allowed' and global "
|
||||
f"'--inverse_fold_avoid'). Relaxing constraints for these positions "
|
||||
f"to allow all amino acids.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
constraint_mask[all_blocked] = 0 # Allow all AAs for fully-blocked positions
|
||||
|
||||
# Per-residue constraints: values 0=allowed, 1=disallowed
|
||||
per_residue_mask = constraint_mask.clone()
|
||||
# Convert binary mask to logit bias: 0 -> 0, 1 -> -inf
|
||||
per_residue_mask = per_residue_mask * (-self.inf)
|
||||
else:
|
||||
# No per-residue constraints: all AAs allowed
|
||||
per_residue_mask = torch.zeros(num_nodes, len(const.canonical_tokens), device=s.device)
|
||||
|
||||
# Add global restriction (backward compatible with --inverse_fold_avoid CLI)
|
||||
if len(self.inverse_fold_restriction) > 0:
|
||||
global_mask = torch.zeros(len(const.canonical_tokens), device=s.device)
|
||||
for res_type in self.inverse_fold_restriction:
|
||||
global_mask[const.canonical_tokens.index(res_type)] = -self.inf
|
||||
# Combine: global restrictions apply to ALL positions additively
|
||||
per_residue_mask = per_residue_mask + global_mask.unsqueeze(0)
|
||||
|
||||
order = torch.randperm(num_nodes, device=s.device).cpu().numpy().tolist()
|
||||
# Non-design residues are not sampled and used as the condition. So the order should filter them out.
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
"""Test configuration: mock heavy dependencies for CPU-only unit tests.
|
||||
"""Optional test configuration for mocking heavy dependencies.
|
||||
|
||||
The per-residue constraint functions under test (_normalize_aa_spec,
|
||||
_convert_aa_names_to_indices, parse_residue_constraints) only use numpy
|
||||
and the boltzgen.data.const module. However, schema.py transitively
|
||||
imports torch, pytorch_lightning, etc. via other boltzgen modules.
|
||||
Enable with:
|
||||
pytest --mock-heavy-deps tests/test_residue_constraints.py
|
||||
|
||||
This conftest patches those heavy imports so tests can run without GPU
|
||||
libraries installed — achieving the "Level 1: No GPU, fast" goal.
|
||||
By default no mocking is performed, so integration tests run against
|
||||
real dependencies.
|
||||
"""
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
@@ -66,5 +63,17 @@ _MOCK_MODULES = [
|
||||
"Bio.PDB",
|
||||
]
|
||||
|
||||
|
||||
def pytest_addoption(parser) -> None:
    """Register the opt-in ``--mock-heavy-deps`` command-line flag.

    The flag is a boolean switch that defaults to off.
    """
    flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "Mock heavy optional dependencies for parser-only unit tests.",
    }
    parser.addoption("--mock-heavy-deps", **flag_settings)
|
||||
|
||||
|
||||
def pytest_configure(config) -> None:
    """Install mocks for every entry of ``_MOCK_MODULES`` when requested.

    Nothing happens unless the run was started with ``--mock-heavy-deps``.
    """
    if not config.getoption("--mock-heavy-deps"):
        return
    for module_name in _MOCK_MODULES:
        _install_mock(module_name)
|
||||
|
||||
90
tests/test_inverse_fold_constraint_masks.py
Normal file
90
tests/test_inverse_fold_constraint_masks.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Integration tests for inverse-folding constraint mask composition."""
|
||||
|
||||
import pytest
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
from boltzgen.data import const
|
||||
from boltzgen.model.modules.inverse_fold import build_constraint_logit_mask
|
||||
|
||||
|
||||
INF = 10**6
|
||||
|
||||
|
||||
def _allowed_only_mask(allowed_tokens: list[str]) -> torch.Tensor:
    """Return a one-row constraint mask that permits only `allowed_tokens`.

    Every column starts blocked (1.0); the columns of the named tokens are
    cleared to 0.0.
    """
    vocabulary = const.canonical_tokens
    row = torch.ones((1, len(vocabulary)), dtype=torch.float32)
    for name in allowed_tokens:
        column = vocabulary.index(name)
        row[0, column] = 0.0
    return row
|
||||
|
||||
|
||||
def test_conflict_allowed_and_global_avoid_keeps_global_restriction() -> None:
    """A CYS-only position with CYS globally avoided is relaxed, CYS stays blocked."""
    vocabulary = const.canonical_tokens
    cys_column = vocabulary.index("CYS")
    call_kwargs = dict(
        num_nodes=1,
        aa_constraint_mask=_allowed_only_mask(["CYS"]),
        inverse_fold_restriction=["CYS"],
        canonical_tokens=vocabulary,
        inf=INF,
        device=torch.device("cpu"),
    )

    with pytest.warns(RuntimeWarning, match="Relaxing per-residue constraints"):
        logits = build_constraint_logit_mask(**call_kwargs)

    row = logits[0]
    # The global avoid survives the conflict handling ...
    assert row[cys_column].item() == -INF
    # ... while every other residue becomes available.
    assert (row == 0).sum().item() == len(vocabulary) - 1
|
||||
|
||||
|
||||
def test_non_conflicting_constraints_compose_correctly() -> None:
    """An ALA-only position combined with a global CYS avoid leaves just ALA."""
    vocabulary = const.canonical_tokens
    ala_column = vocabulary.index("ALA")
    cys_column = vocabulary.index("CYS")
    ala_only = _allowed_only_mask(["ALA"])

    logits = build_constraint_logit_mask(
        num_nodes=1,
        aa_constraint_mask=ala_only,
        inverse_fold_restriction=["CYS"],
        canonical_tokens=vocabulary,
        inf=INF,
        device=torch.device("cpu"),
    )

    row = logits[0]
    # ALA is the single residue left open; CYS is blocked.
    assert row[ala_column].item() == 0.0
    assert row[cys_column].item() == -INF
    assert (row == 0).sum().item() == 1
|
||||
|
||||
|
||||
def test_global_restrictions_that_block_all_raise() -> None:
    """Globally avoiding every canonical token must raise a ValueError."""
    every_token = const.canonical_tokens
    cpu = torch.device("cpu")

    with pytest.raises(ValueError, match="no valid amino acids"):
        build_constraint_logit_mask(
            num_nodes=1,
            aa_constraint_mask=None,
            inverse_fold_restriction=every_token,
            canonical_tokens=every_token,
            inf=INF,
            device=cpu,
        )
|
||||
|
||||
|
||||
def test_shape_mismatch_ignores_per_residue_mask() -> None:
    """A constraint mask with the wrong row count is dropped with a warning."""
    vocabulary = const.canonical_tokens
    wrong_rows = torch.zeros((2, 20), dtype=torch.float32)
    call_kwargs = dict(
        num_nodes=1,
        aa_constraint_mask=wrong_rows,
        inverse_fold_restriction=[],
        canonical_tokens=vocabulary,
        inf=INF,
        device=torch.device("cpu"),
    )

    with pytest.warns(RuntimeWarning, match="shape mismatch"):
        logits = build_constraint_logit_mask(**call_kwargs)

    # With the mismatched mask ignored, no entry carries a restriction.
    assert logits.shape == (1, len(vocabulary))
    assert torch.all(logits == 0)
|
||||
Reference in New Issue
Block a user