mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
feat: Add Intel XPU support for Foundry Models (#156)
feat: Add Intel XPU support for all models - Add XPU accelerator, precision, and strategy classes - Update DDP utilities to detect and use Intel XPU - Add XPU trainer configs for RF3 and RFD3 - Update MPNN inference engine for XPU compatibility - Update README with XPU documentation
This commit is contained in:
10
README.md
10
README.md
@@ -15,6 +15,16 @@ All models within Foundry rely on [AtomWorks](https://github.com/RosettaCommons/
|
||||
pip install "rc-foundry[all]"
|
||||
```
|
||||
|
||||
**Intel XPU Installation**
|
||||
|
||||
For Intel XPU devices, install PyTorch with XPU support first, then install Foundry.
|
||||
```bash
|
||||
pip install torch --index-url https://download.pytorch.org/whl/xpu
|
||||
pip install "rc-foundry[all]"
|
||||
```
|
||||
> [!NOTE]
|
||||
> Use `pip` (not `uv`) for XPU installs: `uv` re-resolves dependencies and may replace your XPU build of torch with the standard PyPI version.
|
||||
|
||||
**Downloading weights** Models can be downloaded to a target folder with:
|
||||
```
|
||||
foundry install base-models --checkpoint-dir <path/to/ckpt/dir>
|
||||
|
||||
@@ -67,11 +67,15 @@ class MPNNInferenceEngine:
|
||||
else checkpoint_path
|
||||
)
|
||||
|
||||
# Determine the device.
|
||||
# Determine the device (supports XPU, CUDA, and CPU).
|
||||
if device is not None:
|
||||
self.device = torch.device(device)
|
||||
elif torch.cuda.is_available():
|
||||
self.device = torch.device("cuda")
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
self.device = torch.device("xpu")
|
||||
else:
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
# Set up allowed model types.
|
||||
self.allowed_model_types = {"protein_mpnn", "ligand_mpnn"}
|
||||
|
||||
6
models/rf3/configs/trainer/xpu.yaml
Normal file
6
models/rf3/configs/trainer/xpu.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
strategy: xpu_single
|
||||
|
||||
accelerator: xpu
|
||||
devices_per_node: 1
|
||||
num_nodes: 1
|
||||
|
||||
6
models/rfd3/configs/trainer/xpu.yaml
Normal file
6
models/rfd3/configs/trainer/xpu.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
strategy: xpu_single
|
||||
|
||||
accelerator: xpu
|
||||
devices_per_node: 1
|
||||
num_nodes: 1
|
||||
|
||||
@@ -36,6 +36,11 @@ from foundry.utils.weights import (
|
||||
freeze_parameters_with_config,
|
||||
load_weights_with_policies,
|
||||
)
|
||||
from foundry.utils.xpu import (
|
||||
SingleXPUStrategy,
|
||||
XPUAccelerator,
|
||||
XPUMixedPrecision,
|
||||
)
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
@@ -122,8 +127,17 @@ class FabricTrainer(ABC):
|
||||
(3) Fabric Loggers (https://lightning.ai/docs/fabric/2.4.0/api/loggers.html)
|
||||
(4) Efficient Gradient Accumulation (https://lightning.ai/docs/fabric/2.4.0/advanced/gradient_accumulation.html)
|
||||
"""
|
||||
# Use custom DDP strategy only for multi-device, non-interactive environments
|
||||
if (
|
||||
# Handle XPU accelerator
|
||||
is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
|
||||
if accelerator == "xpu" or (accelerator == "auto" and is_xpu and not torch.cuda.is_available()):
|
||||
accelerator = XPUAccelerator()
|
||||
precision_plugin = None
|
||||
if precision in ("16-mixed", "bf16-mixed"):
|
||||
precision_plugin = XPUMixedPrecision(precision=precision)
|
||||
precision = None # Handled by plugin
|
||||
strategy = SingleXPUStrategy(precision_plugin=precision_plugin)
|
||||
ranked_logger.info("Using Intel XPU with SingleXPUStrategy")
|
||||
elif (
|
||||
strategy == "ddp"
|
||||
and not is_interactive_environment()
|
||||
and not (num_nodes == 1 and devices_per_node == 1)
|
||||
|
||||
@@ -20,7 +20,7 @@ def is_rank_zero() -> bool:
|
||||
|
||||
|
||||
def set_accelerator_based_on_availability(cfg: dict | DictConfig):
|
||||
"""Set training accelerator to CPU if no GPUs are available.
|
||||
"""Set training accelerator based on available hardware.
|
||||
|
||||
Args:
|
||||
cfg: Hydra object with trainer settings "accelerator", "devices_per_node", and "num_nodes".
|
||||
@@ -28,22 +28,25 @@ def set_accelerator_based_on_availability(cfg: dict | DictConfig):
|
||||
Returns:
|
||||
None; modifies the input `cfg` object in place.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
logger.error(
|
||||
"No GPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
|
||||
)
|
||||
assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
|
||||
for key in ["accelerator", "devices_per_node", "num_nodes"]:
|
||||
assert (
|
||||
key in cfg.trainer
|
||||
), f"Configuration object must have a 'trainer.{key}' key."
|
||||
assert "trainer" in cfg, "Configuration object must have a 'trainer' key."
|
||||
for key in ["accelerator", "devices_per_node", "num_nodes"]:
|
||||
assert (
|
||||
key in cfg.trainer
|
||||
), f"Configuration object must have a 'trainer.{key}' key."
|
||||
|
||||
# Override accelerator settings
|
||||
if torch.cuda.is_available():
|
||||
cfg.trainer.accelerator = "gpu"
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
logger.info("Intel XPU detected - using XPU accelerator")
|
||||
cfg.trainer.accelerator = "xpu"
|
||||
else:
|
||||
logger.error(
|
||||
"No GPUs/XPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?"
|
||||
)
|
||||
cfg.trainer.accelerator = "cpu"
|
||||
cfg.trainer.devices_per_node = 1
|
||||
cfg.trainer.num_nodes = 1
|
||||
else:
|
||||
cfg.trainer.accelerator = "gpu"
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
28
src/foundry/utils/xpu/__init__.py
Normal file
28
src/foundry/utils/xpu/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""XPU utilities for Intel GPU support.
|
||||
|
||||
XPU support in PyTorch is now native (torch.xpu.is_available()), but Lightning Fabric
|
||||
requires custom Accelerator, Strategy, and Precision plugins for proper XPU handling.
|
||||
|
||||
These components are used directly (not via registry) when XPU is detected:
|
||||
- XPUAccelerator: Custom accelerator for XPU devices
|
||||
- SingleXPUStrategy: Strategy for single-device XPU training/inference
|
||||
- XPUMixedPrecision: Precision plugin with proper XPU autocast support
|
||||
|
||||
Usage:
|
||||
from foundry.utils.xpu import XPUAccelerator, SingleXPUStrategy, XPUMixedPrecision
|
||||
|
||||
# Check availability
|
||||
if XPUAccelerator.is_available():
|
||||
strategy = SingleXPUStrategy(precision_plugin=XPUMixedPrecision("bf16-mixed"))
|
||||
|
||||
Note:
|
||||
The FabricTrainer automatically uses these components when XPU is detected.
|
||||
You typically don't need to use them directly unless customizing behavior.
|
||||
"""
|
||||
|
||||
from .single_xpu_strategy import SingleXPUStrategy
|
||||
from .xpu_accelerator import XPUAccelerator
|
||||
from .xpu_precision import XPUMixedPrecision
|
||||
|
||||
__all__ = ["SingleXPUStrategy", "XPUAccelerator", "XPUMixedPrecision"]
|
||||
|
||||
48
src/foundry/utils/xpu/single_xpu_strategy.py
Normal file
48
src/foundry/utils/xpu/single_xpu_strategy.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Lightning Fabric strategy for single XPU device.
|
||||
|
||||
https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html
|
||||
"""
|
||||
|
||||
import torch
|
||||
from lightning.fabric.plugins import CheckpointIO
|
||||
from lightning.fabric.plugins.precision import Precision
|
||||
from lightning.fabric.strategies import SingleDeviceStrategy
|
||||
from lightning.fabric.utilities.types import _DEVICE
|
||||
|
||||
|
||||
class SingleXPUStrategy(SingleDeviceStrategy):
    """Fabric strategy that runs on exactly one Intel XPU device.

    A thin specialisation of ``SingleDeviceStrategy`` that refuses to be
    constructed when no XPU is usable, defaults the device to ``"xpu:0"``,
    and optionally installs a caller-supplied precision plugin.
    """

    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        checkpoint_io: CheckpointIO | None = None,
        precision_plugin: Precision | None = None,
    ) -> None:
        """Create the strategy after verifying that an XPU is present.

        Args:
            device: Target XPU device; ``"xpu:0"`` when omitted.
            checkpoint_io: Checkpoint I/O plugin forwarded to the base class.
            precision_plugin: Optional precision plugin; when given it is
                stored on the private ``_precision`` attribute.

        Raises:
            RuntimeError: If ``torch`` has no ``xpu`` backend or the backend
                reports that no XPU is available.
        """
        xpu_backend = getattr(torch, "xpu", None)
        if xpu_backend is None or not xpu_backend.is_available():
            msg = "`SingleXPUStrategy` requires XPU devices to run"
            raise RuntimeError(msg)

        super().__init__(device=device, checkpoint_io=checkpoint_io)

        if precision_plugin is None:
            return

        # The plugin is assigned to `_precision` directly instead of being
        # passed to the base constructor.
        # NOTE(review): presumably for compatibility across Lightning
        # versions - confirm against the supported Lightning range.
        self._precision = precision_plugin
|
||||
|
||||
92
src/foundry/utils/xpu/xpu_accelerator.py
Normal file
92
src/foundry/utils/xpu/xpu_accelerator.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""XPU Accelerator for Intel XPU devices."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from lightning.fabric.accelerators import Accelerator
|
||||
|
||||
|
||||
class XPUAccelerator(Accelerator):
    """Accelerator for Intel XPU devices.

    Enables training and inference on Intel GPUs through PyTorch's native
    XPU backend (``torch.xpu``).
    """

    @property
    def name(self) -> str:
        """Return the name of this accelerator (``"xpu"``)."""
        return "xpu"

    @staticmethod
    def setup_device(device: torch.device) -> None:
        """Make ``device`` the current XPU device.

        Args:
            device: The torch device to activate.

        Raises:
            RuntimeError: If ``device`` is not an XPU device.
        """
        if device.type != "xpu":
            msg = f"Device should be xpu, got {device} instead"
            raise RuntimeError(msg)

        torch.xpu.set_device(device)

    @staticmethod
    def parse_devices(devices: int | str | list | torch.device) -> list:
        """Turn a devices specification into a list of XPU device indices.

        The result is meant to be fed to ``get_parallel_devices``, which
        treats every element as a device index. An integer is therefore
        interpreted as a device *count* (``1`` -> ``[0]``), not as an index;
        wrapping it verbatim would map ``devices=1`` to ``xpu:1``.

        Args:
            devices: One of
                * a list of indices - returned unchanged;
                * an int ``n >= 0`` - the first ``n`` devices;
                * ``-1`` - every device reported by ``torch.xpu``;
                * a string - comma-separated indices (``"0,2"``) or a
                  single integer handled like the int case;
                * a ``torch.device`` - its index (``0`` if it has none).

        Returns:
            List of device indices.

        Raises:
            ValueError: If a string cannot be parsed as integers, or an
                integer is below ``-1``.
        """
        if isinstance(devices, list):
            return devices

        if isinstance(devices, torch.device):
            return [0 if devices.index is None else devices.index]

        if isinstance(devices, str):
            spec = devices.strip()
            try:
                if "," in spec:
                    return [int(part) for part in spec.split(",") if part.strip()]
                devices = int(spec)
            except ValueError:
                msg = f"Invalid XPU devices specification: {devices!r}"
                raise ValueError(msg) from None

        if isinstance(devices, int):
            if devices == -1:
                return list(range(torch.xpu.device_count()))
            if devices < 0:
                msg = f"Invalid XPU devices specification: {devices!r}"
                raise ValueError(msg)
            return list(range(devices))

        # Unrecognised specification types keep the legacy behaviour of
        # being wrapped in a single-element list.
        return [devices]

    @staticmethod
    def get_parallel_devices(devices: list) -> list[torch.device]:
        """Build ``torch.device`` objects for the given XPU indices.

        Args:
            devices: List of device indices.

        Returns:
            One ``torch.device("xpu", idx)`` per entry, in order.
        """
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Return the number of XPU devices reported by ``torch.xpu``."""
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Return True when torch has an ``xpu`` backend and it is available."""
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    @staticmethod
    def get_device_stats(device: str | torch.device) -> dict[str, Any]:
        """Return XPU device statistics.

        No statistics are collected; the result is always an empty dict.

        Args:
            device: The device to get stats for (ignored).

        Returns:
            An empty dictionary.
        """
        del device  # Unused
        return {}

    def teardown(self) -> None:
        """Clean up accelerator resources (nothing to release)."""
        # Intentionally empty: the method exists to satisfy the base class.
|
||||
|
||||
73
src/foundry/utils/xpu/xpu_precision.py
Normal file
73
src/foundry/utils/xpu/xpu_precision.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""XPU Precision Plugin for Lightning Fabric."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator, Literal
|
||||
|
||||
import torch
|
||||
from lightning.fabric.plugins.precision import MixedPrecision
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class XPUMixedPrecision(MixedPrecision):
    """Mixed precision plugin for Intel XPU devices.

    Specialises ``MixedPrecision`` so that ``torch.autocast`` is entered with
    ``device_type="xpu"``.
    """

    def __init__(
        self,
        precision: Literal["16-mixed", "bf16-mixed"] = "bf16-mixed",
    ) -> None:
        """Initialize XPU mixed precision.

        Args:
            precision: The precision mode. "16-mixed" uses float16,
                "bf16-mixed" uses bfloat16. Defaults to "bf16-mixed".

        Raises:
            ValueError: If precision is not "16-mixed" or "bf16-mixed".
        """
        # Map the precision string onto the autocast / input dtype.
        if precision == "16-mixed":
            dtype = torch.float16
        elif precision == "bf16-mixed":
            dtype = torch.bfloat16
        else:
            msg = f"Invalid precision: {precision}. Must be '16-mixed' or 'bf16-mixed'"
            raise ValueError(msg)

        # The base class is told the device type is "xpu".
        super().__init__(precision=precision, device="xpu")
        self._desired_input_dtype = dtype

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Run the forward pass under XPU autocast with the configured dtype."""
        with torch.autocast(device_type="xpu", dtype=self._desired_input_dtype):
            yield

    def convert_input(self, data: Any) -> Any:
        """Convert floating point tensors in ``data`` to the configured dtype.

        Delegates to the parent plugin's ``convert_input`` instead of
        ``_convert_fp_tensor``: the latter only handles a bare tensor, so
        tensors nested inside a container (such as an ``(args, kwargs)``
        tuple) were returned unconverted.

        NOTE(review): relies on the parent implementation traversing
        collections - confirm for the pinned Lightning version.

        Args:
            data: Input data to convert; may be a tensor or a collection.

        Returns:
            The value returned by the parent's ``convert_input``.
        """
        return super().convert_input(data)

    def _convert_fp_tensor(self, data: Any) -> Any:
        """Convert a single floating point tensor to the desired dtype.

        Args:
            data: Object to convert.

        Returns:
            ``data`` cast to the configured dtype if it is a floating point
            tensor; otherwise ``data`` unchanged (containers are not
            traversed).
        """
        if isinstance(data, Tensor) and data.is_floating_point():
            return data.to(self._desired_input_dtype)
        return data
|
||||
|
||||
Reference in New Issue
Block a user