From 3ee6cb493c75b35cedde7dbae5ea808869f6de10 Mon Sep 17 00:00:00 2001 From: Jun Liu Date: Mon, 24 Nov 2025 16:41:29 +0800 Subject: [PATCH] Refactor model loading and update argument parser Updated argument parser description and removed unused imports. Refactored model loading to use a single weights file instead of multiple model paths. --- run_pplm-affinity.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/run_pplm-affinity.py b/run_pplm-affinity.py index a90245a..9d911bd 100644 --- a/run_pplm-affinity.py +++ b/run_pplm-affinity.py @@ -1,16 +1,10 @@ import os -import sys import torch import argparse -# mian_path = os.path.dirname(__file__) + "/../" -# sys.path.append(os.path.abspath(mian_path)) - -# import pplm_ppi from pplm_affinity import PPLM_Affinity - def main(): - parser = argparse.ArgumentParser(description="Protein-Protein Interaction Prediction", + parser = argparse.ArgumentParser(description="Protein-Protein Biniding Affinity Prediction", epilog="v0.0.1") parser.add_argument("receptor_seqs_path", @@ -34,9 +28,10 @@ def main(): assigned_device = "cuda:" + str(args.gpu_id) device = assigned_device if torch.cuda.is_available() else "cpu" - script_dir = os.path.dirname(os.path.abspath(__file__)) - models_path = [os.path.join(script_dir, "pplm_affinity/models/model_cv" + str(i) + ".pkl") for i in range(0, 5)] + cv_models_weight = torch.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights/affinity_models.pkl"), map_location=device) + model = PPLM_Affinity(cv_models_weight["pplm_param"]) + model.to(device) ### Read sequences ### seqA = read_sequence(args.receptor_seqs_path) @@ -45,19 +40,11 @@ def main(): ### Prediction ### with torch.no_grad(): predictions_list = [] - for model_path in models_path: - checkpoint = torch.load(model_path, map_location=device) - pplm_model_param = checkpoint["pplm_param"] - model_state = checkpoint["model_state_dict"] - - model = PPLM_Affinity(pplm_model_param) - model.load_state_dict(model_state) - model.to(device) - + for cv in range(0, 5): + model.load_state_dict(cv_models_weight['cv' + str(cv)]) predictions = model(seqA, seqB, device) predictions2 = model(seqB, seqA, device) predictions = (predictions + predictions2) / 2 - predictions_list.append(predictions) predictions = torch.stack(predictions_list) @@ -75,4 +62,4 @@ def read_sequence(seq_path): if __name__ == "__main__": main() - +