mirror of
https://github.com/HannesStark/boltzgen.git
synced 2026-06-04 11:54:23 +08:00
fix: resolve inverse-fold constraint conflicts with global avoid
Idead by buerbaumer@ and quality checked with Codex 5.3
This commit is contained in:
committed by
Hannes Stärk
parent
0822cb3b71
commit
39c87c65d6
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Tuple, List
|
from typing import Dict, Tuple, List, Optional
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -39,6 +39,75 @@ def softmax_dropout(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_constraint_logit_mask(
|
||||||
|
num_nodes: int,
|
||||||
|
aa_constraint_mask: Optional[Tensor],
|
||||||
|
inverse_fold_restriction: list[str],
|
||||||
|
canonical_tokens: list[str],
|
||||||
|
inf: float,
|
||||||
|
device: torch.device,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Build per-position inverse-folding logit mask.
|
||||||
|
|
||||||
|
The mask uses additive logit bias semantics:
|
||||||
|
0.0 = allowed, -inf = disallowed.
|
||||||
|
"""
|
||||||
|
num_aa = len(canonical_tokens)
|
||||||
|
has_per_residue_constraints = False
|
||||||
|
|
||||||
|
if aa_constraint_mask is None:
|
||||||
|
per_residue_blocked = torch.zeros(
|
||||||
|
num_nodes, num_aa, dtype=torch.bool, device=device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
expected_shape = (num_nodes, num_aa)
|
||||||
|
if aa_constraint_mask.shape != expected_shape:
|
||||||
|
warnings.warn(
|
||||||
|
f"aa_constraint_mask shape mismatch: "
|
||||||
|
f"got {aa_constraint_mask.shape}, expected {expected_shape}. "
|
||||||
|
f"Ignoring per-residue constraints.",
|
||||||
|
RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
per_residue_blocked = torch.zeros(
|
||||||
|
num_nodes, num_aa, dtype=torch.bool, device=device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
has_per_residue_constraints = True
|
||||||
|
per_residue_blocked = aa_constraint_mask.to(device=device) > 0
|
||||||
|
|
||||||
|
global_blocked = torch.zeros(num_aa, dtype=torch.bool, device=device)
|
||||||
|
for res_type in inverse_fold_restriction:
|
||||||
|
global_blocked[canonical_tokens.index(res_type)] = True
|
||||||
|
|
||||||
|
combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
|
||||||
|
all_blocked = combined_blocked.all(dim=1)
|
||||||
|
|
||||||
|
if all_blocked.any() and has_per_residue_constraints:
|
||||||
|
blocked_positions = torch.where(all_blocked)[0].tolist()
|
||||||
|
warnings.warn(
|
||||||
|
f"Positions {blocked_positions} have all amino acids blocked by the "
|
||||||
|
f"combination of per-residue constraints and '--inverse_fold_avoid'. "
|
||||||
|
f"Relaxing per-residue constraints for these positions.",
|
||||||
|
RuntimeWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
per_residue_blocked = per_residue_blocked.clone()
|
||||||
|
per_residue_blocked[all_blocked] = False
|
||||||
|
combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
|
||||||
|
|
||||||
|
still_all_blocked = combined_blocked.all(dim=1)
|
||||||
|
if still_all_blocked.any():
|
||||||
|
blocked_positions = torch.where(still_all_blocked)[0].tolist()
|
||||||
|
raise ValueError(
|
||||||
|
f"Inverse folding has no valid amino acids at token positions "
|
||||||
|
f"{blocked_positions} after applying '--inverse_fold_avoid'. "
|
||||||
|
f"Reduce global restrictions to keep at least one amino acid."
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_blocked.to(dtype=torch.float32) * (-inf)
|
||||||
|
|
||||||
|
|
||||||
class MLPAttnGNN(nn.Module):
|
class MLPAttnGNN(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -589,53 +658,17 @@ class InverseFoldingDecoder(nn.Module):
|
|||||||
f"num_design: {num_design}, num_not_design: {num_not_design}"
|
f"num_design: {num_design}, num_not_design: {num_not_design}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create per-residue restriction mask
|
constraint_mask = None
|
||||||
# Initialize with per-residue constraints from YAML if available
|
|
||||||
if "aa_constraint_mask" in feats:
|
if "aa_constraint_mask" in feats:
|
||||||
constraint_mask = feats["aa_constraint_mask"][valid_mask]
|
constraint_mask = feats["aa_constraint_mask"][valid_mask]
|
||||||
# Validate shape: should be (num_nodes, 20) for 20 canonical amino acids
|
per_residue_mask = build_constraint_logit_mask(
|
||||||
expected_shape = (num_nodes, len(const.canonical_tokens))
|
num_nodes=num_nodes,
|
||||||
if constraint_mask.shape != expected_shape:
|
aa_constraint_mask=constraint_mask,
|
||||||
warnings.warn(
|
inverse_fold_restriction=self.inverse_fold_restriction,
|
||||||
f"aa_constraint_mask shape mismatch: "
|
canonical_tokens=const.canonical_tokens,
|
||||||
f"got {constraint_mask.shape}, expected {expected_shape}. "
|
inf=self.inf,
|
||||||
f"Ignoring per-residue constraints.",
|
device=s.device,
|
||||||
RuntimeWarning,
|
)
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
per_residue_mask = torch.zeros(num_nodes, len(const.canonical_tokens), device=s.device)
|
|
||||||
else:
|
|
||||||
# Check for positions with ALL AAs blocked (would leave no valid options).
|
|
||||||
# This can happen when per-residue 'allowed' and global '--inverse_fold_avoid'
|
|
||||||
# combine to eliminate every amino acid at a position.
|
|
||||||
all_blocked = constraint_mask.sum(dim=1) >= len(const.canonical_tokens)
|
|
||||||
if all_blocked.any():
|
|
||||||
blocked_positions = torch.where(all_blocked)[0].tolist()
|
|
||||||
warnings.warn(
|
|
||||||
f"Positions {blocked_positions} have ALL amino acids blocked "
|
|
||||||
f"(likely conflict between per-residue 'allowed' and global "
|
|
||||||
f"'--inverse_fold_avoid'). Relaxing constraints for these positions "
|
|
||||||
f"to allow all amino acids.",
|
|
||||||
RuntimeWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
constraint_mask[all_blocked] = 0 # Allow all AAs for fully-blocked positions
|
|
||||||
|
|
||||||
# Per-residue constraints: values 0=allowed, 1=disallowed
|
|
||||||
per_residue_mask = constraint_mask.clone()
|
|
||||||
# Convert binary mask to logit bias: 0 -> 0, 1 -> -inf
|
|
||||||
per_residue_mask = per_residue_mask * (-self.inf)
|
|
||||||
else:
|
|
||||||
# No per-residue constraints: all AAs allowed
|
|
||||||
per_residue_mask = torch.zeros(num_nodes, len(const.canonical_tokens), device=s.device)
|
|
||||||
|
|
||||||
# Add global restriction (backward compatible with --inverse_fold_avoid CLI)
|
|
||||||
if len(self.inverse_fold_restriction) > 0:
|
|
||||||
global_mask = torch.zeros(len(const.canonical_tokens), device=s.device)
|
|
||||||
for res_type in self.inverse_fold_restriction:
|
|
||||||
global_mask[const.canonical_tokens.index(res_type)] = -self.inf
|
|
||||||
# Combine: global restrictions apply to ALL positions additively
|
|
||||||
per_residue_mask = per_residue_mask + global_mask.unsqueeze(0)
|
|
||||||
|
|
||||||
order = torch.randperm(num_nodes, device=s.device).cpu().numpy().tolist()
|
order = torch.randperm(num_nodes, device=s.device).cpu().numpy().tolist()
|
||||||
# Non-design residues are not sampled and used as the condition. So the order should filter them out.
|
# Non-design residues are not sampled and used as the condition. So the order should filter them out.
|
||||||
|
|||||||
@@ -1,15 +1,12 @@
|
|||||||
"""Test configuration: mock heavy dependencies for CPU-only unit tests.
|
"""Optional test configuration for mocking heavy dependencies.
|
||||||
|
|
||||||
The per-residue constraint functions under test (_normalize_aa_spec,
|
Enable with:
|
||||||
_convert_aa_names_to_indices, parse_residue_constraints) only use numpy
|
pytest --mock-heavy-deps tests/test_residue_constraints.py
|
||||||
and the boltzgen.data.const module. However, schema.py transitively
|
|
||||||
imports torch, pytorch_lightning, etc. via other boltzgen modules.
|
|
||||||
|
|
||||||
This conftest patches those heavy imports so tests can run without GPU
|
By default no mocking is performed, so integration tests run against
|
||||||
libraries installed — achieving the "Level 1: No GPU, fast" goal.
|
real dependencies.
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
from types import ModuleType
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
||||||
@@ -66,5 +63,17 @@ _MOCK_MODULES = [
|
|||||||
"Bio.PDB",
|
"Bio.PDB",
|
||||||
]
|
]
|
||||||
|
|
||||||
for mod in _MOCK_MODULES:
|
|
||||||
_install_mock(mod)
|
def pytest_addoption(parser) -> None:
|
||||||
|
parser.addoption(
|
||||||
|
"--mock-heavy-deps",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Mock heavy optional dependencies for parser-only unit tests.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config) -> None:
|
||||||
|
if config.getoption("--mock-heavy-deps"):
|
||||||
|
for mod in _MOCK_MODULES:
|
||||||
|
_install_mock(mod)
|
||||||
|
|||||||
90
tests/test_inverse_fold_constraint_masks.py
Normal file
90
tests/test_inverse_fold_constraint_masks.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Integration tests for inverse-folding constraint mask composition."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
torch = pytest.importorskip("torch")
|
||||||
|
|
||||||
|
from boltzgen.data import const
|
||||||
|
from boltzgen.model.modules.inverse_fold import build_constraint_logit_mask
|
||||||
|
|
||||||
|
|
||||||
|
INF = 10**6
|
||||||
|
|
||||||
|
|
||||||
|
def _allowed_only_mask(allowed_tokens: list[str]) -> torch.Tensor:
|
||||||
|
"""Build a single-row mask where only `allowed_tokens` are permitted."""
|
||||||
|
num_aa = len(const.canonical_tokens)
|
||||||
|
mask = torch.ones((1, num_aa), dtype=torch.float32)
|
||||||
|
for token in allowed_tokens:
|
||||||
|
mask[0, const.canonical_tokens.index(token)] = 0.0
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def test_conflict_allowed_and_global_avoid_keeps_global_restriction() -> None:
|
||||||
|
cys_idx = const.canonical_tokens.index("CYS")
|
||||||
|
aa_constraint_mask = _allowed_only_mask(["CYS"])
|
||||||
|
|
||||||
|
with pytest.warns(RuntimeWarning, match="Relaxing per-residue constraints"):
|
||||||
|
out = build_constraint_logit_mask(
|
||||||
|
num_nodes=1,
|
||||||
|
aa_constraint_mask=aa_constraint_mask,
|
||||||
|
inverse_fold_restriction=["CYS"],
|
||||||
|
canonical_tokens=const.canonical_tokens,
|
||||||
|
inf=INF,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global avoid must still block CYS after conflict handling.
|
||||||
|
assert out[0, cys_idx].item() == -INF
|
||||||
|
# All other residues remain available.
|
||||||
|
assert (out[0] == 0).sum().item() == len(const.canonical_tokens) - 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_conflicting_constraints_compose_correctly() -> None:
|
||||||
|
ala_idx = const.canonical_tokens.index("ALA")
|
||||||
|
cys_idx = const.canonical_tokens.index("CYS")
|
||||||
|
aa_constraint_mask = _allowed_only_mask(["ALA"])
|
||||||
|
|
||||||
|
out = build_constraint_logit_mask(
|
||||||
|
num_nodes=1,
|
||||||
|
aa_constraint_mask=aa_constraint_mask,
|
||||||
|
inverse_fold_restriction=["CYS"],
|
||||||
|
canonical_tokens=const.canonical_tokens,
|
||||||
|
inf=INF,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only ALA should remain available.
|
||||||
|
assert out[0, ala_idx].item() == 0.0
|
||||||
|
assert out[0, cys_idx].item() == -INF
|
||||||
|
assert (out[0] == 0).sum().item() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_global_restrictions_that_block_all_raise() -> None:
|
||||||
|
with pytest.raises(ValueError, match="no valid amino acids"):
|
||||||
|
build_constraint_logit_mask(
|
||||||
|
num_nodes=1,
|
||||||
|
aa_constraint_mask=None,
|
||||||
|
inverse_fold_restriction=const.canonical_tokens,
|
||||||
|
canonical_tokens=const.canonical_tokens,
|
||||||
|
inf=INF,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shape_mismatch_ignores_per_residue_mask() -> None:
|
||||||
|
bad_shape = torch.zeros((2, 20), dtype=torch.float32)
|
||||||
|
|
||||||
|
with pytest.warns(RuntimeWarning, match="shape mismatch"):
|
||||||
|
out = build_constraint_logit_mask(
|
||||||
|
num_nodes=1,
|
||||||
|
aa_constraint_mask=bad_shape,
|
||||||
|
inverse_fold_restriction=[],
|
||||||
|
canonical_tokens=const.canonical_tokens,
|
||||||
|
inf=INF,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# No restrictions should remain after ignoring mismatched input.
|
||||||
|
assert out.shape == (1, len(const.canonical_tokens))
|
||||||
|
assert torch.all(out == 0)
|
||||||
Reference in New Issue
Block a user