feat: Add Intel XPU support for Foundry Models (#156)

feat: Add Intel XPU support for all models

- Add XPU accelerator, precision, and strategy classes
- Update DDP utilities to detect and use Intel XPU
- Add XPU trainer configs for RF3 and RFD3
- Update MPNN inference engine for XPU compatibility
- Update README with XPU documentation
This commit is contained in:
Daan Krol
2026-01-08 01:30:04 +01:00
committed by GitHub
parent d0379cea93
commit 8bfde2381a
10 changed files with 301 additions and 17 deletions

View File

@@ -15,6 +15,16 @@ All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/
pip install "rc-foundry[all]"
```
**Intel XPU Installation**
For Intel XPU devices, install PyTorch with XPU support first, then install Foundry.
```bash
pip install torch --index-url https://download.pytorch.org/whl/xpu
pip install "rc-foundry[all]"
```
> [!NOTE]
> Use `pip` (not `uv`) for XPU installs, because `uv` re-resolves dependencies and may replace your XPU build of torch with the standard PyPI version.
**Downloading weights** Models can be downloaded to a target folder with:
```
foundry install base-models --checkpoint-dir <path/to/ckpt/dir>

View File

@@ -67,11 +67,15 @@ class MPNNInferenceEngine:
else checkpoint_path
)
# Determine the device.
# Determine the device (supports XPU, CUDA, and CPU).
if device is not None:
self.device = torch.device(device)
elif torch.cuda.is_available():
self.device = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
self.device = torch.device("xpu")
else:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = torch.device("cpu")
# Set up allowed model types.
self.allowed_model_types = {"protein_mpnn", "ligand_mpnn"}

View File

@@ -0,0 +1,6 @@
strategy: xpu_single
accelerator: xpu
devices_per_node: 1
num_nodes: 1

View File

@@ -0,0 +1,6 @@
strategy: xpu_single
accelerator: xpu
devices_per_node: 1
num_nodes: 1

View File

