# Mirror of https://github.com/aqlaboratory/openfold.git
# (synced 2026-06-04 20:54:24 +08:00; 348 lines, 11 KiB, Python)
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import torch
|
|
import numpy as np
|
|
import unittest
|
|
|
|
from openfold.data.data_transforms import make_atom14_masks_np
|
|
from openfold.np.residue_constants import (
|
|
restype_atom14_mask,
|
|
restype_atom37_mask,
|
|
)
|
|
from openfold.model.structure_module import (
|
|
StructureModule,
|
|
StructureModuleTransition,
|
|
AngleResnet,
|
|
InvariantPointAttention,
|
|
)
|
|
from openfold.utils.rigid_utils import Rotation, Rigid
|
|
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
|
|
from openfold.utils.geometry.rotation_matrix import Rot3Array
|
|
from openfold.utils.geometry.vector import Vec3Array
|
|
import tests.compare_utils as compare_utils
|
|
from tests.config import consts
|
|
from tests.data_utils import (
|
|
random_affines_4x4,
|
|
)
|
|
|
|
# AlphaFold and its JAX/Haiku stack are optional dependencies of this test
# module: they are imported only when compare_utils reports AlphaFold as
# installed. When it is not, the names `alphafold`, `jax` and `hk` are left
# undefined, and the comparison tests that use them are skipped via
# `compare_utils.skip_unless_alphafold_installed()`.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    # NOTE(review): kept after import_alphafold() on purpose in case that
    # helper prepares the environment before JAX is imported — confirm.
    import jax
    import haiku as hk
