mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
revert accidentally removed files
This commit is contained in:
65
test/test_crosslink_inference.py
Normal file
65
test/test_crosslink_inference.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
from unifold.modules.alphafold import AlphaFold
|
||||
from unifold.alphalink_inference import prepare_model_runner
|
||||
from unifold.alphalink_inference import alphalink_prediction
|
||||
from unifold.dataset import process_ap
|
||||
from unifold.config import model_config
|
||||
from alphapulldown.utils import create
|
||||
from alphapulldown.run_multimer_jobs import predict_individual_jobs,create_custom_jobs
|
||||
|
||||
class _TestBase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.crosslink_file_path = os.path.join(os.path.dirname(__file__),"test_data/example_crosslink.pkl.gz")
|
||||
self.config_data_model_name = 'model_5_ptm_af2'
|
||||
self.config_alphafold_model_name = 'multimer_af2_crop'
|
||||
|
||||
class TestCrosslinkInference(_TestBase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.output_dir = tempfile.mkdtemp()
|
||||
self.monomer_object_path = os.path.join(os.path.dirname(__file__),"test_data/")
|
||||
self.protein_list = os.path.join(os.path.dirname(__file__),"test_data/example_crosslinked_pair.txt")
|
||||
self.alphalink_weight = '/g/alphafold/alphalink_weights/AlphaLink-Multimer_SDA_v3.pt'
|
||||
self.multimerobject = create_custom_jobs(self.protein_list,self.monomer_object_path,job_index=1,pair_msa=True)[0]
|
||||
|
||||
def test1_process_features(self):
|
||||
"""Test whether the PyTorch model of AlphaLink can be initiated successfully"""
|
||||
configs = model_config(self.config_data_model_name)
|
||||
processed_features,_ = process_ap(config=configs.data,
|
||||
features=self.multimerobject.feature_dict,
|
||||
mode="predict",labels=None,
|
||||
seed=42,batch_idx=None,
|
||||
data_idx=None,is_distillation=False,
|
||||
chain_id_map = self.multimerobject.chain_id_map,
|
||||
crosslinks = self.crosslink_file_path
|
||||
)
|
||||
|
||||
def test2_load_AlphaLink_weights(self):
|
||||
"""This is testing weither loading the PyTorch checkpoint is sucessfull"""
|
||||
if torch.cuda.is_available():
|
||||
model_device = 'cuda:0'
|
||||
else:
|
||||
model_device = 'cpu'
|
||||
|
||||
config = model_config(self.config_alphafold_model_name)
|
||||
model = AlphaFold(config)
|
||||
state_dict = torch.load(self.alphalink_weight)["ema"]["params"]
|
||||
state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(model_device)
|
||||
|
||||
def test3_test_inference(self):
|
||||
if torch.cuda.is_available():
|
||||
model_device = 'cuda:0'
|
||||
else:
|
||||
model_device = 'cpu'
|
||||
model = prepare_model_runner(self.alphalink_weight,model_device=model_device)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
49
test/test_crosslink_input.py
Normal file
49
test/test_crosslink_input.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import unittest
|
||||
from unifold.dataset import calculate_offsets,create_xl_features,bin_xl
|
||||
from alphafold.data.pipeline_multimer import _FastaChain
|
||||
import numpy as np
|
||||
import gzip,pickle
|
||||
import torch
|
||||
class TestCreateObjects(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.crosslink_info ="./test/test_data/test_xl_input.pkl.gz"
|
||||
self.asym_id = torch.tensor([1]*10 + [2]*25 + [3]*40)
|
||||
self.chain_id_map = {
|
||||
"A":_FastaChain(sequence='',description='chain1'),
|
||||
"B":_FastaChain(sequence='',description='chain2'),
|
||||
"C":_FastaChain(sequence='',description='chain3')
|
||||
}
|
||||
self.bins = torch.arange(0,1.05,0.05)
|
||||
return super().setUp()
|
||||
|
||||
def test1_calculate_offsets(self):
|
||||
offsets = calculate_offsets(self.asym_id)
|
||||
offsets = offsets.tolist()
|
||||
expected_offsets = [0,10,35,75]
|
||||
self.assertEqual(offsets,expected_offsets)
|
||||
|
||||
def test2_create_xl_inputs(self):
|
||||
offsets = calculate_offsets(self.asym_id)
|
||||
xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb'))
|
||||
xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)
|
||||
expected_xl = torch.tensor([[10,35,0.01],
|
||||
[3,27,0.01],
|
||||
[5,56,0.01],
|
||||
[20,65,0.01]])
|
||||
self.assertTrue(torch.equal(xl,expected_xl))
|
||||
|
||||
def test3_bin_xl(self):
|
||||
offsets = calculate_offsets(self.asym_id)
|
||||
xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb'))
|
||||
xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)
|
||||
num_res = len(self.asym_id)
|
||||
xl = bin_xl(xl,num_res)
|
||||
expected_xl = np.zeros((num_res,num_res,1))
|
||||
expected_xl[3,27,0] = expected_xl[27,3,0] = torch.bucketize(0.99,self.bins)
|
||||
expected_xl[10,35,0] = expected_xl[35,10,0] = torch.bucketize(0.99,self.bins)
|
||||
expected_xl[5,56,0] = expected_xl[56,5,0] = torch.bucketize(0.99,self.bins)
|
||||
expected_xl[20,65,0] = expected_xl[65,20,0] = torch.bucketize(0.99,self.bins)
|
||||
self.assertTrue(np.array_equal(xl,expected_xl))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user