diff --git a/src/foundry/trainers/fabric.py b/src/foundry/trainers/fabric.py index 86ee699..c3ddcbb 100755 --- a/src/foundry/trainers/fabric.py +++ b/src/foundry/trainers/fabric.py @@ -129,7 +129,9 @@ class FabricTrainer(ABC): """ # Handle XPU accelerator is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available() - if accelerator == "xpu" or (accelerator == "auto" and is_xpu and not torch.cuda.is_available()): + if accelerator == "xpu" or ( + accelerator == "auto" and is_xpu and not torch.cuda.is_available() + ): accelerator = XPUAccelerator() precision_plugin = None if precision in ("16-mixed", "bf16-mixed"): diff --git a/src/foundry/utils/xpu/__init__.py b/src/foundry/utils/xpu/__init__.py index 62781d8..9b94fc1 100644 --- a/src/foundry/utils/xpu/__init__.py +++ b/src/foundry/utils/xpu/__init__.py @@ -25,4 +25,3 @@ from .xpu_accelerator import XPUAccelerator from .xpu_precision import XPUMixedPrecision __all__ = ["SingleXPUStrategy", "XPUAccelerator", "XPUMixedPrecision"] - diff --git a/src/foundry/utils/xpu/single_xpu_strategy.py b/src/foundry/utils/xpu/single_xpu_strategy.py index 577cd9d..7a518ba 100644 --- a/src/foundry/utils/xpu/single_xpu_strategy.py +++ b/src/foundry/utils/xpu/single_xpu_strategy.py @@ -45,4 +45,3 @@ class SingleXPUStrategy(SingleDeviceStrategy): # Precision is handled via the _precision property in newer Lightning versions if precision_plugin is not None: self._precision = precision_plugin - diff --git a/src/foundry/utils/xpu/xpu_accelerator.py b/src/foundry/utils/xpu/xpu_accelerator.py index 38386df..4700528 100644 --- a/src/foundry/utils/xpu/xpu_accelerator.py +++ b/src/foundry/utils/xpu/xpu_accelerator.py @@ -89,4 +89,3 @@ class XPUAccelerator(Accelerator): """Clean up XPU accelerator resources.""" # Empty implementation required by base class pass - diff --git a/src/foundry/utils/xpu/xpu_precision.py b/src/foundry/utils/xpu/xpu_precision.py index 665e358..22b3357 100644 --- a/src/foundry/utils/xpu/xpu_precision.py +++ b/src/foundry/utils/xpu/xpu_precision.py @@ -70,4 +70,3 @@ class XPUMixedPrecision(MixedPrecision): if isinstance(data, Tensor) and data.is_floating_point(): return data.to(self._desired_input_dtype) return data -