mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Code to manually set angle means
This commit is contained in:
@@ -376,7 +376,7 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
if self.means is not None and not ignore_zero_center:
|
||||
assert (
|
||||
self.means.shape[0] == angles.shape[1]
|
||||
), f"Mismatched shapes: {self.means.shape} != {angles.shape}"
|
||||
), f"Mismatched shapes for mean offset: {self.means.shape} != {angles.shape}"
|
||||
angles = angles - self.means
|
||||
|
||||
# The distance features all contain a single ":"
|
||||
@@ -527,6 +527,13 @@ class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset):
|
||||
return None
|
||||
return np.copy(self.means)[self.feature_idx]
|
||||
|
||||
def set_masked_means(self, mean_values: np.ndarray) -> None:
|
||||
"""Set the means to the subset of features used"""
|
||||
if self.means is None:
|
||||
raise NotImplementedError
|
||||
logging.info(f"Setting means for features {self.feature_idx} <- {mean_values}")
|
||||
self.means[self.feature_idx] = mean_values.copy()
|
||||
|
||||
def __getitem__(
|
||||
self, index, ignore_zero_center: bool = False
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user