Files
openfold/tests/test_import_weights.py
2023-11-03 14:26:18 -04:00

96 lines
3.0 KiB
Python

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import numpy as np
import unittest
from pathlib import Path
from tests.config import consts
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_, import_openfold_weights_
class TestImportWeights(unittest.TestCase):
def test_import_jax_weights_(self):
npz_path = Path(__file__).parent.resolve() / f"../openfold/resources/params/params_{consts.model}.npz"
c = model_config(consts.model)
c.globals.blocks_per_ckpt = None
model = AlphaFold(c)
model.eval()
import_jax_weights_(
model,
npz_path,
version=consts.model
)
data = np.load(npz_path)
prefix = "alphafold/alphafold_iteration/"
test_pairs = [
# Normal linear weight
(
torch.as_tensor(
data[
prefix + "structure_module/initial_projection//weights"
]
).transpose(-1, -2),
model.structure_module.linear_in.weight,
),
# Normal layer norm param
(
torch.as_tensor(
data[prefix + "evoformer/prev_pair_norm//offset"],
),
model.recycling_embedder.layer_norm_z.bias,
),
# From a stack
(
torch.as_tensor(
data[
prefix
+ (
"evoformer/evoformer_iteration/outer_product_mean/"
"left_projection//weights"
)
][1].transpose(-1, -2)
),
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
),
]
for w_alpha, w_repro in test_pairs:
self.assertTrue(torch.all(w_alpha == w_repro))
def test_import_openfold_weights_(self):
model_name = 'initial_training'
pt_path = Path(__file__).parent.resolve() / f"../openfold/resources/openfold_params/{model_name}.pt"
if os.path.exists(pt_path):
c = model_config(model_name)
c.globals.blocks_per_ckpt = None
model = AlphaFold(c)
model.eval()
d = torch.load(pt_path)
import_openfold_weights_(
model=model,
state_dict=d,
)