diff --git a/alphafold/model/all_atom_test.py b/alphafold/model/all_atom_test.py index 36ba45f..0bab235 100644 --- a/alphafold/model/all_atom_test.py +++ b/alphafold/model/all_atom_test.py @@ -18,6 +18,7 @@ from absl.testing import absltest from absl.testing import parameterized from alphafold.model import all_atom from alphafold.model import r3 +import jax import numpy as np L1_CLAMP_DISTANCE = 10 @@ -80,7 +81,7 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase): global_rigid_transform = get_global_rigid_transform( rot_angle, translation, 1) - target_positions = r3.vecs_from_tensor(target_positions) + target_positions = r3.vecs_from_tensor(jax.numpy.array(target_positions)) pred_positions = r3.rigids_mul_vecs( global_rigid_transform, target_positions) positions_mask = np.ones(target_positions.x.shape[0]) @@ -93,7 +94,7 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase): pred_frames, target_frames, frames_mask, pred_positions, target_positions, positions_mask, L1_CLAMP_DISTANCE, L1_CLAMP_DISTANCE, epsilon=0) - self.assertAlmostEqual(fape, 0.) + self.assertAlmostEqual(fape, 0., places=6) @parameterized.named_parameters( ('identity', @@ -120,8 +121,8 @@ class AllAtomTest(parameterized.TestCase, absltest.TestCase): pred_frames = target_frames frames_mask = np.ones(2) - target_positions = r3.vecs_from_tensor(np.array(target_positions)) - pred_positions = r3.vecs_from_tensor(np.array(pred_positions)) + target_positions = r3.vecs_from_tensor(jax.numpy.array(target_positions)) + pred_positions = r3.vecs_from_tensor(jax.numpy.array(pred_positions)) positions_mask = np.ones(target_positions.x.shape[0]) alddt = all_atom.frame_aligned_point_error( diff --git a/alphafold/notebooks/notebook_utils_test.py b/alphafold/notebooks/notebook_utils_test.py index 6df7689..9d1764f 100644 --- a/alphafold/notebooks/notebook_utils_test.py +++ b/alphafold/notebooks/notebook_utils_test.py @@ -159,10 +159,10 @@ class NotebookUtilsTest(parameterized.TestCase): def test_show_msa_info(self, mocked_stdout): single_chain_msas = [ parsers.Msa(sequences=['A', 'B', 'C', 'C'], - deletion_matrix=[None] * 4, + deletion_matrix=[[0]] * 4, descriptions=[''] * 4), parsers.Msa(sequences=['A', 'A', 'A', 'D'], - deletion_matrix=[None] * 4, + deletion_matrix=[[0]] * 4, descriptions=[''] * 4) ] notebook_utils.show_msa_info(