Vectorize some code, fix some bugs

This commit is contained in:
Kevin Wu
2023-04-13 10:54:44 -07:00
parent 902b0301d2
commit 12fa730caf
3 changed files with 21 additions and 28 deletions

View File

@@ -60,25 +60,19 @@ def canonical_distances_and_dihedrals(
for a in non_dihedral_angles:
if a == "tau" or a == "N:CA:C":
# tau = N - CA - C internal angles
idx = np.array(
[list(range(i, i + 3)) for i in range(3, len(backbone_atoms), 3)]
+ [(0, 0, 0)]
)
r = np.arange(3, len(backbone_atoms), 3)
idx = np.hstack([np.vstack([r, r + 1, r + 2]), np.zeros((3, 1))]).T
elif a == "CA:C:1N": # Same as C-N angle in nerf
# This measures an angle between two residues. Due to the way we build
# proteins out later, we do not need to meas
idx = np.array(
[(i + 1, i + 2, i + 3) for i in range(0, len(backbone_atoms) - 3, 3)]
+ [(0, 0, 0)]
)
r = np.arange(0, len(backbone_atoms) - 3, 3)
idx = np.hstack([np.vstack([r + 1, r + 2, r + 3]), np.zeros((3, 1))]).T
elif a == "C:1N:1CA":
idx = np.array(
[(i + 2, i + 3, i + 4) for i in range(0, len(backbone_atoms) - 3, 3)]
+ [(0, 0, 0)]
)
r = np.arange(0, len(backbone_atoms) - 3, 3)
idx = np.hstack([np.vstack([r + 2, r + 3, r + 4]), np.zeros((3, 1))]).T
else:
raise ValueError(f"Unrecognized angle: {a}")
calc_angles[a] = struc.index_angle(backbone_atoms, indices=idx)
calc_angles[a] = struc.index_angle(backbone_atoms, indices=idx.astype(int))
# At this point we've only looked at dihedral and angles; check value range
for k, v in calc_angles.items():
@@ -92,29 +86,25 @@ def canonical_distances_and_dihedrals(
# Since this is measuring the distance between pairs of residues, there
# is one fewer such measurement than the total number of residues like
# for dihedrals. Therefore, we pad this with a null 0 value at the end.
idx = np.array(
[(i + 2, i + 3) for i in range(0, len(backbone_atoms) - 3, 3)]
+ [(0, 0)]
)
r = np.arange(0, len(backbone_atoms) - 3, 3)
idx = np.hstack([np.vstack([r + 2, r + 3]), np.zeros((2, 1))]).T
elif d == "N:CA":
# We start resconstructing with a fixed initial residue so we do not need
# to predict or record the initial distance. Additionally we pad with a
# null value at the end
idx = np.array(
[(i, i + 1) for i in range(3, len(backbone_atoms), 3)] + [(0, 0)]
)
r = np.arange(3, len(backbone_atoms), 3)
idx = np.hstack([np.vstack([r, r + 1]), np.zeros((2, 1))]).T
assert len(idx) == len(calc_angles["phi"])
elif d == "CA:C":
# We start reconstructing with a fixed initial residue so we do not need
# to predict or record the initial distance. Additionally, we pad with a
# null value at the end.
idx = np.array(
[(i + 1, i + 2) for i in range(3, len(backbone_atoms), 3)] + [(0, 0)]
)
r = np.arange(3, len(backbone_atoms), 3)
idx = np.hstack([np.vstack([r + 1, r + 2]), np.zeros((2, 1))]).T
assert len(idx) == len(calc_angles["phi"])
else:
raise ValueError(f"Unrecognized distance: {d}")
calc_angles[d] = struc.index_distance(backbone_atoms, indices=idx)
calc_angles[d] = struc.index_distance(backbone_atoms, indices=idx.astype(int))
return pd.DataFrame({k: calc_angles[k].squeeze() for k in distances + angles})
@@ -368,18 +358,20 @@ def build_aa_sidechain_dict(
that specify how to build out that sidechain's atoms from the backbone
"""
if not reference_pdbs:
glob.glob(
reference_pdbs = glob.glob(
os.path.join(os.path.dirname(os.path.dirname(__file__)), "data/*.pdb")
)
ref_file_counter = 0
retval = {}
for pdb in reference_pdbs:
try:
sidechain_angles = collect_aa_sidechain_angles(pdb)
retval.update(sidechain_angles) # Overwrites any existing key/value pairs
ref_file_counter += 1
except ValueError:
continue
logging.info(f"Built sidechain dictionary with {len(retval)} amino acids")
logging.info(f"Built sidechain dictionary with {len(retval)} amino acids from {ref_file_counter} files")
return retval

View File

@@ -164,4 +164,5 @@ class TestNoisedDataset(unittest.TestCase):
"sqrt_alphas_cumprod_t"
]
recovered = utils.modulo_with_wrapped_range(recovered, -np.pi, np.pi)
self.assertTrue(torch.allclose(recovered, orig_angles, atol=1e-5))
delta = recovered - orig_angles
self.assertTrue(torch.allclose(delta, torch.zeros_like(delta), atol=1e-4), f"Got non-zero delta on de-noise: {delta}")

View File

@@ -11,7 +11,7 @@ class TestRadianSmoothL1Loss(unittest.TestCase):
"""
Easy test of basic wrapping functionality
"""
l = losses.radian_smooth_l1_loss(torch.tensor(0.1), 2 * torch.pi, beta=1.0)
l = losses.radian_smooth_l1_loss(torch.tensor(0.1), torch.tensor(2 * torch.pi), beta=1.0)
self.assertAlmostEqual(0.0050, l.item())
def test_rounding(self):