mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fix: make weighted_rigid_align dtype-agnostic (accept float64) (#286)
* test: bootstrap mypy + pytest + coverage CI gates Wire up the tooling the upcoming test/annotation/refactor work depends on: - Add mypy>=1.13 to [dev] and a lenient [tool.mypy] config scoped to src/foundry + src/foundry_cli (ignore_missing_imports, no disallow_untyped_defs). The 14 modules with pre-existing type errors are pinned via [[tool.mypy.overrides]] ignore_errors=true and listed as the ratchet target — fix and remove, never add. - Add [tool.pytest.ini_options] (testpaths=["tests"], --strict-markers --strict-config) and [tool.coverage.*] (source = src/foundry + src/foundry_cli, branch = true) for opt-in gap finding. - Add .github/workflows/test.yaml with mypy and pytest jobs running on the same triggers as lint_production.yaml. Top-level tests/ only; per-model tests under models/*/tests/ may require GPU and checkpoints and stay out of CI for now. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * test(mypy): extend ignore list with 4 modules CI surfaced Local sanity-check ran mypy without foundry installed, so torch / lightning resolved to `Any` and errors that depend on knowing those types stayed invisible. First CI run installed the full deps and surfaced 6 errors in 4 additional modules: foundry.model.layers.blocks foundry.training.schedulers foundry.utils.logging foundry.utils.xpu.xpu_accelerator Same ratchet contract as the original 14: do not add, only remove (after fixing the errors and removing `ignore_errors = true`). Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * style: ruff format 2 pre-existing files unrelated to bootstrap These files have been failing `ruff format --check` on production HEAD (merged via #275 and #281 without a pre-commit run). They block the existing `lint_production` workflow on every PR, including this bootstrap. Strictly out of scope for 0001 — kept in a separate commit so it can be cherry-picked or reverted cleanly. models/rfd3/src/rfd3/inference/input_parsing.py models/rfd3/tests/test_partial_diffusion.py No semantic changes — `ruff format` output only. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * ci: run test workflow on all pull requests Drop the base-branch filter on the pull_request trigger so PRs targeting stacked task branches are gated too, not only PRs into the mainline branches. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * test: unit tests for foundry geometry utils (alignment, rotation augmentation) Pin the numeric behaviour of weighted_rigid_align/get_rmsd and the rotation/centre/augment helpers used across the diffusion models: exact recovery of known rigid transforms, mask/weight exclusion from the fit, proper-rotation invariants, global-centroid centring, and distance-preserving augmentation. CPU-only, float32, 18 tests. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * fix: make weighted_rigid_align dtype-agnostic (accept float64) The determinant-correction matrix F was built with torch.eye's default dtype (float32), so float64 coordinate inputs raised at U @ F @ V. Pass dtype=X_L.dtype so F follows the input. Add a float64 regression test. Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> * Add test for float64 coords with float32 weights Add test for float64 coordinates with float32 weights to ensure correct output dtype and alignment. * Add additional blank line before functions Fix formatting by adding blank line before functions in test_alignment --------- Co-authored-by: Sergey Lyskov <sergey.lyskov@jhu.edu> Co-authored-by: Hope Woods <hope.woods@omsf.io>
This commit is contained in:
@@ -61,7 +61,12 @@ def weighted_rigid_align(
|
||||
|
||||
R = U @ V
|
||||
B, _, _ = X_L.shape
|
||||
F = torch.eye(3, 3, device=X_L.device)[None].tile(
|
||||
# F is the reflection correction for the Kabsch rotation: its last diagonal
|
||||
# entry flips to -1 when det(R) < 0 so R stays a proper rotation. It feeds
|
||||
# `U @ F @ V` below, where matmul requires a shared dtype — U and V carry
|
||||
# X_L's dtype (SVD of an input-derived covariance), so F must too. torch.eye
|
||||
# defaults to float32, so without dtype= float64 inputs raise a mismatch.
|
||||
F = torch.eye(3, 3, device=X_L.device, dtype=X_L.dtype)[None].tile(
|
||||
(
|
||||
B,
|
||||
1,
|
||||
|
||||
@@ -6,9 +6,9 @@ sampling. Its contract (recover an exact rigid transform, ignore points the
|
||||
mask/weights exclude from the fit, detach the output) is not obvious from the
|
||||
signature, so the tests below pin it on representative CPU inputs.
|
||||
|
||||
The implementation builds its determinant-correction matrix with the default
|
||||
float dtype, so it only accepts float32 inputs in practice; all tests use
|
||||
float32 to match production call sites.
|
||||
Most tests use float32 to match production call sites; one float64 test guards
|
||||
that the det-correction matrix follows the input dtype rather than defaulting to
|
||||
float32 (which used to make float64 inputs raise).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@@ -17,11 +17,11 @@ import torch
|
||||
from foundry.utils.alignment import get_rmsd, weighted_rigid_align
|
||||
|
||||
|
||||
def _rotation_about_z(angle: float) -> torch.Tensor:
|
||||
"""Proper rotation (det +1) about the z-axis, as a [3, 3] float32 matrix."""
|
||||
a = torch.tensor(angle)
|
||||
def _rotation_about_z(angle: float, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
"""Proper rotation (det +1) about the z-axis, as a [3, 3] matrix in `dtype`."""
|
||||
a = torch.tensor(angle, dtype=dtype)
|
||||
c, s = torch.cos(a), torch.sin(a)
|
||||
return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
|
||||
return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=dtype)
|
||||
|
||||
|
||||
def test_identity_alignment_is_noop():
|
||||
@@ -53,6 +53,39 @@ def test_recovers_pure_translation():
|
||||
assert torch.allclose(aligned, X, atol=1e-4)
|
||||
|
||||
|
||||
def test_float64_inputs_supported():
|
||||
"""float64 coordinates align in float64 — det-correction matrix follows the input dtype."""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(1, 16, 3, dtype=torch.float64)
|
||||
R = _rotation_about_z(1.0, dtype=torch.float64)
|
||||
t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64)
|
||||
X_gt = X @ R.T + t # an exact rigid image of X, in float64
|
||||
|
||||
aligned = weighted_rigid_align(X, X_gt)
|
||||
assert aligned.dtype == torch.float64
|
||||
assert torch.allclose(aligned, X, atol=1e-10)
|
||||
|
||||
|
||||
def test_float64_coords_with_float32_weights():
|
||||
"""float64 coords + explicit float32 w_L: output is float64, alignment is correct.
|
||||
|
||||
w_L is cast inside the function (currently to float32, potentially to X_L.dtype
|
||||
in future). Either way the weighted products promote to float64, so the output
|
||||
dtype must follow the coordinate dtype, not the weight dtype.
|
||||
"""
|
||||
torch.manual_seed(0)
|
||||
X = torch.randn(1, 16, 3, dtype=torch.float64)
|
||||
R = _rotation_about_z(1.0, dtype=torch.float64)
|
||||
t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64)
|
||||
X_gt = X @ R.T + t
|
||||
|
||||
w = torch.ones(1, 16, dtype=torch.float32) # deliberately mismatched dtype
|
||||
aligned = weighted_rigid_align(X, X_gt, w_L=w)
|
||||
|
||||
assert aligned.dtype == torch.float64
|
||||
assert torch.allclose(aligned, X, atol=1e-10)
|
||||
|
||||
|
||||
def test_output_is_detached_and_canonicalized():
|
||||
"""Output is detached and a bare [L, 3] input is promoted to [1, L, 3]."""
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user