diff --git a/src/foundry/utils/alignment.py b/src/foundry/utils/alignment.py index c5d066d..891744d 100644 --- a/src/foundry/utils/alignment.py +++ b/src/foundry/utils/alignment.py @@ -61,7 +61,12 @@ def weighted_rigid_align( R = U @ V B, _, _ = X_L.shape - F = torch.eye(3, 3, device=X_L.device)[None].tile( + # F is the reflection correction for the Kabsch rotation: its last diagonal + # entry flips to -1 when det(R) < 0 so R stays a proper rotation. It feeds + # `U @ F @ V` below, where matmul requires a shared dtype — U and V carry + # X_L's dtype (SVD of an input-derived covariance), so F must too. torch.eye + # defaults to float32, so without dtype= float64 inputs raise a mismatch. + F = torch.eye(3, 3, device=X_L.device, dtype=X_L.dtype)[None].tile( ( B, 1, diff --git a/tests/test_alignment.py b/tests/test_alignment.py index af72f4a..acddad8 100644 --- a/tests/test_alignment.py +++ b/tests/test_alignment.py @@ -6,9 +6,9 @@ sampling. Its contract (recover an exact rigid transform, ignore points the mask/weights exclude from the fit, detach the output) is not obvious from the signature, so the tests below pin it on representative CPU inputs. -The implementation builds its determinant-correction matrix with the default -float dtype, so it only accepts float32 inputs in practice; all tests use -float32 to match production call sites. +Most tests use float32 to match production call sites; one float64 test guards +that the det-correction matrix follows the input dtype rather than defaulting to +float32 (which used to make float64 inputs raise). """ import pytest @@ -17,11 +17,11 @@ import torch from foundry.utils.alignment import get_rmsd, weighted_rigid_align -def _rotation_about_z(angle: float) -> torch.Tensor: - """Proper rotation (det +1) about the z-axis, as a [3, 3] float32 matrix.""" - a = torch.tensor(angle) +def _rotation_about_z(angle: float, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Proper rotation (det +1) about the z-axis, as a [3, 3] matrix in `dtype`.""" + a = torch.tensor(angle, dtype=dtype) c, s = torch.cos(a), torch.sin(a) - return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]) + return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=dtype) def test_identity_alignment_is_noop(): @@ -53,6 +53,39 @@ def test_recovers_pure_translation(): assert torch.allclose(aligned, X, atol=1e-4) +def test_float64_inputs_supported(): + """float64 coordinates align in float64 — det-correction matrix follows the input dtype.""" + torch.manual_seed(0) + X = torch.randn(1, 16, 3, dtype=torch.float64) + R = _rotation_about_z(1.0, dtype=torch.float64) + t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64) + X_gt = X @ R.T + t # an exact rigid image of X, in float64 + + aligned = weighted_rigid_align(X, X_gt) + assert aligned.dtype == torch.float64 + assert torch.allclose(aligned, X, atol=1e-10) + + +def test_float64_coords_with_float32_weights(): + """float64 coords + explicit float32 w_L: output is float64, alignment is correct. + + w_L is cast inside the function (currently to float32, potentially to X_L.dtype + in future). Either way the weighted products promote to float64, so the output + dtype must follow the coordinate dtype, not the weight dtype. + """ + torch.manual_seed(0) + X = torch.randn(1, 16, 3, dtype=torch.float64) + R = _rotation_about_z(1.0, dtype=torch.float64) + t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64) + X_gt = X @ R.T + t + + w = torch.ones(1, 16, dtype=torch.float32) # deliberately mismatched dtype + aligned = weighted_rigid_align(X, X_gt, w_L=w) + + assert aligned.dtype == torch.float64 + assert torch.allclose(aligned, X, atol=1e-10) + + def test_output_is_detached_and_canonicalized(): """Output is detached and a bare [L, 3] input is promoted to [1, L, 3].""" torch.manual_seed(0)