mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Support for coordinate based datasets
This commit is contained in:
@@ -59,10 +59,12 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
"tau",
|
||||
"CA:C:1N",
|
||||
"C:1N:1CA",
|
||||
]
|
||||
],
|
||||
"coords": ["x", "y", "z"],
|
||||
}
|
||||
feature_is_angular = {
|
||||
"angles": [False, False, False, True, True, True, True, True, True]
|
||||
"angles": [False, False, False, True, True, True, True, True, True],
|
||||
"coords": [False, False, False],
|
||||
}
|
||||
cache_fname = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
@@ -254,9 +256,9 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
raise IndexError("Index out of range")
|
||||
|
||||
angles = self.structures[index]["angles"]
|
||||
# NOTE coords are NOT shifted or wrapped
|
||||
# NOTE coords are NOT shifted or wrapped, has same length as angles
|
||||
coords = self.structures[index]["coords"]
|
||||
print(angles.shape, coords.shape)
|
||||
assert angles.shape[0] == coords.shape[0]
|
||||
|
||||
# If given, offset the angles with mean
|
||||
if self.means is not None and not ignore_zero_center:
|
||||
@@ -302,16 +304,24 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
coords = np.pad(
|
||||
coords,
|
||||
((0, self.pad - coords.shape[0]), (0, 0)),
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
elif angles.shape[0] > self.pad:
|
||||
if self.trim_strategy == "leftalign":
|
||||
angles = angles[: self.pad]
|
||||
coords = coords[: self.pad]
|
||||
elif self.trim_strategy == "randomcrop":
|
||||
# Randomly crop the sequence to
|
||||
start_idx = self.rng.integers(0, angles.shape[0] - self.pad)
|
||||
end_idx = start_idx + self.pad
|
||||
assert end_idx < angles.shape[0]
|
||||
angles = angles[start_idx:end_idx]
|
||||
assert angles.shape[0] == self.pad
|
||||
coords = coords[start_idx:end_idx]
|
||||
assert angles.shape[0] == coords.shape[0] == self.pad
|
||||
else:
|
||||
raise ValueError(f"Unknown trim strategy: {self.trim_strategy}")
|
||||
|
||||
@@ -328,9 +338,11 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
angles[:, angular_idx], "<=", np.pi
|
||||
), f"Illegal value: {np.max(angles[:, angular_idx])}"
|
||||
angles = torch.from_numpy(angles).float()
|
||||
coords = torch.from_numpy(coords).float()
|
||||
|
||||
retval = {
|
||||
"angles": angles,
|
||||
"coords": coords,
|
||||
"attn_mask": attn_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
@@ -355,6 +367,25 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
return torch.var_mean(all_vals)[::-1] # Default is (var, mean)
|
||||
|
||||
|
||||
class CathCanonicalCoordsDataset(CathCanonicalAnglesDataset):
|
||||
"""
|
||||
Building on the CATH dataset, return the XYZ coordaintes of each alpha carbon
|
||||
"""
|
||||
|
||||
feature_names = {"coords": list("xyz")}
|
||||
feature_is_angular = {"coords": [False, False, False]}
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __getitem__(
|
||||
self, index, ignore_zero_center: bool = True
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
return_dict = super().__getitem__(index, ignore_zero_center=ignore_zero_center)
|
||||
return_dict.pop("angles", None)
|
||||
return return_dict
|
||||
|
||||
|
||||
class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset):
|
||||
"""
|
||||
Building on the CATH dataset, return the 3 canonical dihedrals and the 3
|
||||
@@ -398,6 +429,7 @@ class CathCanonicalAnglesOnlyDataset(CathCanonicalAnglesDataset):
|
||||
assert torch.all(
|
||||
return_dict["angles"] <= torch.pi
|
||||
), f"Maximum value {torch.max(return_dict['angles'])} higher than pi"
|
||||
return_dict.pop("coords", None)
|
||||
|
||||
return return_dict
|
||||
|
||||
@@ -1000,11 +1032,13 @@ class ScoreMatchingNoisedAnglesDataset(Dataset):
|
||||
|
||||
|
||||
def main():
|
||||
dset = CathCanonicalAnglesDataset(pad=128, trim_strategy="discard", use_cache=False)
|
||||
# noised_dset = NoisedAnglesDataset(dset, dset_key="angles")
|
||||
dset = CathCanonicalCoordsDataset(
|
||||
pad=128, trim_strategy="discard", use_cache=False, zero_center=False
|
||||
)
|
||||
noised_dset = NoisedAnglesDataset(dset, dset_key="coords")
|
||||
# print(len(noised_dset))
|
||||
# print(noised_dset[0])
|
||||
x = dset[0]
|
||||
x = noised_dset[0]
|
||||
|
||||
# x = noised_dset[0]
|
||||
for k, v in x.items():
|
||||
|
||||
Reference in New Issue
Block a user