mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
test: unit tests for foundry geometry utils (alignment, rotation augmentation) (#285)
* test: bootstrap mypy + pytest + coverage CI gates Wire up the tooling the upcoming test/annotation/refactor work depends on: - Add mypy>=1.13 to [dev] and a lenient [tool.mypy] config scoped to src/foundry + src/foundry_cli (ignore_missing_imports, no disallow_untyped_defs). The 14 modules with pre-existing type errors are pinned via [[tool.mypy.overrides]] ignore_errors=true and listed as the ratchet target — fix and remove, never add. - Add [tool.pytest.ini_options] (testpaths=["tests"], --strict-markers --strict-config) and [tool.coverage.*] (source = src/foundry + src/foundry_cli, branch = true) for opt-in gap finding. - Add .github/workflows/test.yaml with mypy and pytest jobs running on the same triggers as lint_production.yaml. Top-level tests/ only; per-model tests under models/*/tests/ may require GPU and checkpoints and stay out of CI for now. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * test(mypy): extend ignore list with 4 modules CI surfaced Local sanity-check ran mypy without foundry installed, so torch / lightning resolved to `Any` and errors that depend on knowing those types stayed invisible. First CI run installed the full deps and surfaced 6 errors in 4 additional modules: foundry.model.layers.blocks foundry.training.schedulers foundry.utils.logging foundry.utils.xpu.xpu_accelerator Same ratchet contract as the original 14: do not add, only remove (after fixing the errors and removing `ignore_errors = true`). Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * style: ruff format 2 pre-existing files unrelated to bootstrap These files have been failing `ruff format --check` on production HEAD (merged via #275 and #281 without a pre-commit run). They block the existing `lint_production` workflow on every PR, including this bootstrap. Strictly out of scope for 0001 — kept in a separate commit so it can be cherry-picked or reverted cleanly. models/rfd3/src/rfd3/inference/input_parsing.py models/rfd3/tests/test_partial_diffusion.py No semantic changes — `ruff format` output only. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * ci: run test workflow on all pull requests Drop the base-branch filter on the pull_request trigger so PRs targeting stacked task branches are gated too, not only PRs into the mainline branches. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * test: unit tests for foundry geometry utils (alignment, rotation augmentation) Pin the numeric behaviour of weighted_rigid_align/get_rmsd and the rotation/centre/augment helpers used across the diffusion models: exact recovery of known rigid transforms, mask/weight exclusion from the fit, proper-rotation invariants, global-centroid centring, and distance-preserving augmentation. CPU-only, float32, 18 tests. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> --------- Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> Co-authored-by: Hope Woods <hope.woods@omsf.io>
This commit is contained in:
136
tests/test_alignment.py
Normal file
136
tests/test_alignment.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Unit tests for foundry.utils.alignment.
|
||||
|
||||
`weighted_rigid_align` is the SE(3) Kabsch alignment used across rf3/rfd3/rfd3na
|
||||
to align predicted/ground-truth coordinates before computing losses and during
|
||||
sampling. Its contract (recover an exact rigid transform, ignore points the
|
||||
mask/weights exclude from the fit, detach the output) is not obvious from the
|
||||
signature, so the tests below pin it on representative CPU inputs.
|
||||
|
||||
The implementation builds its determinant-correction matrix with the default
|
||||
float dtype, so it only accepts float32 inputs in practice; all tests use
|
||||
float32 to match production call sites.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from foundry.utils.alignment import get_rmsd, weighted_rigid_align
|
||||
|
||||
|
||||
def _rotation_about_z(angle: float) -> torch.Tensor:
|
||||
"""Proper rotation (det +1) about the z-axis, as a [3, 3] float32 matrix."""
|
||||
a = torch.tensor(angle)
|
||||
c, s = torch.cos(a), torch.sin(a)
|
||||
return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
|
||||
|
||||
|
||||
def test_identity_alignment_is_noop():
|
||||
"""Aligning a structure onto itself returns the structure unchanged."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(1, 16, 3)
|
||||
aligned = weighted_rigid_align(X, X)
|
||||
assert torch.allclose(aligned, X, atol=1e-4)
|
||||
|
||||
|
||||
def test_recovers_known_rigid_transform():
|
||||
"""A rigid image of X aligns back onto X regardless of the rotation/translation."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(1, 16, 3)
|
||||
R = _rotation_about_z(1.0)
|
||||
t = torch.tensor([3.0, -2.0, 5.0])
|
||||
X_gt = X @ R.T + t # an exact rigid image of X
|
||||
|
||||
aligned = weighted_rigid_align(X, X_gt)
|
||||
assert torch.allclose(aligned, X, atol=1e-4)
|
||||
|
||||
|
||||
def test_recovers_pure_translation():
|
||||
"""Translation-only ground truth aligns back exactly."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(1, 16, 3)
|
||||
X_gt = X + torch.tensor([10.0, -4.0, 0.5])
|
||||
aligned = weighted_rigid_align(X, X_gt)
|
||||
assert torch.allclose(aligned, X, atol=1e-4)
|
||||
|
||||
|
||||
def test_output_is_detached_and_canonicalized():
|
||||
"""Output is detached and a bare [L, 3] input is promoted to [1, L, 3]."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(16, 3, requires_grad=True)
|
||||
X_gt = (X @ _rotation_about_z(0.7).T).detach()
|
||||
|
||||
aligned = weighted_rigid_align(X, X_gt)
|
||||
assert aligned.shape == (1, 16, 3)
|
||||
assert not aligned.requires_grad
|
||||
|
||||
|
||||
def test_x_exists_mask_excludes_points_from_fit():
|
||||
"""Points marked absent must not influence the recovered transform."""
|
||||
torch.manual_seed(0)
|
||||
L = 16
|
||||
X = torch.randn(1, L, 3)
|
||||
R = _rotation_about_z(1.0)
|
||||
t = torch.tensor([1.0, 2.0, 3.0])
|
||||
X_gt = (X @ R.T + t).clone()
|
||||
|
||||
exists = torch.ones(L, dtype=torch.bool)
|
||||
exists[-3:] = False
|
||||
X_gt[:, ~exists] = 1e3 # corrupt the absent points
|
||||
|
||||
aligned = weighted_rigid_align(X, X_gt, exists)
|
||||
# The fit ignored the corrupted points, so the present ones recover exactly.
|
||||
assert torch.allclose(aligned[:, exists], X[:, exists], atol=1e-4)
|
||||
|
||||
|
||||
def test_zero_weight_points_excluded_from_fit():
|
||||
"""Zero-weighted points are excluded from the fit, like an absent-point mask."""
|
||||
torch.manual_seed(0)
|
||||
L = 16
|
||||
X = torch.randn(1, L, 3)
|
||||
R = _rotation_about_z(1.0)
|
||||
X_gt = (X @ R.T).clone()
|
||||
X_gt[:, -3:] = 1e3 # corrupt the points we will zero-weight
|
||||
|
||||
w = torch.ones(1, L)
|
||||
w[:, -3:] = 0.0
|
||||
aligned = weighted_rigid_align(X, X_gt, w_L=w)
|
||||
assert torch.allclose(aligned[:, :-3], X[:, :-3], atol=1e-4)
|
||||
|
||||
|
||||
def test_uniform_weights_match_default():
|
||||
"""Explicit all-ones weights produce the same result as the default (None)."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(1, 16, 3)
|
||||
X_gt = X @ _rotation_about_z(0.4).T
|
||||
default = weighted_rigid_align(X, X_gt)
|
||||
ones = weighted_rigid_align(X, X_gt, None, torch.ones(1, 16))
|
||||
assert torch.allclose(default, ones, atol=1e-5)
|
||||
|
||||
|
||||
def test_x_exists_must_be_boolean():
|
||||
"""A non-boolean mask is rejected to avoid a silent mis-alignment."""
|
||||
X = torch.randn(1, 8, 3)
|
||||
with pytest.raises(AssertionError, match="boolean mask"):
|
||||
weighted_rigid_align(X, X, torch.ones(8))
|
||||
|
||||
|
||||
def test_get_rmsd_identical_returns_sqrt_eps():
|
||||
"""RMSD of identical coordinates is sqrt(eps), not 0 (eps lives under the sqrt)."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(4, 10, 3)
|
||||
rmsd = get_rmsd(X, X)
|
||||
assert torch.allclose(rmsd, torch.full_like(rmsd, 0.01)) # sqrt(1e-4)
|
||||
|
||||
|
||||
def test_get_rmsd_constant_offset():
|
||||
"""A constant per-atom offset v gives RMSD sqrt(|v|^2 + eps)."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(2, 10, 3)
|
||||
v = torch.tensor([1.0, 2.0, 2.0]) # |v|^2 = 9
|
||||
rmsd = get_rmsd(X, X + v)
|
||||
expected = torch.full_like(rmsd, (9.0 + 1e-4) ** 0.5)
|
||||
assert torch.allclose(rmsd, expected, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v", __file__])
|
||||
111
tests/test_rotation_augmentation.py
Normal file
111
tests/test_rotation_augmentation.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Unit tests for foundry.utils.rotation_augmentation.
|
||||
|
||||
These helpers apply the random SE(3) augmentation used during rf3 sampling.
|
||||
The contracts worth pinning are geometric: `uniform_random_rotation` must emit
|
||||
proper rotations, `centre` removes the global centroid of the present atoms and
|
||||
zeros the absent ones, and the augmentation is rigid (distance-preserving).
|
||||
Inputs follow the production shapes: coordinates are [D, L, 3] and the
|
||||
existence mask is [D, L].
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from foundry.utils.rotation_augmentation import (
|
||||
centre,
|
||||
centre_random_augmentation,
|
||||
get_random_augmentation,
|
||||
uniform_random_rotation,
|
||||
)
|
||||
|
||||
|
||||
def _pairwise_distances(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Per-batch pairwise distance matrices for [D, L, 3] coordinates."""
|
||||
return torch.cdist(x, x)
|
||||
|
||||
|
||||
def test_uniform_random_rotation_shape():
|
||||
rotations = uniform_random_rotation((5,))
|
||||
assert rotations.shape == (5, 3, 3)
|
||||
|
||||
|
||||
def test_uniform_random_rotation_is_proper_rotation():
|
||||
"""Sampled matrices are orthogonal with determinant +1 (no reflections)."""
|
||||
torch.manual_seed(0)
|
||||
n = 8
|
||||
rotations = uniform_random_rotation((n,))
|
||||
identity = torch.eye(3).expand(n, 3, 3)
|
||||
assert torch.allclose(rotations @ rotations.transpose(-1, -2), identity, atol=1e-5)
|
||||
assert torch.allclose(torch.linalg.det(rotations), torch.ones(n), atol=1e-5)
|
||||
|
||||
|
||||
def test_centre_removes_global_centroid_of_present_atoms():
|
||||
"""Present atoms are shifted by the centroid taken over all present atoms."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(2, 12, 3)
|
||||
mask = torch.ones(2, 12, dtype=torch.bool)
|
||||
mask[1, 0] = False # one absent atom
|
||||
|
||||
centred = centre(X, mask)
|
||||
expected_present = X[mask] - X[mask].mean(dim=0)
|
||||
assert torch.allclose(centred[mask], expected_present, atol=1e-6)
|
||||
assert torch.allclose(centred[mask].mean(dim=0), torch.zeros(3), atol=1e-5)
|
||||
|
||||
|
||||
def test_centre_zeros_absent_atoms():
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(2, 12, 3)
|
||||
mask = torch.ones(2, 12, dtype=torch.bool)
|
||||
mask[0, 3:6] = False
|
||||
|
||||
centred = centre(X, mask)
|
||||
assert torch.all(centred[~mask] == 0.0)
|
||||
|
||||
|
||||
def test_centre_does_not_mutate_input():
|
||||
"""centre clones, so the caller's tensor is left untouched."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(2, 12, 3)
|
||||
original = X.clone()
|
||||
centre(X, torch.ones(2, 12, dtype=torch.bool))
|
||||
assert torch.equal(X, original)
|
||||
|
||||
|
||||
def test_get_random_augmentation_preserves_distances():
|
||||
"""Augmentation is rigid: intra-structure distances are unchanged."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(3, 12, 3)
|
||||
augmented = get_random_augmentation(X, s_trans=2.0)
|
||||
assert augmented.shape == X.shape
|
||||
assert torch.allclose(
|
||||
_pairwise_distances(X), _pairwise_distances(augmented), atol=1e-4
|
||||
)
|
||||
|
||||
|
||||
def test_get_random_augmentation_zero_translation_keeps_centroid_rotating():
|
||||
"""With s_trans=0 the centroid only rotates, so its distance to origin holds."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(3, 12, 3)
|
||||
augmented = get_random_augmentation(X, s_trans=0.0)
|
||||
before = X.mean(dim=1).norm(dim=-1)
|
||||
after = augmented.mean(dim=1).norm(dim=-1)
|
||||
assert torch.allclose(before, after, atol=1e-4)
|
||||
|
||||
|
||||
def test_centre_random_augmentation_preserves_present_distances():
|
||||
"""The composed centre+augment step stays rigid over the present atoms."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(3, 12, 3)
|
||||
mask = torch.ones(3, 12, dtype=torch.bool)
|
||||
|
||||
result = centre_random_augmentation(X, mask, s_trans=1.0)
|
||||
assert result.shape == X.shape
|
||||
# centre then a rigid transform preserves all pairwise distances.
|
||||
centred = centre(X, mask)
|
||||
assert torch.allclose(
|
||||
_pairwise_distances(centred), _pairwise_distances(result), atol=1e-4
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v", __file__])
|
||||
Reference in New Issue
Block a user