diff --git a/run_pplm-contact.py b/run_pplm-contact.py index 727bfed..f53b626 100644 --- a/run_pplm-contact.py +++ b/run_pplm-contact.py @@ -310,13 +310,7 @@ def collect_all_features(): return feats def predict_contact(feats, mode, device="cpu"): - script_dir = os.path.dirname(os.path.abspath(__file__)) - # if mode == "homo": - # model_paths = [os.path.join(script_dir, "pplm_contact/models/pplm_contact.homo_" + str(i) + ".pkl") for i in range(1, 6)] - # else: - # model_paths = [os.path.join(script_dir, "pplm_contact/models/pplm_contact.hetero_" + str(i) + ".pkl") for i in range(1, 6)] - # models_weight_ = torch.load(os.path.join(script_dir, "weights/pplm_contact_models.pkl"), map_location=device) if mode == "homo": models_weight = models_weight_['homo'] @@ -336,7 +330,6 @@ def predict_contact(feats, mode, device="cpu"): ensemble_pred_inter_contact = [] with torch.no_grad(): for model_weight in models_weight: - # checkpoint = torch.load(model_path, map_location=device) model.load_state_dict(model_weight) model.eval() @@ -346,7 +339,7 @@ def predict_contact(feats, mode, device="cpu"): contact_pred_ = model(intra2_1d, intra2_2d, intra1_1d, intra1_2d, inter_2d.transpose(-1, -2), intra2_Mdist, intra1_Mdist) contact_pred = (contact_pred + contact_pred_.transpose(-1, -2)) / 2 - pred_inter_contact = contact_pred # [inter_contact_mask_ur] + pred_inter_contact = contact_pred ensemble_pred_inter_contact.append(pred_inter_contact) pred_inter_contact = torch.stack(ensemble_pred_inter_contact) @@ -354,6 +347,10 @@ def predict_contact(feats, mode, device="cpu"): return pred_inter_contact + + + + if __name__ == "__main__": main()