Update run_pplm-contact.py

This commit is contained in:
Jun Liu
2025-11-24 17:38:43 +08:00
committed by GitHub
parent 8d33c75376
commit 4fc3bb0814

View File

@@ -310,13 +310,7 @@ def collect_all_features():
return feats
def predict_contact(feats, mode, device="cpu"):
script_dir = os.path.dirname(os.path.abspath(__file__))
# if mode == "homo":
# model_paths = [os.path.join(script_dir, "pplm_contact/models/pplm_contact.homo_" + str(i) + ".pkl") for i in range(1, 6)]
# else:
# model_paths = [os.path.join(script_dir, "pplm_contact/models/pplm_contact.hetero_" + str(i) + ".pkl") for i in range(1, 6)]
#
models_weight_ = torch.load(os.path.join(script_dir, "weights/pplm_contact_models.pkl"), map_location=device)
if mode == "homo":
models_weight = models_weight_['homo']
@@ -336,7 +330,6 @@ def predict_contact(feats, mode, device="cpu"):
ensemble_pred_inter_contact = []
with torch.no_grad():
for model_weight in models_weight:
# checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(model_weight)
model.eval()
@@ -346,7 +339,7 @@ def predict_contact(feats, mode, device="cpu"):
contact_pred_ = model(intra2_1d, intra2_2d, intra1_1d, intra1_2d, inter_2d.transpose(-1, -2), intra2_Mdist, intra1_Mdist)
contact_pred = (contact_pred + contact_pred_.transpose(-1, -2)) / 2
pred_inter_contact = contact_pred # [inter_contact_mask_ur]
pred_inter_contact = contact_pred
ensemble_pred_inter_contact.append(pred_inter_contact)
pred_inter_contact = torch.stack(ensemble_pred_inter_contact)
@@ -354,6 +347,10 @@ def predict_contact(feats, mode, device="cpu"):
return pred_inter_contact
if __name__ == "__main__":
main()