From 836396b763afafde90ab4e09ade9d2db06548e14 Mon Sep 17 00:00:00 2001 From: Jun Liu Date: Mon, 24 Nov 2025 16:37:43 +0800 Subject: [PATCH] Update run_pplm-ppi.py --- run_pplm-ppi.py | 93 +++++++++++++++++++------------------------------ 1 file changed, 35 insertions(+), 58 deletions(-) diff --git a/run_pplm-ppi.py b/run_pplm-ppi.py index 4a6b978..46c24da 100644 --- a/run_pplm-ppi.py +++ b/run_pplm-ppi.py @@ -4,20 +4,20 @@ import torch import argparse from pplm_ppi import PPLM_PPI - def main(): parser = argparse.ArgumentParser(description="Protein-Protein Interaction Prediction", epilog="v0.0.1") - parser.add_argument("seq_pairs_path", + parser.add_argument("seqA_path", action="store", - help="Path of paired sequence list") + help="Location of sequence A") - parser.add_argument("output_path", + parser.add_argument("seqB_path", action="store", - help="Path of output file") + help="Location of sequence B") parser.add_argument("--gpu_id", + "-gpu", type=int, default=0, help="gpu device specified", @@ -29,71 +29,55 @@ def main(): assigned_device = "cuda:" + str(args.gpu_id) device = assigned_device if torch.cuda.is_available() else "cpu" - script_dir = os.path.dirname(os.path.abspath(__file__)) - models_path = [os.path.join(script_dir, "pplm_ppi/models/model" + str(i) + ".pkl") for i in range(1, 6)] + model_weights = torch.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights/ppi_models.pkl"), map_location=device) model = PPLM_PPI() model.to(device) ### Read sequences ### + seqA = read_sequence(args.seqA_path) + seqB = read_sequence(args.seqB_path) - last_flag = '' - last_seq = '' - seq_list = [] - for line in open(args.seq_pairs_path).readlines(): - if line.startswith('>'): - if last_seq != '': - seq_list.append([last_flag, last_seq.split(':')[0], last_seq.split(':')[1]]) - last_seq = '' - last_flag = line.strip()[1:] - - elif len(line.strip()) != 0: - last_seq += line.strip() - - if last_seq != '': - seq_list.append([last_flag, last_seq.split(':')[0], last_seq.split(':')[1]]) - - print("Number of paired sequences:", len(seq_list)) + ### Get pplm features ### + mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B, max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B = get_pplm_features(seqA, seqB, device) ### Prediction ### - score_list = [] - for i in range(len(seq_list)): - flag = seq_list[i][0] - seqA = seq_list[i][1] - seqB = seq_list[i][2] + with torch.no_grad(): + predictions_list = [] + for model_weight in model_weights['mean']: + model.load_state_dict(model_weight) + predictions = model(mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B) + predictions_ = model(mean_inter_attn, mean_attn_BB, mean_attn_AA, mean_embed_B, mean_embed_A) + predictions = (predictions + predictions_) / 2 + predictions_list.append(predictions) - mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B, max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B = get_pplm_features(seqA, seqB, device) + for model_weight in model_weights['max']: + model.load_state_dict(model_weight) + predictions = model(max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B) + predictions_ = model(max_inter_attn, max_attn_BB, max_attn_AA, max_embed_B, max_embed_A) + predictions = (predictions + predictions_) / 2 + predictions_list.append(predictions) - with torch.no_grad(): - predictions_list = [] - for model_path in models_path: - checkpoint = torch.load(model_path, map_location=device) - model.load_state_dict(checkpoint["net"]) + predictions = torch.stack(predictions_list) + predictions = torch.mean(predictions, dim=0).squeeze().cpu().numpy() - predictions = model(mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B, max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B) - predictions_list.append(predictions) - - predictions = torch.stack(predictions_list) - predictions = torch.mean(predictions, dim=0).squeeze().cpu().numpy() - - score_list.append([flag, predictions]) - - ### Write results ### - with open(args.output_path, "w") as f: - for i in range(len(score_list)): - flag, prediction = score_list[i] - f.write(">" + flag + "\n") - f.write(f"{prediction:.6f}" + "\n") + print("Predicted interaction score:", predictions) +def read_sequence(seq_path): + seq = "" + for line in open(seq_path, "r").readlines(): + if not line.startswith(">"): + seq += line.strip() + return seq def get_pplm_features(seqA, seqB, device): mian_path = os.path.dirname(__file__) sys.path.append(os.path.abspath(mian_path)) from pplm import PPLM, Alphabet - model_location = os.path.join(mian_path, 'pplm/models/', 'pplm_t33_650M.pt') + model_location = os.path.join(mian_path, 'weights/', 'pplm_t33_650M.pt') ##### Loading PPLM Model ##### alphabet = Alphabet.from_architecture() @@ -146,13 +130,6 @@ def get_pplm_features(seqA, seqB, device): return mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B, max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B -def read_sequence(seq_path): - seq = "" - for line in open(seq_path, "r").readlines(): - if not line.startswith(">"): - seq += line.strip() - - return seq - if __name__ == "__main__": main() +