mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Implement code path to discard too-long sequences
This commit is contained in:
@@ -41,7 +41,7 @@ from angles_and_coords import (
|
||||
from custom_metrics import kl_from_empirical, wrapped_mean
|
||||
import utils
|
||||
|
||||
TRIM_STRATEGIES = Literal["leftalign", "randomcrop"]
|
||||
TRIM_STRATEGIES = Literal["leftalign", "randomcrop", "discard"]
|
||||
|
||||
|
||||
class CathConsecutiveAnglesDataset(Dataset):
|
||||
@@ -260,6 +260,7 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
use_cache: bool = True, # Use/build cached computations of dihedrals and angles
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert pad > min_length
|
||||
self.trim_strategy = trim_strategy
|
||||
self.pad = pad
|
||||
self.min_length = min_length
|
||||
@@ -326,6 +327,15 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
logging.info(
|
||||
f"Removing structures shorter than {self.min_length} residues excludes {len_delta}/{orig_len} --> {len(self.structures)} sequences"
|
||||
)
|
||||
if self.trim_strategy == "discard":
|
||||
orig_len = len(self.structures)
|
||||
self.structures = [
|
||||
s for s in self.structures if s["angles"].shape[0] <= self.pad
|
||||
]
|
||||
len_delta = orig_len - len(self.structures)
|
||||
logging.info(
|
||||
f"Removing structures longer than {self.pad} produces {orig_len} - {len_delta} = {len(self.structures)} sequences"
|
||||
)
|
||||
|
||||
# Split the dataset if requested. This is implemented here to maintain
|
||||
# functional parity with the original CATH dataset. Original CATH uses
|
||||
@@ -694,7 +704,7 @@ class NoisedAnglesDataset(Dataset):
|
||||
def pad(self):
|
||||
"""Pass through the pad property of the wrapped dset"""
|
||||
return self.dset.pad
|
||||
|
||||
|
||||
@property
|
||||
def filenames(self):
|
||||
"""Pass through the filenames property of the wrapped dset"""
|
||||
@@ -1159,7 +1169,7 @@ def main():
|
||||
# print(len(noised_dset))
|
||||
# print(noised_dset[0])
|
||||
|
||||
dset = CathCanonicalAnglesDataset(pad=32, trim_strategy="randomcrop")
|
||||
dset = CathCanonicalAnglesDataset(pad=128, trim_strategy="discard", use_cache=False)
|
||||
noised_dset = NoisedAnglesDataset(dset, dset_key="angles")
|
||||
print(len(noised_dset))
|
||||
print(noised_dset[0])
|
||||
|
||||
Reference in New Issue
Block a user