From 2652cafb439667b9c12bc98060e3a907af5afa92 Mon Sep 17 00:00:00 2001 From: Ryan Pachauri Date: Wed, 17 Sep 2025 08:31:42 -0700 Subject: [PATCH] Refactor between_residue_bond_loss to use a shared helper function. PiperOrigin-RevId: 808166329 Change-Id: I3d60463cbc8f0ec3cbceab7a445d8f6f2d652ce0 --- alphafold/model/all_atom.py | 197 ++++++++++++++++++++++-------------- 1 file changed, 119 insertions(+), 78 deletions(-) diff --git a/alphafold/model/all_atom.py b/alphafold/model/all_atom.py index 6278217..356d1cb 100644 --- a/alphafold/model/all_atom.py +++ b/alphafold/model/all_atom.py @@ -32,8 +32,7 @@ computationally more efficient. The internal atom14 representation is turned into the atom37 at the output of the network to facilitate easier conversion to existing protein datastructures. """ - -from typing import Dict, Optional +from typing import Dict, Optional, Union from alphafold.common import residue_constants from alphafold.model import r3 @@ -393,59 +392,71 @@ def atom37_to_torsion_angles( torsion_frames = r3.rigids_from_3_points( point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]), origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]), - point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :])) + point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]), + ) # Compute the position of the forth atom in this frame (y and z coordinate # define the chi angle) # r3.Vecs (B, N, torsions=7) forth_atom_rel_pos = r3.rigids_mul_vecs( r3.invert_rigids(torsion_frames), - r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :])) + r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]), + ) # Normalize to have the sin and cos of the torsion angle. # jnp.ndarray (B, N, torsions=7, sincos=2) torsion_angles_sin_cos = jnp.stack( - [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1) + [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1 + ) torsion_angles_sin_cos /= jnp.sqrt( - jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) - + 1e-8) + jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8 + ) # Mirror psi, because we computed it from the Oxygen-atom. - torsion_angles_sin_cos *= jnp.asarray( - [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None] + torsion_angles_sin_cos *= jnp.asarray([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0])[ + None, None, :, None + ] # Create alternative angles for ambiguous atom names. chi_is_ambiguous = utils.batched_gather( - jnp.asarray(residue_constants.chi_pi_periodic), aatype) + jnp.asarray(residue_constants.chi_pi_periodic), aatype + ) mirror_torsion_angles = jnp.concatenate( - [jnp.ones([num_batch, num_res, 3]), - 1.0 - 2.0 * chi_is_ambiguous], axis=-1) + [jnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous], axis=-1 + ) alt_torsion_angles_sin_cos = ( - torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None]) + torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None] + ) if placeholder_for_undefined: # Add placeholder torsions in place of undefined torsion angles # (e.g. N-terminus pre-omega) - placeholder_torsions = jnp.stack([ - jnp.ones(torsion_angles_sin_cos.shape[:-1]), - jnp.zeros(torsion_angles_sin_cos.shape[:-1]) - ], axis=-1) + placeholder_torsions = jnp.stack( + [ + jnp.ones(torsion_angles_sin_cos.shape[:-1]), + jnp.zeros(torsion_angles_sin_cos.shape[:-1]), + ], + axis=-1, + ) torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[ - ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) - alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[ - ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + ..., None + ] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + alt_torsion_angles_sin_cos = ( + alt_torsion_angles_sin_cos * torsion_angles_mask[..., None] + + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + ) return { 'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, N, 7, 2) 'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, # (B, N, 7, 2) - 'torsion_angles_mask': torsion_angles_mask # (B, N, 7) + 'torsion_angles_mask': torsion_angles_mask, # (B, N, 7) } def torsion_angles_to_frames( aatype: jnp.ndarray, # (N) backb_to_global: r3.Rigids, # (N) - torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2) + torsion_angles_sin_cos: jnp.ndarray, # (N, 7, 2) ) -> r3.Rigids: # (N, 8) """Compute rigid group frames from torsion angles. @@ -666,21 +677,26 @@ def between_residue_bond_loss( # The C-N bond to proline has slightly different length because of the ring. next_is_proline = ( aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(jnp.float32) - gt_length = ( - (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0] - + next_is_proline * residue_constants.between_res_bond_length_c_n[1]) - gt_stddev = ( - (1. - next_is_proline) * - residue_constants.between_res_bond_length_stddev_c_n[0] + - next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]) - c_n_bond_length_error = jnp.sqrt(1e-6 + - jnp.square(c_n_bond_length - gt_length)) - c_n_loss_per_residue = jax.nn.relu( - c_n_bond_length_error - tolerance_factor_soft * gt_stddev) - mask = this_c_mask * next_n_mask * has_no_gap_mask - c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) - c_n_violation_mask = mask * ( - c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + c_n_loss_per_residue, c_n_loss, c_n_violation_mask = ( + _loss_and_violation_mask( + metric=c_n_bond_length, + gt_metric=( + (1.0 - next_is_proline) + * residue_constants.between_res_bond_length_c_n[0] + + next_is_proline + * residue_constants.between_res_bond_length_c_n[1] + ), + gt_stddev=( + (1.0 - next_is_proline) + * residue_constants.between_res_bond_length_stddev_c_n[0] + + next_is_proline + * residue_constants.between_res_bond_length_stddev_c_n[1] + ), + mask=this_c_mask * next_n_mask * has_no_gap_mask, + tolerance_factor_soft=tolerance_factor_soft, + tolerance_factor_hard=tolerance_factor_hard, + ) + ) # Compute loss for the angles. ca_c_bond_length = jnp.sqrt(1e-6 + jnp.sum( @@ -692,53 +708,77 @@ def between_residue_bond_loss( c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None] n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None] - ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1) - gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] - gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] - ca_c_n_cos_angle_error = jnp.sqrt( - 1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle)) - ca_c_n_loss_per_residue = jax.nn.relu( - ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) - mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask - ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) - ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > - (tolerance_factor_hard * gt_stddev)) + ca_c_n_loss_per_residue, ca_c_n_loss, ca_c_n_violation_mask = ( + _loss_and_violation_mask( + metric=jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1), + gt_metric=residue_constants.between_res_cos_angles_ca_c_n[0], + gt_stddev=residue_constants.between_res_bond_length_stddev_c_n[0], + mask=this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask, + tolerance_factor_soft=tolerance_factor_soft, + tolerance_factor_hard=tolerance_factor_hard, + ) + ) - c_n_ca_cos_angle = jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1) - gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] - gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] - c_n_ca_cos_angle_error = jnp.sqrt( - 1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle)) - c_n_ca_loss_per_residue = jax.nn.relu( - c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) - mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask - c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6) - c_n_ca_violation_mask = mask * ( - c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + c_n_ca_loss_per_residue, c_n_ca_loss, c_n_ca_violation_mask = ( + _loss_and_violation_mask( + metric=jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1), + gt_metric=residue_constants.between_res_cos_angles_c_n_ca[0], + gt_stddev=residue_constants.between_res_cos_angles_c_n_ca[1], + mask=this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask, + tolerance_factor_soft=tolerance_factor_soft, + tolerance_factor_hard=tolerance_factor_hard, + ) + ) # Compute a per residue loss (equally distribute the loss to both # neighbouring residues). - per_residue_loss_sum = (c_n_loss_per_residue + - ca_c_n_loss_per_residue + - c_n_ca_loss_per_residue) - per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) + - jnp.pad(per_residue_loss_sum, [[1, 0]])) + per_residue_loss_sum = ( + c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue + ) + per_residue_loss_sum = 0.5 * ( + jnp.pad(per_residue_loss_sum, [[0, 1]]) + + jnp.pad(per_residue_loss_sum, [[1, 0]]) + ) # Compute hard violations. violation_mask = jnp.max( - jnp.stack([c_n_violation_mask, - ca_c_n_violation_mask, - c_n_ca_violation_mask]), axis=0) + jnp.stack( + [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask] + ), + axis=0, + ) violation_mask = jnp.maximum( - jnp.pad(violation_mask, [[0, 1]]), - jnp.pad(violation_mask, [[1, 0]])) + jnp.pad(violation_mask, [[0, 1]]), jnp.pad(violation_mask, [[1, 0]]) + ) - return {'c_n_loss_mean': c_n_loss, # shape () - 'ca_c_n_loss_mean': ca_c_n_loss, # shape () - 'c_n_ca_loss_mean': c_n_ca_loss, # shape () - 'per_residue_loss_sum': per_residue_loss_sum, # shape (N) - 'per_residue_violation_mask': violation_mask # shape (N) - } + return { + 'c_n_loss_mean': c_n_loss, # shape () + 'ca_c_n_loss_mean': ca_c_n_loss, # shape () + 'c_n_ca_loss_mean': c_n_ca_loss, # shape () + 'per_residue_loss_sum': per_residue_loss_sum, # shape (N) + 'per_residue_violation_mask': violation_mask, # shape (N) + } + + +def _loss_and_violation_mask( + *, + metric: jnp.ndarray, # (N - 1) + gt_metric: Union[float, jnp.ndarray], + gt_stddev: Union[float, jnp.ndarray], + mask: jnp.ndarray, # (N - 1) + tolerance_factor_soft: float = 12.0, + tolerance_factor_hard: float = 12.0, +): + """Compute loss and violation mask for a given metric.""" + error = jnp.sqrt(1e-6 + jnp.square(metric - gt_metric)) + loss_per_residue = jax.nn.relu( + error - tolerance_factor_soft * gt_stddev + ) + loss = jnp.sum(mask * loss_per_residue) / (jnp.sum(mask) + 1e-6) + violation_mask = mask * ( + error > (tolerance_factor_hard * gt_stddev) + ) + return loss_per_residue, loss, violation_mask def between_residue_clash_loss( @@ -1098,9 +1138,10 @@ def _make_renaming_matrices(): resname].index(target_atom_swap) correspondences[source_index] = target_index correspondences[target_index] = source_index - renaming_matrix = np.zeros((14, 14), dtype=np.float32) - for index, correspondence in enumerate(correspondences): - renaming_matrix[index, correspondence] = 1. + + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. all_matrices[resname] = renaming_matrix.astype(np.float32) renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) return renaming_matrices