mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Vectorize some code, fix some bugs
This commit is contained in:
@@ -60,25 +60,19 @@ def canonical_distances_and_dihedrals(
|
||||
for a in non_dihedral_angles:
|
||||
if a == "tau" or a == "N:CA:C":
|
||||
# tau = N - CA - C internal angles
|
||||
idx = np.array(
|
||||
[list(range(i, i + 3)) for i in range(3, len(backbone_atoms), 3)]
|
||||
+ [(0, 0, 0)]
|
||||
)
|
||||
r = np.arange(3, len(backbone_atoms), 3)
|
||||
idx = np.hstack([np.vstack([r, r + 1, r + 2]), np.zeros((3, 1))]).T
|
||||
elif a == "CA:C:1N": # Same as C-N angle in nerf
|
||||
# This measures an angle between two residues. Due to the way we build
|
||||
# proteins out later, we do not need to measure this for the final residue.
|
||||
idx = np.array(
|
||||
[(i + 1, i + 2, i + 3) for i in range(0, len(backbone_atoms) - 3, 3)]
|
||||
+ [(0, 0, 0)]
|
||||
)
|
||||
r = np.arange(0, len(backbone_atoms) - 3, 3)
|
||||
idx = np.hstack([np.vstack([r + 1, r + 2, r + 3]), np.zeros((3, 1))]).T
|
||||
elif a == "C:1N:1CA":
|
||||
idx = np.array(
|
||||
[(i + 2, i + 3, i + 4) for i in range(0, len(backbone_atoms) - 3, 3)]
|
||||
+ [(0, 0, 0)]
|
||||
)
|
||||
r = np.arange(0, len(backbone_atoms) - 3, 3)
|
||||
idx = np.hstack([np.vstack([r + 2, r + 3, r + 4]), np.zeros((3, 1))]).T
|
||||
else:
|
||||
raise ValueError(f"Unrecognized angle: {a}")
|
||||
calc_angles[a] = struc.index_angle(backbone_atoms, indices=idx)
|
||||
calc_angles[a] = struc.index_angle(backbone_atoms, indices=idx.astype(int))
|
||||
|
||||
# At this point we've only looked at dihedral and angles; check value range
|
||||
for k, v in calc_angles.items():
|
||||
@@ -92,29 +86,25 @@ def canonical_distances_and_dihedrals(
|
||||
# Since this is measuring the distance between pairs of residues, there
|
||||
# is one fewer such measurement than the total number of residues like
|
||||
# for dihedrals. Therefore, we pad this with a null 0 value at the end.
|
||||
idx = np.array(
|
||||
[(i + 2, i + 3) for i in range(0, len(backbone_atoms) - 3, 3)]
|
||||
+ [(0, 0)]
|
||||
)
|
||||
r = np.arange(0, len(backbone_atoms) - 3, 3)
|
||||
idx = np.hstack([np.vstack([r + 2, r + 3]), np.zeros((2, 1))]).T
|
||||
elif d == "N:CA":
|
||||
# We start reconstructing with a fixed initial residue so we do not need
|
||||
# to predict or record the initial distance. Additionally, we pad with a
|
||||
# null value at the end
|
||||
idx = np.array(
|
||||
[(i, i + 1) for i in range(3, len(backbone_atoms), 3)] + [(0, 0)]
|
||||
)
|
||||
r = np.arange(3, len(backbone_atoms), 3)
|
||||
idx = np.hstack([np.vstack([r, r + 1]), np.zeros((2, 1))]).T
|
||||
assert len(idx) == len(calc_angles["phi"])
|
||||
elif d == "CA:C":
|
||||
# We start reconstructing with a fixed initial residue so we do not need
|
||||
# to predict or record the initial distance. Additionally, we pad with a
|
||||
# null value at the end.
|
||||
idx = np.array(
|
||||
[(i + 1, i + 2) for i in range(3, len(backbone_atoms), 3)] + [(0, 0)]
|
||||
)
|
||||
r = np.arange(3, len(backbone_atoms), 3)
|
||||
idx = np.hstack([np.vstack([r + 1, r + 2]), np.zeros((2, 1))]).T
|
||||
assert len(idx) == len(calc_angles["phi"])
|
||||
else:
|
||||
raise ValueError(f"Unrecognized distance: {d}")
|
||||
calc_angles[d] = struc.index_distance(backbone_atoms, indices=idx)
|
||||
calc_angles[d] = struc.index_distance(backbone_atoms, indices=idx.astype(int))
|
||||
|
||||
return pd.DataFrame({k: calc_angles[k].squeeze() for k in distances + angles})
|
||||
|
||||
@@ -368,18 +358,20 @@ def build_aa_sidechain_dict(
|
||||
that specify how to build out that sidechain's atoms from the backbone
|
||||
"""
|
||||
if not reference_pdbs:
|
||||
glob.glob(
|
||||
reference_pdbs = glob.glob(
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), "data/*.pdb")
|
||||
)
|
||||
|
||||
ref_file_counter = 0
|
||||
retval = {}
|
||||
for pdb in reference_pdbs:
|
||||
try:
|
||||
sidechain_angles = collect_aa_sidechain_angles(pdb)
|
||||
retval.update(sidechain_angles) # Overwrites any existing key/value pairs
|
||||
ref_file_counter += 1
|
||||
except ValueError:
|
||||
continue
|
||||
logging.info(f"Built sidechain dictionary with {len(retval)} amino acids")
|
||||
logging.info(f"Built sidechain dictionary with {len(retval)} amino acids from {ref_file_counter} files")
|
||||
return retval
|
||||
|
||||
|
||||
|
||||
@@ -164,4 +164,5 @@ class TestNoisedDataset(unittest.TestCase):
|
||||
"sqrt_alphas_cumprod_t"
|
||||
]
|
||||
recovered = utils.modulo_with_wrapped_range(recovered, -np.pi, np.pi)
|
||||
self.assertTrue(torch.allclose(recovered, orig_angles, atol=1e-5))
|
||||
delta = recovered - orig_angles
|
||||
self.assertTrue(torch.allclose(delta, torch.zeros_like(delta), atol=1e-4), f"Got non-zero delta on de-noise: {delta}")
|
||||
|
||||
@@ -11,7 +11,7 @@ class TestRadianSmoothL1Loss(unittest.TestCase):
|
||||
"""
|
||||
Easy test of basic wrapping functionality
|
||||
"""
|
||||
l = losses.radian_smooth_l1_loss(torch.tensor(0.1), 2 * torch.pi, beta=1.0)
|
||||
l = losses.radian_smooth_l1_loss(torch.tensor(0.1), torch.tensor(2 * torch.pi), beta=1.0)
|
||||
self.assertAlmostEqual(0.0050, l.item())
|
||||
|
||||
def test_rounding(self):
|
||||
|
||||
Reference in New Issue
Block a user