mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Also pre-cache coordinates
This commit is contained in:
@@ -272,7 +272,6 @@ def extract_backbone_coords(
|
||||
chain = structure.get_structure()[0]
|
||||
backbone = chain[struc.filter_backbone(chain)]
|
||||
ca = [c for c in backbone if c.atom_name in atoms]
|
||||
assert len(ca) == get_pdb_length(fname)
|
||||
coords = np.vstack([c.coord for c in ca])
|
||||
return coords
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from angles_and_coords import (
|
||||
canonical_distances_and_dihedrals,
|
||||
EXHAUSTIVE_ANGLES,
|
||||
EXHAUSTIVE_DISTS,
|
||||
extract_backbone_coords,
|
||||
)
|
||||
from custom_metrics import wrapped_mean
|
||||
import utils
|
||||
@@ -94,8 +95,9 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
distances=EXHAUSTIVE_DISTS,
|
||||
angles=EXHAUSTIVE_ANGLES,
|
||||
)
|
||||
coords_pfunc = functools.partial(extract_backbone_coords, atoms=["CA"])
|
||||
|
||||
# self.structures should be a list of dicts with keys (angles, fname)
|
||||
# self.structures should be a list of dicts with keys (angles, coords, fname)
|
||||
# Always compute for toy; do not save
|
||||
if toy:
|
||||
if isinstance(toy, bool):
|
||||
@@ -104,11 +106,18 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
|
||||
logging.info(f"Loading toy dataset of {toy} structures")
|
||||
struct_arrays = [pfunc(f) for f in fnames]
|
||||
# Call the partial on every file so each entry holds that file's coordinates, not the function object.
coord_arrays = [coords_pfunc(f) for f in fnames]
|
||||
self.structures = []
|
||||
for fname, s in zip(fnames, struct_arrays):
|
||||
for fname, s, c in zip(fnames, struct_arrays, coord_arrays):
|
||||
if s is None:
|
||||
continue
|
||||
self.structures.append({"angles": s, "fname": fname})
|
||||
self.structures.append(
|
||||
{
|
||||
"angles": s,
|
||||
"coords": c,
|
||||
"fname": fname,
|
||||
}
|
||||
)
|
||||
elif not use_cache or not os.path.exists(self.cache_fname):
|
||||
# No cache yet or not using cache
|
||||
logging.info(
|
||||
@@ -116,16 +125,23 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
)
|
||||
# Generate dihedral angles
|
||||
pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
|
||||
struct_arrays = pool.map(pfunc, fnames, chunksize=250)
|
||||
struct_arrays = list(pool.map(pfunc, fnames, chunksize=250))
|
||||
coord_arrays = list(pool.map(coords_pfunc, fnames, chunksize=250))
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
# Contains only non-null structures
|
||||
self.structures = []
|
||||
for fname, s in zip(fnames, struct_arrays):
|
||||
for fname, s, c in zip(fnames, struct_arrays, coord_arrays):
|
||||
if s is None:
|
||||
continue
|
||||
self.structures.append({"angles": s, "fname": fname})
|
||||
self.structures.append(
|
||||
{
|
||||
"angles": s,
|
||||
"coords": c,
|
||||
"fname": fname,
|
||||
}
|
||||
)
|
||||
# Write the output to a file for faster loading next time
|
||||
if use_cache:
|
||||
logging.info(f"Saving full dataset to cache at {self.cache_fname}")
|
||||
@@ -238,6 +254,9 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
raise IndexError("Index out of range")
|
||||
|
||||
angles = self.structures[index]["angles"]
|
||||
# NOTE coords are NOT shifted or wrapped
|
||||
coords = self.structures[index]["coords"]
|
||||
print(angles.shape, coords.shape)
|
||||
|
||||
# If given, offset the angles with mean
|
||||
if self.means is not None and not ignore_zero_center:
|
||||
@@ -982,14 +1001,15 @@ class ScoreMatchingNoisedAnglesDataset(Dataset):
|
||||
|
||||
def main():
|
||||
dset = CathCanonicalAnglesDataset(pad=128, trim_strategy="discard", use_cache=False)
|
||||
noised_dset = NoisedAnglesDataset(dset, dset_key="angles")
|
||||
print(len(noised_dset))
|
||||
print(noised_dset[0])
|
||||
# noised_dset = NoisedAnglesDataset(dset, dset_key="angles")
|
||||
# print(len(noised_dset))
|
||||
# print(noised_dset[0])
|
||||
x = dset[0]
|
||||
|
||||
# x = noised_dset[0]
|
||||
# for k, v in x.items():
|
||||
# print(k)
|
||||
# print(v)
|
||||
for k, v in x.items():
|
||||
print(k)
|
||||
print(v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user