def build_constraint_logit_mask(
    num_nodes: int,
    aa_constraint_mask: Optional[Tensor],
    inverse_fold_restriction: list[str],
    canonical_tokens: list[str],
    inf: float,
    device: torch.device,
) -> Tensor:
    """Build the per-position additive logit mask for inverse folding.

    The returned mask uses additive logit-bias semantics:
    ``0.0`` = amino acid allowed, ``-inf`` = amino acid disallowed.

    Two sources of restrictions are combined with a logical OR:

    * per-residue constraints (``aa_constraint_mask``), where any entry
      ``> 0`` marks that amino acid as blocked at that position;
    * global restrictions (``inverse_fold_restriction``), which block the
      listed residue types at every position.

    If the combination leaves a position with no allowed amino acid, the
    per-residue constraints of that position are dropped (with a
    ``RuntimeWarning``) while the global restrictions are kept. If the
    global restrictions alone block everything, a ``ValueError`` is raised.

    Args:
        num_nodes: Number of positions (rows of the mask).
        aa_constraint_mask: Optional tensor of shape
            ``(num_nodes, len(canonical_tokens))``. A tensor of any other
            shape is ignored with a ``RuntimeWarning``.
        inverse_fold_restriction: Residue type names blocked everywhere.
            Each name must be present in ``canonical_tokens``.
        canonical_tokens: Ordered residue type names; the index of a name
            is its column in the mask.
        inf: Magnitude of the bias applied to blocked entries. May be a
            large finite number or a true infinity.
        device: Device on which the mask is built.

    Returns:
        A ``float32`` tensor of shape ``(num_nodes, len(canonical_tokens))``
        holding ``0.0`` for allowed entries and ``-inf`` for blocked ones.

    Raises:
        ValueError: If a name in ``inverse_fold_restriction`` is not in
            ``canonical_tokens``, or if some position has no allowed amino
            acid even after relaxing its per-residue constraints.
    """
    num_aa = len(canonical_tokens)
    expected_shape = (num_nodes, num_aa)

    has_per_residue_constraints = (
        aa_constraint_mask is not None
        and tuple(aa_constraint_mask.shape) == expected_shape
    )
    if has_per_residue_constraints:
        per_residue_blocked = aa_constraint_mask.to(device=device) > 0
    else:
        if aa_constraint_mask is not None:
            warnings.warn(
                f"aa_constraint_mask shape mismatch: "
                f"got {aa_constraint_mask.shape}, expected {expected_shape}. "
                f"Ignoring per-residue constraints.",
                RuntimeWarning,
                stacklevel=2,
            )
        per_residue_blocked = torch.zeros(
            num_nodes, num_aa, dtype=torch.bool, device=device
        )

    global_blocked = torch.zeros(num_aa, dtype=torch.bool, device=device)
    for res_type in inverse_fold_restriction:
        try:
            token_idx = canonical_tokens.index(res_type)
        except ValueError:
            # Same exception type as the bare list.index failure, but with a
            # message that tells the user what went wrong and what is valid.
            raise ValueError(
                f"Unknown residue type {res_type!r} in inverse-fold "
                f"restriction; expected one of {list(canonical_tokens)}."
            ) from None
        global_blocked[token_idx] = True

    combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)
    all_blocked = combined_blocked.all(dim=1)

    if all_blocked.any() and has_per_residue_constraints:
        blocked_positions = torch.where(all_blocked)[0].tolist()
        warnings.warn(
            f"Positions {blocked_positions} have all amino acids blocked by the "
            f"combination of per-residue constraints and '--inverse_fold_avoid'. "
            f"Relaxing per-residue constraints for these positions.",
            RuntimeWarning,
            stacklevel=2,
        )
        # Clone so the relaxation never writes through to a caller-owned tensor.
        per_residue_blocked = per_residue_blocked.clone()
        per_residue_blocked[all_blocked] = False
        combined_blocked = per_residue_blocked | global_blocked.unsqueeze(0)

    # Only the global restrictions can still block a whole row at this point.
    still_all_blocked = combined_blocked.all(dim=1)
    if still_all_blocked.any():
        blocked_positions = torch.where(still_all_blocked)[0].tolist()
        raise ValueError(
            f"Inverse folding has no valid amino acids at token positions "
            f"{blocked_positions} after applying '--inverse_fold_avoid'. "
            f"Reduce global restrictions to keep at least one amino acid."
        )

    # Fill instead of multiplying: ``0.0 * -inf`` is NaN, so a multiplicative
    # mask would corrupt every *allowed* entry whenever ``inf`` is infinite.
    logit_mask = torch.zeros(expected_shape, dtype=torch.float32, device=device)
    return logit_mask.masked_fill(combined_blocked, -float(inf))
