# openfold/tests/test_embedders.py
# (snapshot metadata: 2021-10-16 01:17:18 -04:00; 111 lines; 2.8 KiB; Python)

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from openfold.model.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
)
class TestInputEmbedder(unittest.TestCase):
    """Shape tests for ``InputEmbedder``."""

    def test_shape(self):
        """InputEmbedder(tf, ri, msa) -> (msa_emb, pair_emb) with expected shapes.

        Expects ``msa_emb`` to be ``(b, n_clust, n_res, c_m)`` and
        ``pair_emb`` to be ``(b, n_res, n_res, c_z)``.
        """
        # Pairwise-distinct sizes: if the module permuted or mixed up any two
        # axes, the resulting shape tuple would differ from the expected one.
        tf_dim = 2
        msa_dim = 3
        c_z = 5
        c_m = 7
        relpos_k = 11
        b = 13
        n_res = 17
        n_clust = 19

        tf = torch.rand((b, n_res, tf_dim))
        # ri: one value per residue (presumably residue indices); random
        # floats are enough for a pure shape check.
        ri = torch.rand((b, n_res))
        msa = torch.rand((b, n_clust, n_res, msa_dim))

        ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
        msa_emb, pair_emb = ie(tf, ri, msa)

        # assertEqual on plain tuples prints both shapes on failure, unlike
        # assertTrue(a == b), which only reports "False is not true".
        self.assertEqual(tuple(msa_emb.shape), (b, n_clust, n_res, c_m))
        self.assertEqual(tuple(pair_emb.shape), (b, n_res, n_res, c_z))
class TestRecyclingEmbedder(unittest.TestCase):
    """Shape tests for ``RecyclingEmbedder``."""

    def test_shape(self):
        """RecyclingEmbedder(m, z, x) returns updates shaped like m and z.

        Expects the first output to be ``(batch_size, n_res, c_m)`` and the
        second ``(batch_size, n_res, n_res, c_z)``.
        """
        # n_res is deliberately different from the coordinate width (3) of x
        # and from every other size, so a swapped axis shows up as a shape
        # mismatch instead of passing silently.
        batch_size = 2
        n_res = 11
        c_z = 5
        c_m = 7
        min_bin = 0
        max_bin = 10
        no_bins = 9

        # Named "embedder" rather than "re" to avoid shadowing the
        # conventional name of the stdlib regex module.
        embedder = RecyclingEmbedder(c_m, c_z, min_bin, max_bin, no_bins)

        m_1 = torch.rand((batch_size, n_res, c_m))
        z = torch.rand((batch_size, n_res, n_res, c_z))
        x = torch.rand((batch_size, n_res, 3))

        # Keep outputs under their own names so the assertions clearly check
        # what the module returned, not the inputs.
        m_update, z_update = embedder(m_1, z, x)

        self.assertEqual(tuple(z_update.shape), (batch_size, n_res, n_res, c_z))
        self.assertEqual(tuple(m_update.shape), (batch_size, n_res, c_m))
class TestTemplateAngleEmbedder(unittest.TestCase):
    """Shape tests for ``TemplateAngleEmbedder``."""

    def test_shape(self):
        """TemplateAngleEmbedder maps the last dim template_angle_dim -> c_m.

        Input ``(batch_size, n_templ, n_res, template_angle_dim)`` is expected
        to yield ``(batch_size, n_templ, n_res, c_m)``.
        """
        # The previous sizes had batch_size == n_templ and n_res == c_m, so a
        # permuted output could not be told apart from the correct one. Use
        # small, pairwise-distinct sizes like the other tests in this file.
        template_angle_dim = 51
        c_m = 7
        batch_size = 2
        n_templ = 3
        n_res = 5

        tae = TemplateAngleEmbedder(
            template_angle_dim,
            c_m,
        )

        x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
        out = tae(x)

        self.assertEqual(tuple(out.shape), (batch_size, n_templ, n_res, c_m))
class TestTemplatePairEmbedder(unittest.TestCase):
    """Shape tests for ``TemplatePairEmbedder``."""

    def test_shape(self):
        """TemplatePairEmbedder maps the last dim template_pair_dim -> c_t.

        Input ``(batch_size, n_templ, n_res, n_res, template_pair_dim)`` is
        expected to yield ``(batch_size, n_templ, n_res, n_res, c_t)``.
        """
        # Pairwise-distinct sizes so any axis mix-up changes the shape tuple.
        batch_size = 2
        n_templ = 3
        n_res = 5
        template_pair_dim = 7
        c_t = 11

        tpe = TemplatePairEmbedder(
            template_pair_dim,
            c_t,
        )

        x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim))
        out = tpe(x)

        self.assertEqual(
            tuple(out.shape), (batch_size, n_templ, n_res, n_res, c_t)
        )
if __name__ == "__main__":
    # Running this file directly loads and runs every TestCase in this module.
    unittest.main(module="__main__")