
Commit f977d14

add npu support
1 parent 93c1ab0 commit f977d14

9 files changed, +152 -13 lines changed

src/lightning/pytorch/accelerators/__init__.py (+1)

@@ -19,6 +19,7 @@
from lightning.pytorch.accelerators.cpu import CPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.cuda import CUDAAccelerator # noqa: F401
from lightning.pytorch.accelerators.mps import MPSAccelerator # noqa: F401
+from lightning.pytorch.accelerators.npu import NPUAccelerator # noqa: F401
from lightning.pytorch.accelerators.xla import XLAAccelerator # noqa: F401

AcceleratorRegistry = _AcceleratorRegistry()

src/lightning/pytorch/accelerators/accelerator.py (+7)

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
+from contextlib import nullcontext
from typing import Any, Dict

import lightning.pytorch as pl
@@ -45,3 +46,9 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:

        """
        raise NotImplementedError
+
+    def get_distribute_name(self) -> str:
+        return "gloo"
+
+    def get_stream_context(self, device_id: Any) -> Any:
+        return nullcontext()
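These two hooks give every accelerator a say in distributed setup: `get_distribute_name` names the `torch.distributed` backend and `get_stream_context` supplies the context manager used while wrapping the model for DDP. A minimal sketch of a hypothetical third-party accelerator overriding them (the class name and backend string are illustrative, not part of this commit):

```python
from contextlib import nullcontext
from typing import Any

from lightning.pytorch.accelerators.accelerator import Accelerator


class MyAccelerator(Accelerator):
    """Hypothetical accelerator sketch; the other required methods are omitted for brevity."""

    def get_distribute_name(self) -> str:
        # Backend string the strategy passes to torch.distributed when none was given explicitly.
        return "gloo"

    def get_stream_context(self, device_id: Any) -> Any:
        # No side-stream support on this backend, so fall back to a no-op context manager.
        return nullcontext()
```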

src/lightning/pytorch/accelerators/cuda.py (+9)

@@ -15,6 +15,7 @@
import os
import shutil
import subprocess
+from contextlib import nullcontext
from typing import Any, Dict, List, Optional, Union

import torch
@@ -104,6 +105,14 @@ def auto_device_count() -> int:
    def is_available() -> bool:
        return num_cuda_devices() > 0

+    @override
+    def get_distribute_name(self) -> str:
+        return "nccl"
+
+    @override
+    def get_stream_context(self, device_id: List[int]) -> Any:
+        return torch.cuda.stream(torch.cuda.Stream()) if device_id is not None else nullcontext()
+
    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
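The `"nccl"` return value matches what the removed fabric helper `_get_default_process_group_backend_for_device` used to derive from the device type, so CUDA behaviour is unchanged. Roughly, as a simplified restatement of that old dispatch (a sketch, not a call into Lightning):

```python
import torch


def default_backend_for(device: torch.device) -> str:
    # Simplified version of the old device-type dispatch:
    # CUDA devices use NCCL, everything else falls back to Gloo.
    return "nccl" if device.type == "cuda" else "gloo"


print(default_backend_for(torch.device("cuda", 0)))  # nccl
print(default_backend_for(torch.device("cpu")))      # gloo
```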
src/lightning/pytorch/accelerators/npu.py (new file, +112)

@@ -0,0 +1,112 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from typing_extensions import override
+
+from lightning.fabric.accelerators import _AcceleratorRegistry
+from lightning.fabric.utilities.types import _DEVICE
+from lightning.pytorch.accelerators.accelerator import Accelerator
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
+
+
+class NPUAccelerator(Accelerator):
+    """Accelerator for Ascend NPU devices."""
+
+    @override
+    def setup_device(self, device: torch.device) -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If the selected device is not NPU.
+        """
+        if device.type != "npu":
+            raise MisconfigurationException(f"Device should be NPU, got {device} instead.")
+        torch.npu.set_device(device)
+
+    @override
+    def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+        return torch.npu.memory_stats(device)
+
+    @override
+    def teardown(self) -> None:
+        torch.npu.empty_cache()
+
+    @staticmethod
+    @override
+    def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
+        """Accelerator device parsing logic.
+
+        -1 or '-1' means use all NPUs.
+
+        """
+
+        if isinstance(devices, list):
+            return devices
+        if isinstance(devices, str):
+            if devices == "-1":
+                return list(range(torch.npu.device_count()))
+            if "," in devices:
+                return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
+            return list(range(int(devices.strip())))
+        if isinstance(devices, int):
+            if devices == -1:
+                return list(range(torch.npu.device_count()))
+            return list(range(devices))
+
+        return None
+
+    @staticmethod
+    @override
+    def get_parallel_devices(devices: List[int]) -> List[torch.device]:
+        """Gets parallel devices for the Accelerator."""
+
+        return [torch.device("npu", i) for i in devices]
+
+    @staticmethod
+    @override
+    def auto_device_count() -> int:
+        """Get the devices when set to auto."""
+
+        return torch.npu.device_count()
+
+    @staticmethod
+    @override
+    def is_available() -> bool:
+        try:
+            import torch_npu  # noqa: F401
+
+            return torch.npu.device_count() > 0
+        except ImportError:
+            # torch_npu may not be installed, or the NPU runtime may not be properly configured.
+            return False
+
+    @override
+    def get_distribute_name(self) -> str:
+        return "hccl"
+
+    @override
+    def get_stream_context(self, device_id: List[int]) -> Any:
+        return torch.npu.stream(torch.npu.Stream()) if device_id is not None else nullcontext()
+
+    @classmethod
+    @override
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
+        accelerator_registry.register(
+            "npu",
+            cls,
+            description=cls.__name__,
+        )
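Once `torch_npu` is installed and at least one Ascend device is visible, the new accelerator can be requested like any other backend. The snippet below is an illustrative sketch rather than part of the commit; `BoringModel` is only a stand-in for any `LightningModule`, and the `parse_devices` results in the comments assume a host with 8 visible NPUs:

```python
import lightning.pytorch as pl
from lightning.pytorch.accelerators import NPUAccelerator
from lightning.pytorch.demos.boring_classes import BoringModel

# Device parsing mirrors the other accelerators:
#   NPUAccelerator.parse_devices("-1")   -> [0, 1, ..., 7]
#   NPUAccelerator.parse_devices("0,2")  -> [0, 2]
#   NPUAccelerator.parse_devices(2)      -> [0, 1]

# Explicitly pick the Ascend backend; DDP will then use the "hccl" process group.
trainer = pl.Trainer(accelerator="npu", devices=2, strategy="ddp")
trainer.fit(BoringModel())
```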

src/lightning/pytorch/strategies/ddp.py (+4 -4)

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from contextlib import nullcontext
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union

@@ -31,7 +30,6 @@
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.utilities.distributed import (
    _distributed_is_initialized,
-    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
@@ -193,7 +191,8 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
        device_ids = self.determine_ddp_device_ids()
        log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
        # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        assert self.accelerator is not None
+        ctx = self.accelerator.get_stream_context(device_ids)
        with ctx:
            return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

@@ -206,7 +205,8 @@ def setup_distributed(self) -> None:
        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

    def _get_process_group_backend(self) -> str:
-        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
+        assert self.accelerator is not None
+        return self._process_group_backend or self.accelerator.get_distribute_name()

    def set_world_ranks(self) -> None:
        if self.cluster_environment is not None:
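`DDPStrategy` now delegates both the process-group backend and the model-wrapping stream context to the accelerator instead of hard-coding the CUDA path. A quick way to see the dispatch (a sketch that assumes this commit is applied; no process group is actually created):

```python
from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator, NPUAccelerator

# The strategy calls accelerator.get_distribute_name() when no backend was
# passed explicitly, so the backend follows the accelerator in use.
for accel_cls in (CPUAccelerator, CUDAAccelerator, NPUAccelerator):
    print(accel_cls.__name__, "->", accel_cls().get_distribute_name())
# Expected: CPUAccelerator -> gloo, CUDAAccelerator -> nccl, NPUAccelerator -> hccl
```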

src/lightning/pytorch/strategies/deepspeed.py (+4 -3)

@@ -39,6 +39,7 @@
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
+from lightning.pytorch.accelerators.npu import NPUAccelerator
from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.strategies.ddp import DDPStrategy
@@ -315,10 +316,10 @@ def __init__(

    @override
    def setup_environment(self) -> None:
-        if not isinstance(self.accelerator, CUDAAccelerator):
+        if not isinstance(self.accelerator, (CUDAAccelerator, NPUAccelerator)):
            raise RuntimeError(
-                f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
-                " is used."
+                "The DeepSpeed strategy is only supported on CUDA GPUs or Ascend NPUs but"
+                f" `{self.accelerator.__class__.__name__}` is used."
            )
        super().setup_environment()
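With the widened `isinstance` check, a DeepSpeed run on Ascend hardware gets past `setup_environment` instead of raising. An illustrative configuration only, assuming `deepspeed` and `torch_npu` are installed and using the standard registered strategy alias:

```python
import lightning.pytorch as pl

# Before this commit, this raised "The DeepSpeed strategy is only supported on CUDA GPUs ...".
trainer = pl.Trainer(accelerator="npu", devices=4, strategy="deepspeed_stage_2", precision="16-mixed")
```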

src/lightning/pytorch/strategies/fsdp.py (+2 -2)

@@ -47,7 +47,6 @@
)
from lightning.fabric.utilities.distributed import (
    _distributed_is_initialized,
-    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
@@ -261,7 +260,8 @@ def setup_environment(self) -> None:
        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

    def _get_process_group_backend(self) -> str:
-        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
+        assert self.accelerator is not None
+        return self._process_group_backend or self.accelerator.get_distribute_name()

    def set_world_ranks(self) -> None:
        if self.cluster_environment is not None:

src/lightning/pytorch/trainer/connectors/accelerator_connector.py (+6 -3)

@@ -34,6 +34,7 @@
from lightning.pytorch.accelerators.accelerator import Accelerator
from lightning.pytorch.accelerators.cuda import CUDAAccelerator
from lightning.pytorch.accelerators.mps import MPSAccelerator
+from lightning.pytorch.accelerators.npu import NPUAccelerator
from lightning.pytorch.accelerators.xla import XLAAccelerator
from lightning.pytorch.plugins import (
    _PLUGIN_INPUT,
@@ -355,6 +356,8 @@ def _choose_auto_accelerator(self) -> str:
            return "mps"
        if CUDAAccelerator.is_available():
            return "cuda"
+        if NPUAccelerator.is_available():
+            return "npu"
        return "cpu"

    @staticmethod
@@ -462,7 +465,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
            return "ddp"
        if len(self._parallel_devices) <= 1:
            if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
-                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
+                isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps", "npu")
            ):
                device = _determine_root_gpu_device(self._parallel_devices)
            else:
@@ -482,9 +485,9 @@ def _check_strategy_and_fallback(self) -> None:

        if (
            strategy_flag in FSDPStrategy.get_registered_strategies() or isinstance(self._strategy_flag, FSDPStrategy)
-        ) and self._accelerator_flag not in ("cuda", "gpu"):
+        ) and self._accelerator_flag not in ("cuda", "gpu", "npu"):
            raise MisconfigurationException(
-                f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used."
+                f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but neither a GPU nor an NPU accelerator is used."
            )
        if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
            raise ValueError(
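The net effect of the connector changes: `accelerator="auto"` now probes for Ascend devices after TPU/MPS/CUDA, and `"npu"` is accepted wherever `"cuda"`/`"gpu"` already were for single-device strategy selection and the FSDP guard. A sketch of the auto path (the printed class depends on the host):

```python
import lightning.pytorch as pl

# On a machine with no TPU/MPS/CUDA but with torch_npu devices,
# "auto" now resolves to the NPU accelerator instead of falling back to CPU.
trainer = pl.Trainer(accelerator="auto", devices="auto")
print(type(trainer.accelerator).__name__)  # e.g. "NPUAccelerator" on an Ascend host
```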

src/lightning/pytorch/trainer/setup.py (+7 -1)

@@ -17,7 +17,7 @@

import lightning.pytorch as pl
from lightning.fabric.utilities.warnings import PossibleUserWarning
-from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
+from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, NPUAccelerator, XLAAccelerator
from lightning.pytorch.loggers.logger import DummyLogger
from lightning.pytorch.profilers import (
    AdvancedProfiler,
@@ -178,6 +178,9 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
        hpu_available = False
    rank_zero_info(f"HPU available: {hpu_available}, using: {num_hpus} HPUs")

+    number_npu_cores = trainer.num_devices if isinstance(trainer.accelerator, NPUAccelerator) else 0
+    rank_zero_info(f"NPU available: {NPUAccelerator.is_available()}, using: {number_npu_cores} NPU cores")
+
    if (
        CUDAAccelerator.is_available()
        and not isinstance(trainer.accelerator, CUDAAccelerator)
@@ -203,3 +206,6 @@ def _log_device_info(trainer: "pl.Trainer") -> None:

    if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
        rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")
+
+    if NPUAccelerator.is_available() and not isinstance(trainer.accelerator, NPUAccelerator):
+        rank_zero_warn("NPU available but not used. You can set it by doing `Trainer(accelerator='npu')`.")
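Both new log lines key off `NPUAccelerator.is_available()`, which returns `False` rather than raising when `torch_npu` cannot be imported, so the device summary stays safe on machines without the Ascend stack. A tiny check, safe to run anywhere once this commit is applied:

```python
from lightning.pytorch.accelerators import NPUAccelerator

# False on hosts without torch_npu; True only when the import succeeds
# and torch.npu.device_count() reports at least one device.
print(NPUAccelerator.is_available())
```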
