Support for coordinate based datasets

This commit is contained in:
Kevin Wu
2022-09-22 14:52:30 -07:00
parent b0c5f28fee
commit d5a70c2d75

View File

@@ -59,10 +59,12 @@ class CathCanonicalAnglesDataset(Dataset):
"tau",
"CA:C:1N",
"C:1N:1CA",
]
],
"coords": ["x", "y", "z"],
}
feature_is_angular = {
"angles": [False, False, False, True, True, True, True, True, True]
"angles": [False, False, False, True, True, True, True, True, True],
"coords": [False, False, False],
}
cache_fname = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
@@ -254,9 +256,9 @@ class CathCanonicalAnglesDataset(Dataset):
raise IndexError("Index out of range")
angles = self.structures[index]["angles"]
# NOTE coords are NOT shifted or wrapped
# NOTE coords are NOT shifted or wrapped, has same length as angles
coords = self.structures[index]["coords"]
print(angles.shape, coords.shape)
assert angles.shape[0] == coords.shape[0]
# If given, offset the angles with mean
if self.means is not None and not ignore_zero_center:
@@ -302,16 +304,24 @@ class CathCanonicalAnglesDataset(Dataset):
mode="constant",
constant_values=0,
)
coords = np.pad(
coords,
((0, self.pad - coords.shape[0]), (0, 0)),
mode="constant",
constant_values=0,
)
elif angles.shape[0] > self.pad:
if self.trim_strategy == "leftalign":
angles = angles[: self.pad]
coords = coords[: self.pad]
elif self.trim_strategy == "randomcrop":
# Randomly crop the sequence to
start_idx = self.rng.integers(0, angles.shape[0] - self.pad)
end_idx = start_idx + self.pad
assert end_idx < angles.shape[0]
angles = angles[start_idx:end_idx]
assert angles.shape[0] == self.pad
coords = coords[start_idx:end_idx]
assert angles.shape[0] == coords.shape[0] == self.pad
else:
raise ValueError(f"Unknown trim strategy: {self.trim_strategy}")
@@ -328,9 +338,11 @@ class CathCanonicalAnglesDataset(Dataset):
angles[:, angular_idx], "<=", np.pi
), f"Illegal value: {np.max(angles[:, angular_idx])}"
angles = torch.from_numpy(angles).float()
coords = torch.from_numpy(coords).float()
retval = {
"angles": angles,
"coords": coords,
"attn_mask": attn_mask,
"position_ids": position_ids,
}
@@ -355,6 +367,25 @@ class CathCanonicalAnglesDataset(Dataset):
return torch.var_mean(all_vals)[::-1] # Default is (var, mean)
class CathCanonicalCoordsDataset(CathCanonicalAnglesDataset):
"""
Building on the CATH dataset, return the XYZ coordaintes of each alpha carbon
"""
feature_names = {"coords": list("xyz")}
feature_is_angular = {"coords": [False, False, False]}
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __getitem__(
self, index, ignore_zero_center: bool = True
) -> Dict[str, torch.Tensor]:
return_dict = super().__getitem__(index, ignore_zero_center=ignore_zero_center)
return_dict.pop("angles", None)
return return_dict
class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset):
"""
Building on the CATH dataset, return the 3 canonical dihedrals and the 3
@@ -398,6 +429,7 @@ class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset):
assert torch.all(
return_dict["angles"] <= torch.pi
), f"Maximum value {torch.max(return_dict['angles'])} higher than pi"
return_dict.pop("coords", None)
return return_dict
@@ -1000,11 +1032,13 @@ class ScoreMatchingNoisedAnglesDataset(Dataset):
def main():
dset = CathCanonicalAnglesDataset(pad=128, trim_strategy="discard", use_cache=False)
# noised_dset = NoisedAnglesDataset(dset, dset_key="angles")
dset = CathCanonicalCoordsDataset(
pad=128, trim_strategy="discard", use_cache=False, zero_center=False
)
noised_dset = NoisedAnglesDataset(dset, dset_key="coords")
# print(len(noised_dset))
# print(noised_dset[0])
x = dset[0]
x = noised_dset[0]
# x = noised_dset[0]
for k, v in x.items():