Files
D-SCRIPT/scripts/push_huggingface.py

43 lines
1.2 KiB
Python

import torch
import argparse
from huggingface_hub import PyTorchModelHubMixin
from dscript.models.interaction import DSCRIPTModel
def build_parser():
    """Return the command-line parser for this script.

    Positional arguments: ``model_pt`` (path to the saved model file),
    ``hf_user`` (HuggingFace user) and ``hf_model_name`` (HuggingFace model
    name).
    """
    # The first positional parameter of ArgumentParser is ``prog``, not
    # ``description``; pass the text by keyword so the usage line shows the
    # real program name and ``--help`` prints the description.
    parser = argparse.ArgumentParser(description="Push a model to HuggingFace")
    parser.add_argument("model_pt", help="Path to the model.pt file")
    parser.add_argument("hf_user", help="HuggingFace user")
    parser.add_argument("hf_model_name", help="HuggingFace model name")
    return parser


def load_hub_model(model_pt):
    """Rebuild a ``DSCRIPTModel`` from the pickled model stored at *model_pt*.

    The file is expected to contain a whole model object (it must expose
    ``state_dict()``, ``do_w``, ``do_pool`` and ``do_sigmoid``), not a bare
    state dict. A fresh ``DSCRIPTModel`` is constructed, the old weights are
    copied into it, and it is returned on CPU in eval mode.
    """
    # SECURITY: this unpickles an arbitrary Python object, so only run the
    # script on checkpoints you trust. ``weights_only=False`` is spelled out
    # because PyTorch 2.6+ defaults to ``weights_only=True``, which refuses
    # to unpickle a full model object.
    model_old = torch.load(
        model_pt,
        map_location=torch.device("cpu"),
        weights_only=False,
    )

    # The size hyperparameters are hard-coded in this script; only the three
    # ``do_*`` flags are copied from the loaded model.
    # NOTE(review): presumably these sizes match the released pretrained
    # architecture -- confirm before pushing a custom-trained checkpoint.
    model = DSCRIPTModel(
        emb_nin=6165,
        emb_nout=100,
        emb_dropout=0.5,
        con_embed_dim=121,
        con_hidden_dim=50,
        con_width=7,
        use_cuda=False,
        do_w=model_old.do_w,
        pool_size=9,
        do_pool=model_old.do_pool,
        do_sigmoid=model_old.do_sigmoid,
    )

    # Load the old weights into the freshly constructed model.
    model.load_state_dict(model_old.state_dict())
    model.eval()
    return model


def main(argv=None):
    """Parse *argv* (defaults to ``sys.argv[1:]``), rebuild the model, push it.

    The target repository id is ``"<hf_user>/<hf_model_name>"``.
    """
    args = build_parser().parse_args(argv)
    model = load_hub_model(args.model_pt)
    # NOTE(review): ``push_to_hub`` is presumably provided by
    # ``PyTorchModelHubMixin`` on ``DSCRIPTModel`` -- confirm in the model class.
    model.push_to_hub(f"{args.hf_user}/{args.hf_model_name}")


if __name__ == "__main__":
    main()