diff --git a/protdiff/angles_and_coords.py b/protdiff/angles_and_coords.py index 9b8eb8d..118a3f3 100644 --- a/protdiff/angles_and_coords.py +++ b/protdiff/angles_and_coords.py @@ -272,7 +272,6 @@ def extract_backbone_coords( chain = structure.get_structure()[0] backbone = chain[struc.filter_backbone(chain)] ca = [c for c in backbone if c.atom_name in atoms] - assert len(ca) == get_pdb_length(fname) coords = np.vstack([c.coord for c in ca]) return coords diff --git a/protdiff/datasets.py b/protdiff/datasets.py index 0ac29fb..b073b6c 100644 --- a/protdiff/datasets.py +++ b/protdiff/datasets.py @@ -33,6 +33,7 @@ from angles_and_coords import ( canonical_distances_and_dihedrals, EXHAUSTIVE_ANGLES, EXHAUSTIVE_DISTS, + extract_backbone_coords, ) from custom_metrics import wrapped_mean import utils @@ -94,8 +95,9 @@ class CathCanonicalAnglesDataset(Dataset): distances=EXHAUSTIVE_DISTS, angles=EXHAUSTIVE_ANGLES, ) + coords_pfunc = functools.partial(extract_backbone_coords, atoms=["CA"]) - # self.structures should be a list of dicts with keys (angles, fname) + # self.structures should be a list of dicts with keys (angles, coords, fname) # Always compute for toy; do not save if toy: if isinstance(toy, bool): @@ -104,11 +106,18 @@ class CathCanonicalAnglesDataset(Dataset): logging.info(f"Loading toy dataset of {toy} structures") struct_arrays = [pfunc(f) for f in fnames] + coord_arrays = [coords_pfunc for f in fnames] self.structures = [] - for fname, s in zip(fnames, struct_arrays): + for fname, s, c in zip(fnames, struct_arrays, coord_arrays): if s is None: continue - self.structures.append({"angles": s, "fname": fname}) + self.structures.append( + { + "angles": s, + "coords": c, + "fname": fname, + } + ) elif not use_cache or not os.path.exists(self.cache_fname): # No cache yet or not using cache logging.info( @@ -116,16 +125,23 @@ class CathCanonicalAnglesDataset(Dataset): ) # Generate dihedral angles pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()) - struct_arrays = pool.map(pfunc, fnames, chunksize=250) + struct_arrays = list(pool.map(pfunc, fnames, chunksize=250)) + coord_arrays = list(pool.map(coords_pfunc, fnames, chunksize=250)) pool.close() pool.join() # Contains only non-null structures self.structures = [] - for fname, s in zip(fnames, struct_arrays): + for fname, s, c in zip(fnames, struct_arrays, coord_arrays): if s is None: continue - self.structures.append({"angles": s, "fname": fname}) + self.structures.append( + { + "angles": s, + "coords": c, + "fname": fname, + } + ) # Write the output to a file for faster loading next time if use_cache: logging.info(f"Saving full dataset to cache at {self.cache_fname}") @@ -238,6 +254,9 @@ class CathCanonicalAnglesDataset(Dataset): raise IndexError("Index out of range") angles = self.structures[index]["angles"] + # NOTE coords are NOT shifted or wrapped + coords = self.structures[index]["coords"] + print(angles.shape, coords.shape) # If given, offset the angles with mean if self.means is not None and not ignore_zero_center: @@ -982,14 +1001,15 @@ class ScoreMatchingNoisedAnglesDataset(Dataset): def main(): dset = CathCanonicalAnglesDataset(pad=128, trim_strategy="discard", use_cache=False) - noised_dset = NoisedAnglesDataset(dset, dset_key="angles") - print(len(noised_dset)) - print(noised_dset[0]) + # noised_dset = NoisedAnglesDataset(dset, dset_key="angles") + # print(len(noised_dset)) + # print(noised_dset[0]) + x = dset[0] # x = noised_dset[0] - # for k, v in x.items(): - # print(k) - # print(v) + for k, v in x.items(): + print(k) + print(v) if __name__ == "__main__":