test: unit tests for foundry geometry utils (alignment, rotation augmentation) (#285)

* test: bootstrap mypy + pytest + coverage CI gates

Wire up the tooling the upcoming test/annotation/refactor work depends on:

- Add mypy>=1.13 to [dev] and a lenient [tool.mypy] config scoped to
  src/foundry + src/foundry_cli (ignore_missing_imports, no
  disallow_untyped_defs). The 14 modules with pre-existing type errors
  are pinned via [[tool.mypy.overrides]] ignore_errors=true and listed
  as the ratchet target — fix and remove, never add.
- Add [tool.pytest.ini_options] (testpaths=["tests"], --strict-markers
  --strict-config) and [tool.coverage.*] (source = src/foundry +
  src/foundry_cli, branch = true) for opt-in gap finding.
- Add .github/workflows/test.yaml with mypy and pytest jobs running on
  the same triggers as lint_production.yaml. Top-level tests/ only;
  per-model tests under models/*/tests/ may require GPU and checkpoints
  and stay out of CI for now.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* test(mypy): extend ignore list with 4 modules CI surfaced

Local sanity-check ran mypy without foundry installed, so torch /
lightning resolved to `Any` and errors that depend on knowing those
types stayed invisible. First CI run installed the full deps and
surfaced 6 errors in 4 additional modules:

  foundry.model.layers.blocks
  foundry.training.schedulers
  foundry.utils.logging
  foundry.utils.xpu.xpu_accelerator

Same ratchet contract as the original 14: do not add, only remove
(after fixing the errors and removing `ignore_errors = true`).

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* style: ruff format 2 pre-existing files unrelated to bootstrap

These files have been failing `ruff format --check` on production HEAD
(merged via #275 and #281 without a pre-commit run). They block the
existing `lint_production` workflow on every PR, including this
bootstrap. Strictly out of scope for 0001 — kept in a separate commit
so it can be cherry-picked or reverted cleanly.

  models/rfd3/src/rfd3/inference/input_parsing.py
  models/rfd3/tests/test_partial_diffusion.py

No semantic changes — `ruff format` output only.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* ci: run test workflow on all pull requests

Drop the base-branch filter on the pull_request trigger so PRs targeting
stacked task branches are gated too, not only PRs into the mainline branches.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* test: unit tests for foundry geometry utils (alignment, rotation augmentation)

Pin the numeric behaviour of weighted_rigid_align/get_rmsd and the
rotation/centre/augment helpers used across the diffusion models:
exact recovery of known rigid transforms, mask/weight exclusion from
the fit, proper-rotation invariants, global-centroid centring, and
distance-preserving augmentation. CPU-only, float32, 18 tests.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

---------

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>
Co-authored-by: Hope Woods <hope.woods@omsf.io>
This commit is contained in:
lyskov-ai
2026-06-02 14:26:39 -06:00
committed by GitHub
parent b69bed5e4c
commit 2d66d998ce
2 changed files with 247 additions and 0 deletions

136
tests/test_alignment.py Normal file
View File

@@ -0,0 +1,136 @@
"""Unit tests for foundry.utils.alignment.
`weighted_rigid_align` is the SE(3) Kabsch alignment used across rf3/rfd3/rfd3na
to align predicted/ground-truth coordinates before computing losses and during
sampling. Its contract (recover an exact rigid transform, ignore points the
mask/weights exclude from the fit, detach the output) is not obvious from the
signature, so the tests below pin it on representative CPU inputs.
The implementation builds its determinant-correction matrix with the default
float dtype, so it only accepts float32 inputs in practice; all tests use
float32 to match production call sites.
"""
import pytest
import torch
from foundry.utils.alignment import get_rmsd, weighted_rigid_align
def _rotation_about_z(angle: float) -> torch.Tensor:
"""Proper rotation (det +1) about the z-axis, as a [3, 3] float32 matrix."""
a = torch.tensor(angle)
c, s = torch.cos(a), torch.sin(a)
return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
def test_identity_alignment_is_noop():
"""Aligning a structure onto itself returns the structure unchanged."""
torch.manual_seed(0)
X = torch.randn(1, 16, 3)
aligned = weighted_rigid_align(X, X)
assert torch.allclose(aligned, X, atol=1e-4)
def test_recovers_known_rigid_transform():
"""A rigid image of X aligns back onto X regardless of the rotation/translation."""
torch.manual_seed(0)
X = torch.randn(1, 16, 3)
R = _rotation_about_z(1.0)
t = torch.tensor([3.0, -2.0, 5.0])
X_gt = X @ R.T + t # an exact rigid image of X
aligned = weighted_rigid_align(X, X_gt)
assert torch.allclose(aligned, X, atol=1e-4)
def test_recovers_pure_translation():
"""Translation-only ground truth aligns back exactly."""
torch.manual_seed(0)
X = torch.randn(1, 16, 3)
X_gt = X + torch.tensor([10.0, -4.0, 0.5])
aligned = weighted_rigid_align(X, X_gt)
assert torch.allclose(aligned, X, atol=1e-4)
def test_output_is_detached_and_canonicalized():
"""Output is detached and a bare [L, 3] input is promoted to [1, L, 3]."""
torch.manual_seed(0)
X = torch.randn(16, 3, requires_grad=True)
X_gt = (X @ _rotation_about_z(0.7).T).detach()
aligned = weighted_rigid_align(X, X_gt)
assert aligned.shape == (1, 16, 3)
assert not aligned.requires_grad
def test_x_exists_mask_excludes_points_from_fit():
"""Points marked absent must not influence the recovered transform."""
torch.manual_seed(0)
L = 16
X = torch.randn(1, L, 3)
R = _rotation_about_z(1.0)
t = torch.tensor([1.0, 2.0, 3.0])
X_gt = (X @ R.T + t).clone()
exists = torch.ones(L, dtype=torch.bool)
exists[-3:] = False
X_gt[:, ~exists] = 1e3 # corrupt the absent points
aligned = weighted_rigid_align(X, X_gt, exists)
# The fit ignored the corrupted points, so the present ones recover exactly.
assert torch.allclose(aligned[:, exists], X[:, exists], atol=1e-4)
def test_zero_weight_points_excluded_from_fit():
"""Zero-weighted points are excluded from the fit, like an absent-point mask."""
torch.manual_seed(0)
L = 16
X = torch.randn(1, L, 3)
R = _rotation_about_z(1.0)
X_gt = (X @ R.T).clone()
X_gt[:, -3:] = 1e3 # corrupt the points we will zero-weight
w = torch.ones(1, L)
w[:, -3:] = 0.0
aligned = weighted_rigid_align(X, X_gt, w_L=w)
assert torch.allclose(aligned[:, :-3], X[:, :-3], atol=1e-4)
def test_uniform_weights_match_default():
"""Explicit all-ones weights produce the same result as the default (None)."""
torch.manual_seed(0)
X = torch.randn(1, 16, 3)
X_gt = X @ _rotation_about_z(0.4).T
default = weighted_rigid_align(X, X_gt)
ones = weighted_rigid_align(X, X_gt, None, torch.ones(1, 16))
assert torch.allclose(default, ones, atol=1e-5)
def test_x_exists_must_be_boolean():
"""A non-boolean mask is rejected to avoid a silent mis-alignment."""
X = torch.randn(1, 8, 3)
with pytest.raises(AssertionError, match="boolean mask"):
weighted_rigid_align(X, X, torch.ones(8))
def test_get_rmsd_identical_returns_sqrt_eps():
"""RMSD of identical coordinates is sqrt(eps), not 0 (eps lives under the sqrt)."""
torch.manual_seed(0)
X = torch.randn(4, 10, 3)
rmsd = get_rmsd(X, X)
assert torch.allclose(rmsd, torch.full_like(rmsd, 0.01)) # sqrt(1e-4)
def test_get_rmsd_constant_offset():
"""A constant per-atom offset v gives RMSD sqrt(|v|^2 + eps)."""
torch.manual_seed(0)
X = torch.randn(2, 10, 3)
v = torch.tensor([1.0, 2.0, 2.0]) # |v|^2 = 9
rmsd = get_rmsd(X, X + v)
expected = torch.full_like(rmsd, (9.0 + 1e-4) ** 0.5)
assert torch.allclose(rmsd, expected, atol=1e-4)
if __name__ == "__main__":
pytest.main(["-v", __file__])

View File

@@ -0,0 +1,111 @@
"""Unit tests for foundry.utils.rotation_augmentation.
These helpers apply the random SE(3) augmentation used during rf3 sampling.
The contracts worth pinning are geometric: `uniform_random_rotation` must emit
proper rotations, `centre` removes the global centroid of the present atoms and
zeros the absent ones, and the augmentation is rigid (distance-preserving).
Inputs follow the production shapes: coordinates are [D, L, 3] and the
existence mask is [D, L].
"""
import pytest
import torch
from foundry.utils.rotation_augmentation import (
centre,
centre_random_augmentation,
get_random_augmentation,
uniform_random_rotation,
)
def _pairwise_distances(x: torch.Tensor) -> torch.Tensor:
"""Per-batch pairwise distance matrices for [D, L, 3] coordinates."""
return torch.cdist(x, x)
def test_uniform_random_rotation_shape():
rotations = uniform_random_rotation((5,))
assert rotations.shape == (5, 3, 3)
def test_uniform_random_rotation_is_proper_rotation():
"""Sampled matrices are orthogonal with determinant +1 (no reflections)."""
torch.manual_seed(0)
n = 8
rotations = uniform_random_rotation((n,))
identity = torch.eye(3).expand(n, 3, 3)
assert torch.allclose(rotations @ rotations.transpose(-1, -2), identity, atol=1e-5)
assert torch.allclose(torch.linalg.det(rotations), torch.ones(n), atol=1e-5)
def test_centre_removes_global_centroid_of_present_atoms():
"""Present atoms are shifted by the centroid taken over all present atoms."""
torch.manual_seed(0)
X = torch.randn(2, 12, 3)
mask = torch.ones(2, 12, dtype=torch.bool)
mask[1, 0] = False # one absent atom
centred = centre(X, mask)
expected_present = X[mask] - X[mask].mean(dim=0)
assert torch.allclose(centred[mask], expected_present, atol=1e-6)
assert torch.allclose(centred[mask].mean(dim=0), torch.zeros(3), atol=1e-5)
def test_centre_zeros_absent_atoms():
torch.manual_seed(0)
X = torch.randn(2, 12, 3)
mask = torch.ones(2, 12, dtype=torch.bool)
mask[0, 3:6] = False
centred = centre(X, mask)
assert torch.all(centred[~mask] == 0.0)
def test_centre_does_not_mutate_input():
"""centre clones, so the caller's tensor is left untouched."""
torch.manual_seed(0)
X = torch.randn(2, 12, 3)
original = X.clone()
centre(X, torch.ones(2, 12, dtype=torch.bool))
assert torch.equal(X, original)
def test_get_random_augmentation_preserves_distances():
"""Augmentation is rigid: intra-structure distances are unchanged."""
torch.manual_seed(0)
X = torch.randn(3, 12, 3)
augmented = get_random_augmentation(X, s_trans=2.0)
assert augmented.shape == X.shape
assert torch.allclose(
_pairwise_distances(X), _pairwise_distances(augmented), atol=1e-4
)
def test_get_random_augmentation_zero_translation_keeps_centroid_rotating():
"""With s_trans=0 the centroid only rotates, so its distance to origin holds."""
torch.manual_seed(0)
X = torch.randn(3, 12, 3)
augmented = get_random_augmentation(X, s_trans=0.0)
before = X.mean(dim=1).norm(dim=-1)
after = augmented.mean(dim=1).norm(dim=-1)
assert torch.allclose(before, after, atol=1e-4)
def test_centre_random_augmentation_preserves_present_distances():
"""The composed centre+augment step stays rigid over the present atoms."""
torch.manual_seed(0)
X = torch.randn(3, 12, 3)
mask = torch.ones(3, 12, dtype=torch.bool)
result = centre_random_augmentation(X, mask, s_trans=1.0)
assert result.shape == X.shape
# centre then a rigid transform preserves all pairwise distances.
centred = centre(X, mask)
assert torch.allclose(
_pairwise_distances(centred), _pairwise_distances(result), atol=1e-4
)
if __name__ == "__main__":
pytest.main(["-v", __file__])