fix: make weighted_rigid_align dtype-agnostic (accept float64) (#286)

* test: bootstrap mypy + pytest + coverage CI gates

Wire up the tooling the upcoming test/annotation/refactor work depends on:

- Add mypy>=1.13 to [dev] and a lenient [tool.mypy] config scoped to
  src/foundry + src/foundry_cli (ignore_missing_imports, no
  disallow_untyped_defs). The 14 modules with pre-existing type errors
  are pinned via [[tool.mypy.overrides]] ignore_errors=true and listed
  as the ratchet target — fix and remove, never add.
- Add [tool.pytest.ini_options] (testpaths=["tests"], --strict-markers
  --strict-config) and [tool.coverage.*] (source = src/foundry +
  src/foundry_cli, branch = true) for opt-in gap finding.
- Add .github/workflows/test.yaml with mypy and pytest jobs running on
  the same triggers as lint_production.yaml. Top-level tests/ only;
  per-model tests under models/*/tests/ may require GPU and checkpoints
  and stay out of CI for now.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* test(mypy): extend ignore list with 4 modules CI surfaced

Local sanity-check ran mypy without foundry installed, so torch /
lightning resolved to `Any` and errors that depend on knowing those
types stayed invisible. First CI run installed the full deps and
surfaced 6 errors in 4 additional modules:

  foundry.model.layers.blocks
  foundry.training.schedulers
  foundry.utils.logging
  foundry.utils.xpu.xpu_accelerator

Same ratchet contract as the original 14: do not add, only remove
(after fixing the errors and removing `ignore_errors = true`).

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* style: ruff format 2 pre-existing files unrelated to bootstrap

These files have been failing `ruff format --check` on production HEAD
(merged via #275 and #281 without a pre-commit run). They block the
existing `lint_production` workflow on every PR, including this
bootstrap. Strictly out of scope for 0001 — kept in a separate commit
so it can be cherry-picked or reverted cleanly.

  models/rfd3/src/rfd3/inference/input_parsing.py
  models/rfd3/tests/test_partial_diffusion.py

No semantic changes — `ruff format` output only.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* ci: run test workflow on all pull requests

Drop the base-branch filter on the pull_request trigger so PRs targeting
stacked task branches are gated too, not only PRs into the mainline branches.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* test: unit tests for foundry geometry utils (alignment, rotation augmentation)

Pin the numeric behaviour of weighted_rigid_align/get_rmsd and the
rotation/centre/augment helpers used across the diffusion models:
exact recovery of known rigid transforms, mask/weight exclusion from
the fit, proper-rotation invariants, global-centroid centring, and
distance-preserving augmentation. CPU-only, float32, 18 tests.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* fix: make weighted_rigid_align dtype-agnostic (accept float64)

The determinant-correction matrix F was built with torch.eye's default
dtype (float32), so float64 coordinate inputs raised at U @ F @ V. Pass
dtype=X_L.dtype so F follows the input. Add a float64 regression test.

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>

* Add test for float64 coords with float32 weights

Add test for float64 coordinates with float32 weights to ensure correct output dtype and alignment.

* Add additional blank line before functions

Fix formatting by adding blank line before functions in test_alignment

---------

Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu>
Co-authored-by: Hope Woods <hope.woods@omsf.io>
This commit is contained in:
lyskov-ai
2026-06-02 15:45:33 -06:00
committed by GitHub
parent 2d66d998ce
commit d8b0be6015
2 changed files with 46 additions and 8 deletions

View File

@@ -61,7 +61,12 @@ def weighted_rigid_align(
R = U @ V
B, _, _ = X_L.shape
F = torch.eye(3, 3, device=X_L.device)[None].tile(
# F is the reflection correction for the Kabsch rotation: its last diagonal
# entry flips to -1 when det(R) < 0 so R stays a proper rotation. It feeds
# `U @ F @ V` below, where matmul requires a shared dtype — U and V carry
# X_L's dtype (SVD of an input-derived covariance), so F must too. torch.eye
# defaults to float32, so without dtype= float64 inputs raise a mismatch.
F = torch.eye(3, 3, device=X_L.device, dtype=X_L.dtype)[None].tile(
(
B,
1,

View File

@@ -6,9 +6,9 @@ sampling. Its contract (recover an exact rigid transform, ignore points the
mask/weights exclude from the fit, detach the output) is not obvious from the
signature, so the tests below pin it on representative CPU inputs.
The implementation builds its determinant-correction matrix with the default
float dtype, so it only accepts float32 inputs in practice; all tests use
float32 to match production call sites.
Most tests use float32 to match production call sites; one float64 test guards
that the det-correction matrix follows the input dtype rather than defaulting to
float32 (which used to make float64 inputs raise).
"""
import pytest
@@ -17,11 +17,11 @@ import torch
from foundry.utils.alignment import get_rmsd, weighted_rigid_align
def _rotation_about_z(angle: float) -> torch.Tensor:
"""Proper rotation (det +1) about the z-axis, as a [3, 3] float32 matrix."""
a = torch.tensor(angle)
def _rotation_about_z(angle: float, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Proper rotation (det +1) about the z-axis, as a [3, 3] matrix in `dtype`."""
a = torch.tensor(angle, dtype=dtype)
c, s = torch.cos(a), torch.sin(a)
return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=dtype)
def test_identity_alignment_is_noop():
@@ -53,6 +53,39 @@ def test_recovers_pure_translation():
assert torch.allclose(aligned, X, atol=1e-4)
def test_float64_inputs_supported():
"""float64 coordinates align in float64 — det-correction matrix follows the input dtype."""
torch.manual_seed(0)
X = torch.randn(1, 16, 3, dtype=torch.float64)
R = _rotation_about_z(1.0, dtype=torch.float64)
t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64)
X_gt = X @ R.T + t # an exact rigid image of X, in float64
aligned = weighted_rigid_align(X, X_gt)
assert aligned.dtype == torch.float64
assert torch.allclose(aligned, X, atol=1e-10)
def test_float64_coords_with_float32_weights():
"""float64 coords + explicit float32 w_L: output is float64, alignment is correct.
w_L is cast inside the function (currently to float32, potentially to X_L.dtype
in future). Either way the weighted products promote to float64, so the output
dtype must follow the coordinate dtype, not the weight dtype.
"""
torch.manual_seed(0)
X = torch.randn(1, 16, 3, dtype=torch.float64)
R = _rotation_about_z(1.0, dtype=torch.float64)
t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64)
X_gt = X @ R.T + t
w = torch.ones(1, 16, dtype=torch.float32) # deliberately mismatched dtype
aligned = weighted_rigid_align(X, X_gt, w_L=w)
assert aligned.dtype == torch.float64
assert torch.allclose(aligned, X, atol=1e-10)
def test_output_is_detached_and_canonicalized():
"""Output is detached and a bare [L, 3] input is promoted to [1, L, 3]."""
torch.manual_seed(0)