Create run_pplm-affinity.py

This commit is contained in:
Jun Liu
2025-04-28 20:35:12 +08:00
committed by GitHub
parent 8876926f82
commit 5339911fed

78
run_pplm-affinity.py Normal file
View File

@@ -0,0 +1,78 @@
import os
import sys
import torch
import argparse
# mian_path = os.path.dirname(__file__) + "/../"
# sys.path.append(os.path.abspath(mian_path))
# import pplm_ppi
from pplm_affinity import PPLM_Affinity
def main():
parser = argparse.ArgumentParser(description="Protein-Protein Interaction Prediction",
epilog="v0.0.1")
parser.add_argument("receptor_seqs_path",
action="store",
help="Location of receptor sequences")
parser.add_argument("ligand_seqs_path",
action="store",
help="Location of ligand sequences")
parser.add_argument("--gpu_id",
"-gpu",
type=int,
default=0,
help="gpu device specified",
)
args = parser.parse_args()
### Define model ###
assigned_device = "cuda:" + str(args.gpu_id)
device = assigned_device if torch.cuda.is_available() else "cpu"
script_dir = os.path.dirname(os.path.abspath(__file__))
models_path = [os.path.join(script_dir, "pplm_affinity/models/model_cv" + str(i) + ".pkl") for i in range(0, 5)]
### Read sequences ###
seqA = read_sequence(args.receptor_seqs_path)
seqB = read_sequence(args.ligand_seqs_path)
### Prediction ###
with torch.no_grad():
predictions_list = []
for model_path in models_path:
checkpoint = torch.load(model_path, map_location=device)
pplm_model_param = checkpoint["pplm_param"]
model_state = checkpoint["model_state_dict"]
model = PPLM_Affinity(pplm_model_param)
model.load_state_dict(model_state)
model.to(device)
predictions = model(seqA, seqB, device)
predictions2 = model(seqB, seqA, device)
predictions = (predictions + predictions2) / 2
predictions_list.append(predictions)
predictions = torch.stack(predictions_list)
predictions = torch.mean(predictions, dim=0).squeeze().cpu().numpy()
print("Predicted binding affinity:", predictions)
def read_sequence(seq_path):
seq = ""
for line in open(seq_path, "r").readlines():
if not line.startswith(">"):
seq += line.strip()
return seq
if __name__ == "__main__":
main()