mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Moar checks
This commit is contained in:
@@ -42,6 +42,9 @@ def radian_smooth_l1_loss(
|
||||
>>> radian_smooth_l1_loss(torch.tensor(-17.0466), torch.tensor(-1.3888), beta=0.1)
|
||||
tensor(3.0414)
|
||||
"""
|
||||
assert (
|
||||
target.shape == input.shape
|
||||
), f"Mismatched shapes: {input.shape} != {target.shape}"
|
||||
assert beta > 0
|
||||
d = target - input
|
||||
d = utils.modulo_with_wrapped_range(d, -torch.pi, torch.pi)
|
||||
|
||||
Reference in New Issue
Block a user