mirror of
https://github.com/junliu621/PPLM.git
synced 2026-06-04 14:24:22 +08:00
Update run_pplm-contact.py
This commit is contained in:
@@ -310,13 +310,7 @@ def collect_all_features():
|
||||
return feats
|
||||
|
||||
def predict_contact(feats, mode, device="cpu"):
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# if mode == "homo":
|
||||
# model_paths = [os.path.join(script_dir, "pplm_contact/models/pplm_contact.homo_" + str(i) + ".pkl") for i in range(1, 6)]
|
||||
# else:
|
||||
# model_paths = [os.path.join(script_dir, "pplm_contact/models/pplm_contact.hetero_" + str(i) + ".pkl") for i in range(1, 6)]
|
||||
#
|
||||
models_weight_ = torch.load(os.path.join(script_dir, "weights/pplm_contact_models.pkl"), map_location=device)
|
||||
if mode == "homo":
|
||||
models_weight = models_weight_['homo']
|
||||
@@ -336,7 +330,6 @@ def predict_contact(feats, mode, device="cpu"):
|
||||
ensemble_pred_inter_contact = []
|
||||
with torch.no_grad():
|
||||
for model_weight in models_weight:
|
||||
# checkpoint = torch.load(model_path, map_location=device)
|
||||
model.load_state_dict(model_weight)
|
||||
model.eval()
|
||||
|
||||
@@ -346,7 +339,7 @@ def predict_contact(feats, mode, device="cpu"):
|
||||
contact_pred_ = model(intra2_1d, intra2_2d, intra1_1d, intra1_2d, inter_2d.transpose(-1, -2), intra2_Mdist, intra1_Mdist)
|
||||
contact_pred = (contact_pred + contact_pred_.transpose(-1, -2)) / 2
|
||||
|
||||
pred_inter_contact = contact_pred # [inter_contact_mask_ur]
|
||||
pred_inter_contact = contact_pred
|
||||
ensemble_pred_inter_contact.append(pred_inter_contact)
|
||||
|
||||
pred_inter_contact = torch.stack(ensemble_pred_inter_contact)
|
||||
@@ -354,6 +347,10 @@ def predict_contact(feats, mode, device="cpu"):
|
||||
|
||||
return pred_inter_contact
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user