mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Begin to spruce up unit tests, fix config
This commit is contained in:
42
config.py
42
config.py
@@ -2,7 +2,15 @@ import copy
|
||||
import ml_collections as mlc
|
||||
|
||||
|
||||
def model_config(name, train=False):
|
||||
def set_inf(c, inf):
    """Recursively overwrite every ``"inf"`` entry of a nested config.

    Each value of ``c`` that is itself an ``mlc.ConfigDict`` is descended
    into; any other value stored under the key ``"inf"`` is replaced by
    ``inf``. The config is modified in place and nothing is returned.
    """
    for key, value in c.items():
        if isinstance(value, mlc.ConfigDict):
            # Sub-configs are never overwritten themselves, only searched.
            set_inf(value, inf)
            continue
        if key == "inf":
            c[key] = inf
|
||||
|
||||
|
||||
def model_config(name, train=False, low_prec=False):
|
||||
c = copy.deepcopy(config)
|
||||
if(name == "model_1"):
|
||||
pass
|
||||
@@ -16,28 +24,34 @@ def model_config(name, train=False):
|
||||
c.model.template.enabled = False
|
||||
elif(name == "model_1_ptm"):
|
||||
c.model.heads.tm.enabled = True
|
||||
c.model.loss.tm.weight = 0.1
|
||||
c.loss.tm.weight = 0.1
|
||||
elif(name == "model_2_ptm"):
|
||||
c.model.heads.tm.enabled = True
|
||||
c.model.loss.tm.weight = 0.1
|
||||
c.loss.tm.weight = 0.1
|
||||
elif(name == "model_3_ptm"):
|
||||
c.model.template.enabled = False
|
||||
c.model.heads.tm.enabled = True
|
||||
c.model.loss.tm.weight = 0.1
|
||||
c.loss.tm.weight = 0.1
|
||||
elif(name == "model_4_ptm"):
|
||||
c.model.template.enabled = False
|
||||
c.model.heads.tm.enabled = True
|
||||
c.model.loss.tm.weight = 0.1
|
||||
c.loss.tm.weight = 0.1
|
||||
elif(name == "model_5_ptm"):
|
||||
c.model.template.enabled = False
|
||||
c.model.heads.tm.enabled = True
|
||||
c.model.loss.tm.weight = 0.1
|
||||
c.loss.tm.weight = 0.1
|
||||
else:
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
if(train):
|
||||
c.globals.model.blocks_per_ckpt = 1
|
||||
c.globals.chunk_size = None
|
||||
|
||||
if(low_prec):
|
||||
c.globals.eps = 1e-4
|
||||
# If we want exact numerical parity with the original, inf can't be
|
||||
# a global constant
|
||||
set_inf(c, 1e4)
|
||||
|
||||
return c
|
||||
|
||||
@@ -51,7 +65,6 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
|
||||
chunk_size = mlc.FieldReference(4, field_type=int)
|
||||
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
|
||||
eps = mlc.FieldReference(1e-8, field_type=float)
|
||||
inf = mlc.FieldReference(1e8, field_type=float)
|
||||
|
||||
config = mlc.ConfigDict({
|
||||
# Recurring FieldReferences that can be changed globally here
|
||||
@@ -64,7 +77,6 @@ config = mlc.ConfigDict({
|
||||
"c_e": c_e,
|
||||
"c_s": c_s,
|
||||
"eps": eps,
|
||||
"inf": inf,
|
||||
},
|
||||
"model": {
|
||||
"no_cycles": 4,
|
||||
@@ -82,7 +94,7 @@ config = mlc.ConfigDict({
|
||||
"min_bin": 3.25,
|
||||
"max_bin": 20.75,
|
||||
"no_bins": 15,
|
||||
"inf": inf,#1e8,
|
||||
"inf": 1e8,
|
||||
},
|
||||
"template": {
|
||||
"distogram": {
|
||||
@@ -111,7 +123,7 @@ config = mlc.ConfigDict({
|
||||
"dropout_rate": 0.25,
|
||||
"blocks_per_ckpt": blocks_per_ckpt,
|
||||
"chunk_size": chunk_size,
|
||||
"inf": inf,
|
||||
"inf": 1e9,
|
||||
},
|
||||
"template_pointwise_attention": {
|
||||
"c_t": c_t,
|
||||
@@ -121,9 +133,9 @@ config = mlc.ConfigDict({
|
||||
"c_hidden": 16,
|
||||
"no_heads": 4,
|
||||
"chunk_size": chunk_size,
|
||||
"inf": inf,#1e-9,
|
||||
"inf": 1e9,
|
||||
},
|
||||
"inf": inf,
|
||||
"inf": 1e9,
|
||||
"eps": eps,#1e-6,
|
||||
"enabled": True,
|
||||
"embed_angles": True,
|
||||
@@ -148,7 +160,7 @@ config = mlc.ConfigDict({
|
||||
"pair_dropout": 0.25,
|
||||
"blocks_per_ckpt": blocks_per_ckpt,
|
||||
"chunk_size": chunk_size,
|
||||
"inf": inf,#1e9,
|
||||
"inf": 1e9,
|
||||
"eps": eps,#1e-10,
|
||||
},
|
||||
"enabled": True,
|
||||
@@ -169,7 +181,7 @@ config = mlc.ConfigDict({
|
||||
"pair_dropout": 0.25,
|
||||
"blocks_per_ckpt": blocks_per_ckpt,
|
||||
"chunk_size": chunk_size,
|
||||
"inf": inf,#1e9,
|
||||
"inf": 1e9,
|
||||
"eps": eps,#1e-10,
|
||||
},
|
||||
"structure_module": {
|
||||
@@ -187,7 +199,7 @@ config = mlc.ConfigDict({
|
||||
"no_angles": 7,
|
||||
"trans_scale_factor": 10,
|
||||
"epsilon": eps,#1e-12,
|
||||
"inf": inf,#1e5,
|
||||
"inf": 1e5,
|
||||
},
|
||||
"heads": {
|
||||
"lddt": {
|
||||
|
||||
@@ -316,7 +316,6 @@ class InvariantPointAttention(nn.Module):
|
||||
|
||||
# [*, N_res, N_res, H]
|
||||
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
|
||||
|
||||
# [*, N_res, N_res]
|
||||
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
|
||||
square_mask = self.inf * (square_mask - 1)
|
||||
@@ -721,7 +720,6 @@ class StructureModule(nn.Module):
|
||||
|
||||
# [*, N]
|
||||
t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)
|
||||
|
||||
outputs = []
|
||||
for i in range(self.no_blocks):
|
||||
# [*, N, C_s]
|
||||
|
||||
@@ -23,6 +23,8 @@ from openfold.utils.affine_utils import T
|
||||
from openfold.utils.tensor_utils import (
|
||||
batched_gather,
|
||||
one_hot,
|
||||
tree_map,
|
||||
tensor_tree_map,
|
||||
)
|
||||
|
||||
|
||||
@@ -143,6 +145,13 @@ def compute_residx(batch):
|
||||
return out
|
||||
|
||||
|
||||
def compute_residx_np(batch):
|
||||
batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
|
||||
out = compute_residx(batch)
|
||||
out = tensor_tree_map(lambda t: np.array(t), out)
|
||||
return out
|
||||
|
||||
|
||||
def atom14_to_atom37(atom14, batch):
|
||||
atom37_data = batched_gather(
|
||||
atom14,
|
||||
|
||||
2
setup.py
2
setup.py
@@ -19,7 +19,7 @@ setup(
|
||||
name='openfold',
|
||||
version='1.0.0',
|
||||
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
|
||||
author='Gustaf Ahdritz',
|
||||
author='Gustaf Ahdritz & DeepMind',
|
||||
author_email='gahdritz@gmail.com',
|
||||
license='Apache License, Version 2.0',
|
||||
url='https://github.com/aqlaboratory/openfold',
|
||||
|
||||
102
tests/compare_utils.py
Normal file
102
tests/compare_utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
import importlib
|
||||
import pkgutil
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from config import model_config
|
||||
from openfold.model.model import AlphaFold
|
||||
from openfold.utils.import_weights import import_jax_weights_
|
||||
from tests.config import consts
|
||||
|
||||
# Give JAX some GPU memory discipline
|
||||
# (by default it hogs 90% of GPU memory. This disables that behavior and also
|
||||
# forces it to proactively free memory that it allocates)
|
||||
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
|
||||
os.environ["JAX_PLATFORM_NAME"] = "gpu"
|
||||
|
||||
|
||||
def alphafold_is_installed():
    """Return True if an importable ``alphafold`` package can be located.

    Only the import machinery is queried; AlphaFold itself is not imported.
    """
    # ``import importlib`` alone does not guarantee that the
    # ``importlib.util`` submodule is loaded, so import it explicitly here
    # instead of relying on some other module having pulled it in.
    import importlib.util

    return importlib.util.find_spec("alphafold") is not None
|
||||
|
||||
|
||||
def skip_unless_alphafold_installed():
    """Build a ``unittest`` decorator that skips a test without AlphaFold.

    Availability is evaluated when this function is called, i.e. at
    decoration time. The skip reason reported is "Requires AlphaFold".
    """
    available = alphafold_is_installed()
    return unittest.skipUnless(available, "Requires AlphaFold")
|
||||
|
||||
|
||||
def import_alphafold():
    """
    Import AlphaFold together with all of its submodules.

    If AlphaFold is installed using the provided setuptools script, this
    is necessary to expose all of AlphaFold's precious insides.

    The imported package is also stored in this module's ``alphafold``
    global, which the other helpers in this file read.

    Returns:
        The top-level ``alphafold`` module.
    """
    # Use our own global as the "fully loaded" marker. Checking
    # ``sys.modules`` instead would return early whenever someone else had
    # imported the bare package, leaving the submodules unloaded and the
    # ``alphafold`` global undefined.
    cached = globals().get("alphafold")
    if cached is not None:
        return cached

    # import_module already registers the package in sys.modules.
    module = importlib.import_module("alphafold")

    # Forcefully import alphafold's submodules
    submodules = pkgutil.walk_packages(module.__path__, prefix="alphafold.")
    for submodule_info in submodules:
        importlib.import_module(submodule_info.name)

    globals()["alphafold"] = module

    return module
|
||||
|
||||
|
||||
def get_alphafold_config():
    """Return AlphaFold's ``model_1_ptm`` config with determinism enabled.

    The config comes from ``alphafold.model.config.model_config`` and has
    ``model.global_config.deterministic`` set to True before it is returned.
    """
    # Resolve the package explicitly rather than relying on the module-level
    # ``alphafold`` global, which only exists after import_alphafold() has
    # run its full path at least once.
    af = import_alphafold()
    config = af.model.config.model_config("model_1_ptm")
    config.model.global_config.deterministic = True
    return config
|
||||
|
||||
|
||||
# Location of the pretrained parameters, relative to the working directory.
_param_path = "openfold/resources/params/params_model_1_ptm.npz"
# Lazily constructed model shared by every caller in the process.
_model = None
def get_global_pretrained_openfold():
    """Return the shared OpenFold model loaded with pretrained weights.

    On first use the model is built from the ``model_1_ptm`` config, put in
    eval mode, filled with the parameters at ``_param_path`` and moved to
    CUDA. Later calls return the same instance.

    Raises:
        FileNotFoundError: if the parameter file does not exist. Nothing is
            cached in that case, so every later call raises again.
    """
    global _model
    if _model is None:
        # Check for the parameter file before building anything. Caching
        # the model first would make the call after a failure silently
        # return a model without pretrained weights.
        if not os.path.exists(_param_path):
            raise FileNotFoundError(
                """Cannot load pretrained parameters. Make sure to run the 
                installation script before running tests."""
            )
        model = AlphaFold(model_config("model_1_ptm").model)
        model = model.eval()
        import_jax_weights_(model, _param_path)
        # Publish the model only once it is fully initialised.
        _model = model.cuda()

    return _model
|
||||
|
||||
|
||||
# Cache for the object returned by np.load on the parameter file.
_orig_weights = None
def _get_orig_weights():
    """Return the contents of the parameter file, loading it at most once.

    The first call runs ``np.load(_param_path)`` and stores the result in the
    module-level ``_orig_weights``; later calls return that stored object.
    """
    global _orig_weights
    if _orig_weights is not None:
        return _orig_weights

    _orig_weights = np.load(_param_path)
    return _orig_weights
|
||||
|
||||
|
||||
def _remove_key_prefix(d, prefix):
    """Strip ``prefix`` from every key of ``d`` that starts with it.

    ``d`` is modified in place. Keys that do not start with ``prefix`` are
    left alone, and nothing is returned.
    """
    cut = len(prefix)
    # Snapshot the matching entries first so the dict can be mutated safely.
    matching = [(key, val) for key, val in d.items() if key.startswith(prefix)]
    for key, val in matching:
        del d[key]
        d[key[cut:]] = val
|
||||
|
||||
|
||||
def fetch_alphafold_module_weights(weight_path):
    """Collect the original parameters for one AlphaFold module.

    Every entry of the parameter file whose key contains ``weight_path`` as
    a substring is selected. If ``weight_path`` contains a ``/``, everything
    up to its last component is stripped from the selected keys, so they
    start at the module name. The flat result is then converted with
    ``alphafold.model.utils.flat_params_to_haiku``.

    Args:
        weight_path: ``/``-separated parameter path; a trailing slash is
            tolerated.

    Returns:
        Whatever ``flat_params_to_haiku`` returns for the selected entries.
    """
    orig_weights = _get_orig_weights()
    params = {
        k: v for k, v in orig_weights.items()
        if weight_path in k
    }
    if "/" in weight_path:
        parts = weight_path.split("/")
        # A trailing slash leaves an empty final component; drop it.
        if len(parts[-1]) == 0:
            parts = parts[:-1]
        prefix = "/".join(parts[:-1]) + "/"
        _remove_key_prefix(params, prefix)

    # Resolve the package explicitly rather than through the module-level
    # ``alphafold`` global, which is only set as an import side effect.
    af = import_alphafold()
    return af.model.utils.flat_params_to_haiku(params)
|
||||
17
tests/config.py
Normal file
17
tests/config.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import ml_collections as mlc
|
||||
|
||||
consts = mlc.ConfigDict({
|
||||
"batch_size": 2,
|
||||
"n_res": 11,
|
||||
"n_seq": 13,
|
||||
"n_templ": 3,
|
||||
"n_extra": 17,
|
||||
"eps": 5e-4,
|
||||
# For compatibility with DeepMind's pretrained weights, it's easiest for
|
||||
# everyone if these take their real values.
|
||||
"c_m": 256,
|
||||
"c_z": 128,
|
||||
"c_s": 384,
|
||||
"c_t": 64,
|
||||
"c_e": 64,
|
||||
})
|
||||
@@ -54,7 +54,7 @@ def random_extra_msa_feats(n_extra, n, batch_size=None):
|
||||
return batch
|
||||
|
||||
|
||||
def random_affine_vectors(dim):
|
||||
def random_affines_vector(dim):
|
||||
prod_dim = 1
|
||||
for d in dim:
|
||||
prod_dim *= d
|
||||
@@ -68,7 +68,7 @@ def random_affine_vectors(dim):
|
||||
return affines.reshape(*dim, 7)
|
||||
|
||||
|
||||
def random_affine_4x4s(dim):
|
||||
def random_affines_4x4(dim):
|
||||
prod_dim = 1
|
||||
for d in dim:
|
||||
prod_dim *= d
|
||||
@@ -15,21 +15,33 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
from alphafold.model.evoformer import *
|
||||
from openfold.model.evoformer import (
|
||||
MSATransition,
|
||||
EvoformerStack,
|
||||
ExtraMSAStack,
|
||||
)
|
||||
from openfold.utils.tensor_utils import tree_map
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestEvoformerStack(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
batch_size = 5
|
||||
s_t = 27
|
||||
n_res = 29
|
||||
c_m = 7
|
||||
c_z = 11
|
||||
batch_size = consts.batch_size
|
||||
n_seq = consts.n_seq
|
||||
n_res = consts.n_res
|
||||
c_m = consts.c_m
|
||||
c_z = consts.c_z
|
||||
c_hidden_msa_att = 12
|
||||
c_hidden_opm = 17
|
||||
c_hidden_mul = 19
|
||||
c_hidden_pair_att = 14
|
||||
c_s = 23
|
||||
c_s = consts.c_s
|
||||
no_heads_msa = 3
|
||||
no_heads_pair = 7
|
||||
no_blocks = 2
|
||||
@@ -59,9 +71,9 @@ class TestEvoformerStack(unittest.TestCase):
|
||||
eps=eps,
|
||||
).eval()
|
||||
|
||||
m = torch.rand((batch_size, s_t, n_res, c_m))
|
||||
m = torch.rand((batch_size, n_seq, n_res, c_m))
|
||||
z = torch.rand((batch_size, n_res, n_res, c_z))
|
||||
msa_mask = torch.randint(0, 2, size=(batch_size, s_t, n_res))
|
||||
msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
|
||||
pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
|
||||
|
||||
shape_m_before = m.shape
|
||||
@@ -73,6 +85,59 @@ class TestEvoformerStack(unittest.TestCase):
|
||||
self.assertTrue(z.shape == shape_z_before)
|
||||
self.assertTrue(s.shape == (batch_size, n_res, c_s))
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
    """Compare one OpenFold Evoformer block against AlphaFold's.

    The same random activations and masks are fed to AlphaFold's
    ``EvoformerIteration`` and to ``model.evoformer.blocks[0]``. Both MSA
    and pair outputs must agree to within ``consts.eps`` at every element.
    """
    def run_ei(activations, masks):
        config = compare_utils.get_alphafold_config()
        c_e = config.model.embeddings_and_evoformer.evoformer
        ei = alphafold.model.modules.EvoformerIteration(
            c_e, config.model.global_config, is_extra_msa=False)
        return ei(activations, masks, is_training=False)

    f = hk.transform(run_ei)

    n_res = consts.n_res
    n_seq = consts.n_seq

    activations = {
        'msa': np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
        'pair': np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
    }

    masks = {
        'msa': np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
        'pair': np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
    }

    params = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
    )
    # Keep only index 0 along the leading axis of every parameter array.
    params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

    key = jax.random.PRNGKey(42)
    out_gt = f.apply(
        params, key, activations, masks
    )
    jax.tree_map(lambda x: x.block_until_ready(), out_gt)
    out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
    out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))

    model = compare_utils.get_global_pretrained_openfold()
    out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
        torch.as_tensor(activations["msa"]).cuda(),
        torch.as_tensor(activations["pair"]).cuda(),
        torch.as_tensor(masks["msa"]).cuda(),
        torch.as_tensor(masks["pair"]).cuda(),
        _mask_trans=False,
    )

    out_repro_msa = out_repro_msa.cpu()
    out_repro_pair = out_repro_pair.cpu()

    # Reduce to the largest absolute deviation BEFORE comparing with eps.
    # Taking max() of the boolean mask instead would pass as soon as any
    # single element happened to be within tolerance.
    self.assertTrue(
        torch.max(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps
    )
    self.assertTrue(
        torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps
    )
|
||||
|
||||
|
||||
|
||||
class TestExtraMSAStack(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
@@ -143,6 +208,47 @@ class TestMSATransition(unittest.TestCase):
|
||||
|
||||
self.assertTrue(shape_before == shape_after)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
    """Compare OpenFold's MSA transition against AlphaFold's ``Transition``.

    Random activations with an all-ones mask are run through both
    implementations; the outputs must agree to within ``consts.eps`` at
    every element.
    """
    def run_msa_transition(msa_act, msa_mask):
        config = compare_utils.get_alphafold_config()
        c_e = config.model.embeddings_and_evoformer.evoformer
        msa_trans = alphafold.model.modules.Transition(
            c_e.msa_transition,
            config.model.global_config,
            name="msa_transition"
        )
        act = msa_trans(act=msa_act, mask=msa_mask)
        return act

    f = hk.transform(run_msa_transition)

    n_res = consts.n_res
    n_seq = consts.n_seq

    msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
    msa_mask = np.ones((n_seq, n_res)).astype(np.float32) # no mask here either

    # Fetch pretrained parameters (but only from one block)
    params = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
        "msa_transition"
    )
    params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

    out_gt = f.apply(
        params, None, msa_act, msa_mask
    ).block_until_ready()
    out_gt = torch.as_tensor(np.array(out_gt))

    model = compare_utils.get_global_pretrained_openfold()
    out_repro = model.evoformer.blocks[0].msa_transition(
        torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
        mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
    ).cpu()

    # Compare the largest absolute deviation with eps. max() over the
    # boolean mask would pass if any single element were within tolerance.
    self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -17,23 +17,37 @@ import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from alphafold.utils.loss import *
|
||||
from alphafold.utils.utils import T
|
||||
from openfold.utils.loss import (
|
||||
torsion_angle_loss,
|
||||
compute_fape,
|
||||
between_residue_bond_loss,
|
||||
between_residue_clash_loss,
|
||||
find_structural_violations,
|
||||
)
|
||||
from openfold.utils.affine_utils import T
|
||||
from openfold.utils.tensor_utils import tensor_tree_map
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestLoss(unittest.TestCase):
|
||||
def test_run_torsion_angle_loss(self):
|
||||
batch_size = 2
|
||||
n = 5
|
||||
batch_size = consts.batch_size
|
||||
n_res = consts.n_res
|
||||
|
||||
a = torch.rand((batch_size, n, 7, 2))
|
||||
a_gt = torch.rand((batch_size, n, 7, 2))
|
||||
a_alt_gt = torch.rand((batch_size, n, 7, 2))
|
||||
a = torch.rand((batch_size, n_res, 7, 2))
|
||||
a_gt = torch.rand((batch_size, n_res, 7, 2))
|
||||
a_alt_gt = torch.rand((batch_size, n_res, 7, 2))
|
||||
|
||||
loss = torsion_angle_loss(a, a_gt, a_alt_gt)
|
||||
|
||||
def test_run_fape(self):
|
||||
batch_size = 2
|
||||
batch_size = consts.batch_size
|
||||
n_frames = 7
|
||||
n_atoms = 5
|
||||
|
||||
@@ -45,12 +59,23 @@ class TestLoss(unittest.TestCase):
|
||||
trans_gt = torch.rand((batch_size, n_frames, 3))
|
||||
t = T(rots, trans)
|
||||
t_gt = T(rots_gt, trans_gt)
|
||||
frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
|
||||
positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
|
||||
length_scale = 10
|
||||
|
||||
loss = compute_fape(t, x, t_gt, x_gt)
|
||||
loss = compute_fape(
|
||||
pred_frames=t,
|
||||
target_frames=t_gt,
|
||||
frames_mask=frames_mask,
|
||||
pred_positions=x,
|
||||
target_positions=x_gt,
|
||||
positions_mask=positions_mask,
|
||||
length_scale=length_scale,
|
||||
)
|
||||
|
||||
def test_between_residue_bond_loss(self):
|
||||
bs = 2
|
||||
n = 10
|
||||
def test_run_between_residue_bond_loss(self):
|
||||
bs = consts.batch_size
|
||||
n = consts.n_res
|
||||
pred_pos = torch.rand(bs, n, 14, 3)
|
||||
pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
|
||||
residue_index = torch.arange(n).unsqueeze(0)
|
||||
@@ -63,9 +88,52 @@ class TestLoss(unittest.TestCase):
|
||||
aatype,
|
||||
)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
|
||||
def test_between_residue_bond_loss_compare(self):
|
||||
def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
|
||||
return alphafold.model.all_atom.between_residue_bond_loss(
|
||||
pred_pos,
|
||||
pred_atom_mask,
|
||||
residue_index,
|
||||
aatype,
|
||||
)
|
||||
|
||||
f = hk.transform(run_brbl)
|
||||
|
||||
n_res = consts.n_res
|
||||
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
|
||||
pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
|
||||
residue_index = np.arange(n_res)
|
||||
aatype = np.random.randint(0, 22, (n_res,))
|
||||
|
||||
out_gt = f.apply(
|
||||
{}, None,
|
||||
pred_pos,
|
||||
pred_atom_mask,
|
||||
residue_index,
|
||||
aatype,
|
||||
)
|
||||
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
|
||||
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
|
||||
|
||||
out_repro = between_residue_bond_loss(
|
||||
torch.tensor(pred_pos).cuda(),
|
||||
torch.tensor(pred_atom_mask).cuda(),
|
||||
torch.tensor(residue_index).cuda(),
|
||||
torch.tensor(aatype).cuda(),
|
||||
)
|
||||
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
|
||||
|
||||
for k in out_gt.keys():
|
||||
self.assertTrue(
|
||||
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
|
||||
)
|
||||
|
||||
|
||||
def test_between_residue_clash_loss(self):
|
||||
bs = 2
|
||||
n = 10
|
||||
bs = consts.batch_size
|
||||
n = consts.n_res
|
||||
|
||||
pred_pos = torch.rand(bs, n, 14, 3)
|
||||
pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
|
||||
atom14_atom_radius = torch.rand(bs, n, 14)
|
||||
@@ -79,7 +147,7 @@ class TestLoss(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_find_structural_violations(self):
|
||||
n = 10
|
||||
n = consts.n_res
|
||||
|
||||
batch = {
|
||||
"atom14_atom_exists": torch.randint(0, 2, (n, 14)),
|
||||
@@ -90,12 +158,12 @@ class TestLoss(unittest.TestCase):
|
||||
|
||||
pred_pos = torch.rand(n, 14, 3)
|
||||
|
||||
config = ml_collections.ConfigDict({
|
||||
config = {
|
||||
"clash_overlap_tolerance": 1.5,
|
||||
"violation_tolerance_factor": 12.0,
|
||||
})
|
||||
}
|
||||
|
||||
find_structural_violations(batch, pred_pos, config)
|
||||
find_structural_violations(batch, pred_pos, **config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -12,25 +12,35 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pickle
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import unittest
|
||||
from config import *
|
||||
from alphafold.model.model import *
|
||||
from alphafold.utils.utils import my_tree_map
|
||||
from tests.alphafold.utils.utils import (
|
||||
from openfold.model.model import AlphaFold
|
||||
import openfold.utils.feats as feats
|
||||
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
from tests.data_utils import (
|
||||
random_template_feats,
|
||||
random_extra_msa_feats,
|
||||
)
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestModel(unittest.TestCase):
|
||||
def test_dry_run(self):
|
||||
batch_size = 2
|
||||
n_seq = 5
|
||||
n_templ = 7
|
||||
n_res = 11
|
||||
n_extra_seq = 13
|
||||
batch_size = consts.batch_size
|
||||
n_seq = consts.n_seq
|
||||
n_templ = consts.n_templ
|
||||
n_res = consts.n_res
|
||||
n_extra_seq = consts.n_extra
|
||||
|
||||
c = model_config("model_1").model
|
||||
c.no_cycles = 2
|
||||
@@ -59,20 +69,65 @@ class TestModel(unittest.TestCase):
|
||||
batch.update({k:torch.tensor(v) for k, v in extra_feats.items()})
|
||||
batch["msa_mask"] = torch.randint(
|
||||
low=0, high=2, size=(batch_size, n_seq, n_res)
|
||||
)
|
||||
).float()
|
||||
batch["seq_mask"] = torch.randint(
|
||||
low=0, high=2, size=(batch_size, n_res)
|
||||
)
|
||||
).float()
|
||||
batch.update(feats.compute_residx(batch))
|
||||
|
||||
add_recycling_dims = lambda t: (
|
||||
t.unsqueeze(-1).expand(*t.shape, c.no_cycles)
|
||||
)
|
||||
batch = my_tree_map(add_recycling_dims, batch, torch.Tensor)
|
||||
batch = tensor_tree_map(add_recycling_dims, batch)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(batch)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
    """Compare full-model atom positions between AlphaFold and OpenFold.

    A pickled feature batch is run through AlphaFold and through the
    pretrained OpenFold model. The final atom positions, converted to
    atom14 on the AlphaFold side, must agree to within 1e-3 at every
    element.
    """
    def run_alphafold(batch):
        config = compare_utils.get_alphafold_config()
        model = alphafold.model.modules.AlphaFold(config.model)
        return model(
            batch=batch, is_training=False, return_representations=True,
        )

    f = hk.transform(run_alphafold)

    params = compare_utils.fetch_alphafold_module_weights('')

    # NOTE(review): pickle.load executes arbitrary code from the file. Only
    # acceptable here because the fixture ships with the repository; never
    # point this at untrusted data.
    with open("tests/test_data/sample_feats.pickle", "rb") as fp:
        batch = pickle.load(fp)

    out_gt = jax.jit(f.apply)(params, jax.random.PRNGKey(42), batch)

    out_gt = out_gt["structure_module"]["final_atom_positions"]
    # atom37_to_atom14 doesn't like batches
    batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
    batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
    out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
    out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))

    batch = {
        k:torch.as_tensor(v).cuda() for k,v in batch.items()
    }
    batch["aatype"] = batch["aatype"].long()
    batch["template_aatype"] = batch["template_aatype"].long()
    batch["extra_msa"] = batch["extra_msa"].long()
    batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()

    # Move the recycling dimension to the end
    move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
    batch = tensor_tree_map(move_dim, batch)

    with torch.no_grad():
        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model(batch)

    out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

    out_repro = out_repro["sm"]["positions"][-1]
    out_repro = out_repro.squeeze(0)

    # Compare the largest absolute deviation with the tolerance. max() over
    # the boolean mask would pass if any single element were close enough.
    self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
|
||||
|
||||
|
||||
@@ -15,23 +15,36 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
from alphafold.model.msa import *
|
||||
from openfold.model.msa import (
|
||||
MSARowAttentionWithPairBias,
|
||||
MSAColumnAttention,
|
||||
MSAColumnGlobalAttention,
|
||||
)
|
||||
from openfold.utils.tensor_utils import tree_map
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestMSARowAttentionWithPairBias(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
batch_size = 2
|
||||
s_t = 3
|
||||
n = 5
|
||||
c_m = 7
|
||||
c_z = 11
|
||||
batch_size = consts.batch_size
|
||||
n_seq = consts.n_seq
|
||||
n_res = consts.n_res
|
||||
c_m = consts.c_m
|
||||
c_z = consts.c_z
|
||||
c = 52
|
||||
no_heads = 4
|
||||
chunk_size=None
|
||||
|
||||
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads)
|
||||
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)
|
||||
|
||||
m = torch.rand((batch_size, s_t, n, c_m))
|
||||
z = torch.rand((batch_size, n, n, c_z))
|
||||
m = torch.rand((batch_size, n_seq, n_res, c_m))
|
||||
z = torch.rand((batch_size, n_res, n_res, c_z))
|
||||
|
||||
shape_before = m.shape
|
||||
m = mrapb(m, z)
|
||||
@@ -39,19 +52,65 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
|
||||
|
||||
self.assertTrue(shape_before == shape_after)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
|
||||
def test_compare(self):
|
||||
def run_msa_row_att(msa_act, msa_mask, pair_act):
|
||||
config = compare_utils.get_alphafold_config()
|
||||
c_e = config.model.embeddings_and_evoformer.evoformer
|
||||
msa_row = alphafold.model.modules.MSARowAttentionWithPairBias(
|
||||
c_e.msa_row_attention_with_pair_bias,
|
||||
config.model.global_config
|
||||
)
|
||||
act = msa_row(
|
||||
msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act
|
||||
)
|
||||
return act
|
||||
|
||||
f = hk.transform(run_msa_row_att)
|
||||
|
||||
n_res = consts.n_res
|
||||
n_seq = consts.n_seq
|
||||
|
||||
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
|
||||
msa_mask = np.random.randint(
|
||||
low=0, high=2, size=(n_seq, n_res)
|
||||
).astype(np.float32)
|
||||
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
|
||||
|
||||
# Fetch pretrained parameters (but only from one block)]
|
||||
params = compare_utils.fetch_alphafold_module_weights(
|
||||
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
|
||||
"msa_row_attention"
|
||||
)
|
||||
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
|
||||
|
||||
out_gt = f.apply(
|
||||
params, None, msa_act, msa_mask, pair_act
|
||||
).block_until_ready()
|
||||
out_gt = torch.as_tensor(np.array(out_gt))
|
||||
|
||||
model = compare_utils.get_global_pretrained_openfold()
|
||||
out_repro = model.evoformer.blocks[0].msa_att_row(
|
||||
torch.as_tensor(msa_act).cuda(),
|
||||
torch.as_tensor(pair_act).cuda(),
|
||||
torch.as_tensor(msa_mask).cuda(),
|
||||
).cpu()
|
||||
|
||||
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
|
||||
|
||||
|
||||
class TestMSAColumnAttention(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
batch_size = 2
|
||||
s_t = 3
|
||||
n = 5
|
||||
c_m = 7
|
||||
batch_size = consts.batch_size
|
||||
n_seq = consts.n_seq
|
||||
n_res = consts.n_res
|
||||
c_m = consts.c_m
|
||||
c = 44
|
||||
no_heads = 4
|
||||
|
||||
msaca = MSAColumnAttention(c_m, c, no_heads)
|
||||
|
||||
x = torch.rand((batch_size, s_t, n, c_m))
|
||||
x = torch.rand((batch_size, n_seq, n_res, c_m))
|
||||
|
||||
shape_before = x.shape
|
||||
x = msaca(x)
|
||||
@@ -59,19 +118,63 @@ class TestMSAColumnAttention(unittest.TestCase):
|
||||
|
||||
self.assertTrue(shape_before == shape_after)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
|
||||
def test_compare(self):
|
||||
def run_msa_col_att(msa_act, msa_mask):
|
||||
config = compare_utils.get_alphafold_config()
|
||||
c_e = config.model.embeddings_and_evoformer.evoformer
|
||||
msa_col = alphafold.model.modules.MSAColumnAttention(
|
||||
c_e.msa_column_attention,
|
||||
config.model.global_config
|
||||
)
|
||||
act = msa_col(
|
||||
msa_act=msa_act, msa_mask=msa_mask
|
||||
)
|
||||
return act
|
||||
|
||||
f = hk.transform(run_msa_col_att)
|
||||
|
||||
n_res = consts.n_res
|
||||
n_seq = consts.n_seq
|
||||
|
||||
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
|
||||
msa_mask = np.random.randint(
|
||||
low=0, high=2, size=(n_seq, n_res)
|
||||
).astype(np.float32)
|
||||
|
||||
# Fetch pretrained parameters (but only from one block)]
|
||||
params = compare_utils.fetch_alphafold_module_weights(
|
||||
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
|
||||
"msa_column_attention"
|
||||
)
|
||||
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
|
||||
|
||||
out_gt = f.apply(
|
||||
params, None, msa_act, msa_mask
|
||||
).block_until_ready()
|
||||
out_gt = torch.as_tensor(np.array(out_gt))
|
||||
|
||||
model = compare_utils.get_global_pretrained_openfold()
|
||||
out_repro = model.evoformer.blocks[0].msa_att_col(
|
||||
torch.as_tensor(msa_act).cuda(),
|
||||
torch.as_tensor(msa_mask).cuda(),
|
||||
).cpu()
|
||||
|
||||
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
|
||||
|
||||
|
||||
class TestMSAColumnGlobalAttention(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
batch_size = 2
|
||||
s_t = 3
|
||||
n = 5
|
||||
c_m = 7
|
||||
batch_size = consts.batch_size
|
||||
n_seq = consts.n_seq
|
||||
n_res = consts.n_res
|
||||
c_m = consts.c_m
|
||||
c = 44
|
||||
no_heads = 4
|
||||
|
||||
msagca = MSAColumnGlobalAttention(c_m, c, no_heads)
|
||||
|
||||
x = torch.rand((batch_size, s_t, n, c_m))
|
||||
x = torch.rand((batch_size, n_seq, n_res, c_m))
|
||||
|
||||
shape_before = x.shape
|
||||
x = msagca(x)
|
||||
@@ -79,6 +182,48 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
|
||||
|
||||
self.assertTrue(shape_before == shape_after)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
    """Compare OpenFold's global column attention against AlphaFold's.

    Random activations and a random 0/1 mask are run through AlphaFold's
    ``MSAColumnGlobalAttention`` and through the first extra-MSA block's
    ``msa_att_col``. The outputs must agree to within ``consts.eps`` at
    every element.
    """
    def run_msa_col_global_att(msa_act, msa_mask):
        config = compare_utils.get_alphafold_config()
        c_e = config.model.embeddings_and_evoformer.evoformer
        msa_col = alphafold.model.modules.MSAColumnGlobalAttention(
            c_e.msa_column_attention,
            config.model.global_config,
            name="msa_column_global_attention"
        )
        act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
        return act

    f = hk.transform(run_msa_col_global_att)

    n_res = consts.n_res
    n_seq = consts.n_seq
    c_e = consts.c_e

    msa_act = np.random.rand(n_seq, n_res, c_e)
    msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res))

    # Fetch pretrained parameters (but only from one block)
    params = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/extra_msa_stack/" +
        "msa_column_global_attention"
    )
    params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

    # One block_until_ready() is enough; the result is already materialised
    # when it is converted below.
    out_gt = f.apply(
        params, None, msa_act, msa_mask
    ).block_until_ready()
    out_gt = torch.as_tensor(np.array(out_gt))

    model = compare_utils.get_global_pretrained_openfold()
    out_repro = model.extra_msa_stack.stack.blocks[0].msa_att_col(
        torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
        mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
    ).cpu()

    # Compare the largest absolute deviation with eps. max() over the
    # boolean mask would pass if any single element were within tolerance.
    self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -15,25 +15,79 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
from alphafold.model.outer_product_mean import *
|
||||
from openfold.model.outer_product_mean import OuterProductMean
|
||||
from openfold.utils.tensor_utils import tree_map
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestOuterProductMean(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
batch_size = 2
|
||||
s = 5
|
||||
n_res = 7
|
||||
c_m = 11
|
||||
c = 13
|
||||
c_z = 17
|
||||
c = 31
|
||||
|
||||
opm = OuterProductMean(c_m, c_z, c)
|
||||
opm = OuterProductMean(consts.c_m, consts.c_z, c)
|
||||
|
||||
m = torch.rand((batch_size, s, n_res, c_m))
|
||||
mask = torch.randint(0, 2, size=(batch_size, s, n_res))
|
||||
m = torch.rand(
|
||||
(consts.batch_size, consts.n_seq, consts.n_res, consts.c_m)
|
||||
)
|
||||
mask = torch.randint(
|
||||
0, 2, size=(consts.batch_size, consts.n_seq, consts.n_res)
|
||||
)
|
||||
m = opm(m, mask)
|
||||
|
||||
self.assertTrue(m.shape == (batch_size, n_res, n_res, c_z))
|
||||
self.assertTrue(
|
||||
m.shape == (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
|
||||
)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_opm_compare(self):
    """Compare the outer product mean between AlphaFold and OpenFold.

    Feeds the same random MSA activations (scaled by 100) and random 0/1
    mask to AlphaFold's ``OuterProductMean`` (through Haiku) and to the
    ``outer_product_mean`` module of Evoformer block 0 of the model
    returned by ``compare_utils.get_global_pretrained_openfold()``. The
    test passes only if the largest absolute element-wise difference is
    below ``5e-4``.
    """
    def run_opm(msa_act, msa_mask):
        config = compare_utils.get_alphafold_config()
        c_evo = config.model.embeddings_and_evoformer.evoformer
        opm = alphafold.model.modules.OuterProductMean(
            c_evo.outer_product_mean,
            config.model.global_config,
            consts.c_z,
        )
        act = opm(act=msa_act, mask=msa_mask)
        return act

    f = hk.transform(run_opm)

    n_res = consts.n_res
    n_seq = consts.n_seq
    c_m = consts.c_m

    msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100
    msa_mask = np.random.randint(
        low=0, high=2, size=(n_seq, n_res)
    ).astype(np.float32)

    # Fetch pretrained parameters (but only from one block): index 0 is
    # taken from every parameter array.
    params = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/" +
        "evoformer_iteration/outer_product_mean"
    )
    params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

    out_gt = f.apply(
        params, None, msa_act, msa_mask
    ).block_until_ready()
    out_gt = torch.as_tensor(np.array(out_gt))

    model = compare_utils.get_global_pretrained_openfold()
    out_repro = model.evoformer.blocks[0].outer_product_mean(
        torch.as_tensor(msa_act).cuda(),
        mask=torch.as_tensor(msa_mask).cuda(),
    ).cpu()

    # Even when correct, OPM has large, precision-related errors. It gets
    # a special pass from consts.eps.
    # The threshold is applied to the maximum error itself; taking
    # torch.max() of the boolean mask would accept a single matching
    # element.
    self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,18 +15,26 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
from alphafold.model.pair_transition import *
|
||||
from openfold.model.pair_transition import PairTransition
|
||||
from openfold.utils.tensor_utils import tree_map
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestPairTransition(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
c_z = 5
|
||||
c_z = consts.c_z
|
||||
n = 4
|
||||
|
||||
pt = PairTransition(c_z, n)
|
||||
|
||||
batch_size = 4
|
||||
n_res = 256
|
||||
batch_size = consts.batch_size
|
||||
n_res = consts.n_res
|
||||
|
||||
z = torch.rand((batch_size, n_res, n_res, c_z))
|
||||
mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
|
||||
@@ -36,6 +44,47 @@ class TestPairTransition(unittest.TestCase):
|
||||
|
||||
self.assertTrue(shape_before == shape_after)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
    """Compare the pair transition between AlphaFold and OpenFold.

    Runs AlphaFold's ``Transition`` module (named ``pair_transition``)
    through Haiku and the ``pair_transition`` module of Evoformer block 0
    of the model returned by
    ``compare_utils.get_global_pretrained_openfold()`` on the same random
    pair activations with an all-ones mask. The test passes only if the
    largest absolute element-wise difference is below ``consts.eps``.
    """
    def run_pair_transition(pair_act, pair_mask):
        config = compare_utils.get_alphafold_config()
        c_e = config.model.embeddings_and_evoformer.evoformer
        pt = alphafold.model.modules.Transition(
            c_e.pair_transition,
            config.model.global_config,
            name="pair_transition"
        )
        act = pt(act=pair_act, mask=pair_mask)
        return act

    f = hk.transform(run_pair_transition)

    n_res = consts.n_res

    pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
    pair_mask = np.ones((n_res, n_res)).astype(np.float32)  # no mask

    # Fetch pretrained parameters (but only from one block): index 0 is
    # taken from every parameter array.
    params = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
        "pair_transition"
    )
    params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

    out_gt = f.apply(
        params, None, pair_act, pair_mask
    ).block_until_ready()
    out_gt = torch.as_tensor(np.array(out_gt))

    model = compare_utils.get_global_pretrained_openfold()
    out_repro = model.evoformer.blocks[0].pair_transition(
        torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
        mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
    ).cpu()

    # Compare the maximum error against eps; torch.max() over the boolean
    # mask would be true whenever any one element happened to agree.
    self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -16,18 +16,28 @@ import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from alphafold.np.residue_constants import (
|
||||
from openfold.np.residue_constants import (
|
||||
restype_rigid_group_default_frame,
|
||||
restype_atom14_to_rigid_group,
|
||||
restype_atom14_mask,
|
||||
restype_atom14_rigid_group_positions,
|
||||
)
|
||||
from alphafold.model.structure_module import *
|
||||
from alphafold.model.structure_module import (
|
||||
from openfold.model.structure_module import *
|
||||
from openfold.model.structure_module import (
|
||||
_torsion_angles_to_frames,
|
||||
_frames_and_literature_positions_to_atom14_pos,
|
||||
)
|
||||
from alphafold.utils.utils import T
|
||||
from openfold.utils.affine_utils import T
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
from tests.data_utils import (
|
||||
random_affines_4x4,
|
||||
)
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestStructureModule(unittest.TestCase):
|
||||
@@ -75,7 +85,7 @@ class TestStructureModule(unittest.TestCase):
|
||||
out = sm(s, z, f)
|
||||
|
||||
self.assertTrue(
|
||||
out["transformations"].shape == (no_layers, batch_size, n, 4, 4)
|
||||
out["frames"].shape == (no_layers, batch_size, n, 4, 4)
|
||||
)
|
||||
self.assertTrue(
|
||||
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
|
||||
@@ -190,6 +200,62 @@ class TestInvariantPointAttention(unittest.TestCase):
|
||||
|
||||
self.assertTrue(s.shape == shape_before)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_ipa_compare(self):
    """Check OpenFold's invariant point attention against AlphaFold's.

    Both implementations receive the same random single/pair activations,
    an all-ones residue mask and the same random 4x4 affines (converted
    to each side's own frame representation). The test passes if the
    largest absolute difference between the outputs is below
    ``consts.eps``.
    """
    def run_ipa(act, static_feat_2d, mask, affine):
        af_config = compare_utils.get_alphafold_config()
        module = alphafold.model.folding.InvariantPointAttention(
            af_config.model.heads.structure_module,
            af_config.model.global_config,
        )
        return module(
            inputs_1d=act,
            inputs_2d=static_feat_2d,
            mask=mask,
            affine=affine
        )

    ipa_fn = hk.transform(run_ipa)

    n_res, c_s, c_z = consts.n_res, consts.c_s, consts.c_z

    # Random inputs shared by both implementations.
    single_act = np.random.rand(n_res, c_s)
    pair_act = np.random.rand(n_res, n_res, c_z)
    res_mask = np.ones((n_res, 1))

    # One set of 4x4 affines, expressed once for AlphaFold (quat affine)
    # and once for OpenFold (T built from the 4x4 tensor).
    affines_4x4 = random_affines_4x4((n_res,))
    quat_affines = alphafold.model.r3.rigids_to_quataffine(
        alphafold.model.r3.rigids_from_tensor4x4(affines_4x4)
    )
    of_transforms = T.from_4x4(
        torch.as_tensor(affines_4x4).float().cuda()
    )

    weights = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/structure_module/" +
        "fold_iteration/invariant_point_attention"
    )

    af_out = ipa_fn.apply(
        weights, None, single_act, pair_act, res_mask, quat_affines
    ).block_until_ready()
    expected = torch.as_tensor(np.array(af_out))

    with torch.no_grad():
        of_model = compare_utils.get_global_pretrained_openfold()
        actual = of_model.structure_module.ipa(
            torch.as_tensor(single_act).float().cuda(),
            torch.as_tensor(pair_act).float().cuda(),
            of_transforms,
            torch.as_tensor(res_mask.squeeze(-1)).float().cuda(),
        ).cpu()

    max_abs_err = torch.max(torch.abs(expected - actual))
    self.assertTrue(max_abs_err < consts.eps)
|
||||
|
||||
|
||||
class TestAngleResnet(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
|
||||
@@ -15,23 +15,38 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
from alphafold.model.template import *
|
||||
from openfold.model.template import (
|
||||
TemplatePointwiseAttention,
|
||||
TemplatePairStack,
|
||||
)
|
||||
from openfold.utils.tensor_utils import tree_map
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
from tests.data_utils import random_template_feats
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestTemplatePointwiseAttention(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
batch_size = 2
|
||||
s_t = 3
|
||||
c_t = 5
|
||||
c_z = 7
|
||||
batch_size = consts.batch_size
|
||||
n_seq = consts.n_seq
|
||||
c_t = consts.c_t
|
||||
c_z = consts.c_z
|
||||
c = 26
|
||||
no_heads = 13
|
||||
n = 17
|
||||
n_res = consts.n_res
|
||||
inf = 1e7
|
||||
|
||||
tpa = TemplatePointwiseAttention(c_t, c_z, c, no_heads, chunk_size=4)
|
||||
tpa = TemplatePointwiseAttention(
|
||||
c_t, c_z, c, no_heads, chunk_size=4, inf=inf
|
||||
)
|
||||
|
||||
t = torch.rand((batch_size, s_t, n, n, c_t))
|
||||
z = torch.rand((batch_size, n, n, c_z))
|
||||
t = torch.rand((batch_size, n_seq, n_res, n_res, c_t))
|
||||
z = torch.rand((batch_size, n_res, n_res, c_z))
|
||||
|
||||
z_update = tpa(t, z)
|
||||
|
||||
@@ -40,17 +55,20 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
|
||||
|
||||
class TestTemplatePairStack(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
batch_size = 2
|
||||
c_t = 5
|
||||
batch_size = consts.batch_size
|
||||
c_t = consts.c_t
|
||||
c_hidden_tri_att = 7
|
||||
c_hidden_tri_mul = 7
|
||||
no_blocks = 2
|
||||
no_heads = 4
|
||||
pt_inner_dim = 15
|
||||
dropout = 0.25
|
||||
n_templ = 3
|
||||
n_res = 5
|
||||
n_templ = consts.n_templ
|
||||
n_res = consts.n_res
|
||||
blocks_per_ckpt = None
|
||||
chunk_size = 4
|
||||
inf=1e7
|
||||
eps=1e-7
|
||||
|
||||
tpe = TemplatePairStack(
|
||||
c_t,
|
||||
@@ -60,7 +78,10 @@ class TestTemplatePairStack(unittest.TestCase):
|
||||
no_heads=no_heads,
|
||||
pair_transition_n=pt_inner_dim,
|
||||
dropout_rate=dropout,
|
||||
blocks_per_ckpt=None,
|
||||
chunk_size=chunk_size,
|
||||
inf=inf,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
|
||||
@@ -71,7 +92,98 @@ class TestTemplatePairStack(unittest.TestCase):
|
||||
|
||||
self.assertTrue(shape_before == shape_after)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
    """Check OpenFold's template pair stack against AlphaFold's.

    The AlphaFold side runs ``TemplatePairStack`` followed by a Haiku
    ``LayerNorm`` named ``output_layer_norm``; the OpenFold side calls
    ``template_pair_stack`` on the model returned by
    ``compare_utils.get_global_pretrained_openfold()``. Every element of
    the two outputs must differ by less than ``consts.eps``.
    """
    def run_template_pair_stack(pair_act, pair_mask):
        af_config = compare_utils.get_alphafold_config()
        template_cfg = af_config.model.embeddings_and_evoformer
        stack = alphafold.model.modules.TemplatePairStack(
            template_cfg.template.template_pair_stack,
            af_config.model.global_config,
            name="template_pair_stack"
        )
        stacked = stack(pair_act, pair_mask, is_training=False)
        norm = hk.LayerNorm([-1], True, True, name="output_layer_norm")
        return norm(stacked)

    stack_fn = hk.transform(run_template_pair_stack)

    n_res = consts.n_res

    # Random template-pair activations and a random 0/1 mask.
    pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
    pair_mask = np.random.randint(
        low=0, high=2, size=(n_res, n_res)
    ).astype(np.float32)

    # Weights for the stack itself, extended with those of the trailing
    # layer norm.
    weights = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/template_embedding/" +
        "single_template_embedding/template_pair_stack"
    )
    norm_weights = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/template_embedding/" +
        "single_template_embedding/output_layer_norm"
    )
    weights.update(norm_weights)

    af_out = stack_fn.apply(
        weights, jax.random.PRNGKey(42), pair_act, pair_mask
    ).block_until_ready()
    expected = torch.as_tensor(np.array(af_out))

    of_model = compare_utils.get_global_pretrained_openfold()
    actual = of_model.template_pair_stack(
        torch.as_tensor(pair_act).cuda(),
        torch.as_tensor(pair_mask).cuda(),
        _mask_trans=False,
    ).cpu()

    within_tolerance = torch.abs(expected - actual) < consts.eps
    self.assertTrue(torch.all(within_tolerance))
|
||||
|
||||
class Template(unittest.TestCase):
    """Numerical comparison of the full template embedding."""

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """Compare the template embedding between AlphaFold and OpenFold.

        Runs AlphaFold's ``TemplateEmbedding`` through Haiku and
        ``embed_templates`` on the model returned by
        ``compare_utils.get_global_pretrained_openfold()`` with the same
        random pair activations, template features and mask. The test
        passes only if the largest absolute element-wise difference is
        below ``consts.eps``.
        """
        def test_template_embedding(pair, batch, mask_2d):
            config = compare_utils.get_alphafold_config()
            te = alphafold.model.modules.TemplateEmbedding(
                config.model.embeddings_and_evoformer.template,
                config.model.global_config
            )
            act = te(pair, batch, mask_2d, is_training=False)
            return act

        f = hk.transform(test_template_embedding)

        n_res = consts.n_res
        n_templ = consts.n_templ

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        batch = random_template_feats(n_templ, n_res)
        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

        # Fetch the pretrained parameters of the whole template embedding
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/template_embedding"
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        # Added only after the AlphaFold run: one-hot rows of width 22 for
        # random indices, passed to the OpenFold side as "target_feat".
        inds = np.random.randint(0, 21, (n_res,))
        batch["target_feat"] = np.eye(22)[inds]

        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.embed_templates(
            {k:torch.as_tensor(v).cuda() for k,v in batch.items()},
            torch.as_tensor(pair_act).cuda(),
            torch.as_tensor(pair_mask).cuda(),
            templ_dim=0,
        )
        out_repro = out_repro["template_pair_embedding"]
        out_repro = out_repro.cpu()

        # Threshold the maximum error, not the maximum of the boolean
        # mask, which would be true if any single element agreed.
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,12 +15,21 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
from alphafold.model.triangular_attention import *
|
||||
from openfold.model.triangular_attention import TriangleAttention
|
||||
from openfold.utils.tensor_utils import tree_map
|
||||
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestTriangularAttention(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
c_z = 2
|
||||
c_z = consts.c_z
|
||||
c = 12
|
||||
no_heads = 4
|
||||
starting = True
|
||||
@@ -32,8 +41,8 @@ class TestTriangularAttention(unittest.TestCase):
|
||||
starting
|
||||
)
|
||||
|
||||
batch_size = 4
|
||||
n_res = 7
|
||||
batch_size = consts.batch_size
|
||||
n_res = consts.n_res
|
||||
|
||||
x = torch.rand((batch_size, n_res, n_res, c_z))
|
||||
shape_before = x.shape
|
||||
@@ -42,9 +51,61 @@ class TestTriangularAttention(unittest.TestCase):
|
||||
|
||||
self.assertTrue(shape_before == shape_after)
|
||||
|
||||
def _tri_att_compare(self, starting=False):
    """Compare one triangle attention variant between AlphaFold and OpenFold.

    Args:
        starting: If True, compare the "starting node" variant
            (``tri_att_start``); otherwise the "ending node" variant
            (``tri_att_end``).

    Both sides come from Evoformer block 0 and receive the same random
    pair activations and random 0/1 mask. The test passes only if the
    largest absolute element-wise difference is below ``consts.eps``.
    """
    name = (
        "triangle_attention_" +
        ("starting" if starting else "ending") +
        "_node"
    )
    def run_tri_att(pair_act, pair_mask):
        config = compare_utils.get_alphafold_config()
        c_e = config.model.embeddings_and_evoformer.evoformer
        tri_att = alphafold.model.modules.TriangleAttention(
            c_e.triangle_attention_starting_node if starting else
            c_e.triangle_attention_ending_node,
            config.model.global_config,
            name=name,
        )
        act = tri_att(pair_act=pair_act, pair_mask=pair_mask)
        return act

    f = hk.transform(run_tri_att)

    n_res = consts.n_res

    pair_act = np.random.rand(n_res, n_res, consts.c_z)
    pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))

    # Fetch pretrained parameters (but only from one block): index 0 is
    # taken from every parameter array.
    params = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
        name
    )
    params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

    out_gt = f.apply(
        params, None, pair_act, pair_mask
    ).block_until_ready()
    out_gt = torch.as_tensor(np.array(out_gt))

    model = compare_utils.get_global_pretrained_openfold()
    module = (
        model.evoformer.blocks[0].tri_att_start if starting else
        model.evoformer.blocks[0].tri_att_end
    )
    out_repro = module(
        torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
        mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
    ).cpu()

    # Threshold the maximum error itself; torch.max() of the boolean
    # mask would pass as soon as one element was within eps.
    self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self):
    """Compare the "ending node" triangle attention with AlphaFold.

    Guarded by the AlphaFold skip decorator because the shared helper
    uses ``alphafold``, ``hk`` and ``jax``, which this module only
    imports when AlphaFold is installed.
    """
    self._tri_att_compare()
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_start_compare(self):
    """Compare the "starting node" triangle attention with AlphaFold.

    Guarded by the AlphaFold skip decorator because the shared helper
    uses ``alphafold``, ``hk`` and ``jax``, which this module only
    imports when AlphaFold is installed.
    """
    self._tri_att_compare(starting=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
|
||||
|
||||
unittest.main()
|
||||
|
||||
@@ -15,12 +15,20 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
from alphafold.model.triangular_multiplicative_update import *
|
||||
from openfold.model.triangular_multiplicative_update import *
|
||||
from openfold.utils.tensor_utils import tree_map
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.config import consts
|
||||
|
||||
if(compare_utils.alphafold_is_installed()):
|
||||
alphafold = compare_utils.import_alphafold()
|
||||
import jax
|
||||
import haiku as hk
|
||||
|
||||
|
||||
class TestTriangularMultiplicativeUpdate(unittest.TestCase):
|
||||
def test_shape(self):
|
||||
c_z = 7
|
||||
c_z = consts.c_z
|
||||
c = 11
|
||||
outgoing = True
|
||||
|
||||
@@ -30,8 +38,8 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
|
||||
outgoing,
|
||||
)
|
||||
|
||||
n_res = 5
|
||||
batch_size = 2
|
||||
n_res = consts.c_z
|
||||
batch_size = consts.batch_size
|
||||
|
||||
x = torch.rand((batch_size, n_res, n_res, c_z))
|
||||
mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
|
||||
@@ -41,6 +49,63 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
|
||||
|
||||
self.assertTrue(shape_before == shape_after)
|
||||
|
||||
def _tri_mul_compare(self, incoming=False):
    """Compare one triangle multiplication variant with AlphaFold's.

    Args:
        incoming: If True, compare the "incoming" variant
            (``tri_mul_in``); otherwise the "outgoing" variant
            (``tri_mul_out``).

    Both sides come from Evoformer block 0 and receive the same random
    float32 pair activations and random 0/1 mask. The test passes only
    if the largest absolute element-wise difference is below
    ``consts.eps``.
    """
    name = (
        "triangle_multiplication_" +
        ("incoming" if incoming else "outgoing")
    )
    def run_tri_mul(pair_act, pair_mask):
        config = compare_utils.get_alphafold_config()
        c_e = config.model.embeddings_and_evoformer.evoformer
        tri_mul = alphafold.model.modules.TriangleMultiplication(
            c_e.triangle_multiplication_incoming if incoming else
            c_e.triangle_multiplication_outgoing,
            config.model.global_config,
            name=name,
        )
        act = tri_mul(act=pair_act, mask=pair_mask)
        return act

    f = hk.transform(run_tri_mul)

    n_res = consts.n_res

    pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
    pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
    pair_mask = pair_mask.astype(np.float32)

    # Fetch pretrained parameters (but only from one block): index 0 is
    # taken from every parameter array.
    params = compare_utils.fetch_alphafold_module_weights(
        "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
        name
    )
    params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

    out_gt = f.apply(
        params, None, pair_act, pair_mask
    ).block_until_ready()
    out_gt = torch.as_tensor(np.array(out_gt))

    model = compare_utils.get_global_pretrained_openfold()
    module = (
        model.evoformer.blocks[0].tri_mul_in if incoming else
        model.evoformer.blocks[0].tri_mul_out
    )
    out_repro = module(
        torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
        mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
    ).cpu()

    # Threshold the maximum error itself; torch.max() of the boolean
    # mask would pass as soon as one element was within eps.
    self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self):
    """Run the shared comparison for the outgoing triangle multiplication."""
    self._tri_mul_compare(incoming=False)
|
||||
|
||||
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_in_compare(self):
    """Run the shared comparison for the incoming triangle multiplication."""
    use_incoming = True
    self._tri_mul_compare(incoming=use_incoming)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -16,7 +16,8 @@ import math
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from alphafold.utils.utils import *
|
||||
from openfold.utils.affine_utils import *
|
||||
from openfold.utils.tensor_utils import *
|
||||
|
||||
|
||||
X_90_ROT = torch.tensor([
|
||||
|
||||
Reference in New Issue
Block a user