|
|
|
|
|
|
class TestStructureModule(unittest.TestCase):
    """Tests for ``StructureModule`` and ``StructureModuleTransition``.

    Covers output shapes of randomly initialised modules and, when AlphaFold
    is installed, numerical parity of the pretrained OpenFold structure
    module with AlphaFold's reference implementation.
    """

    @classmethod
    def setUpClass(cls):
        """Bind the AlphaFold reference modules matching ``consts.is_multimer``.

        Nothing is bound when AlphaFold is not installed; the tests that use
        these attributes are decorated to skip in that case.
        """
        if compare_utils.alphafold_is_installed():
            if consts.is_multimer:
                cls.am_atom = alphafold.model.all_atom_multimer
                cls.am_fold = alphafold.model.folding_multimer
                cls.am_modules = alphafold.model.modules_multimer
                cls.am_rigid = alphafold.model.geometry
            else:
                cls.am_atom = alphafold.model.all_atom
                cls.am_fold = alphafold.model.folding
                cls.am_modules = alphafold.model.modules
                cls.am_rigid = alphafold.model.r3

    def test_structure_module_shape(self):
        """Check the per-layer output shapes of a randomly initialised module."""
        batch_size = consts.batch_size
        n = consts.n_res
        c_s = consts.c_s
        c_z = consts.c_z
        c_ipa = 13
        c_resnet = 17
        no_heads_ipa = 6
        no_query_points = 4
        no_value_points = 4
        dropout_rate = 0.1
        no_layers = 3
        no_transition_layers = 3
        no_resnet_layers = 3
        ar_epsilon = 1e-6
        no_angles = 7
        trans_scale_factor = 10
        inf = 1e5

        sm = StructureModule(
            c_s,
            c_z,
            c_ipa,
            c_resnet,
            no_heads_ipa,
            no_query_points,
            no_value_points,
            dropout_rate,
            no_layers,
            no_transition_layers,
            no_resnet_layers,
            no_angles,
            trans_scale_factor,
            ar_epsilon,
            inf,
            is_multimer=consts.is_multimer,
        )

        s = torch.rand((batch_size, n, c_s))
        z = torch.rand((batch_size, n, n, c_z))
        # Random residue-type indices in [0, 21).
        f = torch.randint(low=0, high=21, size=(batch_size, n)).long()

        out = sm({"single": s, "pair": z}, f)

        # Frames come out as 4x4 matrices in multimer mode and as a
        # 7-element tensor per residue otherwise; every output carries a
        # leading per-layer dimension.
        frame_shape = (4, 4) if consts.is_multimer else (7,)
        self.assertEqual(
            tuple(out["frames"].shape), (no_layers, batch_size, n, *frame_shape)
        )
        self.assertEqual(
            tuple(out["angles"].shape), (no_layers, batch_size, n, no_angles, 2)
        )
        self.assertEqual(
            tuple(out["positions"].shape), (no_layers, batch_size, n, 14, 3)
        )

    def test_structure_module_transition_shape(self):
        """``StructureModuleTransition`` must preserve its input's shape."""
        batch_size = 2
        n = 5
        c = 7
        num_layers = 3
        dropout = 0.1

        smt = StructureModuleTransition(c, num_layers, dropout)

        s = torch.rand((batch_size, n, c))
        shape_before = s.shape
        s = smt(s)

        self.assertEqual(s.shape, shape_before)

    @compare_utils.skip_unless_alphafold_installed()
    def test_structure_module_compare(self):
        """Compare the pretrained OpenFold structure module against AlphaFold.

        Both implementations receive the same random single/pair
        representations, residue types and a random 0/1 residue mask.
        AlphaFold's ``final_atom14_positions`` is compared with the last
        layer of OpenFold's ``positions``. The test requires the mean absolute
        difference to stay below 0.05. The OpenFold side runs on CUDA.
        """
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module
        c_global = config.model.global_config

        def run_sm(representations, batch):
            # Executed under hk.transform; inputs are wrapped in
            # stop_gradient before being handed to AlphaFold.
            sm = self.am_fold.StructureModule(c_sm, c_global)
            representations = {
                k: jax.lax.stop_gradient(v) for k, v in representations.items()
            }
            batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}

            if consts.is_multimer:
                return sm(
                    representations, batch, is_training=False, compute_loss=True
                )
            return sm(representations, batch, is_training=False)

        f = hk.transform(run_sm)

        n_res = 200

        representations = {
            "single": np.random.rand(n_res, consts.c_s).astype(np.float32),
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 21, (n_res,)),
        }
        # Per-residue atom-existence masks looked up from the residue type.
        batch["atom14_atom_exists"] = np.take(
            restype_atom14_mask, batch["aatype"], axis=0
        )
        batch["atom37_atom_exists"] = np.take(
            restype_atom37_mask, batch["aatype"], axis=0
        )
        batch.update(make_atom14_masks_np(batch))

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module"
        )

        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, representations, batch)
        out_gt = torch.as_tensor(
            np.array(out_gt["final_atom14_positions"].block_until_ready())
        )

        # Inference only: no_grad avoids building an autograd graph over the
        # whole 200-residue forward pass.
        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            out_repro = model.structure_module(
                {
                    "single": torch.as_tensor(representations["single"]).cuda(),
                    "pair": torch.as_tensor(representations["pair"]).cuda(),
                },
                torch.as_tensor(batch["aatype"]).cuda(),
                mask=torch.as_tensor(batch["seq_mask"]).cuda(),
                inplace_safe=False,
            )
        out_repro = out_repro["positions"][-1].cpu()

        # The structure module, thanks to angle normalization, is very volatile
        # We only assess the mean here. Heuristically speaking, it seems to
        # have lower error in general on real rather than synthetic data.
        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
