mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Make format
This commit is contained in:
@@ -129,7 +129,9 @@ class FabricTrainer(ABC):
|
||||
"""
|
||||
# Handle XPU accelerator
|
||||
is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
|
||||
if accelerator == "xpu" or (accelerator == "auto" and is_xpu and not torch.cuda.is_available()):
|
||||
if accelerator == "xpu" or (
|
||||
accelerator == "auto" and is_xpu and not torch.cuda.is_available()
|
||||
):
|
||||
accelerator = XPUAccelerator()
|
||||
precision_plugin = None
|
||||
if precision in ("16-mixed", "bf16-mixed"):
|
||||
|
||||
@@ -25,4 +25,3 @@ from .xpu_accelerator import XPUAccelerator
|
||||
from .xpu_precision import XPUMixedPrecision
|
||||
|
||||
__all__ = ["SingleXPUStrategy", "XPUAccelerator", "XPUMixedPrecision"]
|
||||
|
||||
|
||||
@@ -45,4 +45,3 @@ class SingleXPUStrategy(SingleDeviceStrategy):
|
||||
# Precision is handled via the _precision property in newer Lightning versions
|
||||
if precision_plugin is not None:
|
||||
self._precision = precision_plugin
|
||||
|
||||
|
||||
@@ -89,4 +89,3 @@ class XPUAccelerator(Accelerator):
|
||||
"""Clean up XPU accelerator resources."""
|
||||
# Empty implementation required by base class
|
||||
pass
|
||||
|
||||
|
||||
@@ -70,4 +70,3 @@ class XPUMixedPrecision(MixedPrecision):
|
||||
if isinstance(data, Tensor) and data.is_floating_point():
|
||||
return data.to(self._desired_input_dtype)
|
||||
return data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user