mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
# Copyright 2021 AlQuraishi Laboratory
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import torch
|
|
import numpy as np
|
|
import unittest
|
|
|
|
from openfold.config import model_config
|
|
from openfold.model.model import AlphaFold
|
|
from openfold.utils.import_weights import import_jax_weights_
|
|
|
|
|
|
class TestImportWeights(unittest.TestCase):
    """Spot-check that ``import_jax_weights_`` copies AlphaFold JAX
    parameters into the matching OpenFold modules."""

    def test_import_jax_weights_(self):
        """Import the ``model_1_ptm`` parameters into a fresh ``AlphaFold``
        model and compare a few imported tensors against the raw arrays
        stored in the ``.npz`` archive.

        The test is skipped when the parameter archive is not present at
        the expected path.
        """
        npz_path = "openfold/resources/params/params_model_1_ptm.npz"

        # Open the archive before building the model so that a missing
        # download is reported as a skip instead of an error raised only
        # after the (comparatively expensive) model construction.
        try:
            data = np.load(npz_path)
        except FileNotFoundError:
            self.skipTest(f"AlphaFold parameters not found at {npz_path}")

        # np.load on an .npz returns an NpzFile holding an open file
        # handle; the context manager guarantees it is closed.
        with data:
            c = model_config("model_1_ptm")
            c.globals.blocks_per_ckpt = None
            model = AlphaFold(c)

            import_jax_weights_(
                model,
                npz_path,
            )

            prefix = "alphafold/alphafold_iteration/"

            # Each entry: (label, tensor from the archive, OpenFold parameter).
            # Indexing the NpzFile reads the array fully into memory, so the
            # tensors stay valid after the archive is closed.
            test_pairs = [
                (
                    "normal linear weight",
                    # The archive's linear weights are transposed relative to
                    # the torch parameter layout, hence the transpose.
                    torch.as_tensor(
                        data[prefix + "structure_module/initial_projection//weights"]
                    ).transpose(-1, -2),
                    model.structure_module.linear_in.weight,
                ),
                (
                    "normal layer norm param",
                    torch.as_tensor(
                        data[prefix + "evoformer/prev_pair_norm//offset"]
                    ),
                    model.recycling_embedder.layer_norm_z.bias,
                ),
                (
                    "parameter from a stack",
                    # Stacked parameters carry a leading per-block axis;
                    # index 1 is compared against evoformer.blocks[1].
                    torch.as_tensor(
                        data[
                            prefix
                            + (
                                "evoformer/evoformer_iteration/outer_product_mean/"
                                "left_projection//weights"
                            )
                        ][1]
                    ).transpose(-1, -2),
                    model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
                ),
            ]

        for label, w_alpha, w_repro in test_pairs:
            with self.subTest(label):
                # Compare shapes explicitly: the elementwise comparison below
                # broadcasts, so mismatched-but-compatible shapes would pass.
                self.assertEqual(w_alpha.shape, w_repro.shape)
                self.assertTrue(torch.all(w_alpha == w_repro))
|