Simplify dihedral calculation code and add tests

This commit is contained in:
Kevin Wu
2022-09-07 12:21:57 -07:00
parent b9c749eae2
commit 149fe8a72b
2 changed files with 21 additions and 24 deletions

View File

@@ -8,14 +8,15 @@ https://benjamin-computer.medium.com/protein-loops-in-tensorflow-a-i-bio-part-2-
import numpy as np
def place_dihedral(a, b, c, bond_angle, bond_length, torsion_angle):
def place_dihedral(a: np.ndarray, b:np.ndarray, c: np.ndarray, bond_angle:float, bond_length:float, torsion_angle:float) -> np.ndarray:
"""
Place the point d such that the bond angle, length, and torsion angle are satisfied
with the series a, b, c, d.
"""
assert a.ndim == b.ndim == c.ndim == 1
unit_vec = lambda x: x / np.linalg.norm(x)
ab = b - a
bc = c - b
bcn = bc / np.linalg.norm(bc)
# numpy is row major
bc = unit_vec(c - b)
d = np.array(
[
-bond_length * np.cos(bond_angle),
@@ -23,15 +24,11 @@ def place_dihedral(a, b, c, bond_angle, bond_length, torsion_angle):
bond_length * np.sin(torsion_angle) * np.sin(bond_angle),
]
)
n = np.cross(ab, bcn)
n /= np.linalg.norm(n)
nbc = np.cross(n, bcn)
m = np.array(
[[bcn[0], nbc[0], n[0]], [bcn[1], nbc[1], n[1]], [bcn[2], nbc[2], n[2]]]
)
n = unit_vec(np.cross(ab, bc))
nbc = np.cross(n, bc)
m = np.stack([bc, nbc, n]).T
d = m.dot(d)
d = d + c
return d
return d + c
if __name__ == "__main__":

View File

@@ -35,20 +35,20 @@ class TestDihedralPlacement(unittest.TestCase):
def test_randomized(self):
"""Simple test using randomized values"""
a, b, c, d = self.rng.uniform(low=-5, high=5, size=(4, 3))
print(a, b, c, d)
calc_d = nerf.place_dihedral(
a,
b,
c,
angle_between(d - c, b - c),
dist_between(c, d),
dihedral(a, b, c, d),
)
self.assertTrue(np.allclose(d, calc_d), f"Mismatched: {d} != {calc_d}")
for _ in range(5):
a, b, c, d = self.rng.uniform(low=-5, high=5, size=(4, 3))
calc_d = nerf.place_dihedral(
a,
b,
c,
angle_between(d - c, b - c),
dist_between(c, d),
dihedral(a, b, c, d),
)
self.assertTrue(np.allclose(d, calc_d), f"Mismatched: {d} != {calc_d}")
def angle_between(v1, v2):
def angle_between(v1, v2) -> float:
"""Gets the angle between u and v"""
# https://stackoverflow.com/questions/2827393/angles-between-two-n-dimensional-vectors-in-python
unit_vector = lambda vector: vector / np.linalg.norm(vector)