diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 7619e9c..2a3bf36 100755 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -41,6 +41,15 @@ def main(conf: HydraConfig) -> None: if conf.inference.deterministic: make_deterministic() + # Check for available GPU and print result of check + if torch.cuda.is_available(): + device_name = torch.cuda.get_device_name(torch.cuda.current_device()) + log.info(f"Found GPU with device_name {device_name}. Will run RFdiffusion on {device_name}") + else: + log.info("////////////////////////////////////////////////") + log.info("///// NO GPU DETECTED! Falling back to CPU /////") + log.info("////////////////////////////////////////////////") + # Initialize sampler and target/contig. sampler = iu.sampler_selector(conf)