From d8b0be6015bd4079e907c938dbb7765feff86b9c Mon Sep 17 00:00:00 2001 From: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> Date: Tue, 2 Jun 2026 15:45:33 -0600 Subject: [PATCH] fix: make weighted_rigid_align dtype-agnostic (accept float64) (#286) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: bootstrap mypy + pytest + coverage CI gates Wire up the tooling the upcoming test/annotation/refactor work depends on: - Add mypy>=1.13 to [dev] and a lenient [tool.mypy] config scoped to src/foundry + src/foundry_cli (ignore_missing_imports, no disallow_untyped_defs). The 14 modules with pre-existing type errors are pinned via [[tool.mypy.overrides]] ignore_errors=true and listed as the ratchet target — fix and remove, never add. - Add [tool.pytest.ini_options] (testpaths=["tests"], --strict-markers --strict-config) and [tool.coverage.*] (source = src/foundry + src/foundry_cli, branch = true) for opt-in gap finding. - Add .github/workflows/test.yaml with mypy and pytest jobs running on the same triggers as lint_production.yaml. Top-level tests/ only; per-model tests under models/*/tests/ may require GPU and checkpoints and stay out of CI for now. Co-authored-by: Sergey Lyskov * test(mypy): extend ignore list with 4 modules CI surfaced Local sanity-check ran mypy without foundry installed, so torch / lightning resolved to `Any` and errors that depend on knowing those types stayed invisible. First CI run installed the full deps and surfaced 6 errors in 4 additional modules: foundry.model.layers.blocks foundry.training.schedulers foundry.utils.logging foundry.utils.xpu.xpu_accelerator Same ratchet contract as the original 14: do not add, only remove (after fixing the errors and removing `ignore_errors = true`). Co-authored-by: Sergey Lyskov * style: ruff format 2 pre-existing files unrelated to bootstrap These files have been failing `ruff format --check` on production HEAD (merged via #275 and #281 without a pre-commit run). They block the existing `lint_production` workflow on every PR, including this bootstrap. Strictly out of scope for 0001 — kept in a separate commit so it can be cherry-picked or reverted cleanly. models/rfd3/src/rfd3/inference/input_parsing.py models/rfd3/tests/test_partial_diffusion.py No semantic changes — `ruff format` output only. Co-authored-by: Sergey Lyskov * ci: run test workflow on all pull requests Drop the base-branch filter on the pull_request trigger so PRs targeting stacked task branches are gated too, not only PRs into the mainline branches. Co-authored-by: Sergey Lyskov * test: unit tests for foundry geometry utils (alignment, rotation augmentation) Pin the numeric behaviour of weighted_rigid_align/get_rmsd and the rotation/centre/augment helpers used across the diffusion models: exact recovery of known rigid transforms, mask/weight exclusion from the fit, proper-rotation invariants, global-centroid centring, and distance-preserving augmentation. CPU-only, float32, 18 tests. Co-authored-by: Sergey Lyskov * fix: make weighted_rigid_align dtype-agnostic (accept float64) The determinant-correction matrix F was built with torch.eye's default dtype (float32), so float64 coordinate inputs raised at U @ F @ V. Pass dtype=X_L.dtype so F follows the input. Add a float64 regression test. Co-authored-by: Sergey Lyskov * Add test for float64 coords with float32 weights Add test for float64 coordinates with float32 weights to ensure correct output dtype and alignment. * Add additional blank line before functions Fix formatting by adding blank line before functions in test_alignment --------- Co-authored-by: Sergey Lyskov Co-authored-by: Hope Woods --- src/foundry/utils/alignment.py | 7 ++++- tests/test_alignment.py | 47 +++++++++++++++++++++++++++++----- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/src/foundry/utils/alignment.py b/src/foundry/utils/alignment.py index c5d066d..891744d 100644 --- a/src/foundry/utils/alignment.py +++ b/src/foundry/utils/alignment.py @@ -61,7 +61,12 @@ def weighted_rigid_align( R = U @ V B, _, _ = X_L.shape - F = torch.eye(3, 3, device=X_L.device)[None].tile( + # F is the reflection correction for the Kabsch rotation: its last diagonal + # entry flips to -1 when det(R) < 0 so R stays a proper rotation. It feeds + # `U @ F @ V` below, where matmul requires a shared dtype — U and V carry + # X_L's dtype (SVD of an input-derived covariance), so F must too. torch.eye + # defaults to float32, so without dtype= float64 inputs raise a mismatch. + F = torch.eye(3, 3, device=X_L.device, dtype=X_L.dtype)[None].tile( ( B, 1, diff --git a/tests/test_alignment.py b/tests/test_alignment.py index af72f4a..acddad8 100644 --- a/tests/test_alignment.py +++ b/tests/test_alignment.py @@ -6,9 +6,9 @@ sampling. Its contract (recover an exact rigid transform, ignore points the mask/weights exclude from the fit, detach the output) is not obvious from the signature, so the tests below pin it on representative CPU inputs. -The implementation builds its determinant-correction matrix with the default -float dtype, so it only accepts float32 inputs in practice; all tests use -float32 to match production call sites. +Most tests use float32 to match production call sites; one float64 test guards +that the det-correction matrix follows the input dtype rather than defaulting to +float32 (which used to make float64 inputs raise). """ import pytest @@ -17,11 +17,11 @@ import torch from foundry.utils.alignment import get_rmsd, weighted_rigid_align -def _rotation_about_z(angle: float) -> torch.Tensor: - """Proper rotation (det +1) about the z-axis, as a [3, 3] float32 matrix.""" - a = torch.tensor(angle) +def _rotation_about_z(angle: float, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Proper rotation (det +1) about the z-axis, as a [3, 3] matrix in `dtype`.""" + a = torch.tensor(angle, dtype=dtype) c, s = torch.cos(a), torch.sin(a) - return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]) + return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=dtype) def test_identity_alignment_is_noop(): @@ -53,6 +53,39 @@ def test_recovers_pure_translation(): assert torch.allclose(aligned, X, atol=1e-4) +def test_float64_inputs_supported(): + """float64 coordinates align in float64 — det-correction matrix follows the input dtype.""" + torch.manual_seed(0) + X = torch.randn(1, 16, 3, dtype=torch.float64) + R = _rotation_about_z(1.0, dtype=torch.float64) + t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64) + X_gt = X @ R.T + t # an exact rigid image of X, in float64 + + aligned = weighted_rigid_align(X, X_gt) + assert aligned.dtype == torch.float64 + assert torch.allclose(aligned, X, atol=1e-10) + + +def test_float64_coords_with_float32_weights(): + """float64 coords + explicit float32 w_L: output is float64, alignment is correct. + + w_L is cast inside the function (currently to float32, potentially to X_L.dtype + in future). Either way the weighted products promote to float64, so the output + dtype must follow the coordinate dtype, not the weight dtype. + """ + torch.manual_seed(0) + X = torch.randn(1, 16, 3, dtype=torch.float64) + R = _rotation_about_z(1.0, dtype=torch.float64) + t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64) + X_gt = X @ R.T + t + + w = torch.ones(1, 16, dtype=torch.float32) # deliberately mismatched dtype + aligned = weighted_rigid_align(X, X_gt, w_L=w) + + assert aligned.dtype == torch.float64 + assert torch.allclose(aligned, X, atol=1e-10) + + def test_output_is_detached_and_canonicalized(): """Output is detached and a bare [L, 3] input is promoted to [1, L, 3].""" torch.manual_seed(0)