mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Refactor between_residue_bond_loss to use a shared helper function.
PiperOrigin-RevId: 808166329 Change-Id: I3d60463cbc8f0ec3cbceab7a445d8f6f2d652ce0
This commit is contained in:
committed by
Copybara-Service
parent
cc9042484e
commit
2652cafb43
@@ -32,8 +32,7 @@ computationally more efficient.
|
||||
The internal atom14 representation is turned into the atom37 at the output of
|
||||
the network to facilitate easier conversion to existing protein datastructures.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
from alphafold.common import residue_constants
|
||||
|
||||
from alphafold.model import r3
|
||||
@@ -393,59 +392,71 @@ def atom37_to_torsion_angles(
|
||||
torsion_frames = r3.rigids_from_3_points(
|
||||
point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]),
|
||||
origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
|
||||
point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))
|
||||
point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]),
|
||||
)
|
||||
|
||||
# Compute the position of the fourth atom in this frame (y and z coordinate
|
||||
# define the chi angle)
|
||||
# r3.Vecs (B, N, torsions=7)
|
||||
forth_atom_rel_pos = r3.rigids_mul_vecs(
|
||||
r3.invert_rigids(torsion_frames),
|
||||
r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))
|
||||
r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]),
|
||||
)
|
||||
|
||||
# Normalize to have the sin and cos of the torsion angle.
|
||||
# jnp.ndarray (B, N, torsions=7, sincos=2)
|
||||
torsion_angles_sin_cos = jnp.stack(
|
||||
[forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1)
|
||||
[forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1
|
||||
)
|
||||
torsion_angles_sin_cos /= jnp.sqrt(
|
||||
jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
|
||||
+ 1e-8)
|
||||
jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8
|
||||
)
|
||||
|
||||
# Mirror psi, because we computed it from the Oxygen-atom.
|
||||
torsion_angles_sin_cos *= jnp.asarray(
|
||||
[1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]
|
||||
torsion_angles_sin_cos *= jnp.asarray([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0])[
|
||||
None, None, :, None
|
||||
]
|
||||
|
||||
# Create alternative angles for ambiguous atom names.
|
||||
chi_is_ambiguous = utils.batched_gather(
|
||||
jnp.asarray(residue_constants.chi_pi_periodic), aatype)
|
||||
jnp.asarray(residue_constants.chi_pi_periodic), aatype
|
||||
)
|
||||
mirror_torsion_angles = jnp.concatenate(
|
||||
[jnp.ones([num_batch, num_res, 3]),
|
||||
1.0 - 2.0 * chi_is_ambiguous], axis=-1)
|
||||
[jnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous], axis=-1
|
||||
)
|
||||
alt_torsion_angles_sin_cos = (
|
||||
torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
|
||||
torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None]
|
||||
)
|
||||
|
||||
if placeholder_for_undefined:
|
||||
# Add placeholder torsions in place of undefined torsion angles
|
||||
# (e.g. N-terminus pre-omega)
|
||||
placeholder_torsions = jnp.stack([
|
||||
jnp.ones(torsion_angles_sin_cos.shape[:-1]),
|
||||
jnp.zeros(torsion_angles_sin_cos.shape[:-1])
|
||||
], axis=-1)
|
||||
placeholder_torsions = jnp.stack(
|
||||
[
|
||||
jnp.ones(torsion_angles_sin_cos.shape[:-1]),
|
||||
jnp.zeros(torsion_angles_sin_cos.shape[:-1]),
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
|
||||
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
|
||||
alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
|
||||
..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
|
||||
..., None
|
||||
] + placeholder_torsions * (1 - torsion_angles_mask[..., None])
|
||||
alt_torsion_angles_sin_cos = (
|
||||
alt_torsion_angles_sin_cos * torsion_angles_mask[..., None]
|
||||
+ placeholder_torsions * (1 - torsion_angles_mask[..., None])
|
||||
)
|
||||
|
||||
return {
|
||||
'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, N, 7, 2)
|
||||
'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, # (B, N, 7, 2)
|
||||
'torsion_angles_mask': torsion_angles_mask # (B, N, 7)
|
||||
'torsion_angles_mask': torsion_angles_mask, # (B, N, 7)
|
||||
}
|
||||
|
||||
|
||||
def torsion_angles_to_frames(
|
||||
aatype: jnp.ndarray, # (N)
|
||||
backb_to_global: r3.Rigids, # (N)
|
||||
torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2)
|
||||
torsion_angles_sin_cos: jnp.ndarray, # (N, 7, 2)
|
||||
) -> r3.Rigids: # (N, 8)
|
||||
"""Compute rigid group frames from torsion angles.
|
||||
|
||||
@@ -666,21 +677,26 @@ def between_residue_bond_loss(
|
||||
# The C-N bond to proline has slightly different length because of the ring.
|
||||
next_is_proline = (
|
||||
aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(jnp.float32)
|
||||
gt_length = (
|
||||
(1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
|
||||
+ next_is_proline * residue_constants.between_res_bond_length_c_n[1])
|
||||
gt_stddev = (
|
||||
(1. - next_is_proline) *
|
||||
residue_constants.between_res_bond_length_stddev_c_n[0] +
|
||||
next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1])
|
||||
c_n_bond_length_error = jnp.sqrt(1e-6 +
|
||||
jnp.square(c_n_bond_length - gt_length))
|
||||
c_n_loss_per_residue = jax.nn.relu(
|
||||
c_n_bond_length_error - tolerance_factor_soft * gt_stddev)
|
||||
mask = this_c_mask * next_n_mask * has_no_gap_mask
|
||||
c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
|
||||
c_n_violation_mask = mask * (
|
||||
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev))
|
||||
c_n_loss_per_residue, c_n_loss, c_n_violation_mask = (
|
||||
_loss_and_violation_mask(
|
||||
metric=c_n_bond_length,
|
||||
gt_metric=(
|
||||
(1.0 - next_is_proline)
|
||||
* residue_constants.between_res_bond_length_c_n[0]
|
||||
+ next_is_proline
|
||||
* residue_constants.between_res_bond_length_c_n[1]
|
||||
),
|
||||
gt_stddev=(
|
||||
(1.0 - next_is_proline)
|
||||
* residue_constants.between_res_bond_length_stddev_c_n[0]
|
||||
+ next_is_proline
|
||||
* residue_constants.between_res_bond_length_stddev_c_n[1]
|
||||
),
|
||||
mask=this_c_mask * next_n_mask * has_no_gap_mask,
|
||||
tolerance_factor_soft=tolerance_factor_soft,
|
||||
tolerance_factor_hard=tolerance_factor_hard,
|
||||
)
|
||||
)
|
||||
|
||||
# Compute loss for the angles.
|
||||
ca_c_bond_length = jnp.sqrt(1e-6 + jnp.sum(
|
||||
@@ -692,53 +708,77 @@ def between_residue_bond_loss(
|
||||
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None]
|
||||
n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None]
|
||||
|
||||
ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1)
|
||||
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
|
||||
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
|
||||
ca_c_n_cos_angle_error = jnp.sqrt(
|
||||
1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle))
|
||||
ca_c_n_loss_per_residue = jax.nn.relu(
|
||||
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev)
|
||||
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
|
||||
ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
|
||||
ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
|
||||
(tolerance_factor_hard * gt_stddev))
|
||||
ca_c_n_loss_per_residue, ca_c_n_loss, ca_c_n_violation_mask = (
|
||||
_loss_and_violation_mask(
|
||||
metric=jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1),
|
||||
gt_metric=residue_constants.between_res_cos_angles_ca_c_n[0],
|
||||
gt_stddev=residue_constants.between_res_bond_length_stddev_c_n[0],
|
||||
mask=this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask,
|
||||
tolerance_factor_soft=tolerance_factor_soft,
|
||||
tolerance_factor_hard=tolerance_factor_hard,
|
||||
)
|
||||
)
|
||||
|
||||
c_n_ca_cos_angle = jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1)
|
||||
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
|
||||
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
|
||||
c_n_ca_cos_angle_error = jnp.sqrt(
|
||||
1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle))
|
||||
c_n_ca_loss_per_residue = jax.nn.relu(
|
||||
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev)
|
||||
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
|
||||
c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6)
|
||||
c_n_ca_violation_mask = mask * (
|
||||
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev))
|
||||
c_n_ca_loss_per_residue, c_n_ca_loss, c_n_ca_violation_mask = (
|
||||
_loss_and_violation_mask(
|
||||
metric=jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1),
|
||||
gt_metric=residue_constants.between_res_cos_angles_c_n_ca[0],
|
||||
gt_stddev=residue_constants.between_res_cos_angles_c_n_ca[1],
|
||||
mask=this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask,
|
||||
tolerance_factor_soft=tolerance_factor_soft,
|
||||
tolerance_factor_hard=tolerance_factor_hard,
|
||||
)
|
||||
)
|
||||
|
||||
# Compute a per residue loss (equally distribute the loss to both
|
||||
# neighbouring residues).
|
||||
per_residue_loss_sum = (c_n_loss_per_residue +
|
||||
ca_c_n_loss_per_residue +
|
||||
c_n_ca_loss_per_residue)
|
||||
per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) +
|
||||
jnp.pad(per_residue_loss_sum, [[1, 0]]))
|
||||
per_residue_loss_sum = (
|
||||
c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
|
||||
)
|
||||
per_residue_loss_sum = 0.5 * (
|
||||
jnp.pad(per_residue_loss_sum, [[0, 1]])
|
||||
+ jnp.pad(per_residue_loss_sum, [[1, 0]])
|
||||
)
|
||||
|
||||
# Compute hard violations.
|
||||
violation_mask = jnp.max(
|
||||
jnp.stack([c_n_violation_mask,
|
||||
ca_c_n_violation_mask,
|
||||
c_n_ca_violation_mask]), axis=0)
|
||||
jnp.stack(
|
||||
[c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask]
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
violation_mask = jnp.maximum(
|
||||
jnp.pad(violation_mask, [[0, 1]]),
|
||||
jnp.pad(violation_mask, [[1, 0]]))
|
||||
jnp.pad(violation_mask, [[0, 1]]), jnp.pad(violation_mask, [[1, 0]])
|
||||
)
|
||||
|
||||
return {'c_n_loss_mean': c_n_loss, # shape ()
|
||||
'ca_c_n_loss_mean': ca_c_n_loss, # shape ()
|
||||
'c_n_ca_loss_mean': c_n_ca_loss, # shape ()
|
||||
'per_residue_loss_sum': per_residue_loss_sum, # shape (N)
|
||||
'per_residue_violation_mask': violation_mask # shape (N)
|
||||
}
|
||||
return {
|
||||
'c_n_loss_mean': c_n_loss, # shape ()
|
||||
'ca_c_n_loss_mean': ca_c_n_loss, # shape ()
|
||||
'c_n_ca_loss_mean': c_n_ca_loss, # shape ()
|
||||
'per_residue_loss_sum': per_residue_loss_sum, # shape (N)
|
||||
'per_residue_violation_mask': violation_mask, # shape (N)
|
||||
}
|
||||
|
||||
|
||||
def _loss_and_violation_mask(
    *,
    metric: jnp.ndarray,  # (N - 1)
    gt_metric: Union[float, jnp.ndarray],
    gt_stddev: Union[float, jnp.ndarray],
    mask: jnp.ndarray,  # (N - 1)
    tolerance_factor_soft: float = 12.0,
    tolerance_factor_hard: float = 12.0,
):
  """Compute loss and violation mask for a given metric.

  The deviation of `metric` from `gt_metric` is penalised only where it
  exceeds a tolerance expressed as a multiple of `gt_stddev`.

  Args:
    metric: Measured values, one entry per position.
    gt_metric: Reference value(s) the metric is compared against; either a
      scalar or an array that broadcasts against `metric`.
    gt_stddev: Standard deviation(s) of the reference value; either a scalar
      or an array that broadcasts against `metric`.
    mask: Multiplicative weights selecting which entries contribute to the
      mean loss and to the violation mask.
    tolerance_factor_soft: Multiple of `gt_stddev` below which the per-entry
      loss is zero.
    tolerance_factor_hard: Multiple of `gt_stddev` above which an entry is
      flagged as a violation.

  Returns:
    A tuple `(loss_per_residue, loss, violation_mask)`:
      * loss_per_residue: relu(error - tolerance_factor_soft * gt_stddev).
        NOTE: this is NOT multiplied by `mask`.
      * loss: mean of `loss_per_residue` weighted by `mask` (scalar).
      * violation_mask: `mask` multiplied by the boolean
        `error > tolerance_factor_hard * gt_stddev`.
  """
  # Smoothed absolute deviation: the 1e-6 under the sqrt means the value is
  # never exactly zero (presumably to keep the gradient finite at zero error).
  error = jnp.sqrt(1e-6 + jnp.square(metric - gt_metric))

  # Flat-bottom penalty: zero inside the soft tolerance, linear outside it.
  loss_per_residue = jax.nn.relu(error - tolerance_factor_soft * gt_stddev)

  # Masked mean; the 1e-6 in the denominator avoids a division by zero when
  # the mask sums to zero.
  loss = jnp.sum(mask * loss_per_residue) / (jnp.sum(mask) + 1e-6)

  # Hard violations, restricted to the entries selected by the mask.
  violation_mask = mask * (error > (tolerance_factor_hard * gt_stddev))

  return loss_per_residue, loss, violation_mask
|
||||
|
||||
|
||||
def between_residue_clash_loss(
|
||||
@@ -1098,9 +1138,10 @@ def _make_renaming_matrices():
|
||||
resname].index(target_atom_swap)
|
||||
correspondences[source_index] = target_index
|
||||
correspondences[target_index] = source_index
|
||||
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
|
||||
for index, correspondence in enumerate(correspondences):
|
||||
renaming_matrix[index, correspondence] = 1.
|
||||
|
||||
renaming_matrix = np.zeros((14, 14), dtype=np.float32)
|
||||
for index, correspondence in enumerate(correspondences):
|
||||
renaming_matrix[index, correspondence] = 1.
|
||||
all_matrices[resname] = renaming_matrix.astype(np.float32)
|
||||
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
|
||||
return renaming_matrices
|
||||
|
||||
Reference in New Issue
Block a user