Refactor between_residue_bond_loss to use a shared helper function.

PiperOrigin-RevId: 808166329
Change-Id: I3d60463cbc8f0ec3cbceab7a445d8f6f2d652ce0
This commit is contained in:
Ryan Pachauri
2025-09-17 08:31:42 -07:00
committed by Copybara-Service
parent cc9042484e
commit 2652cafb43

View File

@@ -32,8 +32,7 @@ computationally more efficient.
The internal atom14 representation is turned into the atom37 at the output of
the network to facilitate easier conversion to existing protein datastructures.
"""
from typing import Dict, Optional
from typing import Dict, Optional, Union
from alphafold.common import residue_constants
from alphafold.model import r3
@@ -393,59 +392,71 @@ def atom37_to_torsion_angles(
torsion_frames = r3.rigids_from_3_points(
point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]),
origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))
point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]),
)
# Compute the position of the fourth atom in this frame (y and z coordinates
# define the chi angle)
# r3.Vecs (B, N, torsions=7)
forth_atom_rel_pos = r3.rigids_mul_vecs(
r3.invert_rigids(torsion_frames),
r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))
r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]),
)
# Normalize to have the sin and cos of the torsion angle.
# jnp.ndarray (B, N, torsions=7, sincos=2)
torsion_angles_sin_cos = jnp.stack(
[forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1)
[forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1
)
torsion_angles_sin_cos /= jnp.sqrt(
jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
+ 1e-8)
jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8
)
# Mirror psi, because we computed it from the Oxygen-atom.
torsion_angles_sin_cos *= jnp.asarray(
[1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]
torsion_angles_sin_cos *= jnp.asarray([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0])[
None, None, :, None
]
# Create alternative angles for ambiguous atom names.
chi_is_ambiguous = utils.batched_gather(
jnp.asarray(residue_constants.chi_pi_periodic), aatype)
jnp.asarray(residue_constants.chi_pi_periodic), aatype
)
mirror_torsion_angles = jnp.concatenate(
[jnp.ones([num_batch, num_res, 3]),
1.0 - 2.0 * chi_is_ambiguous], axis=-1)
[jnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous], axis=-1
)
alt_torsion_angles_sin_cos = (
torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None]
)
if placeholder_for_undefined:
# Add placeholder torsions in place of undefined torsion angles
# (e.g. N-terminus pre-omega)
placeholder_torsions = jnp.stack([
jnp.ones(torsion_angles_sin_cos.shape[:-1]),
jnp.zeros(torsion_angles_sin_cos.shape[:-1])
], axis=-1)
placeholder_torsions = jnp.stack(
[
jnp.ones(torsion_angles_sin_cos.shape[:-1]),
jnp.zeros(torsion_angles_sin_cos.shape[:-1]),
],
axis=-1,
)
torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
..., None
] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
alt_torsion_angles_sin_cos = (
alt_torsion_angles_sin_cos * torsion_angles_mask[..., None]
+ placeholder_torsions * (1 - torsion_angles_mask[..., None])
)
return {
'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, N, 7, 2)
'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, # (B, N, 7, 2)
'torsion_angles_mask': torsion_angles_mask # (B, N, 7)
'torsion_angles_mask': torsion_angles_mask, # (B, N, 7)
}
def torsion_angles_to_frames(
aatype: jnp.ndarray, # (N)
backb_to_global: r3.Rigids, # (N)
torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2)
torsion_angles_sin_cos: jnp.ndarray, # (N, 7, 2)
) -> r3.Rigids: # (N, 8)
"""Compute rigid group frames from torsion angles.
@@ -666,21 +677,26 @@ def between_residue_bond_loss(
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = (
aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(jnp.float32)
gt_length = (
(1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
+ next_is_proline * residue_constants.between_res_bond_length_c_n[1])
gt_stddev = (
(1. - next_is_proline) *
residue_constants.between_res_bond_length_stddev_c_n[0] +
next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1])
c_n_bond_length_error = jnp.sqrt(1e-6 +
jnp.square(c_n_bond_length - gt_length))
c_n_loss_per_residue = jax.nn.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev))
c_n_loss_per_residue, c_n_loss, c_n_violation_mask = (
_loss_and_violation_mask(
metric=c_n_bond_length,
gt_metric=(
(1.0 - next_is_proline)
* residue_constants.between_res_bond_length_c_n[0]
+ next_is_proline
* residue_constants.between_res_bond_length_c_n[1]
),
gt_stddev=(
(1.0 - next_is_proline)
* residue_constants.between_res_bond_length_stddev_c_n[0]
+ next_is_proline
* residue_constants.between_res_bond_length_stddev_c_n[1]
),
mask=this_c_mask * next_n_mask * has_no_gap_mask,
tolerance_factor_soft=tolerance_factor_soft,
tolerance_factor_hard=tolerance_factor_hard,
)
)
# Compute loss for the angles.
ca_c_bond_length = jnp.sqrt(1e-6 + jnp.sum(
@@ -692,53 +708,77 @@ def between_residue_bond_loss(
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None]
n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None]
ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1)
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle))
ca_c_n_loss_per_residue = jax.nn.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
(tolerance_factor_hard * gt_stddev))
ca_c_n_loss_per_residue, ca_c_n_loss, ca_c_n_violation_mask = (
_loss_and_violation_mask(
metric=jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1),
gt_metric=residue_constants.between_res_cos_angles_ca_c_n[0],
gt_stddev=residue_constants.between_res_bond_length_stddev_c_n[0],
mask=this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask,
tolerance_factor_soft=tolerance_factor_soft,
tolerance_factor_hard=tolerance_factor_hard,
)
)
c_n_ca_cos_angle = jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = jnp.sqrt(
1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle))
c_n_ca_loss_per_residue = jax.nn.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev))
c_n_ca_loss_per_residue, c_n_ca_loss, c_n_ca_violation_mask = (
_loss_and_violation_mask(
metric=jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1),
gt_metric=residue_constants.between_res_cos_angles_c_n_ca[0],
gt_stddev=residue_constants.between_res_cos_angles_c_n_ca[1],
mask=this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask,
tolerance_factor_soft=tolerance_factor_soft,
tolerance_factor_hard=tolerance_factor_hard,
)
)
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (c_n_loss_per_residue +
ca_c_n_loss_per_residue +
c_n_ca_loss_per_residue)
per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) +
jnp.pad(per_residue_loss_sum, [[1, 0]]))
per_residue_loss_sum = (
c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
)
per_residue_loss_sum = 0.5 * (
jnp.pad(per_residue_loss_sum, [[0, 1]])
+ jnp.pad(per_residue_loss_sum, [[1, 0]])
)
# Compute hard violations.
violation_mask = jnp.max(
jnp.stack([c_n_violation_mask,
ca_c_n_violation_mask,
c_n_ca_violation_mask]), axis=0)
jnp.stack(
[c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask]
),
axis=0,
)
violation_mask = jnp.maximum(
jnp.pad(violation_mask, [[0, 1]]),
jnp.pad(violation_mask, [[1, 0]]))
jnp.pad(violation_mask, [[0, 1]]), jnp.pad(violation_mask, [[1, 0]])
)
return {'c_n_loss_mean': c_n_loss, # shape ()
'ca_c_n_loss_mean': ca_c_n_loss, # shape ()
'c_n_ca_loss_mean': c_n_ca_loss, # shape ()
'per_residue_loss_sum': per_residue_loss_sum, # shape (N)
'per_residue_violation_mask': violation_mask # shape (N)
}
return {
'c_n_loss_mean': c_n_loss, # shape ()
'ca_c_n_loss_mean': ca_c_n_loss, # shape ()
'c_n_ca_loss_mean': c_n_ca_loss, # shape ()
'per_residue_loss_sum': per_residue_loss_sum, # shape (N)
'per_residue_violation_mask': violation_mask, # shape (N)
}
def _loss_and_violation_mask(
    *,
    metric: jnp.ndarray,  # (N - 1)
    gt_metric: Union[float, jnp.ndarray],
    gt_stddev: Union[float, jnp.ndarray],
    mask: jnp.ndarray,  # (N - 1)
    tolerance_factor_soft: float = 12.0,
    tolerance_factor_hard: float = 12.0,
):
  """Compute a flat-bottom loss and a hard-violation mask for one metric.

  The deviation of `metric` from `gt_metric` is penalised only once it exceeds
  `tolerance_factor_soft * gt_stddev`; deviations larger than
  `tolerance_factor_hard * gt_stddev` are additionally flagged as violations.

  Args:
    metric: Measured values, one entry per position.
    gt_metric: Reference value(s) the metric is compared against. A scalar or
      an array that broadcasts against `metric`.
    gt_stddev: Scale of the tolerated deviation. A scalar or an array that
      broadcasts against `metric`.
    mask: Multiplicative weights selecting which positions contribute to the
      mean loss and to the violation mask.
    tolerance_factor_soft: Multiple of `gt_stddev` below which the deviation
      incurs no loss.
    tolerance_factor_hard: Multiple of `gt_stddev` above which a position is
      marked as violated.

  Returns:
    A tuple `(loss_per_residue, loss, violation_mask)`:
      loss_per_residue: Per-position loss. NOT multiplied by `mask`.
      loss: Mask-weighted mean of `loss_per_residue`.
      violation_mask: `mask` multiplied by the indicator of a hard violation.
  """
  # Absolute deviation, written as sqrt(eps + x^2) so that the argument of the
  # square root stays strictly positive.
  error = jnp.sqrt(1e-6 + jnp.square(metric - gt_metric))

  # Flat-bottom penalty: zero inside the soft tolerance band, linear outside.
  loss_per_residue = jax.nn.relu(error - tolerance_factor_soft * gt_stddev)

  # Masked mean; the 1e-6 in the denominator avoids 0 / 0 for an all-zero mask.
  loss = jnp.sum(mask * loss_per_residue) / (jnp.sum(mask) + 1e-6)

  # Positions outside the hard tolerance band, restricted to the mask.
  violation_mask = mask * (error > (tolerance_factor_hard * gt_stddev))

  return loss_per_residue, loss, violation_mask
def between_residue_clash_loss(
@@ -1098,9 +1138,10 @@ def _make_renaming_matrices():
resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
all_matrices[resname] = renaming_matrix.astype(np.float32)
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices