diff --git a/protdiff/sampling.py b/protdiff/sampling.py index 202e67c..76f8a9e 100644 --- a/protdiff/sampling.py +++ b/protdiff/sampling.py @@ -74,7 +74,6 @@ def p_sample_loop( noise: torch.Tensor, timesteps: int, betas: torch.Tensor, - noise_modulo: Optional[Union[float, torch.Tensor]] = None, is_angle: Union[bool, List[bool]] = [False, True, True, True], ) -> torch.Tensor: """ @@ -86,7 +85,7 @@ def p_sample_loop( # Report metrics on starting noise # amin and amax support reducing on multiple dimensions logging.info( - f"Starting from noise {noise.shape} with modulo {noise_modulo} and range {torch.amin(img, dim=(0, 1))} - {torch.amax(img, dim=(0, 1))} using {device}" + f"Starting from noise {noise.shape} with angularity {is_angle} and range {torch.amin(img, dim=(0, 1))} - {torch.amax(img, dim=(0, 1))} using {device}" ) imgs = [] @@ -144,9 +143,28 @@ def sample( noise=noise, timesteps=train_dset.timesteps, betas=train_dset.alpha_beta_terms["betas"], - is_angle=True, + is_angle=train_dset.feature_is_angular["angles"], ) # Gets to size (timesteps, seq_len, n_ft) trimmed_sampled = [sampled[:, i, :l, :] for i, l in enumerate(lengths)] retval.extend(trimmed_sampled) + # Note that we don't use means variable here directly because we may need a subset + # of it based on which features are active in the dataset. The function + # get_masked_means handles this gracefully + if ( + hasattr(train_dset, "dset") + and hasattr(train_dset.dset, "get_masked_means") + and train_dset.dset.get_masked_means() is not None + ): + logging.info( + f"Shifting predicted values by original offset: {train_dset.dset.means}" + ) + retval = [s + train_dset.dset.get_masked_means() for s in retval] + # Because shifting may have caused us to go across the circle boundary, re-wrap + angular_idx = np.where(train_dset.feature_is_angular["angles"])[0] + for s in retval: + s[..., angular_idx] = utils.modulo_with_wrapped_range( + s[..., angular_idx], range_min=-np.pi, range_max=np.pi + ) + return retval