diff --git a/foldingdiff/datasets.py b/foldingdiff/datasets.py index b3e30aa..67ed551 100644 --- a/foldingdiff/datasets.py +++ b/foldingdiff/datasets.py @@ -376,7 +376,7 @@ class CathCanonicalAnglesDataset(Dataset): if self.means is not None and not ignore_zero_center: assert ( self.means.shape[0] == angles.shape[1] - ), f"Mismatched shapes: {self.means.shape} != {angles.shape}" + ), f"Mismatched shapes for mean offset: {self.means.shape} != {angles.shape}" angles = angles - self.means # The distance features all contain a single ":" @@ -527,6 +527,13 @@ class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset): return None return np.copy(self.means)[self.feature_idx] + def set_masked_means(self, mean_values: np.ndarray) -> None: + """Set the means to the subset of features used""" + if self.means is None: + raise NotImplementedError + logging.info(f"Setting means for features {self.feature_idx} <- {mean_values}") + self.means[self.feature_idx] = mean_values.copy() + def __getitem__( self, index, ignore_zero_center: bool = False ) -> Dict[str, torch.Tensor]: