Code to manually set angle means

This commit is contained in:
Kevin Wu
2023-02-13 16:18:56 -08:00
parent 6a05cb4104
commit cb306a3e1b

View File

@@ -376,7 +376,7 @@ class CathCanonicalAnglesDataset(Dataset):
if self.means is not None and not ignore_zero_center:
assert (
self.means.shape[0] == angles.shape[1]
), f"Mismatched shapes: {self.means.shape} != {angles.shape}"
), f"Mismatched shapes for mean offset: {self.means.shape} != {angles.shape}"
angles = angles - self.means
# The distance features all contain a single ":"
@@ -527,6 +527,13 @@ class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset):
return None
return np.copy(self.means)[self.feature_idx]
def set_masked_means(self, mean_values: np.ndarray) -> None:
"""Set the means to the subset of features used"""
if self.means is None:
raise NotImplementedError
logging.info(f"Setting means for features {self.feature_idx} <- {mean_values}")
self.means[self.feature_idx] = mean_values.copy()
def __getitem__(
self, index, ignore_zero_center: bool = False
) -> Dict[str, torch.Tensor]: