mirror of
https://github.com/gcorso/DiffDock.git
synced 2026-06-04 09:54:21 +08:00
Also add web app code for simple gradio app. Refine requirements.txt/environment.yml. Automatically download models if not present.
21 lines
562 B
Python
21 lines
562 B
Python
import os
|
|
import torch
|
|
# from utils.utils import get_default_device
|
|
|
|
|
|
def get_default_device():
    """Return the torch device DiffDock should run on.

    Preference order:

    1. CUDA, whenever ``torch.cuda.is_available()`` is true.
    2. Apple MPS, only when the MPS backend is available *and* the
       environment variable ``PYTORCH_ENABLE_MPS_FALLBACK`` is exactly
       ``"1"``. Not all operations are implemented on MPS yet. Without that
       fallback switch the CPU is used instead of MPS.
    3. CPU in every other case.

    Returns:
        torch.device: ``torch.device('cuda')``, ``torch.device('mps')`` or
        ``torch.device('cpu')``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # ``torch.backends.mps`` only exists in newer PyTorch releases; look it up
    # defensively so older installs fall through to the CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        # Not all operations implemented in MPS yet: only opt into MPS when
        # the user has explicitly enabled PyTorch's CPU fallback.
        use_mps = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0") == "1"
        if use_mps:
            return torch.device('mps')

    return torch.device('cpu')
|
|
|
|
|
|
# Resolve the compute device once, when this module is executed, expose it as
# the module-level ``device`` global, and report the choice on stdout.
device = get_default_device()
print(f"DiffDock Device: {device}")