mirror of
https://github.com/gcorso/DiffDock.git
synced 2026-06-04 18:04:23 +08:00
Fix ckpt path #85
This commit is contained in:
@@ -213,7 +213,7 @@ class ConfidenceDataset(Dataset):
|
||||
t_to_sigma = partial(t_to_sigma_compl, args=self.original_model_args)
|
||||
|
||||
model = get_model(self.original_model_args, self.device, t_to_sigma=t_to_sigma, no_parallel=True)
|
||||
state_dict = torch.load(f'{self.original_model_dir}/{self.model_ckpt}.pt', map_location=torch.device('cpu'))
|
||||
state_dict = torch.load(f'{self.original_model_dir}/{self.model_ckpt}', map_location=torch.device('cpu'))
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model = model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user