Use correct type information in tests.

PiperOrigin-RevId: 797102858
Change-Id: Ie49cf8102dee57a507d60464c746cf92fa70192f
This commit is contained in:
Ryan Pachauri
2025-08-19 18:37:25 -07:00
committed by Copybara-Service
parent 29f082183b
commit aba97651a6
2 changed files with 7 additions and 6 deletions

View File

@@ -18,6 +18,7 @@ from absl.testing import absltest
from absl.testing import parameterized
from alphafold.model import all_atom
from alphafold.model import r3
import jax
import numpy as np
L1_CLAMP_DISTANCE = 10
@@ -80,7 +81,7 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase):
global_rigid_transform = get_global_rigid_transform(
rot_angle, translation, 1)
target_positions = r3.vecs_from_tensor(target_positions)
target_positions = r3.vecs_from_tensor(jax.numpy.array(target_positions))
pred_positions = r3.rigids_mul_vecs(
global_rigid_transform, target_positions)
positions_mask = np.ones(target_positions.x.shape[0])
@@ -93,7 +94,7 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase):
pred_frames, target_frames, frames_mask, pred_positions,
target_positions, positions_mask, L1_CLAMP_DISTANCE,
L1_CLAMP_DISTANCE, epsilon=0)
self.assertAlmostEqual(fape, 0.)
self.assertAlmostEqual(fape, 0., places=6)
@parameterized.named_parameters(
('identity',
@@ -120,8 +121,8 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase):
pred_frames = target_frames
frames_mask = np.ones(2)
target_positions = r3.vecs_from_tensor(np.array(target_positions))
pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
target_positions = r3.vecs_from_tensor(jax.numpy.array(target_positions))
pred_positions = r3.vecs_from_tensor(jax.numpy.array(pred_positions))
positions_mask = np.ones(target_positions.x.shape[0])
alddt = all_atom.frame_aligned_point_error(

View File

@@ -159,10 +159,10 @@ class NotebookUtilsTest(parameterized.TestCase):
def test_show_msa_info(self, mocked_stdout):
single_chain_msas = [
parsers.Msa(sequences=['A', 'B', 'C', 'C'],
deletion_matrix=[None] * 4,
deletion_matrix=[[0]] * 4,
descriptions=[''] * 4),
parsers.Msa(sequences=['A', 'A', 'A', 'D'],
deletion_matrix=[None] * 4,
deletion_matrix=[[0]] * 4,
descriptions=[''] * 4)
]
notebook_utils.show_msa_info(