diff --git a/protdiff/datasets.py b/protdiff/datasets.py index b073b6c..dbf376c 100644 --- a/protdiff/datasets.py +++ b/protdiff/datasets.py @@ -59,10 +59,12 @@ class CathCanonicalAnglesDataset(Dataset): "tau", "CA:C:1N", "C:1N:1CA", - ] + ], + "coords": ["x", "y", "z"], } feature_is_angular = { - "angles": [False, False, False, True, True, True, True, True, True] + "angles": [False, False, False, True, True, True, True, True, True], + "coords": [False, False, False], } cache_fname = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -254,9 +256,9 @@ class CathCanonicalAnglesDataset(Dataset): raise IndexError("Index out of range") angles = self.structures[index]["angles"] - # NOTE coords are NOT shifted or wrapped + # NOTE coords are NOT shifted or wrapped, has same length as angles coords = self.structures[index]["coords"] - print(angles.shape, coords.shape) + assert angles.shape[0] == coords.shape[0] # If given, offset the angles with mean if self.means is not None and not ignore_zero_center: @@ -302,16 +304,24 @@ class CathCanonicalAnglesDataset(Dataset): mode="constant", constant_values=0, ) + coords = np.pad( + coords, + ((0, self.pad - coords.shape[0]), (0, 0)), + mode="constant", + constant_values=0, + ) elif angles.shape[0] > self.pad: if self.trim_strategy == "leftalign": angles = angles[: self.pad] + coords = coords[: self.pad] elif self.trim_strategy == "randomcrop": # Randomly crop the sequence to start_idx = self.rng.integers(0, angles.shape[0] - self.pad) end_idx = start_idx + self.pad assert end_idx < angles.shape[0] angles = angles[start_idx:end_idx] - assert angles.shape[0] == self.pad + coords = coords[start_idx:end_idx] + assert angles.shape[0] == coords.shape[0] == self.pad else: raise ValueError(f"Unknown trim strategy: {self.trim_strategy}") @@ -328,9 +338,11 @@ class CathCanonicalAnglesDataset(Dataset): angles[:, angular_idx], "<=", np.pi ), f"Illegal value: {np.max(angles[:, angular_idx])}" angles = torch.from_numpy(angles).float() + coords = torch.from_numpy(coords).float() retval = { "angles": angles, + "coords": coords, "attn_mask": attn_mask, "position_ids": position_ids, } @@ -355,6 +367,25 @@ class CathCanonicalAnglesDataset(Dataset): return torch.var_mean(all_vals)[::-1] # Default is (var, mean) +class CathCanonicalCoordsDataset(CathCanonicalAnglesDataset): + """ + Building on the CATH dataset, return the XYZ coordaintes of each alpha carbon + """ + + feature_names = {"coords": list("xyz")} + feature_is_angular = {"coords": [False, False, False]} + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def __getitem__( + self, index, ignore_zero_center: bool = True + ) -> Dict[str, torch.Tensor]: + return_dict = super().__getitem__(index, ignore_zero_center=ignore_zero_center) + return_dict.pop("angles", None) + return return_dict + + class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset): """ Building on the CATH dataset, return the 3 canonical dihedrals and the 3 @@ -398,6 +429,7 @@ class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset): assert torch.all( return_dict["angles"] <= torch.pi ), f"Maximum value {torch.max(return_dict['angles'])} higher than pi" + return_dict.pop("coords", None) return return_dict @@ -1000,11 +1032,13 @@ class ScoreMatchingNoisedAnglesDataset(Dataset): def main(): - dset = CathCanonicalAnglesDataset(pad=128, trim_strategy="discard", use_cache=False) - # noised_dset = NoisedAnglesDataset(dset, dset_key="angles") + dset = CathCanonicalCoordsDataset( + pad=128, trim_strategy="discard", use_cache=False, zero_center=False + ) + noised_dset = NoisedAnglesDataset(dset, dset_key="coords") # print(len(noised_dset)) # print(noised_dset[0]) - x = dset[0] + x = noised_dset[0] # x = noised_dset[0] for k, v in x.items():