Mirror of https://github.com/samsledje/D-SCRIPT.git (synced 2026-06-04 15:04:24 +08:00) — Python, 43 lines, 1.2 KiB.
import argparse
from typing import Optional, Sequence

import torch
from huggingface_hub import PyTorchModelHubMixin

from dscript.models.interaction import DSCRIPTModel

# Command-line tool: rebuild a pickled D-SCRIPT model checkpoint as a
# DSCRIPTModel and upload it to the HuggingFace Hub as <hf_user>/<hf_model_name>.


def build_parser() -> argparse.ArgumentParser:
    """Return the argument parser for this script.

    The three positional arguments are the same as before. The optional
    flags expose DSCRIPTModel hyperparameters that used to be hard-coded;
    their defaults are exactly those hard-coded values, so existing
    invocations behave the same.
    """
    # Pass the text as ``description=``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the program name in the
    # usage line.
    parser = argparse.ArgumentParser(description="Push a model to HuggingFace")
    parser.add_argument("model_pt", help="Path to the model.pt file")
    parser.add_argument("hf_user", help="HuggingFace user")
    parser.add_argument("hf_model_name", help="HuggingFace model name")

    arch = parser.add_argument_group(
        "architecture",
        "DSCRIPTModel hyperparameters; they must match the checkpoint being converted",
    )
    arch.add_argument("--emb-nin", type=int, default=6165,
                      help="DSCRIPTModel emb_nin (default: %(default)s)")
    arch.add_argument("--emb-nout", type=int, default=100,
                      help="DSCRIPTModel emb_nout (default: %(default)s)")
    arch.add_argument("--emb-dropout", type=float, default=0.5,
                      help="DSCRIPTModel emb_dropout (default: %(default)s)")
    arch.add_argument("--con-embed-dim", type=int, default=121,
                      help="DSCRIPTModel con_embed_dim (default: %(default)s)")
    arch.add_argument("--con-hidden-dim", type=int, default=50,
                      help="DSCRIPTModel con_hidden_dim (default: %(default)s)")
    arch.add_argument("--con-width", type=int, default=7,
                      help="DSCRIPTModel con_width (default: %(default)s)")
    arch.add_argument("--pool-size", type=int, default=9,
                      help="DSCRIPTModel pool_size (default: %(default)s)")
    return parser


def load_checkpoint(model_pt: str) -> torch.nn.Module:
    """Load the full pickled model stored at *model_pt*, mapped onto the CPU.

    Args:
        model_pt: Path to a file written with ``torch.save(model, ...)``.

    Returns:
        The unpickled ``torch.nn.Module``.

    Raises:
        TypeError: If the file does not contain a ``torch.nn.Module`` (for
            example a bare state dict). This script needs the module object
            because it reads the ``do_w`` / ``do_pool`` / ``do_sigmoid``
            attributes from it.
    """
    # The checkpoint is a whole pickled nn.Module, not a state dict, so
    # ``weights_only=False`` must be explicit: newer PyTorch releases default
    # to ``weights_only=True``, which refuses to unpickle module classes.
    # SECURITY: this unpickles arbitrary Python objects; only run it on
    # checkpoint files from a trusted source.
    model_old = torch.load(
        model_pt, map_location=torch.device("cpu"), weights_only=False
    )
    if not isinstance(model_old, torch.nn.Module):
        raise TypeError(
            f"Expected {model_pt!r} to contain a pickled torch.nn.Module, "
            f"got {type(model_old).__name__}; a bare state dict cannot be "
            "converted because the model flags are read from the module."
        )
    return model_old


def build_model(model_old: torch.nn.Module, args: argparse.Namespace):
    """Create a CPU DSCRIPTModel configured like *model_old*.

    ``do_w``, ``do_pool`` and ``do_sigmoid`` are copied from the loaded
    model. All other hyperparameters come from the parsed command-line
    flags; see ``build_parser`` for their defaults.
    """
    return DSCRIPTModel(
        emb_nin=args.emb_nin,
        emb_nout=args.emb_nout,
        emb_dropout=args.emb_dropout,
        con_embed_dim=args.con_embed_dim,
        con_hidden_dim=args.con_hidden_dim,
        con_width=args.con_width,
        use_cuda=False,
        do_w=model_old.do_w,
        pool_size=args.pool_size,
        do_pool=model_old.do_pool,
        do_sigmoid=model_old.do_sigmoid,
    )


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Convert the checkpoint named on the command line and push it to the Hub.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default).
    """
    args = build_parser().parse_args(argv)

    model_old = load_checkpoint(args.model_pt)
    model = build_model(model_old, args)

    # Copy the trained weights into the freshly constructed model. Strict
    # loading (the default) raises if the architecture flags do not match
    # the checkpoint's parameters.
    model.load_state_dict(model_old.state_dict())
    model.eval()

    # NOTE(review): assumes DSCRIPTModel provides ``push_to_hub`` (e.g. via
    # huggingface_hub's PyTorchModelHubMixin); that is not visible from this
    # file.
    model.push_to_hub(f"{args.hf_user}/{args.hf_model_name}")


if __name__ == "__main__":
    main()
|