mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Use correct type information in tests.
PiperOrigin-RevId: 797102858 Change-Id: Ie49cf8102dee57a507d60464c746cf92fa70192f
This commit is contained in:
committed by
Copybara-Service
parent
29f082183b
commit
aba97651a6
@@ -18,6 +18,7 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from alphafold.model import all_atom
|
||||
from alphafold.model import r3
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
L1_CLAMP_DISTANCE = 10
|
||||
@@ -80,7 +81,7 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase):
|
||||
global_rigid_transform = get_global_rigid_transform(
|
||||
rot_angle, translation, 1)
|
||||
|
||||
target_positions = r3.vecs_from_tensor(target_positions)
|
||||
target_positions = r3.vecs_from_tensor(jax.numpy.array(target_positions))
|
||||
pred_positions = r3.rigids_mul_vecs(
|
||||
global_rigid_transform, target_positions)
|
||||
positions_mask = np.ones(target_positions.x.shape[0])
|
||||
@@ -93,7 +94,7 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase):
|
||||
pred_frames, target_frames, frames_mask, pred_positions,
|
||||
target_positions, positions_mask, L1_CLAMP_DISTANCE,
|
||||
L1_CLAMP_DISTANCE, epsilon=0)
|
||||
self.assertAlmostEqual(fape, 0.)
|
||||
self.assertAlmostEqual(fape, 0., places=6)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('identity',
|
||||
@@ -120,8 +121,8 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase):
|
||||
pred_frames = target_frames
|
||||
frames_mask = np.ones(2)
|
||||
|
||||
target_positions = r3.vecs_from_tensor(np.array(target_positions))
|
||||
pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
|
||||
target_positions = r3.vecs_from_tensor(jax.numpy.array(target_positions))
|
||||
pred_positions = r3.vecs_from_tensor(jax.numpy.array(pred_positions))
|
||||
positions_mask = np.ones(target_positions.x.shape[0])
|
||||
|
||||
alddt = all_atom.frame_aligned_point_error(
|
||||
|
||||
@@ -159,10 +159,10 @@ class NotebookUtilsTest(parameterized.TestCase):
|
||||
def test_show_msa_info(self, mocked_stdout):
|
||||
single_chain_msas = [
|
||||
parsers.Msa(sequences=['A', 'B', 'C', 'C'],
|
||||
deletion_matrix=[None] * 4,
|
||||
deletion_matrix=[[0]] * 4,
|
||||
descriptions=[''] * 4),
|
||||
parsers.Msa(sequences=['A', 'A', 'A', 'D'],
|
||||
deletion_matrix=[None] * 4,
|
||||
deletion_matrix=[[0]] * 4,
|
||||
descriptions=[''] * 4)
|
||||
]
|
||||
notebook_utils.show_msa_info(
|
||||
|
||||
Reference in New Issue
Block a user