revert accidentally removed files

This commit is contained in:
Dima Molodenskiy
2023-12-20 13:14:55 +01:00
parent 503f6e945c
commit 692d9baf64
2 changed files with 114 additions and 0 deletions

View File

@@ -0,0 +1,65 @@
import shutil
import tempfile
import unittest
import sys
import os
import torch
from unifold.modules.alphafold import AlphaFold
from unifold.alphalink_inference import prepare_model_runner
from unifold.alphalink_inference import alphalink_prediction
from unifold.dataset import process_ap
from unifold.config import model_config
from alphapulldown.utils import create
from alphapulldown.run_multimer_jobs import predict_individual_jobs,create_custom_jobs
class _TestBase(unittest.TestCase):
def setUp(self) -> None:
self.crosslink_file_path = os.path.join(os.path.dirname(__file__),"test_data/example_crosslink.pkl.gz")
self.config_data_model_name = 'model_5_ptm_af2'
self.config_alphafold_model_name = 'multimer_af2_crop'
class TestCrosslinkInference(_TestBase):
def setUp(self) -> None:
super().setUp()
self.output_dir = tempfile.mkdtemp()
self.monomer_object_path = os.path.join(os.path.dirname(__file__),"test_data/")
self.protein_list = os.path.join(os.path.dirname(__file__),"test_data/example_crosslinked_pair.txt")
self.alphalink_weight = '/g/alphafold/alphalink_weights/AlphaLink-Multimer_SDA_v3.pt'
self.multimerobject = create_custom_jobs(self.protein_list,self.monomer_object_path,job_index=1,pair_msa=True)[0]
def test1_process_features(self):
"""Test whether the PyTorch model of AlphaLink can be initiated successfully"""
configs = model_config(self.config_data_model_name)
processed_features,_ = process_ap(config=configs.data,
features=self.multimerobject.feature_dict,
mode="predict",labels=None,
seed=42,batch_idx=None,
data_idx=None,is_distillation=False,
chain_id_map = self.multimerobject.chain_id_map,
crosslinks = self.crosslink_file_path
)
def test2_load_AlphaLink_weights(self):
"""This is testing weither loading the PyTorch checkpoint is sucessfull"""
if torch.cuda.is_available():
model_device = 'cuda:0'
else:
model_device = 'cpu'
config = model_config(self.config_alphafold_model_name)
model = AlphaFold(config)
state_dict = torch.load(self.alphalink_weight)["ema"]["params"]
state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.to(model_device)
def test3_test_inference(self):
if torch.cuda.is_available():
model_device = 'cuda:0'
else:
model_device = 'cpu'
model = prepare_model_runner(self.alphalink_weight,model_device=model_device)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,49 @@
import unittest
from unifold.dataset import calculate_offsets,create_xl_features,bin_xl
from alphafold.data.pipeline_multimer import _FastaChain
import numpy as np
import gzip,pickle
import torch
class TestCreateObjects(unittest.TestCase):
def setUp(self) -> None:
self.crosslink_info ="./test/test_data/test_xl_input.pkl.gz"
self.asym_id = torch.tensor([1]*10 + [2]*25 + [3]*40)
self.chain_id_map = {
"A":_FastaChain(sequence='',description='chain1'),
"B":_FastaChain(sequence='',description='chain2'),
"C":_FastaChain(sequence='',description='chain3')
}
self.bins = torch.arange(0,1.05,0.05)
return super().setUp()
def test1_calculate_offsets(self):
offsets = calculate_offsets(self.asym_id)
offsets = offsets.tolist()
expected_offsets = [0,10,35,75]
self.assertEqual(offsets,expected_offsets)
def test2_create_xl_inputs(self):
offsets = calculate_offsets(self.asym_id)
xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb'))
xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)
expected_xl = torch.tensor([[10,35,0.01],
[3,27,0.01],
[5,56,0.01],
[20,65,0.01]])
self.assertTrue(torch.equal(xl,expected_xl))
def test3_bin_xl(self):
offsets = calculate_offsets(self.asym_id)
xl_pickle = pickle.load(gzip.open(self.crosslink_info,'rb'))
xl = create_xl_features(xl_pickle,offsets,chain_id_map = self.chain_id_map)
num_res = len(self.asym_id)
xl = bin_xl(xl,num_res)
expected_xl = np.zeros((num_res,num_res,1))
expected_xl[3,27,0] = expected_xl[27,3,0] = torch.bucketize(0.99,self.bins)
expected_xl[10,35,0] = expected_xl[35,10,0] = torch.bucketize(0.99,self.bins)
expected_xl[5,56,0] = expected_xl[56,5,0] = torch.bucketize(0.99,self.bins)
expected_xl[20,65,0] = expected_xl[65,20,0] = torch.bucketize(0.99,self.bins)
self.assertTrue(np.array_equal(xl,expected_xl))
if __name__ == "__main__":
unittest.main()