Make format

This commit is contained in:
jbutch
2026-01-14 16:59:41 -08:00
parent e94fdf63e7
commit d35d8ac3c1
5 changed files with 3 additions and 5 deletions

View File

@@ -129,7 +129,9 @@ class FabricTrainer(ABC):
"""
# Handle XPU accelerator
is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
if accelerator == "xpu" or (accelerator == "auto" and is_xpu and not torch.cuda.is_available()):
if accelerator == "xpu" or (
accelerator == "auto" and is_xpu and not torch.cuda.is_available()
):
accelerator = XPUAccelerator()
precision_plugin = None
if precision in ("16-mixed", "bf16-mixed"):

View File

@@ -25,4 +25,3 @@ from .xpu_accelerator import XPUAccelerator
from .xpu_precision import XPUMixedPrecision
__all__ = ["SingleXPUStrategy", "XPUAccelerator", "XPUMixedPrecision"]

View File

@@ -45,4 +45,3 @@ class SingleXPUStrategy(SingleDeviceStrategy):
# Precision is handled via the _precision property in newer Lightning versions
if precision_plugin is not None:
self._precision = precision_plugin

View File

@@ -89,4 +89,3 @@ class XPUAccelerator(Accelerator):
"""Clean up XPU accelerator resources."""
# Empty implementation required by base class
pass

View File

@@ -70,4 +70,3 @@ class XPUMixedPrecision(MixedPrecision):
if isinstance(data, Tensor) and data.is_floating_point():
return data.to(self._desired_input_dtype)
return data