@@ -36,6 +36,11 @@ from foundry.utils.weights import (
freeze_parameters_with_config,
load_weights_with_policies,
)
from foundry.utils.xpu import (
SingleXPUStrategy,
XPUAccelerator,
XPUMixedPrecision,
)
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
@@ -122,8 +127,17 @@ class FabricTrainer(ABC):
(3) Fabric Loggers (https://lightning.ai/docs/fabric/2.4.0/api/loggers.html)
(4) Efficient Gradient Accumulation (https://lightning.ai/docs/fabric/2.4.0/advanced/gradient_accumulation.html)
"""
# Use custom DDP strategy only for multi-device, non-interactive environments
if (
# Handle XPU accelerator
is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
if accelerator == "xpu" or (accelerator == "auto" and is_xpu and not torch.cuda.is_available()):
accelerator = XPUAccelerator()
precision_plugin = None
if precision in ("16-mixed", "bf16-mixed"):
precision_plugin = XPUMixedPrecision(precision=precision)
precision = None # Handled by plugin
strategy = SingleXPUStrategy(precision_plugin=precision_plugin)
ranked_logger.info("Using Intel XPU with SingleXPUStrategy")
elif (
strategy == "ddp"
and not is_interactive_environment()
and not (num_nodes == 1 and devices_per_node == 1)

View File

@@ -20,7 +20,7 @@ def is_rank_zero() -> bool:
def set_accelerator_based_on_availability(cfg: dict | DictConfig):
"""Set training accelerator to CPU if no GPUs are available.
"""Set training accelerator based on available hardware.
Args:
cfg: Hydra object with trainer settings "accelerator", "devices_per_node", and "num_nodes".
@@ -28,22 +28,25 @@ def set_accelerator_based_on_availability(cfg: dict | DictConfig):
Returns:
The input `cfg` object, which is also modified in place.
"""
if not torch.cuda.is_available():
logger.error(
"No GPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
)
assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
for key in ["accelerator", "devices_per_node", "num_nodes"]:
assert (
key in cfg.trainer
), f"Configuration object must have a 'trainer.{key}' key."
assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
for key in ["accelerator", "devices_per_node", "num_nodes"]:
assert (
key in cfg.trainer
), f"Configuration object must have a 'trainer.{key}' key."
# Override accelerator settings
if torch.cuda.is_available():
cfg.trainer.accelerator = "gpu"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
logger.info("Intel XPU detected - using XPU accelerator")
cfg.trainer.accelerator = "xpu"
else:
logger.error(
"No GPUs/XPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
)
cfg.trainer.accelerator = "cpu"
cfg.trainer.devices_per_node = 1
cfg.trainer.num_nodes = 1
else:
cfg.trainer.accelerator = "gpu"
return cfg

View File

@@ -0,0 +1,28 @@
"""XPU utilities for Intel GPU support.
XPU support in PyTorch is now native (torch.xpu.is_available()), but Lightning Fabric
requires custom Accelerator, Strategy, and Precision plugins for proper XPU handling.
These components are used directly (not via registry) when XPU is detected:
- XPUAccelerator: Custom accelerator for XPU devices
- SingleXPUStrategy: Strategy for single-device XPU training/inference
- XPUMixedPrecision: Precision plugin with proper XPU autocast support
Usage:
from foundry.utils.xpu import XPUAccelerator, SingleXPUStrategy, XPUMixedPrecision
# Check availability
if XPUAccelerator.is_available():
strategy = SingleXPUStrategy(precision_plugin=XPUMixedPrecision("bf16-mixed"))
Note:
The FabricTrainer automatically uses these components when XPU is detected.
You typically don't need to use them directly unless customizing behavior.
"""
from .single_xpu_strategy import SingleXPUStrategy
from .xpu_accelerator import XPUAccelerator
from .xpu_precision import XPUMixedPrecision
__all__ = ["SingleXPUStrategy", "XPUAccelerator", "XPUMixedPrecision"]

View File

@@ -0,0 +1,48 @@
"""Lightning Fabric strategy for single XPU device.
https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html
"""
import torch
from lightning.fabric.plugins import CheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.utilities.types import _DEVICE
class SingleXPUStrategy(SingleDeviceStrategy):
    """Single-device Fabric strategy that targets one Intel XPU.

    Thin specialization of ``SingleDeviceStrategy`` whose default device is
    ``"xpu:0"`` and which refuses to be constructed when no XPU is usable.
    """

    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: Precision | None = None,
    ) -> None:
        """Create the strategy for a single XPU device.

        Args:
            device: XPU device to run on. Defaults to "xpu:0".
            checkpoint_io: Optional checkpoint I/O plugin, forwarded to the
                base strategy.
            precision_plugin: Optional precision plugin; when given it is
                stored on ``self._precision`` after base initialization.

        Raises:
            RuntimeError: If this torch build has no usable XPU device.
        """
        # Fail fast: without an XPU runtime the strategy cannot do anything useful.
        xpu_ready = hasattr(torch, "xpu") and torch.xpu.is_available()
        if not xpu_ready:
            raise RuntimeError("`SingleXPUStrategy` requires XPU devices to run")

        super().__init__(device=device, checkpoint_io=checkpoint_io)

        # Only override the precision the base class set up when a plugin
        # was explicitly supplied.
        if precision_plugin is None:
            return
        self._precision = precision_plugin

View File

@@ -0,0 +1,92 @@
"""XPU Accelerator for Intel XPU devices."""
from typing import Any
import torch
from lightning.fabric.accelerators import Accelerator
class XPUAccelerator(Accelerator):
    """Accelerator for Intel XPU devices.

    This accelerator enables training and inference on Intel GPUs using
    PyTorch's native XPU support (torch.xpu).
    """

    @property
    def name(self) -> str:
        """Return the name of this accelerator."""
        return "xpu"

    @staticmethod
    def setup_device(device: torch.device) -> None:
        """Make ``device`` the current XPU device.

        Args:
            device: The torch device to set up.

        Raises:
            RuntimeError: If device is not an XPU device.
        """
        if device.type != "xpu":
            msg = f"Device should be xpu, got {device} instead"
            raise RuntimeError(msg)
        torch.xpu.set_device(device)

    @staticmethod
    def parse_devices(devices: int | str | list | torch.device) -> list:
        """Normalize a device specification into a list of XPU indices.

        The result is meant to be fed to ``get_parallel_devices``, which
        builds ``torch.device("xpu", idx)`` from each entry, so every entry
        must be an integer index.

        Accepted forms:
            * ``int`` n >= 0: a device *count*, giving indices ``0..n-1``.
            * ``-1`` (or ``"-1"``): all devices reported by
              ``auto_device_count``.
            * ``str``: a single integer (a count, as above) or a
              comma-separated list of indices such as ``"0,1"``.
            * ``list`` / ``tuple``: explicit indices; entries may be ints,
              integer strings, or ``torch.device`` objects.
            * ``torch.device``: that device's index (0 when unset).

        Args:
            devices: Device specification in one of the forms above.

        Returns:
            List of integer device indices.

        Raises:
            TypeError: If ``devices`` is of an unsupported type.
            ValueError: If a count is below -1 or a string is not numeric.
        """

        def as_index(entry: Any) -> int:
            # A torch.device without an explicit index refers to device 0 here.
            if isinstance(entry, torch.device):
                return 0 if entry.index is None else entry.index
            return int(entry)

        if isinstance(devices, torch.device):
            return [as_index(devices)]
        if isinstance(devices, (list, tuple)):
            return [as_index(entry) for entry in devices]

        if isinstance(devices, str):
            text = devices.strip()
            if "," in text:
                # Comma-separated values are explicit indices, not a count.
                return [int(part) for part in text.split(",") if part.strip()]
            devices = int(text)

        if not isinstance(devices, int):
            msg = f"Unsupported devices specification: {devices!r}"
            raise TypeError(msg)

        if devices == -1:
            devices = XPUAccelerator.auto_device_count()
        if devices < 0:
            msg = f"Number of devices must be >= -1, got {devices}"
            raise ValueError(msg)
        # An integer is a device count: n devices means indices 0..n-1.
        return list(range(devices))

    @staticmethod
    def get_parallel_devices(devices: list) -> list[torch.device]:
        """Generate a list of parallel XPU devices.

        Args:
            devices: List of device indices.

        Returns:
            List of torch.device objects for XPU.
        """
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Return the number of available XPU devices (0 without torch.xpu)."""
        # Guarded like is_available() so torch builds without XPU support
        # report zero devices instead of raising AttributeError.
        if not hasattr(torch, "xpu"):
            return 0
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Check if XPU is available."""
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    @staticmethod
    def get_device_stats(device: str | torch.device) -> dict[str, Any]:
        """Return XPU device statistics.

        Currently always returns an empty dict; no XPU stats are collected.

        Args:
            device: The device to get stats for (unused).

        Returns:
            Dictionary of device statistics (always empty).
        """
        del device  # Unused
        return {}

    def teardown(self) -> None:
        """Clean up XPU accelerator resources."""
        # Empty implementation required by base class
        pass

View File

@@ -0,0 +1,73 @@
"""XPU Precision Plugin for Lightning Fabric."""
from contextlib import contextmanager
from typing import Any, Generator, Literal
import torch
from lightning.fabric.plugins.precision import MixedPrecision
from torch import Tensor
class XPUMixedPrecision(MixedPrecision):
    """Mixed-precision plugin that runs autocast on Intel XPU.

    Specializes ``MixedPrecision`` so that ``torch.autocast`` is entered with
    ``device_type="xpu"`` and inputs are cast to the matching reduced dtype.
    """

    def __init__(
        self,
        precision: Literal["16-mixed", "bf16-mixed"] = "bf16-mixed",
    ) -> None:
        """Set up mixed precision for XPU.

        Args:
            precision: "16-mixed" selects float16, "bf16-mixed" selects
                bfloat16. Defaults to "bf16-mixed".

        Raises:
            ValueError: If precision is not "16-mixed" or "bf16-mixed".
        """
        # Reject unsupported modes before touching the base class.
        if precision not in ("16-mixed", "bf16-mixed"):
            raise ValueError(
                f"Invalid precision: {precision}. Must be '16-mixed' or 'bf16-mixed'"
            )
        target_dtype = torch.float16 if precision == "16-mixed" else torch.bfloat16

        # The base plugin is told to operate on the "xpu" device type.
        super().__init__(precision=precision, device="xpu")
        self._desired_input_dtype = target_dtype

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Run the forward pass under XPU autocast with the selected dtype."""
        autocast_ctx = torch.autocast(
            device_type="xpu", dtype=self._desired_input_dtype
        )
        with autocast_ctx:
            yield

    def convert_input(self, data: Any) -> Any:
        """Cast input data to the plugin's reduced-precision dtype.

        Only a top-level floating point tensor is cast; any other value is
        passed through untouched.

        Args:
            data: Input data to convert.

        Returns:
            Converted data.
        """
        return self._convert_fp_tensor(data)

    def _convert_fp_tensor(self, data: Any) -> Any:
        """Cast a floating point tensor to the desired dtype.

        Args:
            data: Data to convert.

        Returns:
            The cast tensor when ``data`` is a floating point tensor,
            otherwise ``data`` unchanged.
        """
        is_fp_tensor = isinstance(data, Tensor) and data.is_floating_point()
        if not is_fp_tensor:
            return data
        return data.to(self._desired_input_dtype)