|
|
|
|
|
|
class TestInvariantPointAttention(unittest.TestCase):
    """Tests for ``InvariantPointAttention`` (IPA).

    Covers the output shape of a randomly initialised module and, when
    AlphaFold is installed, parity of the pretrained OpenFold IPA with
    AlphaFold's implementation.
    """

    @classmethod
    def setUpClass(cls):
        """Bind the AlphaFold reference modules matching ``consts.is_multimer``.

        Nothing is bound when AlphaFold is not installed; the tests that use
        these attributes are decorated to skip in that case.
        """
        if compare_utils.alphafold_is_installed():
            if consts.is_multimer:
                cls.am_atom = alphafold.model.all_atom_multimer
                cls.am_fold = alphafold.model.folding_multimer
                cls.am_modules = alphafold.model.modules_multimer
                cls.am_rigid = alphafold.model.geometry
            else:
                cls.am_atom = alphafold.model.all_atom
                cls.am_fold = alphafold.model.folding
                cls.am_modules = alphafold.model.modules
                cls.am_rigid = alphafold.model.r3

    def test_shape(self):
        """IPA must return a single representation with its input's shape."""
        c_m = 13
        c_z = 17
        c_hidden = 19
        no_heads = 5
        no_qp = 7
        no_vp = 11

        batch_size = 2
        n_res = 23

        s = torch.rand((batch_size, n_res, c_m))
        z = torch.rand((batch_size, n_res, n_res, c_z))
        mask = torch.ones((batch_size, n_res))

        # Arbitrary (non-orthonormal) matrices are sufficient for a shape
        # check.
        rot_mats = torch.rand((batch_size, n_res, 3, 3))
        trans = torch.rand((batch_size, n_res, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rot_mats)
            translation = Vec3Array.from_array(trans)
            r = Rigid3Array(rotation, translation)
        else:
            rots = Rotation(rot_mats=rot_mats, quats=None)
            r = Rigid(rots, trans)

        ipa = InvariantPointAttention(
            c_m, c_z, c_hidden, no_heads, no_qp, no_vp, is_multimer=consts.is_multimer
        )

        shape_before = s.shape
        s = ipa(s, z, r, mask)

        self.assertEqual(s.shape, shape_before)

    @compare_utils.skip_unless_alphafold_installed()
    def test_ipa_compare(self):
        """Compare the pretrained OpenFold IPA against AlphaFold's IPA.

        Both run on the same random activations, pair features, an all-ones
        mask and random 4x4 affines. The test requires the maximum absolute
        difference to stay within ``consts.eps``. The OpenFold side runs on
        CUDA.
        """
        # Fetched once, outside the Haiku-transformed function.
        config = compare_utils.get_alphafold_config()

        def run_ipa(act, static_feat_2d, mask, affine):
            ipa = self.am_fold.InvariantPointAttention(
                config.model.heads.structure_module,
                config.model.global_config,
            )

            # The multimer implementation receives the frames as ``rigid``,
            # the monomer implementation as ``affine``.
            if consts.is_multimer:
                return ipa(
                    inputs_1d=act,
                    inputs_2d=static_feat_2d,
                    mask=mask,
                    rigid=affine,
                )
            return ipa(
                inputs_1d=act,
                inputs_2d=static_feat_2d,
                mask=mask,
                affine=affine,
            )

        f = hk.transform(run_ipa)

        n_res = consts.n_res
        c_s = consts.c_s
        c_z = consts.c_z

        sample_act = np.random.rand(n_res, c_s)
        sample_2d = np.random.rand(n_res, n_res, c_z)
        sample_mask = np.ones((n_res, 1))

        affines = random_affines_4x4((n_res,))

        # Build the same frames in both frameworks: AlphaFold's representation
        # for the reference run and OpenFold's (on CUDA) for the reproduction.
        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float().cuda()
            )
            sample_affine = rigids
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            quats = self.am_rigid.rigids_to_quataffine(rigids)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float().cuda()
            )
            sample_affine = quats

        ipa_params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module/"
            + "fold_iteration/invariant_point_attention"
        )

        out_gt = f.apply(
            ipa_params, None, sample_act, sample_2d, sample_mask, sample_affine
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            out_repro = model.structure_module.ipa(
                torch.as_tensor(sample_act).float().cuda(),
                torch.as_tensor(sample_2d).float().cuda(),
                transformations,
                # OpenFold takes the mask without the trailing singleton dim.
                torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
            ).cpu()

        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
|
|
|
|
|
|
class TestAngleResnet(unittest.TestCase):
    """Tests for ``AngleResnet``."""

    def test_shape(self):
        """Check the shape of the second output of ``AngleResnet``.

        For ``(batch_size, n, c_s)`` inputs, that output must have shape
        ``(batch_size, n, no_angles, 2)``.
        """
        batch_size = 2
        n = 3
        c_s = 13
        c_hidden = 11
        no_layers = 5
        no_angles = 7
        epsilon = 1e-12

        ar = AngleResnet(c_s, c_hidden, no_layers, no_angles, epsilon)
        s = torch.rand((batch_size, n, c_s))
        s_initial = torch.rand((batch_size, n, c_s))

        # The module returns a pair of tensors; only the second one is checked
        # here.
        _, angles = ar(s, s_initial)

        self.assertEqual(tuple(angles.shape), (batch_size, n, no_angles, 2))
|
|
|
|
|
|
# Allow running this test module directly (e.g. `python <this file>`) in
# addition to discovery by a test runner.
if __name__ == "__main__":
    unittest.main()
|