Also pre-cache coordinates

This commit is contained in:
Kevin Wu
2022-09-22 14:24:53 -07:00
parent fa52ec1a2d
commit b0c5f28fee
2 changed files with 32 additions and 13 deletions

View File

@@ -272,7 +272,6 @@ def extract_backbone_coords(
chain = structure.get_structure()[0]
backbone = chain[struc.filter_backbone(chain)]
ca = [c for c in backbone if c.atom_name in atoms]
assert len(ca) == get_pdb_length(fname)
coords = np.vstack([c.coord for c in ca])
return coords

View File

@@ -33,6 +33,7 @@ from angles_and_coords import (
canonical_distances_and_dihedrals,
EXHAUSTIVE_ANGLES,
EXHAUSTIVE_DISTS,
extract_backbone_coords,
)
from custom_metrics import wrapped_mean
import utils
@@ -94,8 +95,9 @@ class CathCanonicalAnglesDataset(Dataset):
distances=EXHAUSTIVE_DISTS,
angles=EXHAUSTIVE_ANGLES,
)
coords_pfunc = functools.partial(extract_backbone_coords, atoms=["CA"])
# self.structures should be a list of dicts with keys (angles, fname)
# self.structures should be a list of dicts with keys (angles, coords, fname)
# Always compute for toy; do not save
if toy:
if isinstance(toy, bool):
@@ -104,11 +106,18 @@ class CathCanonicalAnglesDataset(Dataset):
logging.info(f"Loading toy dataset of {toy} structures")
struct_arrays = [pfunc(f) for f in fnames]
coord_arrays = [coords_pfunc for f in fnames]
self.structures = []
for fname, s in zip(fnames, struct_arrays):
for fname, s, c in zip(fnames, struct_arrays, coord_arrays):
if s is None:
continue
self.structures.append({"angles": s, "fname": fname})
self.structures.append(
{
"angles": s,
"coords": c,
"fname": fname,
}
)
elif not use_cache or not os.path.exists(self.cache_fname):
# No cache yet or not using cache
logging.info(
@@ -116,16 +125,23 @@ class CathCanonicalAnglesDataset(Dataset):
)
# Generate dihedral angles
pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
struct_arrays = pool.map(pfunc, fnames, chunksize=250)
struct_arrays = list(pool.map(pfunc, fnames, chunksize=250))
coord_arrays = list(pool.map(coords_pfunc, fnames, chunksize=250))
pool.close()
pool.join()
# Contains only non-null structures
self.structures = []
for fname, s in zip(fnames, struct_arrays):
for fname, s, c in zip(fnames, struct_arrays, coord_arrays):
if s is None:
continue
self.structures.append({"angles": s, "fname": fname})
self.structures.append(
{
"angles": s,
"coords": c,
"fname": fname,
}
)
# Write the output to a file for faster loading next time
if use_cache:
logging.info(f"Saving full dataset to cache at {self.cache_fname}")
@@ -238,6 +254,9 @@ class CathCanonicalAnglesDataset(Dataset):
raise IndexError("Index out of range")
angles = self.structures[index]["angles"]
# NOTE coords are NOT shifted or wrapped
coords = self.structures[index]["coords"]
print(angles.shape, coords.shape)
# If given, offset the angles with mean
if self.means is not None and not ignore_zero_center:
@@ -982,14 +1001,15 @@ class ScoreMatchingNoisedAnglesDataset(Dataset):
def main():
dset = CathCanonicalAnglesDataset(pad=128, trim_strategy="discard", use_cache=False)
noised_dset = NoisedAnglesDataset(dset, dset_key="angles")
print(len(noised_dset))
print(noised_dset[0])
# noised_dset = NoisedAnglesDataset(dset, dset_key="angles")
# print(len(noised_dset))
# print(noised_dset[0])
x = dset[0]
# x = noised_dset[0]
# for k, v in x.items():
# print(k)
# print(v)
for k, v in x.items():
print(k)
print(v)
if __name__ == "__main__":