mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Automatically un-shift the sampled data if data was originally shifted
This commit is contained in:
@@ -74,7 +74,6 @@ def p_sample_loop(
|
||||
noise: torch.Tensor,
|
||||
timesteps: int,
|
||||
betas: torch.Tensor,
|
||||
noise_modulo: Optional[Union[float, torch.Tensor]] = None,
|
||||
is_angle: Union[bool, List[bool]] = [False, True, True, True],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -86,7 +85,7 @@ def p_sample_loop(
|
||||
# Report metrics on starting noise
|
||||
# amin and amax support reducing on multiple dimensions
|
||||
logging.info(
|
||||
f"Starting from noise {noise.shape} with modulo {noise_modulo} and range {torch.amin(img, dim=(0, 1))} - {torch.amax(img, dim=(0, 1))} using {device}"
|
||||
f"Starting from noise {noise.shape} with angularity {is_angle} and range {torch.amin(img, dim=(0, 1))} - {torch.amax(img, dim=(0, 1))} using {device}"
|
||||
)
|
||||
|
||||
imgs = []
|
||||
@@ -144,9 +143,28 @@ def sample(
|
||||
noise=noise,
|
||||
timesteps=train_dset.timesteps,
|
||||
betas=train_dset.alpha_beta_terms["betas"],
|
||||
is_angle=True,
|
||||
is_angle=train_dset.feature_is_angular["angles"],
|
||||
)
|
||||
# Gets to size (timesteps, seq_len, n_ft)
|
||||
trimmed_sampled = [sampled[:, i, :l, :] for i, l in enumerate(lengths)]
|
||||
retval.extend(trimmed_sampled)
|
||||
# Note that we don't use means variable here directly because we may need a subset
|
||||
# of it based on which features are active in the dataset. The function
|
||||
# get_masked_means handles this gracefully
|
||||
if (
|
||||
hasattr(train_dset, "dset")
|
||||
and hasattr(train_dset.dset, "get_masked_means")
|
||||
and train_dset.dset.get_masked_means() is not None
|
||||
):
|
||||
logging.info(
|
||||
f"Shifting predicted values by original offset: {train_dset.dset.means}"
|
||||
)
|
||||
retval = [s + train_dset.dset.get_masked_means() for s in retval]
|
||||
# Because shifting may have caused us to go across the circle boundary, re-wrap
|
||||
angular_idx = np.where(train_dset.feature_is_angular["angles"])[0]
|
||||
for s in retval:
|
||||
s[..., angular_idx] = utils.modulo_with_wrapped_range(
|
||||
s[..., angular_idx], range_min=-np.pi, range_max=np.pi
|
||||
)
|
||||
|
||||
return retval
|
||||
|
||||
Reference in New Issue
Block a user