Relaxing constraints for these positions " - f"to allow all amino acids.", - RuntimeWarning, - stacklevel=2, - ) - constraint_mask[all_blocked] = 0 # Allow all AAs for fully-blocked positions - - # Per-residue constraints: values 0=allowed, 1=disallowed - per_residue_mask = constraint_mask.clone() - # Convert binary mask to logit bias: 0 -> 0, 1 -> -inf - per_residue_mask = per_residue_mask * (-self.inf) - else: - # No per-residue constraints: all AAs allowed - per_residue_mask = torch.zeros(num_nodes, len(const.canonical_tokens), device=s.device) - - # Add global restriction (backward compatible with --inverse_fold_avoid CLI) - if len(self.inverse_fold_restriction) > 0: - global_mask = torch.zeros(len(const.canonical_tokens), device=s.device) - for res_type in self.inverse_fold_restriction: - global_mask[const.canonical_tokens.index(res_type)] = -self.inf - # Combine: global restrictions apply to ALL positions additively - per_residue_mask = per_residue_mask + global_mask.unsqueeze(0) + per_residue_mask = build_constraint_logit_mask( + num_nodes=num_nodes, + aa_constraint_mask=constraint_mask, + inverse_fold_restriction=self.inverse_fold_restriction, + canonical_tokens=const.canonical_tokens, + inf=self.inf, + device=s.device, + ) order = torch.randperm(num_nodes, device=s.device).cpu().numpy().tolist() # Non-design residues are not sampled and used as the condition. So the order should filter them out. diff --git a/tests/conftest.py b/tests/conftest.py index 3b3577c..5594ac1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,12 @@ -"""Test configuration: mock heavy dependencies for CPU-only unit tests. +"""Optional test configuration for mocking heavy dependencies. -The per-residue constraint functions under test (_normalize_aa_spec, -_convert_aa_names_to_indices, parse_residue_constraints) only use numpy -and the boltzgen.data.const module. However, schema.py transitively -imports torch, pytorch_lightning, etc. via other boltzgen modules. 
+Enable with: + pytest --mock-heavy-deps tests/test_residue_constraints.py -This conftest patches those heavy imports so tests can run without GPU -libraries installed — achieving the "Level 1: No GPU, fast" goal. +By default no mocking is performed, so integration tests run against +real dependencies. """ import sys -from types import ModuleType from unittest.mock import MagicMock @@ -66,5 +63,17 @@ _MOCK_MODULES = [ "Bio.PDB", ] -for mod in _MOCK_MODULES: - _install_mock(mod) + +def pytest_addoption(parser) -> None: + parser.addoption( + "--mock-heavy-deps", + action="store_true", + default=False, + help="Mock heavy optional dependencies for parser-only unit tests.", + ) + + +def pytest_configure(config) -> None: + if config.getoption("--mock-heavy-deps"): + for mod in _MOCK_MODULES: + _install_mock(mod) diff --git a/tests/test_inverse_fold_constraint_masks.py b/tests/test_inverse_fold_constraint_masks.py new file mode 100644 index 0000000..74b3a00 --- /dev/null +++ b/tests/test_inverse_fold_constraint_masks.py @@ -0,0 +1,90 @@ +"""Integration tests for inverse-folding constraint mask composition.""" + +import pytest + +torch = pytest.importorskip("torch") + +from boltzgen.data import const +from boltzgen.model.modules.inverse_fold import build_constraint_logit_mask + + +INF = 10**6 + + +def _allowed_only_mask(allowed_tokens: list[str]) -> torch.Tensor: + """Build a single-row mask where only `allowed_tokens` are permitted.""" + num_aa = len(const.canonical_tokens) + mask = torch.ones((1, num_aa), dtype=torch.float32) + for token in allowed_tokens: + mask[0, const.canonical_tokens.index(token)] = 0.0 + return mask + + +def test_conflict_allowed_and_global_avoid_keeps_global_restriction() -> None: + cys_idx = const.canonical_tokens.index("CYS") + aa_constraint_mask = _allowed_only_mask(["CYS"]) + + with pytest.warns(RuntimeWarning, match="Relaxing per-residue constraints"): + out = build_constraint_logit_mask( + num_nodes=1, + 
aa_constraint_mask=aa_constraint_mask, + inverse_fold_restriction=["CYS"], + canonical_tokens=const.canonical_tokens, + inf=INF, + device=torch.device("cpu"), + ) + + # Global avoid must still block CYS after conflict handling. + assert out[0, cys_idx].item() == -INF + # All other residues remain available. + assert (out[0] == 0).sum().item() == len(const.canonical_tokens) - 1 + + +def test_non_conflicting_constraints_compose_correctly() -> None: + ala_idx = const.canonical_tokens.index("ALA") + cys_idx = const.canonical_tokens.index("CYS") + aa_constraint_mask = _allowed_only_mask(["ALA"]) + + out = build_constraint_logit_mask( + num_nodes=1, + aa_constraint_mask=aa_constraint_mask, + inverse_fold_restriction=["CYS"], + canonical_tokens=const.canonical_tokens, + inf=INF, + device=torch.device("cpu"), + ) + + # Only ALA should remain available. + assert out[0, ala_idx].item() == 0.0 + assert out[0, cys_idx].item() == -INF + assert (out[0] == 0).sum().item() == 1 + + +def test_global_restrictions_that_block_all_raise() -> None: + with pytest.raises(ValueError, match="no valid amino acids"): + build_constraint_logit_mask( + num_nodes=1, + aa_constraint_mask=None, + inverse_fold_restriction=const.canonical_tokens, + canonical_tokens=const.canonical_tokens, + inf=INF, + device=torch.device("cpu"), + ) + + +def test_shape_mismatch_ignores_per_residue_mask() -> None: + bad_shape = torch.zeros((2, 20), dtype=torch.float32) + + with pytest.warns(RuntimeWarning, match="shape mismatch"): + out = build_constraint_logit_mask( + num_nodes=1, + aa_constraint_mask=bad_shape, + inverse_fold_restriction=[], + canonical_tokens=const.canonical_tokens, + inf=INF, + device=torch.device("cpu"), + ) + + # No restrictions should remain after ignoring mismatched input. + assert out.shape == (1, len(const.canonical_tokens)) + assert torch.all(out == 0)