# Mirror of https://github.com/aqlaboratory/openfold.git
# (synced 2026-06-04 12:44:26 +08:00; 154 lines, 4.5 KiB, Python)
# Copyright 2021 AlQuraishi Laboratory
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import random
|
|
import torch
|
|
import unittest
|
|
from tests.config import consts
|
|
from tests.data_utils import random_asym_ids
|
|
from openfold.model.embedders import (
|
|
InputEmbedder,
|
|
InputEmbedderMultimer,
|
|
PreembeddingEmbedder,
|
|
RecyclingEmbedder,
|
|
TemplateSingleEmbedder,
|
|
TemplatePairEmbedder
|
|
)
|
|
|
|
|
|
class TestInputEmbedder(unittest.TestCase):
|
|
def test_shape(self):
|
|
tf_dim = 2
|
|
msa_dim = 3
|
|
c_z = 5
|
|
c_m = 7
|
|
relpos_k = 11
|
|
|
|
b = 13
|
|
n_res = 17
|
|
n_clust = 19
|
|
|
|
max_relative_chain = 2
|
|
max_relative_idx = 32
|
|
use_chain_relative = True
|
|
|
|
tf = torch.rand((b, n_res, tf_dim))
|
|
ri = torch.rand((b, n_res))
|
|
msa = torch.rand((b, n_clust, n_res, msa_dim))
|
|
asym_ids_flat = torch.Tensor(random_asym_ids(n_res))
|
|
asym_id = torch.tile(asym_ids_flat.unsqueeze(0), (b, 1))
|
|
entity_id = asym_id
|
|
sym_id = torch.zeros_like(entity_id)
|
|
|
|
if consts.is_multimer:
|
|
ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m,
|
|
max_relative_idx=max_relative_idx,
|
|
use_chain_relative=use_chain_relative,
|
|
max_relative_chain=max_relative_chain)
|
|
batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa,
|
|
"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id}
|
|
msa_emb, pair_emb = ie(batch)
|
|
else:
|
|
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
|
|
msa_emb, pair_emb = ie(tf=tf, ri=ri, msa=msa, inplace_safe=False)
|
|
|
|
self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
|
|
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
|
|
|
|
|
|
class TestPreembeddingEmbedder(unittest.TestCase):
|
|
def test_shape(self):
|
|
tf_dim = 22
|
|
preembedding_dim = 1280
|
|
c_z = 4
|
|
c_m = 6
|
|
relpos_k = 10
|
|
|
|
batch_size = 4
|
|
num_res = 20
|
|
|
|
tf = torch.rand((batch_size, num_res, tf_dim))
|
|
ri = torch.rand((batch_size, num_res))
|
|
preemb = torch.rand((batch_size, num_res, preembedding_dim))
|
|
|
|
pe = PreembeddingEmbedder(tf_dim, preembedding_dim, c_z, c_m, relpos_k)
|
|
|
|
seq_emb, pair_emb = pe(tf, ri, preemb)
|
|
self.assertTrue(seq_emb.shape == (batch_size, 1, num_res, c_m))
|
|
self.assertTrue(pair_emb.shape == (batch_size, num_res, num_res, c_z))
|
|
|
|
|
|
class TestRecyclingEmbedder(unittest.TestCase):
|
|
def test_shape(self):
|
|
batch_size = 2
|
|
n = 3
|
|
c_z = 5
|
|
c_m = 7
|
|
min_bin = 0
|
|
max_bin = 10
|
|
no_bins = 9
|
|
|
|
re = RecyclingEmbedder(c_m, c_z, min_bin, max_bin, no_bins)
|
|
|
|
m_1 = torch.rand((batch_size, n, c_m))
|
|
z = torch.rand((batch_size, n, n, c_z))
|
|
x = torch.rand((batch_size, n, 3))
|
|
|
|
m_1, z = re(m_1, z, x)
|
|
|
|
self.assertTrue(z.shape == (batch_size, n, n, c_z))
|
|
self.assertTrue(m_1.shape == (batch_size, n, c_m))
|
|
|
|
|
|
class TestTemplateAngleEmbedder(unittest.TestCase):
|
|
def test_shape(self):
|
|
template_angle_dim = 51
|
|
c_m = 256
|
|
batch_size = 4
|
|
n_templ = 4
|
|
n_res = 256
|
|
|
|
tae = TemplateSingleEmbedder(
|
|
template_angle_dim,
|
|
c_m,
|
|
)
|
|
|
|
x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
|
|
x = tae(x)
|
|
|
|
self.assertTrue(x.shape == (batch_size, n_templ, n_res, c_m))
|
|
|
|
|
|
class TestTemplatePairEmbedder(unittest.TestCase):
|
|
def test_shape(self):
|
|
batch_size = 2
|
|
n_templ = 3
|
|
n_res = 5
|
|
template_pair_dim = 7
|
|
c_t = 11
|
|
|
|
tpe = TemplatePairEmbedder(
|
|
template_pair_dim,
|
|
c_t,
|
|
)
|
|
|
|
x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim))
|
|
x = tpe(x)
|
|
|
|
self.assertTrue(x.shape == (batch_size, n_templ, n_res, n_res, c_t))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase in this module.
    unittest.main()
|