mirror of
https://github.com/junliu621/PPLM.git
synced 2026-06-04 06:14:23 +08:00
Update run_pplm-ppi.py
This commit is contained in:
@@ -4,20 +4,20 @@ import torch
|
||||
import argparse
|
||||
from pplm_ppi import PPLM_PPI
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Protein-Protein Interaction Prediction",
|
||||
epilog="v0.0.1")
|
||||
|
||||
parser.add_argument("seq_pairs_path",
|
||||
parser.add_argument("seqA_path",
|
||||
action="store",
|
||||
help="Path of paired sequence list")
|
||||
help="Location of sequence A")
|
||||
|
||||
parser.add_argument("output_path",
|
||||
parser.add_argument("seqB_path",
|
||||
action="store",
|
||||
help="Path of output file")
|
||||
help="Location of sequence B")
|
||||
|
||||
parser.add_argument("--gpu_id",
|
||||
"-gpu",
|
||||
type=int,
|
||||
default=0,
|
||||
help="gpu device specified",
|
||||
@@ -29,71 +29,55 @@ def main():
|
||||
assigned_device = "cuda:" + str(args.gpu_id)
|
||||
device = assigned_device if torch.cuda.is_available() else "cpu"
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
models_path = [os.path.join(script_dir, "pplm_ppi/models/model" + str(i) + ".pkl") for i in range(1, 6)]
|
||||
model_weights = torch.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights/ppi_models.pkl"), map_location=device)
|
||||
|
||||
model = PPLM_PPI()
|
||||
model.to(device)
|
||||
|
||||
### Read sequences ###
|
||||
seqA = read_sequence(args.seqA_path)
|
||||
seqB = read_sequence(args.seqB_path)
|
||||
|
||||
last_flag = ''
|
||||
last_seq = ''
|
||||
seq_list = []
|
||||
for line in open(args.seq_pairs_path).readlines():
|
||||
if line.startswith('>'):
|
||||
if last_seq != '':
|
||||
seq_list.append([last_flag, last_seq.split(':')[0], last_seq.split(':')[1]])
|
||||
last_seq = ''
|
||||
last_flag = line.strip()[1:]
|
||||
|
||||
elif len(line.strip()) != 0:
|
||||
last_seq += line.strip()
|
||||
|
||||
if last_seq != '':
|
||||
seq_list.append([last_flag, last_seq.split(':')[0], last_seq.split(':')[1]])
|
||||
|
||||
print("Number of paired sequences:", len(seq_list))
|
||||
### Get pplm features ###
|
||||
mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B, max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B = get_pplm_features(seqA, seqB, device)
|
||||
|
||||
### Prediction ###
|
||||
score_list = []
|
||||
for i in range(len(seq_list)):
|
||||
flag = seq_list[i][0]
|
||||
seqA = seq_list[i][1]
|
||||
seqB = seq_list[i][2]
|
||||
with torch.no_grad():
|
||||
predictions_list = []
|
||||
for model_weight in model_weights['mean']:
|
||||
model.load_state_dict(model_weight)
|
||||
predictions = model(mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B)
|
||||
predictions_ = model(mean_inter_attn, mean_attn_BB, mean_attn_AA, mean_embed_B, mean_embed_A)
|
||||
predictions = (predictions + predictions_) / 2
|
||||
predictions_list.append(predictions)
|
||||
|
||||
mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B, max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B = get_pplm_features(seqA, seqB, device)
|
||||
for model_weight in model_weights['max']:
|
||||
model.load_state_dict(model_weight)
|
||||
predictions = model(max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B)
|
||||
predictions_ = model(max_inter_attn, max_attn_BB, max_attn_AA, max_embed_B, max_embed_A)
|
||||
predictions = (predictions + predictions_) / 2
|
||||
predictions_list.append(predictions)
|
||||
|
||||
with torch.no_grad():
|
||||
predictions_list = []
|
||||
for model_path in models_path:
|
||||
checkpoint = torch.load(model_path, map_location=device)
|
||||
model.load_state_dict(checkpoint["net"])
|
||||
predictions = torch.stack(predictions_list)
|
||||
predictions = torch.mean(predictions, dim=0).squeeze().cpu().numpy()
|
||||
|
||||
predictions = model(mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B, max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B)
|
||||
predictions_list.append(predictions)
|
||||
|
||||
predictions = torch.stack(predictions_list)
|
||||
predictions = torch.mean(predictions, dim=0).squeeze().cpu().numpy()
|
||||
|
||||
score_list.append([flag, predictions])
|
||||
|
||||
### Write results ###
|
||||
with open(args.output_path, "w") as f:
|
||||
for i in range(len(score_list)):
|
||||
flag, prediction = score_list[i]
|
||||
f.write(">" + flag + "\n")
|
||||
f.write(f"{prediction:.6f}" + "\n")
|
||||
print("Predicted interaction score:", predictions)
|
||||
|
||||
|
||||
def read_sequence(seq_path):
|
||||
seq = ""
|
||||
for line in open(seq_path, "r").readlines():
|
||||
if not line.startswith(">"):
|
||||
seq += line.strip()
|
||||
|
||||
return seq
|
||||
|
||||
def get_pplm_features(seqA, seqB, device):
|
||||
mian_path = os.path.dirname(__file__)
|
||||
sys.path.append(os.path.abspath(mian_path))
|
||||
|
||||
from pplm import PPLM, Alphabet
|
||||
model_location = os.path.join(mian_path, 'pplm/models/', 'pplm_t33_650M.pt')
|
||||
model_location = os.path.join(mian_path, 'weights/', 'pplm_t33_650M.pt')
|
||||
|
||||
##### Loading PPLM Model #####
|
||||
alphabet = Alphabet.from_architecture()
|
||||
@@ -146,13 +130,6 @@ def get_pplm_features(seqA, seqB, device):
|
||||
|
||||
return mean_inter_attn, mean_attn_AA, mean_attn_BB, mean_embed_A, mean_embed_B, max_inter_attn, max_attn_AA, max_attn_BB, max_embed_A, max_embed_B
|
||||
|
||||
def read_sequence(seq_path):
|
||||
seq = ""
|
||||
for line in open(seq_path, "r").readlines():
|
||||
if not line.startswith(">"):
|
||||
seq += line.strip()
|
||||
|
||||
return seq
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user