Add tests for all_atom.between_residue_bond_loss.

PiperOrigin-RevId: 807751146
Change-Id: I5e6ac821ae077a798255e437d95fe1c0a528295e
This commit is contained in:
Ryan Pachauri
2025-09-16 10:33:07 -07:00
committed by Copybara-Service
parent c9b9901dda
commit cc9042484e

View File

@@ -12,17 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for all_atom."""
from absl.testing import absltest
from absl.testing import parameterized
from alphafold.common import residue_constants
from alphafold.model import all_atom
from alphafold.model import r3
import jax
import jax.numpy as jnp
import numpy as np
L1_CLAMP_DISTANCE = 10
BL_C_N = residue_constants.between_res_bond_length_c_n
BL_STD_DEV_C_N = residue_constants.between_res_bond_length_stddev_c_n
COS_CA_C_N = residue_constants.between_res_cos_angles_ca_c_n
COS_C_N_CA = residue_constants.between_res_cos_angles_c_n_ca
def _relu(x):
"""Computes relu on a numpy array."""
return np.maximum(0, x)
def get_identity_rigid(shape):
"""Returns identity rigid transform."""
@@ -56,7 +67,7 @@ def get_global_rigid_transform(rot_angle, translation, bcast_dims):
return r3.Rigids(rot, trans)
class AllAtomTest(parameterized.TestCase, absltest.TestCase):
class AllAtomTest(parameterized.TestCase):
@parameterized.named_parameters(
('identity', 0, [0, 0, 0]),
@@ -126,11 +137,110 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase):
positions_mask = np.ones(target_positions.x.shape[0])
alddt = all_atom.frame_aligned_point_error(
pred_frames, target_frames, frames_mask, pred_positions,
target_positions, positions_mask, L1_CLAMP_DISTANCE,
L1_CLAMP_DISTANCE, epsilon=0)
pred_frames,
target_frames,
frames_mask,
pred_positions,
target_positions,
positions_mask,
L1_CLAMP_DISTANCE,
L1_CLAMP_DISTANCE,
epsilon=0,
)
self.assertAlmostEqual(alddt, expected_alddt)
@parameterized.named_parameters(
dict(
testcase_name='c_n_loss',
key='c_n_loss_mean',
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
residue_index=np.arange(2, dtype=np.int32),
aatype=np.zeros(2, dtype=np.int32),
expected_val=np.sum(
_relu(
np.sqrt(1e-6 + np.square(0.001 - BL_C_N[0]))
- 12.0 * BL_STD_DEV_C_N[0]
)
).astype(np.float32),
),
dict(
testcase_name='ca_c_n_loss',
key='ca_c_n_loss_mean',
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
residue_index=np.arange(2, dtype=np.int32),
aatype=np.zeros(2, dtype=np.int32),
expected_val=np.sum(
_relu(
np.sqrt(1e-6 + np.square(-COS_CA_C_N[0]))
- 12.0 * BL_STD_DEV_C_N[0]
)
).astype(np.float32),
),
dict(
testcase_name='c_n_ca_loss',
key='c_n_ca_loss_mean',
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
residue_index=np.arange(2, dtype=np.int32),
aatype=np.zeros(2, dtype=np.int32),
expected_val=np.sum(
_relu(
np.sqrt(1e-6 + np.square(0.0 - COS_C_N_CA[0]))
- 12.0 * COS_C_N_CA[1]
)
).astype(np.float32),
),
dict(
testcase_name='per_residue_loss_sum',
key='per_residue_loss_sum',
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
residue_index=np.arange(2, dtype=np.int32),
aatype=np.zeros(2, dtype=np.int32),
expected_val=np.array([0.768001, 0.768001], dtype=np.float32),
),
dict(
testcase_name='per_residue_violation_mask',
key='per_residue_violation_mask',
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
residue_index=np.arange(2, dtype=np.int32),
aatype=np.zeros(2, dtype=np.int32),
expected_val=np.array([1.0, 1.0], dtype=np.float32),
),
)
def test_between_residue_bond_loss(
self,
key,
pred_atom_positions,
pred_atom_mask,
residue_index,
aatype,
expected_val,
):
got = all_atom.between_residue_bond_loss(
pred_atom_positions=jnp.array(pred_atom_positions),
pred_atom_mask=jnp.array(pred_atom_mask),
residue_index=jnp.array(residue_index),
aatype=jnp.array(aatype),
)
self.assertIn(key, got)
self.assertEqual(
got[key].shape,
expected_val.shape,
f'Shape mismatch for key "{key}"',
)
self.assertEqual(
got[key].dtype,
expected_val.dtype,
f'Dtype mismatch for key "{key}"',
)
np.testing.assert_allclose(
got[key], expected_val, rtol=2e-6
)
if __name__ == '__main__':
absltest.main()