Fix ckpt path #85

This commit is contained in:
Gabriele Corso
2023-07-18 01:40:55 -07:00
committed by GitHub
parent 2782769ccd
commit 600a23fa23

View File

@@ -213,7 +213,7 @@ class ConfidenceDataset(Dataset):
t_to_sigma = partial(t_to_sigma_compl, args=self.original_model_args)
model = get_model(self.original_model_args, self.device, t_to_sigma=t_to_sigma, no_parallel=True)
state_dict = torch.load(f'{self.original_model_dir}/{self.model_ckpt}.pt', map_location=torch.device('cpu'))
state_dict = torch.load(f'{self.original_model_dir}/{self.model_ckpt}', map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)
model = model.to(self.device)
model.eval()