Files
openfold/tests/test_import_weights.py
2022-04-13 23:18:57 -04:00

74 lines
2.3 KiB
Python

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import unittest
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
class TestImportWeights(unittest.TestCase):
    """Spot-checks that ``import_jax_weights_`` copies JAX params into OpenFold."""

    def test_import_jax_weights_(self):
        """Load ``model_1_ptm`` params and compare a few parameters exactly.

        Each check compares a tensor read directly from the ``.npz`` archive
        against the corresponding OpenFold module parameter after
        ``import_jax_weights_`` has populated the model.
        """
        npz_path = "openfold/resources/params/params_model_1_ptm.npz"
        c = model_config("model_1_ptm")
        c.globals.blocks_per_ckpt = None
        model = AlphaFold(c)
        import_jax_weights_(
            model,
            npz_path,
        )

        prefix = "alphafold/alphafold_iteration/"
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it once the arrays below have been read.
        with np.load(npz_path) as data:
            test_pairs = [
                (
                    "normal linear weight",
                    # Transposed to match torch's (out_features, in_features)
                    # Linear weight layout.
                    torch.as_tensor(
                        data[prefix + "structure_module/initial_projection//weights"]
                    ).transpose(-1, -2),
                    model.structure_module.linear_in.weight,
                ),
                (
                    "normal layer norm param",
                    torch.as_tensor(
                        data[prefix + "evoformer/prev_pair_norm//offset"]
                    ),
                    model.recycling_embedder.layer_norm_z.bias,
                ),
                (
                    "stacked param",
                    # Evoformer iteration params are stored stacked; index 1
                    # is compared against evoformer block 1.
                    torch.as_tensor(
                        data[
                            prefix
                            + (
                                "evoformer/evoformer_iteration/outer_product_mean/"
                                "left_projection//weights"
                            )
                        ][1]
                    ).transpose(-1, -2),
                    model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
                ),
            ]

        for name, w_alpha, w_repro in test_pairs:
            with self.subTest(param=name):
                # Compare shapes first: `==` broadcasts, so a shape mismatch
                # could otherwise pass silently or fail with an opaque error.
                self.assertEqual(tuple(w_alpha.shape), tuple(w_repro.shape))
                self.assertTrue(torch.all(w_alpha == w_repro))