From 9a7bf400e868533413ced407bd5749be52c89852 Mon Sep 17 00:00:00 2001 From: hehepig4 Date: Tue, 3 Dec 2024 16:20:44 +0800 Subject: [PATCH 1/8] Add Deepspeed Zero 3 MiCS support (Issues #20378) --- src/lightning/pytorch/strategies/deepspeed.py | 25 +++- .../strategies/test_deepspeed.py | 122 ++++++++++++++++++ 2 files changed, 141 insertions(+), 6 deletions(-) mode change 100644 => 100755 tests/tests_pytorch/strategies/test_deepspeed.py diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 4fa771114768d..ae317b15c9057 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -521,12 +521,25 @@ def model_sharded_context(self) -> Generator[None, None, None]: import deepspeed self._init_config_if_needed() - with deepspeed.zero.Init( - enabled=self.zero_stage_3, - remote_device=self.remote_device, - config_dict_or_path=self.config, - ): - yield + + # If detect 'mics_shard_size'>0 in config['zero_optimization'], alter to use deepspeed.zero.MiCS_Init() + # https://deepspeed.readthedocs.io/en/latest/zero3.html#mics-configurations + #! default deepspeed 0.9.0 is not compatible + if 'zero_optimization' in self.config and 'mics_shard_size' in self.config['zero_optimization']\ + and self.config['zero_optimization']['mics_shard_size'] > 0 and self.zero_stage_3: + with deepspeed.zero.MiCS_Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ): + yield + else: + with deepspeed.zero.Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ): + yield def _set_deepspeed_activation_checkpointing(self) -> None: import deepspeed diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py old mode 100644 new mode 100755 index 73697ea131545..2a1317c9c9707 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -1279,3 +1279,125 @@ def test_deepspeed_load_checkpoint_validate_path(tmp_path): checkpoint_path.touch() with pytest.raises(FileNotFoundError, match=f"Try to load using this parent directory instead: {tmp_path}"): strategy.load_checkpoint(checkpoint_path=checkpoint_path) + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_support(tmp_path): + """Test to ensure we can use DeepSpeed with basic ZeRO Stage 3 MiCS Support""" + model = ModelParallelBoringModel() + strategy = DeepSpeedStrategy(stage=3) + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + + trainer = Trainer( + default_root_dir=tmp_path, + strategy=strategy, + accelerator="gpu", + devices=2, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.test(model) + trainer.fit(model) + + _assert_save_model_is_equal(model, tmp_path, trainer) + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert 'zero_optimization' in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] == False + assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 + assert trainer.strategy.config["zero_optimization"]["stage"] == 3 + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def 
test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(tmp_path): + """Test to ensure we can use DeepSpeed with ZeRO Stage param offload 3 MiCS Support \ + However, in some past pratice, offload param + mics + torchrun will cause inner exception in multi-node environment. \ + Probably this exception is caused by torchrun, not deepspeed. """ + model = ModelParallelBoringModel() + strategy = DeepSpeedStrategy(stage=3,offload_params_device="cpu") + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + trainer = Trainer( + default_root_dir=tmp_path, + strategy=strategy, + accelerator="gpu", + devices=2, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.test(model) + trainer.fit(model) + + _assert_save_model_is_equal(model, tmp_path, trainer) + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert 'zero_optimization' in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] == False + assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 + assert trainer.strategy.config["zero_optimization"]["stage"] == 3 + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(tmp_path): + """Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support""" + model = ModelParallelBoringModel() + strategy = DeepSpeedStrategy(stage=3,offload_params_device="cpu", offload_optimizer_device="cpu") + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + trainer = Trainer( + default_root_dir=tmp_path, + strategy=strategy, + accelerator="gpu", + devices=2, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.test(model) + trainer.fit(model) + + _assert_save_model_is_equal(model, tmp_path, trainer) + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert 'zero_optimization' in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] == False + assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 + assert trainer.strategy.config["zero_optimization"]["stage"] == 3 + + +@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(tmp_path): + """Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' = True).""" + model = ModelParallelBoringModel() + strategy = DeepSpeedStrategy(stage=3) + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 2 + strategy.config["zero_optimization"]["offload_param"] = {} + strategy.config["zero_optimization"]["offload_optimizer"] = {} + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True + #Forming a 2 x 2 hierarchy + trainer = Trainer( + default_root_dir=tmp_path, + strategy=strategy, + accelerator="gpu", + devices=4, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + trainer.test(model) + trainer.fit(model) + + _assert_save_model_is_equal(model, 
tmp_path, trainer) + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert 'zero_optimization' in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] == True + assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 2 + assert trainer.strategy.config["zero_optimization"]["stage"] == 3 + From 2bb1e4a36ed815b7d687c601a19d61292b5c05fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:00:20 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/strategies/deepspeed.py | 10 +++-- .../strategies/test_deepspeed.py | 37 ++++++++++--------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index ae317b15c9057..fd29ca82fe4d1 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -521,12 +521,16 @@ def model_sharded_context(self) -> Generator[None, None, None]: import deepspeed self._init_config_if_needed() - + # If detect 'mics_shard_size'>0 in config['zero_optimization'], alter to use deepspeed.zero.MiCS_Init() # https://deepspeed.readthedocs.io/en/latest/zero3.html#mics-configurations #! default deepspeed 0.9.0 is not compatible - if 'zero_optimization' in self.config and 'mics_shard_size' in self.config['zero_optimization']\ - and self.config['zero_optimization']['mics_shard_size'] > 0 and self.zero_stage_3: + if ( + "zero_optimization" in self.config + and "mics_shard_size" in self.config["zero_optimization"] + and self.config["zero_optimization"]["mics_shard_size"] > 0 + and self.zero_stage_3 + ): with deepspeed.zero.MiCS_Init( enabled=self.zero_stage_3, remote_device=self.remote_device, diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 2a1317c9c9707..d807eaee49590 100755 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -1283,7 +1283,7 @@ def test_deepspeed_load_checkpoint_validate_path(tmp_path): @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multigpu_stage_3_MiCS_support(tmp_path): - """Test to ensure we can use DeepSpeed with basic ZeRO Stage 3 MiCS Support""" + """Test to ensure we can use DeepSpeed with basic ZeRO Stage 3 MiCS Support.""" model = ModelParallelBoringModel() strategy = DeepSpeedStrategy(stage=3) strategy.config["zero_optimization"]["stage"] = 3 @@ -1302,11 +1302,11 @@ def test_deepspeed_multigpu_stage_3_MiCS_support(tmp_path): ) trainer.test(model) trainer.fit(model) - + _assert_save_model_is_equal(model, tmp_path, trainer) assert isinstance(trainer.strategy, DeepSpeedStrategy) - assert 'zero_optimization' in trainer.strategy.config - assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] == False + assert "zero_optimization" in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 assert trainer.strategy.config["zero_optimization"]["stage"] == 3 @@ -1317,9 +1317,9 @@ def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(tmp_path): However, in some past pratice, offload param + mics + torchrun will cause 
inner exception in multi-node environment. \ Probably this exception is caused by torchrun, not deepspeed. """ model = ModelParallelBoringModel() - strategy = DeepSpeedStrategy(stage=3,offload_params_device="cpu") + strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu") strategy.config["zero_optimization"]["stage"] = 3 - strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False trainer = Trainer( default_root_dir=tmp_path, @@ -1336,18 +1336,19 @@ def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(tmp_path): _assert_save_model_is_equal(model, tmp_path, trainer) assert isinstance(trainer.strategy, DeepSpeedStrategy) - assert 'zero_optimization' in trainer.strategy.config - assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] == False + assert "zero_optimization" in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 assert trainer.strategy.config["zero_optimization"]["stage"] == 3 + @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(tmp_path): - """Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support""" + """Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support.""" model = ModelParallelBoringModel() - strategy = DeepSpeedStrategy(stage=3,offload_params_device="cpu", offload_optimizer_device="cpu") + strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", offload_optimizer_device="cpu") strategy.config["zero_optimization"]["stage"] = 3 - strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False trainer = Trainer( default_root_dir=tmp_path, @@ -1364,15 +1365,16 @@ def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(tmp_pat _assert_save_model_is_equal(model, tmp_path, trainer) assert isinstance(trainer.strategy, DeepSpeedStrategy) - assert 'zero_optimization' in trainer.strategy.config - assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] == False + assert "zero_optimization" in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1 assert trainer.strategy.config["zero_optimization"]["stage"] == 3 @RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True) def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(tmp_path): - """Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' = True).""" + """Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' = + True).""" model = ModelParallelBoringModel() strategy = DeepSpeedStrategy(stage=3) strategy.config["zero_optimization"]["stage"] = 3 @@ -1380,7 +1382,7 @@ def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(tmp_path): strategy.config["zero_optimization"]["offload_param"] = {} strategy.config["zero_optimization"]["offload_optimizer"] = {} 
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True - #Forming a 2 x 2 hierarchy + # Forming a 2 x 2 hierarchy trainer = Trainer( default_root_dir=tmp_path, strategy=strategy, @@ -1396,8 +1398,7 @@ def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(tmp_path): _assert_save_model_is_equal(model, tmp_path, trainer) assert isinstance(trainer.strategy, DeepSpeedStrategy) - assert 'zero_optimization' in trainer.strategy.config - assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] == True + assert "zero_optimization" in trainer.strategy.config + assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is True assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 2 assert trainer.strategy.config["zero_optimization"]["stage"] == 3 - From 689d61c47da7c1b2d61107e73a45534cda242452 Mon Sep 17 00:00:00 2001 From: hehepig4 Date: Wed, 4 Dec 2024 15:29:09 +0800 Subject: [PATCH 3/8] Add Deepspeed Zero 3 MiCS support (Issues #20378) --- src/lightning/pytorch/strategies/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index fd29ca82fe4d1..a39be03fea5dd 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -521,7 +521,7 @@ def model_sharded_context(self) -> Generator[None, None, None]: import deepspeed self._init_config_if_needed() - + assert self.config is not None # If detect 'mics_shard_size'>0 in config['zero_optimization'], alter to use deepspeed.zero.MiCS_Init() # https://deepspeed.readthedocs.io/en/latest/zero3.html#mics-configurations #! default deepspeed 0.9.0 is not compatible From ff1efa0711b5e5b80f18fe059ad9175297a9093a Mon Sep 17 00:00:00 2001 From: hehepig4 Date: Fri, 6 Dec 2024 14:28:36 +0800 Subject: [PATCH 4/8] Add documentation for Deepspeed Zero 3 MiCS support (#20378) --- docs/source-pytorch/advanced/model_parallel/deepspeed.rst | 1 + tests/tests_pytorch/strategies/test_deepspeed.py | 0 2 files changed, 1 insertion(+) mode change 100755 => 100644 tests/tests_pytorch/strategies/test_deepspeed.py diff --git a/docs/source-pytorch/advanced/model_parallel/deepspeed.rst b/docs/source-pytorch/advanced/model_parallel/deepspeed.rst index 9689f8c217eaf..b7d6ee5a5dd19 100644 --- a/docs/source-pytorch/advanced/model_parallel/deepspeed.rst +++ b/docs/source-pytorch/advanced/model_parallel/deepspeed.rst @@ -408,6 +408,7 @@ Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lig * Treat your GPU/CPU memory as one large pool. In some cases, you may not want to offload certain things (like activations) to provide even more space to offload model parameters * When offloading to the CPU, make sure to bump up the batch size as GPU memory will be freed * We also support sharded checkpointing. By passing ``save_full_weights=False`` to the ``DeepSpeedStrategy``, we'll save shards of the model which allows you to save extremely large models. However to load the model and run test/validation/predict you must use the Trainer object. +* DeepSpeed provides `MiCS support `_ which allows you to control how model parameters are sharded across GPUs. This can be useful if you have a large cluster of GPUs and want to avoid communication overhead. .. 
_deepspeed-zero-stage-3-single-file: diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py old mode 100755 new mode 100644 From e66dd110513416fc9ef7d4b3e35e7955d3c4c32d Mon Sep 17 00:00:00 2001 From: hehepig4 Date: Fri, 6 Dec 2024 14:35:49 +0800 Subject: [PATCH 5/8] Add Deepspeed Zero 3 MiCS support for fabric (Issues #20378, pr #20461) --- src/lightning/fabric/strategies/deepspeed.py | 20 ++- .../strategies/test_deepspeed_integration.py | 145 ++++++++++++++++++ 2 files changed, 160 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 03d90cd5df057..4d0e5a134f698 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -373,11 +373,21 @@ def module_sharded_context(self) -> AbstractContextManager: import deepspeed assert self._config_initialized - return deepspeed.zero.Init( - enabled=self.zero_stage_3, - remote_device=self.remote_device, - config_dict_or_path=self.config, - ) + assert self.config is not None + + if 'zero_optimization' in self.config and 'mics_shard_size' in self.config['zero_optimization']\ + and self.config['zero_optimization']['mics_shard_size'] > 0 and self.zero_stage_3: + return deepspeed.zero.MiCS_Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ) + else: + return deepspeed.zero.Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ) @override def save_checkpoint( diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 4811599ed05ab..3cd4c6dbf4951 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -414,3 +414,148 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init): zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY) assert init_mock.call_count == int(not empty_init) assert model.layer.weight.dtype == torch.bfloat16 + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_support(): + """Test to ensure ZeRO Stage 3 MiCS works with a parallel model.""" + strategy = DeepSpeedStrategy(stage=3) + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + + fabric = Fabric( + strategy= strategy, + accelerator="cuda", + devices=2, + precision="16-mixed", + ) + fabric.launch() + + def _make_block(): + return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) + + with fabric.init_module(): + model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + model, optimizer = fabric.setup(model, optimizer) + + x = torch.rand(2, 32, device=fabric.device) + y = torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = model(x) + x = x.float() # Ensure output is in float32 for softmax operation + logits = F.softmax(x, dim=1) + loss = F.cross_entropy(logits, y) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(): + """Test to ensure we can 
use DeepSpeed with ZeRO Stage param offload 3 MiCS Support""" + strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu") + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + + fabric = Fabric( + strategy= strategy, + accelerator="cuda", + devices=2, + precision="16-mixed", + ) + fabric.launch() + + def _make_block(): + return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) + + with fabric.init_module(): + model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + model, optimizer = fabric.setup(model, optimizer) + + x = torch.rand(2, 32, device=fabric.device) + y = torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = model(x) + x = x.float() # Ensure output is in float32 for softmax operation + logits = F.softmax(x, dim=1) + loss = F.cross_entropy(logits, y) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(): + """Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support.""" + strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", offload_optimizer_device="cpu") + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 1 + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False + + fabric = Fabric( + strategy= strategy, + accelerator="cuda", + devices=2, + precision="16-mixed", + ) + fabric.launch() + + def _make_block(): + return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) + + with fabric.init_module(): + model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + model, optimizer = fabric.setup(model, optimizer) + + x = torch.rand(2, 32, device=fabric.device) + y = torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = model(x) + x = x.float() # Ensure output is in float32 for softmax operation + logits = F.softmax(x, dim=1) + loss = F.cross_entropy(logits, y) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + +@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True) +def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(): + """Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' = + True).""" + strategy = DeepSpeedStrategy(stage=3) + strategy.config["zero_optimization"]["stage"] = 3 + strategy.config["zero_optimization"]["mics_shard_size"] = 2 + strategy.config["zero_optimization"]["offload_param"] = {} + strategy.config["zero_optimization"]["offload_optimizer"] = {} + strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True + + fabric = Fabric( + strategy= strategy, + accelerator="cuda", + devices=2, + precision="16-mixed", + ) + fabric.launch() + + def _make_block(): + return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) + + with fabric.init_module(): + model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3)) + + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + model, optimizer = fabric.setup(model, optimizer) + + x = torch.rand(2, 32, device=fabric.device) + y = torch.ones(x.size(0), device=x.device, dtype=torch.long) + x = 
model(x) + x = x.float() # Ensure output is in float32 for softmax operation + logits = F.softmax(x, dim=1) + loss = F.cross_entropy(logits, y) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() \ No newline at end of file From 9ccbb1fec7bb0da4163004a2037e82b28dc636fa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 06:37:34 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/strategies/deepspeed.py | 10 ++++++--- .../strategies/test_deepspeed_integration.py | 22 ++++++++++--------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 4d0e5a134f698..e377d696e1f59 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -374,9 +374,13 @@ def module_sharded_context(self) -> AbstractContextManager: assert self._config_initialized assert self.config is not None - - if 'zero_optimization' in self.config and 'mics_shard_size' in self.config['zero_optimization']\ - and self.config['zero_optimization']['mics_shard_size'] > 0 and self.zero_stage_3: + + if ( + "zero_optimization" in self.config + and "mics_shard_size" in self.config["zero_optimization"] + and self.config["zero_optimization"]["mics_shard_size"] > 0 + and self.zero_stage_3 + ): return deepspeed.zero.MiCS_Init( enabled=self.zero_stage_3, remote_device=self.remote_device, diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 3cd4c6dbf4951..d30d649346c75 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -423,9 +423,9 @@ def test_deepspeed_multigpu_stage_3_MiCS_support(): strategy.config["zero_optimization"]["stage"] = 3 strategy.config["zero_optimization"]["mics_shard_size"] = 1 strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False - + fabric = Fabric( - strategy= strategy, + strategy=strategy, accelerator="cuda", devices=2, precision="16-mixed", @@ -451,16 +451,17 @@ def _make_block(): optimizer.step() optimizer.zero_grad() + @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(): - """Test to ensure we can use DeepSpeed with ZeRO Stage param offload 3 MiCS Support""" + """Test to ensure we can use DeepSpeed with ZeRO Stage param offload 3 MiCS Support.""" strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu") strategy.config["zero_optimization"]["stage"] = 3 strategy.config["zero_optimization"]["mics_shard_size"] = 1 strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False - + fabric = Fabric( - strategy= strategy, + strategy=strategy, accelerator="cuda", devices=2, precision="16-mixed", @@ -494,9 +495,9 @@ def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(): strategy.config["zero_optimization"]["stage"] = 3 strategy.config["zero_optimization"]["mics_shard_size"] = 1 strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False - + fabric = Fabric( - strategy= strategy, + strategy=strategy, accelerator="cuda", devices=2, precision="16-mixed", @@ -522,6 +523,7 @@ def _make_block(): optimizer.step() optimizer.zero_grad() + 
@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True) def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(): """Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' = @@ -532,9 +534,9 @@ def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(): strategy.config["zero_optimization"]["offload_param"] = {} strategy.config["zero_optimization"]["offload_optimizer"] = {} strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True - + fabric = Fabric( - strategy= strategy, + strategy=strategy, accelerator="cuda", devices=2, precision="16-mixed", @@ -558,4 +560,4 @@ def _make_block(): loss = F.cross_entropy(logits, y) fabric.backward(loss) optimizer.step() - optimizer.zero_grad() \ No newline at end of file + optimizer.zero_grad() From 5409bc94ce5ff4728180c9938512c3f0ecc18fa1 Mon Sep 17 00:00:00 2001 From: hehepig4 Date: Tue, 10 Dec 2024 10:33:11 +0800 Subject: [PATCH 7/8] update docs (#20378) --- docs/source-pytorch/advanced/model_parallel/deepspeed.rst | 2 +- src/lightning/pytorch/strategies/deepspeed.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel/deepspeed.rst b/docs/source-pytorch/advanced/model_parallel/deepspeed.rst index b7d6ee5a5dd19..2680f600a445e 100644 --- a/docs/source-pytorch/advanced/model_parallel/deepspeed.rst +++ b/docs/source-pytorch/advanced/model_parallel/deepspeed.rst @@ -408,7 +408,7 @@ Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lig * Treat your GPU/CPU memory as one large pool. In some cases, you may not want to offload certain things (like activations) to provide even more space to offload model parameters * When offloading to the CPU, make sure to bump up the batch size as GPU memory will be freed * We also support sharded checkpointing. By passing ``save_full_weights=False`` to the ``DeepSpeedStrategy``, we'll save shards of the model which allows you to save extremely large models. However to load the model and run test/validation/predict you must use the Trainer object. -* DeepSpeed provides `MiCS support `_ which allows you to control how model parameters are sharded across GPUs. This can be useful if you have a large cluster of GPUs and want to avoid communication overhead. +* DeepSpeed provides `MiCS support `_ which allows you to control how model parameters are sharded across GPUs. For example, with 16 GPUs, ZeRO-3 will shard the model into 16 pieces by default. Instead with ``mics_shard_size=8``, every 8 GPUs will keep a full copy of the model weights, reducing the communication overhead. You can set ``"zero_optimization": {"stage": 3, "mics_shard_size": (shards num), ...}`` in a DeepSpeed config file to take advantage of this feature. .. 
_deepspeed-zero-stage-3-single-file: diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index a39be03fea5dd..38df1874deed6 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -522,7 +522,7 @@ def model_sharded_context(self) -> Generator[None, None, None]: self._init_config_if_needed() assert self.config is not None - # If detect 'mics_shard_size'>0 in config['zero_optimization'], alter to use deepspeed.zero.MiCS_Init() + # If we detect `'mics_shard_size' > 0` in `config['zero_optimization']`, use `deepspeed.zero.MiCS_Init(...)` instead of `deepspeed.zero.Init(...)` # https://deepspeed.readthedocs.io/en/latest/zero3.html#mics-configurations #! default deepspeed 0.9.0 is not compatible if ( From 10cdeb4bb8d9ec00a67ffcce7399794bce2e35c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 07:20:19 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/strategies/deepspeed.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index b66b6b0cb4f40..555a4ef2e6db7 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -389,12 +389,11 @@ def module_sharded_context(self) -> AbstractContextManager: remote_device=self.remote_device, config_dict_or_path=self.config, ) - else: - return deepspeed.zero.Init( - enabled=self.zero_stage_3, - remote_device=self.remote_device, - config_dict_or_path=self.config, - ) + return deepspeed.zero.Init( + enabled=self.zero_stage_3, + remote_device=self.remote_device, + config_dict_or_path=self.config, + ) @override def save_checkpoint(
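
For reference, below is a minimal sketch of how a user could enable MiCS sharding through the strategy config, mirroring the pattern the tests in this series use (mutating `strategy.config["zero_optimization"]` before fitting). It is an illustration under assumptions, not part of the patch: it assumes a machine with at least 2 CUDA GPUs and a DeepSpeed build that actually provides `deepspeed.zero.MiCS_Init`; `BoringModel` is used only as a stand-in for a real LightningModule, and the shard size of 1 (full replica per GPU) is just an example value. The same keys could equally be supplied via a DeepSpeed JSON config passed to `DeepSpeedStrategy(config=...)`.

    # Hedged usage sketch: enable ZeRO Stage 3 MiCS sharding via the strategy config.
    # Assumes >= 2 GPUs and a DeepSpeed version that ships deepspeed.zero.MiCS_Init.
    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel  # placeholder model
    from lightning.pytorch.strategies import DeepSpeedStrategy

    strategy = DeepSpeedStrategy(stage=3)
    # A 'mics_shard_size' > 0 under "zero_optimization" is what makes model_sharded_context()
    # pick deepspeed.zero.MiCS_Init() instead of deepspeed.zero.Init() in this patch series.
    strategy.config["zero_optimization"]["mics_shard_size"] = 1  # each GPU keeps a full parameter copy
    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

    trainer = Trainer(
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        precision="16-mixed",
    )
    trainer.fit(BoringModel())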