Refactor model loading and update argument parser

Updated argument parser description and removed unused imports. Refactored model loading to use a single weights file instead of multiple model paths.
This commit is contained in:
Jun Liu
2025-11-24 16:41:29 +08:00
committed by GitHub
parent 29ad4cdc0f
commit 3ee6cb493c

View File

@@ -1,16 +1,10 @@
import os
import sys
import torch
import argparse
# mian_path = os.path.dirname(__file__) + "/../"
# sys.path.append(os.path.abspath(mian_path))
# import pplm_ppi
from pplm_affinity import PPLM_Affinity
def main():
parser = argparse.ArgumentParser(description="Protein-Protein Interaction Prediction",
parser = argparse.ArgumentParser(description="Protein-Protein Biniding Affinity Prediction",
epilog="v0.0.1")
parser.add_argument("receptor_seqs_path",
@@ -34,9 +28,10 @@ def main():
assigned_device = "cuda:" + str(args.gpu_id)
device = assigned_device if torch.cuda.is_available() else "cpu"
script_dir = os.path.dirname(os.path.abspath(__file__))
models_path = [os.path.join(script_dir, "pplm_affinity/models/model_cv" + str(i) + ".pkl") for i in range(0, 5)]
cv_models_weight = torch.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights/affinity_models.pkl"), map_location=device)
model = PPLM_Affinity(cv_models_weight["pplm_param"])
model.to(device)
### Read sequences ###
seqA = read_sequence(args.receptor_seqs_path)
@@ -45,19 +40,11 @@ def main():
### Prediction ###
with torch.no_grad():
predictions_list = []
for model_path in models_path:
checkpoint = torch.load(model_path, map_location=device)
pplm_model_param = checkpoint["pplm_param"]
model_state = checkpoint["model_state_dict"]
model = PPLM_Affinity(pplm_model_param)
model.load_state_dict(model_state)
model.to(device)
for cv in range(0, 5):
model.load_state_dict(cv_models_weight['cv' + str(cv)])
predictions = model(seqA, seqB, device)
predictions2 = model(seqB, seqA, device)
predictions = (predictions + predictions2) / 2
predictions_list.append(predictions)
predictions = torch.stack(predictions_list)
@@ -75,4 +62,4 @@ def read_sequence(seq_path):
if __name__ == "__main__":
main()