mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Add tests for all_atom.between_residue_bond_loss.
PiperOrigin-RevId: 807751146 Change-Id: I5e6ac821ae077a798255e437d95fe1c0a528295e
This commit is contained in:
committed by
Copybara-Service
parent
c9b9901dda
commit
cc9042484e
@@ -12,17 +12,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for all_atom."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from alphafold.common import residue_constants
|
||||
from alphafold.model import all_atom
|
||||
from alphafold.model import r3
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
L1_CLAMP_DISTANCE = 10
|
||||
|
||||
BL_C_N = residue_constants.between_res_bond_length_c_n
|
||||
BL_STD_DEV_C_N = residue_constants.between_res_bond_length_stddev_c_n
|
||||
COS_CA_C_N = residue_constants.between_res_cos_angles_ca_c_n
|
||||
COS_C_N_CA = residue_constants.between_res_cos_angles_c_n_ca
|
||||
|
||||
|
||||
def _relu(x):
|
||||
"""Computes relu on a numpy array."""
|
||||
return np.maximum(0, x)
|
||||
|
||||
|
||||
def get_identity_rigid(shape):
|
||||
"""Returns identity rigid transform."""
|
||||
@@ -56,7 +67,7 @@ def get_global_rigid_transform(rot_angle, translation, bcast_dims):
|
||||
return r3.Rigids(rot, trans)
|
||||
|
||||
|
||||
class AllAtomTest(parameterized.TestCase, absltest.TestCase):
|
||||
class AllAtomTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('identity', 0, [0, 0, 0]),
|
||||
@@ -126,11 +137,110 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase):
|
||||
positions_mask = np.ones(target_positions.x.shape[0])
|
||||
|
||||
alddt = all_atom.frame_aligned_point_error(
|
||||
pred_frames, target_frames, frames_mask, pred_positions,
|
||||
target_positions, positions_mask, L1_CLAMP_DISTANCE,
|
||||
L1_CLAMP_DISTANCE, epsilon=0)
|
||||
pred_frames,
|
||||
target_frames,
|
||||
frames_mask,
|
||||
pred_positions,
|
||||
target_positions,
|
||||
positions_mask,
|
||||
L1_CLAMP_DISTANCE,
|
||||
L1_CLAMP_DISTANCE,
|
||||
epsilon=0,
|
||||
)
|
||||
self.assertAlmostEqual(alddt, expected_alddt)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(
|
||||
testcase_name='c_n_loss',
|
||||
key='c_n_loss_mean',
|
||||
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
|
||||
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
|
||||
residue_index=np.arange(2, dtype=np.int32),
|
||||
aatype=np.zeros(2, dtype=np.int32),
|
||||
expected_val=np.sum(
|
||||
_relu(
|
||||
np.sqrt(1e-6 + np.square(0.001 - BL_C_N[0]))
|
||||
- 12.0 * BL_STD_DEV_C_N[0]
|
||||
)
|
||||
).astype(np.float32),
|
||||
),
|
||||
dict(
|
||||
testcase_name='ca_c_n_loss',
|
||||
key='ca_c_n_loss_mean',
|
||||
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
|
||||
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
|
||||
residue_index=np.arange(2, dtype=np.int32),
|
||||
aatype=np.zeros(2, dtype=np.int32),
|
||||
expected_val=np.sum(
|
||||
_relu(
|
||||
np.sqrt(1e-6 + np.square(-COS_CA_C_N[0]))
|
||||
- 12.0 * BL_STD_DEV_C_N[0]
|
||||
)
|
||||
).astype(np.float32),
|
||||
),
|
||||
dict(
|
||||
testcase_name='c_n_ca_loss',
|
||||
key='c_n_ca_loss_mean',
|
||||
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
|
||||
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
|
||||
residue_index=np.arange(2, dtype=np.int32),
|
||||
aatype=np.zeros(2, dtype=np.int32),
|
||||
expected_val=np.sum(
|
||||
_relu(
|
||||
np.sqrt(1e-6 + np.square(0.0 - COS_C_N_CA[0]))
|
||||
- 12.0 * COS_C_N_CA[1]
|
||||
)
|
||||
).astype(np.float32),
|
||||
),
|
||||
dict(
|
||||
testcase_name='per_residue_loss_sum',
|
||||
key='per_residue_loss_sum',
|
||||
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
|
||||
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
|
||||
residue_index=np.arange(2, dtype=np.int32),
|
||||
aatype=np.zeros(2, dtype=np.int32),
|
||||
expected_val=np.array([0.768001, 0.768001], dtype=np.float32),
|
||||
),
|
||||
dict(
|
||||
testcase_name='per_residue_violation_mask',
|
||||
key='per_residue_violation_mask',
|
||||
pred_atom_positions=np.zeros((2, 37, 3), dtype=np.float32),
|
||||
pred_atom_mask=np.ones((2, 37), dtype=np.float32),
|
||||
residue_index=np.arange(2, dtype=np.int32),
|
||||
aatype=np.zeros(2, dtype=np.int32),
|
||||
expected_val=np.array([1.0, 1.0], dtype=np.float32),
|
||||
),
|
||||
)
|
||||
def test_between_residue_bond_loss(
|
||||
self,
|
||||
key,
|
||||
pred_atom_positions,
|
||||
pred_atom_mask,
|
||||
residue_index,
|
||||
aatype,
|
||||
expected_val,
|
||||
):
|
||||
got = all_atom.between_residue_bond_loss(
|
||||
pred_atom_positions=jnp.array(pred_atom_positions),
|
||||
pred_atom_mask=jnp.array(pred_atom_mask),
|
||||
residue_index=jnp.array(residue_index),
|
||||
aatype=jnp.array(aatype),
|
||||
)
|
||||
self.assertIn(key, got)
|
||||
self.assertEqual(
|
||||
got[key].shape,
|
||||
expected_val.shape,
|
||||
f'Shape mismatch for key "{key}"',
|
||||
)
|
||||
self.assertEqual(
|
||||
got[key].dtype,
|
||||
expected_val.dtype,
|
||||
f'Dtype mismatch for key "{key}"',
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
got[key], expected_val, rtol=2e-6
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
||||
Reference in New Issue
Block a user