Moar checks

This commit is contained in:
Kevin Wu
2022-11-08 21:54:35 -08:00
parent 4be2d89309
commit 966a3be053

View File

@@ -42,6 +42,9 @@ def radian_smooth_l1_loss(
>>> radian_smooth_l1_loss(torch.tensor(-17.0466), torch.tensor(-1.3888), beta=0.1)
tensor(3.0414)
"""
assert (
target.shape == input.shape
), f"Mismatched shapes: {input.shape} != {target.shape}"
assert beta > 0
d = target - input
d = utils.modulo_with_wrapped_range(d, -torch.pi, torch.pi)