Automatically un-shift the sampled data if data was originally shifted

This commit is contained in:
Kevin Wu
2022-08-31 16:00:12 -07:00
parent fdb3c2bf96
commit cffc794833

View File

@@ -74,7 +74,6 @@ def p_sample_loop(
noise: torch.Tensor,
timesteps: int,
betas: torch.Tensor,
noise_modulo: Optional[Union[float, torch.Tensor]] = None,
is_angle: Union[bool, List[bool]] = [False, True, True, True],
) -> torch.Tensor:
"""
@@ -86,7 +85,7 @@ def p_sample_loop(
# Report metrics on starting noise
# amin and amax support reducing on multiple dimensions
logging.info(
f"Starting from noise {noise.shape} with modulo {noise_modulo} and range {torch.amin(img, dim=(0, 1))} - {torch.amax(img, dim=(0, 1))} using {device}"
f"Starting from noise {noise.shape} with angularity {is_angle} and range {torch.amin(img, dim=(0, 1))} - {torch.amax(img, dim=(0, 1))} using {device}"
)
imgs = []
@@ -144,9 +143,28 @@ def sample(
noise=noise,
timesteps=train_dset.timesteps,
betas=train_dset.alpha_beta_terms["betas"],
is_angle=True,
is_angle=train_dset.feature_is_angular["angles"],
)
# Gets to size (timesteps, seq_len, n_ft)
trimmed_sampled = [sampled[:, i, :l, :] for i, l in enumerate(lengths)]
retval.extend(trimmed_sampled)
# Note that we don't use means variable here directly because we may need a subset
# of it based on which features are active in the dataset. The function
# get_masked_means handles this gracefully
if (
hasattr(train_dset, "dset")
and hasattr(train_dset.dset, "get_masked_means")
and train_dset.dset.get_masked_means() is not None
):
logging.info(
f"Shifting predicted values by original offset: {train_dset.dset.means}"
)
retval = [s + train_dset.dset.get_masked_means() for s in retval]
# Because shifting may have caused us to go across the circle boundary, re-wrap
angular_idx = np.where(train_dset.feature_is_angular["angles"])[0]
for s in retval:
s[..., angular_idx] = utils.modulo_with_wrapped_range(
s[..., angular_idx], range_min=-np.pi, range_max=np.pi
)
return retval