diff --git a/src/lightning/fabric/accelerators/__init__.py b/src/lightning/fabric/accelerators/__init__.py
index 3d4b43f75c762..1a3dda20828a7 100644
--- a/src/lightning/fabric/accelerators/__init__.py
+++ b/src/lightning/fabric/accelerators/__init__.py
@@ -18,6 +18,7 @@
 from lightning.fabric.accelerators.mps import MPSAccelerator  # noqa: F401
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.accelerators.xla import XLAAccelerator  # noqa: F401
+from lightning.fabric.accelerators.xpu import XPUAccelerator  # noqa: F401
 from lightning.fabric.utilities.registry import _register_classes
 
 ACCELERATOR_REGISTRY = _AcceleratorRegistry()
diff --git a/src/lightning/fabric/accelerators/xpu.py b/src/lightning/fabric/accelerators/xpu.py
new file mode 100644
index 0000000000000..0d88468ccabd0
--- /dev/null
+++ b/src/lightning/fabric/accelerators/xpu.py
@@ -0,0 +1,113 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import lru_cache
+from typing import Any, Dict, List, Union
+
+import torch
+from lightning_utilities.core.imports import RequirementCache
+from typing_extensions import override
+
+from lightning.fabric.accelerators.accelerator import Accelerator
+
+
+class XPUAccelerator(Accelerator):
+    """Support for Intel discrete graphics cards ('XPU')."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        if not _IPEX_AVAILABLE:
+            raise ModuleNotFoundError(str(_IPEX_AVAILABLE))
+        super().__init__(*args, **kwargs)
+
+    @staticmethod
+    @override
+    def parse_devices(devices: Any) -> Any:
+        # Parse the devices passed to the Trainer
+        # via the `devices` argument
+        from lightning.fabric.utilities.device_parser import _parse_gpu_ids
+
+        return _parse_gpu_ids(devices, include_xpu=True)
+
+    @staticmethod
+    @override
+    def get_parallel_devices(devices: Any) -> Any:
+        # Convert the device indices to actual device objects
+
+        return [torch.device("xpu", idx) for idx in devices]
+
+    @staticmethod
+    @override
+    def auto_device_count() -> int:
+        # Return a value for auto-device selection when `Trainer(devices="auto")`
+        return num_xpu_devices()
+
+    @staticmethod
+    @override
+    def is_available() -> bool:
+        # Check that the requirement is met before attempting the import:
+        if _IPEX_AVAILABLE:
+            import intel_extension_for_pytorch as ipex
+
+            return ipex.xpu.is_available()
+        return False
+
+    @override
+    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
+        # Return optional device statistics for loggers
+        return {}
+
+    @override
+    def setup_device(self, device: torch.device) -> None:
+        pass
+
+    @override
+    def teardown(self) -> None:
+        pass
+
+    @classmethod
+    @override
+    def register_accelerators(cls, accelerator_registry: Any) -> None:
+        accelerator_registry.register(
+            "xpu",
+            cls,
+            description=cls.__name__,
+        )
+
+
+_IPEX_AVAILABLE = RequirementCache("intel_extension_for_pytorch>=2.0", "intel_extension_for_pytorch")
+
+
+@lru_cache(1)
+def num_xpu_devices() -> int:
+    """Returns the number of available XPU devices.
+
+    Unlike :func:`torch.xpu.device_count`, this function does its best not to create an XPU context for fork support,
+    if the platform allows it.
+
+    """
+    if _IPEX_AVAILABLE:
+        import intel_extension_for_pytorch as ipex
+
+        return ipex.xpu.device_count()
+    return 0
+
+
+def _get_all_visible_xpu_devices() -> List[int]:
+    """Returns a list of all visible Intel XPU devices.
+
+    Devices masked by the environment variable ``ZE_AFFINITY_MASK`` won't be returned here. For example, assume you
+    have 8 physical GPUs. If ``ZE_AFFINITY_MASK="1,3,6"``, then this function will return the list ``[0, 1, 2]``
+    because these are the three visible GPUs after applying the mask ``ZE_AFFINITY_MASK``.
+
+    """
+    return list(range(num_xpu_devices()))
diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py
index d8c6fe47b6630..22e9082ef27d4 100644
--- a/src/lightning/fabric/cli.py
+++ b/src/lightning/fabric/cli.py
@@ -36,7 +36,7 @@
 _CLICK_AVAILABLE = RequirementCache("click")
 _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
 
-_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
+_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "xpu")
 
 
 def _get_supported_strategies() -> List[str]:
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index f677893100351..3fb183ae8d363 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -23,6 +23,7 @@
 from lightning.fabric.accelerators.cuda import CUDAAccelerator
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.accelerators.xla import XLAAccelerator
+from lightning.fabric.accelerators.xpu import XPUAccelerator
 from lightning.fabric.plugins import (
     BitsandbytesPrecision,
     CheckpointIO,
@@ -321,6 +322,8 @@ def _choose_auto_accelerator(self) -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
+        if XPUAccelerator.is_available():
+            return "xpu"
         return "cpu"
 
     @staticmethod
@@ -329,6 +332,8 @@ def _choose_gpu_accelerator_backend() -> str:
             return "mps"
         if CUDAAccelerator.is_available():
             return "cuda"
+        if XPUAccelerator.is_available():
+            return "xpu"
         raise RuntimeError("No supported gpu backend found!")
 
     def _set_parallel_devices_and_init_accelerator(self) -> None:
@@ -399,8 +404,8 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if self._num_nodes_flag > 1:
             return "ddp"
         if len(self._parallel_devices) <= 1:
-            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
-                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
+            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator, XPUAccelerator)) or (
+                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "xpu")
             ):
                 device = _determine_root_gpu_device(self._parallel_devices)
             else:
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 71d8f623dcee4..94e1c5411d5b9 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -97,7 +97,7 @@ class Fabric:
 
     Args:
         accelerator: The hardware to run on. Possible choices are:
-            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
+            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"xpu"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
         strategy: Strategy for how to run across multiple devices. Possible choices are:
             ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
         devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index 0ec5df1a6b0ae..8d21f9460053c 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -123,8 +123,13 @@ def setup_environment(self) -> None:
 
     def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        if self.root_device.type == "cuda":
+            ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        elif self.root_device.type == "xpu":
+            ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext()
+        else:
+            ctx = nullcontext()
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 6bfed6a270b68..2e194025f0cd6 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -326,6 +326,8 @@ def load_checkpoint(
 
         """
         torch.cuda.empty_cache()
+        if hasattr(torch, "xpu"):
+            torch.xpu.empty_cache()
         checkpoint = self.checkpoint_io.load_checkpoint(path)
         if not state:
             return checkpoint
diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py
index 9f06dc50cfbef..f44f961de3104 100644
--- a/src/lightning/fabric/utilities/device_dtype_mixin.py
+++ b/src/lightning/fabric/utilities/device_dtype_mixin.py
@@ -44,6 +44,9 @@ def device(self) -> torch.device:
         if device.type == "cuda" and device.index is None:
             return torch.device(f"cuda:{torch.cuda.current_device()}")
 
+        if hasattr(torch, "xpu") and device.type == "xpu" and device.index is None:
+            return torch.device(f"xpu:{torch.xpu.current_device()}")
+
         return device
 
     @override
@@ -75,6 +78,27 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
         _update_properties(self, device=device)
         return super().cuda(device=device)
 
+    @override
+    def xpu(self, device: Optional[Union[torch.device, int]] = None) -> Self:
+        """Moves all model parameters and buffers to the XPU device. This also makes associated parameters and
+        buffers different objects, so it should be called before constructing the optimizer if the module will live
+        on XPU while being optimized.
+
+        Arguments:
+            device: If specified, all parameters will be copied to that device. If `None`, the current XPU device
+                index will be used.
+
+        Returns:
+            Module: self
+
+        """
+        if device is None:
+            device = torch.device("xpu", torch.xpu.current_device())
+        elif isinstance(device, int):
+            device = torch.device("xpu", index=device)
+        _update_properties(self, device=device)
+        return super().xpu(device=device)
+
     @override
     def cpu(self) -> Self:
         """See :meth:`torch.nn.Module.cpu`."""
diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py
index 16965d944caec..5962b33504835 100644
--- a/src/lightning/fabric/utilities/device_parser.py
+++ b/src/lightning/fabric/utilities/device_parser.py
@@ -49,6 +49,7 @@ def _parse_gpu_ids(
     gpus: Optional[Union[int, str, List[int]]],
     include_cuda: bool = False,
     include_mps: bool = False,
+    include_xpu: bool = False,
 ) -> Optional[List[int]]:
     """Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
 
@@ -60,6 +61,7 @@
             Any int N > 0 indicates that GPUs [0..N) should be used.
         include_cuda: A boolean value indicating whether to include CUDA devices for GPU parsing.
         include_mps: A boolean value indicating whether to include MPS devices for GPU parsing.
+        include_xpu: A boolean value indicating whether to include XPU devices for GPU parsing.
 
     Returns:
         A list of GPUs to be used or ``None`` if no GPUs were requested
@@ -69,7 +71,7 @@
         If no GPUs are available but the value of gpus variable indicates request for GPUs
 
     .. note::
-        ``include_cuda`` and ``include_mps`` default to ``False`` so that you only
+        ``include_cuda``, ``include_mps``, and ``include_xpu`` default to ``False`` so that you only
         have to specify which device type to use and all other devices are not disabled.
 
     """
@@ -83,7 +85,9 @@
     # We know the user requested GPUs therefore if some of the
     # requested GPUs are not available an exception is thrown.
     gpus = _normalize_parse_gpu_string_input(gpus)
-    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps)
+    gpus = _normalize_parse_gpu_input_to_list(
+        gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
+    )
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")
 
@@ -91,7 +95,8 @@
         torch.distributed.is_available()
         and torch.distributed.is_torchelastic_launched()
         and len(gpus) != 1
-        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)) == 1
+        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu))
+        == 1
     ):
         # Omit sanity check on torchelastic because by default it shows one visible GPU per process
         return gpus
@@ -99,7 +104,7 @@
     # Check that GPUs are unique. Duplicate GPUs are not supported by the backend.
     _check_unique(gpus)
 
-    return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps)
+    return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
 
 
 def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
@@ -112,7 +117,9 @@
     return int(s.strip())
 
 
-def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]:
+def _sanitize_gpu_ids(
+    gpus: List[int], include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
+) -> List[int]:
     """Checks that each of the GPUs in the list is actually available.
 
     Raises a MisconfigurationException if any of the GPUs is not available.
@@ -127,9 +134,11 @@
         If machine has fewer available GPUs than requested.
 
     """
-    if sum((include_cuda, include_mps)) == 0:
+    if sum((include_cuda, include_mps, include_xpu)) == 0:
         raise ValueError("At least one gpu type should be specified!")
-    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
+    all_available_gpus = _get_all_available_gpus(
+        include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu
+    )
     for gpu in gpus:
         if gpu not in all_available_gpus:
             raise MisconfigurationException(
@@ -139,7 +148,7 @@
 
 
 def _normalize_parse_gpu_input_to_list(
-    gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool
+    gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool, include_xpu: bool
 ) -> Optional[List[int]]:
     assert gpus is not None
     if isinstance(gpus, (MutableSequence, tuple)):
@@ -149,22 +158,27 @@
     if not gpus:  # gpus==0
         return None
     if gpus == -1:
-        return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps)
+        return _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_xpu=include_xpu)
 
     return list(range(gpus))
 
 
-def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]:
+def _get_all_available_gpus(
+    include_cuda: bool = False, include_mps: bool = False, include_xpu: bool = False
+) -> List[int]:
     """
     Returns:
        A list of all available GPUs
     """
+    from lightning.fabric.accelerators.cuda import _get_all_visible_cuda_devices
     from lightning.fabric.accelerators.mps import _get_all_available_mps_gpus
+    from lightning.fabric.accelerators.xpu import _get_all_visible_xpu_devices
 
     cuda_gpus = _get_all_visible_cuda_devices() if include_cuda else []
     mps_gpus = _get_all_available_mps_gpus() if include_mps else []
-    return cuda_gpus + mps_gpus
+    xpu_gpus = _get_all_visible_xpu_devices() if include_xpu else []
+    return cuda_gpus + mps_gpus + xpu_gpus
 
 
 def _check_unique(device_ids: List[int]) -> None:
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 30bfe4e254a07..b1461efd5f465 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -289,6 +289,10 @@
     os.environ["MASTER_ADDR"] = cluster_environment.main_address
     os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
     log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
+
+    if torch_distributed_backend.lower() == "ccl":
+        pass  # the "ccl" backend is registered by Intel's oneCCL bindings for PyTorch, assumed imported elsewhere
+
     torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
 
     # On rank=0 let everyone know training is starting
@@ -301,7 +305,11 @@
 
 
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    if device.type == "cuda":
+        return "nccl"
+    if device.type == "xpu":
+        return "ccl"
+    return "gloo"
 
 
 class _DatasetSamplerWrapper(Dataset):
diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py
index b274bce88fcdf..4e78f01079b18 100644
--- a/src/lightning/fabric/utilities/seed.py
+++ b/src/lightning/fabric/utilities/seed.py
@@ -102,7 +102,7 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:
     random.seed(stdlib_seed)
 
 
-def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
+def _collect_rng_states(include_cuda: bool = True, include_xpu: bool = True) -> Dict[str, Any]:
     r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
     states = {
         "torch": torch.get_rng_state(),
@@ -111,6 +111,8 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
     }
     if include_cuda:
         states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
+    if include_xpu and hasattr(torch, "xpu"):
+        states["torch.xpu"] = torch.xpu.get_rng_state_all() if torch.xpu.is_available() else []
     return states
 
 
@@ -121,6 +123,8 @@ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
     # torch.cuda rng_state is only included since v1.8.
     if "torch.cuda" in rng_state_dict:
         torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
+    if "torch.xpu" in rng_state_dict and hasattr(torch, "xpu"):
+        torch.xpu.set_rng_state_all(rng_state_dict["torch.xpu"])
     np.random.set_state(rng_state_dict["numpy"])
     version, state, gauss = rng_state_dict["python"]
     python_set_rng_state((version, tuple(state), gauss))
diff --git a/src/lightning/pytorch/accelerators/__init__.py b/src/lightning/pytorch/accelerators/__init__.py
index 4cadee51f64c7..cbeb82d1cb32e 100644
--- a/src/lightning/pytorch/accelerators/__init__.py
+++ b/src/lightning/pytorch/accelerators/__init__.py
@@ -20,6 +20,7 @@
 from lightning.pytorch.accelerators.cuda import CUDAAccelerator  # noqa: F401
 from lightning.pytorch.accelerators.mps import MPSAccelerator  # noqa: F401
 from lightning.pytorch.accelerators.xla import XLAAccelerator  # noqa: F401
+from lightning.pytorch.accelerators.xpu import XPUAccelerator  # noqa: F401
 
 AcceleratorRegistry = _AcceleratorRegistry()
 _register_classes(AcceleratorRegistry, "register_accelerators", sys.modules[__name__], Accelerator)
diff --git a/src/lightning/pytorch/accelerators/xpu.py b/src/lightning/pytorch/accelerators/xpu.py
new file mode 100644
index 0000000000000..321bb3f206473
--- /dev/null
+++ b/src/lightning/pytorch/accelerators/xpu.py
@@ -0,0 +1,114 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import lru_cache
+from typing import Any, Dict, List
+
+import torch
+from lightning_utilities.core.imports import RequirementCache
+from typing_extensions import override
+
+from lightning.fabric.utilities.types import _DEVICE
+from lightning.pytorch.accelerators.accelerator import Accelerator
+
+
+class XPUAccelerator(Accelerator):
+    """Support for Intel discrete graphics cards ('XPU')."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        if not _IPEX_AVAILABLE:
+            raise ModuleNotFoundError(str(_IPEX_AVAILABLE))
+        super().__init__(*args, **kwargs)
+
+    @staticmethod
+    @override
+    def parse_devices(devices: Any) -> Any:
+        # Parse the devices passed to the Trainer
+        # via the `devices` argument
+        from lightning.fabric.utilities.device_parser import _parse_gpu_ids
+
+        return _parse_gpu_ids(devices, include_xpu=True)
+
+    @staticmethod
+    @override
+    def get_parallel_devices(devices: Any) -> Any:
+        # Convert the device indices to actual device objects
+
+        return [torch.device("xpu", idx) for idx in devices]
+
+    @staticmethod
+    @override
+    def auto_device_count() -> int:
+        # Return a value for auto-device selection when `Trainer(devices="auto")`
+        return num_xpu_devices()
+
+    @staticmethod
+    @override
+    def is_available() -> bool:
+        # Check that the requirement is met before attempting the import:
+        if _IPEX_AVAILABLE:
+            import intel_extension_for_pytorch as ipex
+
+            return ipex.xpu.is_available()
+        return False
+
+    @override
+    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+        # Return optional device statistics for loggers
+        return torch.xpu.memory_stats(device)
+
+    @override
+    def setup_device(self, device: torch.device) -> None:
+        pass
+
+    @override
+    def teardown(self) -> None:
+        pass
+
+    @classmethod
+    @override
+    def register_accelerators(cls, accelerator_registry: Any) -> None:
+        accelerator_registry.register(
+            "xpu",
+            cls,
+            description=cls.__name__,
+        )
+
+
+_IPEX_AVAILABLE = RequirementCache("intel_extension_for_pytorch>=1.13", "intel_extension_for_pytorch")
+
+
+@lru_cache(1)
+def num_xpu_devices() -> int:
+    """Returns the number of available XPU devices.
+
+    Unlike :func:`torch.xpu.device_count`, this function does its best not to create an XPU context for fork support,
+    if the platform allows it.
+
+    """
+    if _IPEX_AVAILABLE:
+        import intel_extension_for_pytorch as ipex
+
+        return ipex.xpu.device_count()
+    return 0
+
+
+def _get_all_visible_xpu_devices() -> List[int]:
+    """Returns a list of all visible Intel XPU devices.
+
+    Devices masked by the environment variable ``ZE_AFFINITY_MASK`` won't be returned here. For example, assume you
+    have 8 physical GPUs. If ``ZE_AFFINITY_MASK="1,3,6"``, then this function will return the list ``[0, 1, 2]``
+    because these are the three visible GPUs after applying the mask ``ZE_AFFINITY_MASK``.
+ + """ + return list(range(num_xpu_devices())) diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index a2d73d83184b1..49cf0f14472a6 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -119,6 +119,8 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, if trainer.strategy.root_device.type == "cuda": # required or else perf_counter() won't be correct torch.cuda.synchronize() + elif trainer.strategy.root_device.type == "xpu": + torch.xpu.synchronize() elapsed = time.perf_counter() - self._t0s[stage] if self.length_fn is not None: diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 5a4f8d4e1bbb1..44003a1120596 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -290,7 +290,7 @@ def on_gpu(self) -> bool: Useful to set flags around the LightningModule for different CPU vs GPU behavior. """ - return self.device.type == "cuda" + return self.device.type == "cuda" or self.device.type == "xpu" @property def automatic_optimization(self) -> bool: diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index f8e9c8300337a..f7e02c4ec156b 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -34,7 +34,7 @@ from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.fabric.utilities.data import AttributeDict from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH -from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator +from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator, XPUAccelerator from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE from lightning.pytorch.utilities.migration import pl_legacy_patch from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint @@ -109,6 +109,8 @@ def _default_map_location(storage: "UntypedStorage", location: str) -> Optional[ and not CUDAAccelerator.is_available() or location.startswith("xla") and not XLAAccelerator.is_available() + or location.startswith("xpu") + and not XPUAccelerator.is_available() ): return storage.cpu() return None # default behavior by `torch.load()` diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 9031b6ee177f3..f95bbbda4ad11 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -190,7 +190,12 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + if self.root_device.type == "cuda": + ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + elif self.root_device.type == "xpu": + ctx = torch.xpu.stream(torch.xpu.Stream()) if device_ids is not None else nullcontext() + else: + ctx = nullcontext() with ctx: return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) @@ -304,7 +309,10 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj obj = [obj] - 
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + if self.root_device.type != "xpu" and isinstance(type(obj[0]), str): + # I don't know why this is true. I will have to investigate. In the meantime, + # This is getting called by the profiler which can be worked around: + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) return obj[0] @override diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 1c97a223b129e..e4c9992f6310b 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -35,6 +35,7 @@ from lightning.pytorch.accelerators.cuda import CUDAAccelerator from lightning.pytorch.accelerators.mps import MPSAccelerator from lightning.pytorch.accelerators.xla import XLAAccelerator +from lightning.pytorch.accelerators.xpu import XPUAccelerator from lightning.pytorch.plugins import ( _PLUGIN_INPUT, BitsandbytesPrecision, @@ -308,6 +309,13 @@ def _check_config_and_set_final_flags( f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cuda" + if self._strategy_flag.parallel_devices[0].type == "xpu": + if self._accelerator_flag and self._accelerator_flag not in ("auto", "xpu", "gpu"): + raise MisconfigurationException( + f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," + f" but accelerator set to {self._accelerator_flag}, please choose one device type" + ) + self._accelerator_flag = "xpu" self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: @@ -341,6 +349,8 @@ def _choose_auto_accelerator(self) -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if XPUAccelerator.is_available(): + return "xpu" return "cpu" @staticmethod @@ -349,6 +359,8 @@ def _choose_gpu_accelerator_backend() -> str: return "mps" if CUDAAccelerator.is_available(): return "cuda" + if XPUAccelerator.is_available(): + return "xpu" raise MisconfigurationException("No supported gpu backend found!") def _set_parallel_devices_and_init_accelerator(self) -> None: diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 00b546b252ac8..5bcd09a054d8e 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -17,7 +17,7 @@ import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning -from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator +from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator, XPUAccelerator from lightning.pytorch.loggers.logger import DummyLogger from lightning.pytorch.profilers import ( AdvancedProfiler, @@ -148,11 +148,14 @@ def _log_device_info(trainer: "pl.Trainer") -> None: elif MPSAccelerator.is_available(): gpu_available = True gpu_type = " (mps)" + elif XPUAccelerator.is_available(): + gpu_available = True + gpu_type = " (xpu)" else: gpu_available = False gpu_type = "" - gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)) + gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator, XPUAccelerator)) rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}") num_tpu_cores = trainer.num_devices if 
isinstance(trainer.accelerator, XLAAccelerator) else 0 @@ -173,6 +176,8 @@ def _log_device_info(trainer: "pl.Trainer") -> None: and not isinstance(trainer.accelerator, CUDAAccelerator) or MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator) + or XPUAccelerator.is_available() + and not isinstance(trainer.accelerator, XPUAccelerator) ): rank_zero_warn( "GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.",
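
A minimal usage sketch for the accelerator added by this patch (not part of the diff; it assumes `intel_extension_for_pytorch` is installed and at least one XPU device is visible):

```python
import torch
from lightning.fabric import Fabric

# "xpu" is the name registered via `register_accelerators` above; on a machine
# where XPU is the only GPU backend, accelerator="gpu" or "auto" should resolve
# to the same backend through `_choose_gpu_accelerator_backend`.
fabric = Fabric(accelerator="xpu", devices=1)
fabric.launch()

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)  # parameters now live on xpu:0

x = torch.randn(8, 4, device=fabric.device)
loss = model(x).sum()
fabric.backward(loss)
optimizer.step()
```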