Update run_pplm-contact2.py

This commit is contained in:
Jun Liu
2025-11-24 17:39:12 +08:00
committed by GitHub
parent 4fc3bb0814
commit 37caa67e02

View File

@@ -14,7 +14,7 @@ from pplm_contact.config import *
def main():
parser = argparse.ArgumentParser(description="Protein-Protein Contact Prediction",
epilog="v0.0.1")
epilog="v2.0.1")
parser.add_argument("dimer_pdb_paths",
nargs='+',
@@ -50,26 +50,23 @@ def main():
##### Step 0: Process the input pdb (clean pdb & extract sequence & distance map) #####
target_list = []
for dimer_pdb_path in dimer_pdb_paths:
target = str(args.output_folder).split('/')[-1].replace('.pdb', '')
target = str(dimer_pdb_path).split('/')[-1].replace('.pdb', '')
target_pdb = os.path.join(workspace, target + ".clean.pdb")
target1_seq = os.path.join(workspace, target + "_A.fasta")
target2_seq = os.path.join(workspace, target + "_B.fasta")
target1_monomer_dist = os.path.join(workspace, target + "_A.monomer_dist.pkl")
target2_monomer_dist = os.path.join(workspace, target + "_B.monomer_dist.pkl")
inter_chain_dist = os.path.join(workspace, target + ".inter_chain_dist.pkl")
subprocess.run("grep \"^ATOM\" " + str(dimer_pdb_path) + " | sed 's/MEX/CYS/g; s/HID/HIS/g; s/HIE/HIS/g; s/HIP/HIS/g; s/MSE/MET/g; s/ASX/ASN/g; s/GLX/GLN/g; s/TYS/TRP/g' > " + target_pdb, shell=True, check=True)
seqA, target1_res_idx_type, seqB, target2_res_idx_type = extract_seq_and_dist_map_dimer(target_pdb, target1_seq, target2_seq, target1_monomer_dist, target2_monomer_dist, inter_chain_dist)
seqA, target1_res_idx_type, target1_monomer_dist, seqB, target2_res_idx_type, target2_monomer_dist, inter_chain_dist = extract_seq_and_dist_map_dimer(target_pdb)
target_data = {'name': target, 'seqA': seqA, 'seqB': seqB, 'target1_res_idx_type': target1_res_idx_type, 'target2_res_idx_type': target2_res_idx_type,
'seqA_path': target1_seq, 'seqB_path': target2_seq, 'monomer_A_dist': target1_monomer_dist, 'monomer_B_dist': target2_monomer_dist, 'inter_chain_dist': inter_chain_dist}
'monomer_A_dist': target1_monomer_dist, 'monomer_B_dist': target2_monomer_dist, 'inter_chain_dist': inter_chain_dist}
target_list.append(target_data)
target = target_list[0]['name']
seqA = target_list[0]['seqA']
seqB = target_list[0]['seqB']
target1_seq = target_list[0]['seqA_path']
target2_seq = target_list[0]['seqB_path']
target1_res_idx_type = target_list[0]['target1_res_idx_type']
target2_res_idx_type = target_list[0]['target2_res_idx_type']
@@ -78,6 +75,14 @@ def main():
print("Error: all complex structure most have the same sequence!!!")
exit()
target1_seq = os.path.join(workspace, target + "_A.fasta")
target2_seq = os.path.join(workspace, target + "_B.fasta")
with open(target1_seq, 'w') as fw:
fw.write(">seqA\n" + seqA)
with open(target2_seq, 'w') as fw:
fw.write(">seqB\n" + seqB)
print("sequence of first chain:", seqA)
print("sequence of second chain:", seqB)
@@ -161,8 +166,6 @@ def main():
data = "{:<10}".format(k+1) + "{:<10}".format(str(res1_idx) + ":A") + "{:<10}".format(res1_type) + "{:<10}".format(str(res2_idx) + ":B") + "{:<10}".format(res2_type) + "{:<10}".format(f"{prob:.6g}") + "\n"
fw.write(data)
def define_param(args, target_name):
global target, device
global target1_msa, target2_msa, target1_hhm, target1_aln
@@ -222,7 +225,7 @@ def get_pplm_features(seqA_path, seqB_path, out_pkl_path, device='cpu'):
sys.path.append(os.path.abspath(mian_path))
from pplm import PPLM, Alphabet
model_location = os.path.join(mian_path, 'pplm/models/', 'pplm_t33_650M.pt')
model_location = os.path.join(mian_path, 'weights/pplm_t33_650M.pt')
##### Loading PPLM Model #####
alphabet = Alphabet.from_architecture()
@@ -270,10 +273,7 @@ def get_pplm_features(seqA_path, seqB_path, out_pkl_path, device='cpu'):
with open(out_pkl_path, mode='wb') as fw:
pickle.dump(inter_attn, fw)
def collect_all_features(target1_monomer_dist, target2_monomer_dist, inter_chain_dist):
with open(target1_monomer_dist, "rb") as fr:
target1_M_dist = pickle.load(fr)
def collect_all_features(target1_M_dist, target2_M_dist, inter_dist):
target1_DCA_DI = np.expand_dims(np.loadtxt(target1_dca_di), 0)
target1_DCA_APC = np.expand_dims(np.loadtxt(target1_dca_apc), 0)
target1_PSSM = load_hmm(target1_hhm)['PSSM']
@@ -283,7 +283,6 @@ def collect_all_features(target1_monomer_dist, target2_monomer_dist, inter_chain
target1_esm_msa_1d = esm_msa_data['esm_msa_1d']
target1_esm_msa_2d = esm_msa_data['row_attentions']
# print(target1_DCA_DI.shape, target1_DCA_APC.shape, target1_esm_msa_2d.shape, RBF(target1_M_dist).shape, target1_M_dist.shape)
intra1_1d = np.concatenate([target1_PSSM, target1_esm_msa_1d], axis=-1).transpose(1,0)
intra1_2d = np.concatenate([target1_DCA_DI, target1_DCA_APC, target1_esm_msa_2d, RBF(target1_M_dist)], axis=0)
intra1_Mdist = target1_M_dist
@@ -298,9 +297,6 @@ def collect_all_features(target1_monomer_dist, target2_monomer_dist, inter_chain
inter_2d = np.concatenate([target1_DCA_DI, target1_DCA_APC, target1_esm_msa_2d, inter_pplm_attn], axis=0)
else:
with open(target2_monomer_dist, "rb") as fr:
target2_M_dist = pickle.load(fr)
target2_DCA_DI = np.expand_dims(np.loadtxt(target2_dca_di), 0)
target2_DCA_APC = np.expand_dims(np.loadtxt(target2_dca_apc), 0)
target2_PSSM = load_hmm(target2_hhm)['PSSM']
@@ -310,7 +306,6 @@ def collect_all_features(target1_monomer_dist, target2_monomer_dist, inter_chain
target2_esm_msa_1d = esm_msa_data['esm_msa_1d']
target2_esm_msa_2d = esm_msa_data['row_attentions']
# print(target2_DCA_DI.shape, target2_DCA_APC.shape, target2_esm_msa_2d.shape, RBF(target2_M_dist).shape, target2_M_dist.shape)
intra2_1d = np.concatenate([target2_PSSM, target2_esm_msa_1d], axis=-1).transpose(1, 0)
intra2_2d = np.concatenate([target2_DCA_DI, target2_DCA_APC, target2_esm_msa_2d, RBF(target2_M_dist)], axis=0)
intra2_Mdist = target2_M_dist
@@ -324,13 +319,8 @@ def collect_all_features(target1_monomer_dist, target2_monomer_dist, inter_chain
len1 = inter_pplm_attn.shape[-2]
len2 = inter_esm_msa_2d.shape[-1]
# print(inter_DCA_DI.shape, inter_DCA_APC.shape, inter_esm_msa_2d.shape, inter_pplm_attn.shape)
inter_2d = np.concatenate([inter_DCA_DI[:, :len1, len1:len1+len2], inter_DCA_APC[:, :len1, len1:len1+len2], inter_esm_msa_2d[:, :len1, len1:len1+len2], inter_pplm_attn], axis=0)
with open(inter_chain_dist, "rb") as fr:
inter_dist = pickle.load(fr)
# print("inter_2d:", inter_2d.shape, inter_dist.shape, RBF(inter_dist).shape)
inter_2d = np.concatenate([inter_2d, RBF(inter_dist)], axis=0)
feats = {"intra1_1d": intra1_1d, "intra1_2d": intra1_2d, "intra1_Mdist": intra1_Mdist, "intra2_1d": intra2_1d, "intra2_2d": intra2_2d, "intra2_Mdist": intra2_Mdist, "inter_2d": inter_2d}
@@ -339,10 +329,11 @@ def collect_all_features(target1_monomer_dist, target2_monomer_dist, inter_chain
def predict_contact(feats, device="cpu"):
script_dir = os.path.dirname(os.path.abspath(__file__))
models_weight_ = torch.load(os.path.join(script_dir, "weights/pplm_contact2_models.pkl"), map_location=device)
if mode == "homo":
model_paths = [os.path.join(script_dir, "pplm_contact/models/pplm_contact2.homo_" + str(i) + ".pkl") for i in range(1, 6)]
models_weight = models_weight_['homo']
else:
model_paths = [os.path.join(script_dir, "pplm_contact/models/pplm_contact2.hetero_" + str(i) + ".pkl") for i in range(1, 6)]
models_weight = models_weight_['hetero']
model = PPLM_Contact(inter_2d_dim=144+2+660+64)
model.to(device)
@@ -357,20 +348,17 @@ def predict_contact(feats, device="cpu"):
ensemble_pred_inter_contact = []
with torch.no_grad():
for model_path in model_paths:
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
for model_weight in models_weight:
model.load_state_dict(model_weight)
model.eval()
#################################### Network predict #######################################
contact_pred = model(intra1_1d, intra1_2d, intra2_1d, intra2_2d, inter_2d, intra1_Mdist, intra2_Mdist)
if mode == "hetero":
contact_pred_ = model(intra2_1d, intra2_2d, intra1_1d, intra1_2d, inter_2d.transpose(-1, -2), intra2_Mdist, intra1_Mdist)
contact_pred = (contact_pred + contact_pred_.transpose(-1, -2)) / 2
pred_inter_contact = contact_pred # [inter_contact_mask_ur]
pred_inter_contact = contact_pred
ensemble_pred_inter_contact.append(pred_inter_contact)
pred_inter_contact = torch.stack(ensemble_pred_inter_contact)
@@ -378,7 +366,6 @@ def predict_contact(feats, device="cpu"):
return pred_inter_contact
if __name__ == "__main__":
main()