Implement code path to discard too-long sequences

This commit is contained in:
Kevin Wu
2022-09-12 13:40:08 -07:00
parent 5e55a04f94
commit 2f9c68bbba

View File

@@ -41,7 +41,7 @@ from angles_and_coords import (
from custom_metrics import kl_from_empirical, wrapped_mean
import utils
TRIM_STRATEGIES = Literal["leftalign", "randomcrop"]
TRIM_STRATEGIES = Literal["leftalign", "randomcrop", "discard"]
class CathConsecutiveAnglesDataset(Dataset):
@@ -260,6 +260,7 @@ class CathCanonicalAnglesDataset(Dataset):
use_cache: bool = True, # Use/build cached computations of dihedrals and angles
) -> None:
super().__init__()
assert pad > min_length
self.trim_strategy = trim_strategy
self.pad = pad
self.min_length = min_length
@@ -326,6 +327,15 @@ class CathCanonicalAnglesDataset(Dataset):
logging.info(
f"Removing structures shorter than {self.min_length} residues excludes {len_delta}/{orig_len} --> {len(self.structures)} sequences"
)
if self.trim_strategy == "discard":
orig_len = len(self.structures)
self.structures = [
s for s in self.structures if s["angles"].shape[0] <= self.pad
]
len_delta = orig_len - len(self.structures)
logging.info(
f"Removing structures longer than {self.pad} produces {orig_len} - {len_delta} = {len(self.structures)} sequences"
)
# Split the dataset if requested. This is implemented here to maintain
# functional parity with the original CATH dataset. Original CATH uses
@@ -694,7 +704,7 @@ class NoisedAnglesDataset(Dataset):
def pad(self):
"""Pas through the pad property of wrapped dset"""
return self.dset.pad
@property
def filenames(self):
"""Pass through the filenames property of the wrapped dset"""
@@ -1159,7 +1169,7 @@ def main():
# print(len(noised_dset))
# print(noised_dset[0])
dset = CathCanonicalAnglesDataset(pad=32, trim_strategy="randomcrop")
dset = CathCanonicalAnglesDataset(pad=128, trim_strategy="discard", use_cache=False)
noised_dset = NoisedAnglesDataset(dset, dset_key="angles")
print(len(noised_dset))
print(noised_dset[0])