diff --git a/_notebooks b/_notebooks
deleted file mode 160000
index b83fde09c7243..0000000000000
--- a/_notebooks
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit b83fde09c724311af0d528e810b2ba606f31c95e
diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/README.md b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/README.md
new file mode 100644
index 0000000000000..5af7be196572c
--- /dev/null
+++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/README.md
@@ -0,0 +1,38 @@
+# PyTorch Native FP8 Training with FSDP1/2 and Torch Compile using Custom Handler
+
+This example shows how to combine PyTorch-native FP8 training (via torchao) with FSDP1/FSDP2 and `torch.compile` in PyTorch Lightning, using small custom handler classes.
+
+## Requirements
+
+Install the requirements by running:
+
+```bash
+sh setup.sh
+```
+
+## Example
+
+In this example we train a small language model on WikiText-2 with FP8 precision under FSDP1 or FSDP2, optionally combined with `torch.compile`, gradient checkpointing, and CPU offloading:
+
+```bash
+# # configure the PYTHONPATH if needed
+# export PYTHONPATH=/teamspace/studios/this_studio/pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile:$PYTHONPATH
+cd pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile
+
+# fsdp1 + fp8 + torch compile + gradient checkpointing + cpu offloading
+python train.py --enable_fp8 --enable_torch_compile --enable_gradient_checkpointing --enable_cpu_offload
+
+# fsdp2 + fp8 + torch compile + gradient checkpointing (the example does not currently implement fsdp2 cpu offloading)
+python train.py --enable_fsdp2 --enable_fp8 --enable_torch_compile --enable_gradient_checkpointing
+```
+
+## Test the handlers
+
+```bash
+# # configure the PYTHONPATH if needed
+# export PYTHONPATH=/teamspace/studios/this_studio/pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile:$PYTHONPATH
+cd pytorch-lightning/examples/pytorch/custom_handler_fp8_fsdp1n2_compile
+pytest tests/*
+```
+
+> **Warning:** FP8 training requires a GPU with compute capability 8.9 or later (for example, H100). Set `emulate_fp8=True` in `FP8Config` to exercise the FP8 code path without such hardware.
diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/__init__.py b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fp8_training_handler.py b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fp8_training_handler.py
new file mode 100644
index 0000000000000..22b05d96b1af3
--- /dev/null
+++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fp8_training_handler.py
@@ -0,0 +1,192 @@
+# This script is adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py
+import logging
+import operator
+from dataclasses import dataclass
+from typing import Union
+
+import torch
+import torch.nn as nn
+from lightning_utilities.core.imports import compare_version
+
+log = logging.getLogger(__name__)
+
+
+def is_sm89_or_later():
+    # Float8 is only supported on SM89 or later (H100+ GPUs)
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+# See https://github.com/pytorch/ao/blob/main/torchao/float8/config.py for more config details
+@dataclass
+class FP8Config:
+    enable_fp8: bool = True
+    enable_amax_init: bool = False
+    scaling_type_input: str = "delayed"
+    scaling_type_weight: str = "delayed"
+    scaling_type_grad_output: str = "delayed"
+    enable_fsdp_float8_all_gather: bool = False
+    precompute_float8_dynamic_scale_for_fsdp: bool = False
+    pad_inner_dim: bool = True
+    emulate_fp8: bool = False  # Set to True for testing without FP8 hardware
+    enable_torch_compile: bool = True
+    enable_pre_and_post_forward: bool = False
+
+
+# Define a map for module filter functions based on model name
+MODULE_FILTER_MAP = {
+    "llama": lambda mod, fqn: isinstance(mod, nn.Linear)
and "mlp" in fqn and "lm_head" not in fqn, + "mixtral": lambda mod, fqn: isinstance(mod, nn.Linear) + and "block_sparse_moe" in fqn + and "block_sparse_moe.gate" not in fqn + and "lm_head" not in fqn, + "default": lambda mod, fqn: isinstance(mod, nn.Linear), # Default filter +} + + +class Float8TrainingHandler: + """Handler for configuring models for FP8 training using torchao.""" + + def __init__(self, args: FP8Config, model_path: str, parallel_dims: dict[str, bool]): + """Initializes the handler for FP8 training and configuration. + + Args: + args (FP8Config): Configuration object for FP8 training, including settings for scaling, amax initialization, and torch compile. + model_path (str): The path to the model. Typically used for determining model-specific settings. + parallel_dims (Dict[str, bool]): Dictionary specifying parallelization settings, such as whether DP shard is enabled. + + Example Usage: + fp8_config = FP8Config( + enable_fp8=True, + enable_amax_init=True, + scaling_type_input="delayed", + scaling_type_weight="delayed", + scaling_type_grad_output="delayed", + enable_fsdp_float8_all_gather=False, + precompute_float8_dynamic_scale_for_fsdp=False, + pad_inner_dim=True, + emulate_fp8=False, # Set to True for testing without FP8 hardware + enable_torch_compile=True, + enable_pre_and_post_forward=False, + ) + + parallel_dims = {"dp_shard_enabled": False} + handler = Float8TrainingHandler(fp8_config, "path/to/model", parallel_dims) + + """ + self.model_path = model_path + self.args = args + self.parallel_dims = parallel_dims + self.compile = args.enable_torch_compile + self.enable_fp8 = args.enable_fp8 + + if not self.enable_fp8: + log.warning("Fp8 is disabled here") + return + + if not is_sm89_or_later() and not args.emulate_fp8: + log.error("Failed to swap to Float8Linear because float8 is only supported on SM89 or later (H100+ GPUs)") + raise RuntimeError("Float8Linear operation is not supported on the current hardware.") + + # Check if torchao is installed and version is >= 0.5.0 + try: + compare_version("torchao", operator.ge, "0.6.1") + from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType + except ImportError as e: + log.error(str(e)) + raise + + # Configure Float8LinearConfig parameters from args + scaling_type_input = ScalingType(args.scaling_type_input) + scaling_type_weight = ScalingType(args.scaling_type_weight) + scaling_type_grad_output = ScalingType(args.scaling_type_grad_output) + + enable_fsdp_float8_all_gather = ( + parallel_dims.get("dp_shard_enabled", False) and args.enable_fsdp_float8_all_gather + ) + + enable_amax_init = args.enable_amax_init + self.config = Float8LinearConfig( + enable_amax_init=enable_amax_init, + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + cast_config_input=CastConfig(scaling_type=scaling_type_input), + cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), + enable_pre_and_post_forward=args.enable_pre_and_post_forward, + pad_inner_dim=args.pad_inner_dim, + emulate=args.emulate_fp8, + ) + + # For precompute_float8_dynamic_scale_for_fsdp + self.precompute_scale = enable_fsdp_float8_all_gather and args.precompute_float8_dynamic_scale_for_fsdp + + # For sync_float8_amax_and_scale_history + self.delayed_scaling = ( + scaling_type_input == ScalingType.DELAYED + or scaling_type_weight == ScalingType.DELAYED + or scaling_type_grad_output == ScalingType.DELAYED + ) + self._sync_float8_amax_and_scale_history = None + + 
log.info("Float8 training active") + + def convert_to_float8_training(self, model: nn.Module, module_filter_fn: callable = None): + """Converts the linear layers of `model` to `Float8Linear` based on a module filter function. Mutates the model + in place. + + Args: + model (nn.Module): The model whose layers should be converted. + module_filter_fn (callable, optional): A function to filter which modules should be replaced. + Defaults to a model-specific filter based on `model_path`. + + """ + if not self.enable_fp8: + log.warning("FP8 is disabled, so layers will not be replaced.") + return + + log.warning("Enabling FP8 Training") + + # Use the provided filter function or select from the map + if module_filter_fn is None: + model_path_lower = self.model_path.lower() + module_filter_fn = next( + (fn for key, fn in MODULE_FILTER_MAP.items() if key in model_path_lower), + MODULE_FILTER_MAP["default"], # Default filter if no match is found + ) + + from torchao.float8 import convert_to_float8_training + + convert_to_float8_training( + model, + config=self.config, + module_filter_fn=module_filter_fn, + ) + log.info( + f"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather={self.config.enable_fsdp_float8_all_gather}" + ) + + def precompute_float8_dynamic_scale_for_fsdp(self, model: Union[nn.Module, list[nn.Module]]): + if not self.enable_fp8 or not self.precompute_scale: + return + + from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp + + models = [model] if isinstance(model, nn.Module) else model + for m in models: + precompute_float8_dynamic_scale_for_fsdp(m) + + def sync_float8_amax_and_scale_history(self, model: Union[nn.Module, list[nn.Module]]): + if not self.enable_fp8 or not self.delayed_scaling: + return + + from torchao.float8 import sync_float8_amax_and_scale_history + + # Cache the compiled function if necessary + if self._sync_float8_amax_and_scale_history is None: + if self.compile: + self._sync_float8_amax_and_scale_history = torch.compile(sync_float8_amax_and_scale_history) + else: + self._sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history + + models = [model] if isinstance(model, nn.Module) else model + for m in models: + self._sync_float8_amax_and_scale_history(m) diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fsdp2_handler.py b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fsdp2_handler.py new file mode 100644 index 0000000000000..0c4fe5649d40e --- /dev/null +++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/fsdp2_handler.py @@ -0,0 +1,100 @@ +import logging +import operator +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn +from lightning_utilities.core.imports import compare_version + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + +log = logging.getLogger(__name__) + + +@dataclass +class FSDP2Config: + enable_cpu_offload: bool = False + enable_gradient_checkpointing: bool = False + + +class FSDP2Handler: + """Handler for wrapping the model layers with FSDP2. + + Args: + args (FSDP2Config): Configuration for FSDP2, including options for CPU offload and gradient checkpointing. + device_mesh (DeviceMesh): Device mesh configuration for FSDP2 parallelism. + + Attributes: + args (FSDP2Config): Stores the FSDP2 configuration. + device_mesh (DeviceMesh): Stores the device mesh configuration. 
+ + """ + + def __init__(self, args: FSDP2Config, device_mesh: "DeviceMesh"): + self.args = args + self.device_mesh = device_mesh + + # Check PyTorch version for FSDP2 support (currently we require PyTorch >= 2.6.0) + try: + compare_version("torch", operator.ge, "2.6.0") + except RuntimeError as e: + log.error(str(e)) + raise + + # Import necessary FSDP modules + try: + from torch.distributed._composable.fsdp import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + fully_shard, + ) + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + ) + + self.fully_shard = fully_shard + self.checkpoint_wrapper = checkpoint_wrapper + self.MixedPrecisionPolicy = MixedPrecisionPolicy + self.CPUOffloadPolicy = CPUOffloadPolicy + except ImportError as e: + log.error(f"Failed to import FSDP modules: {e}") + raise + + def wrap_model(self, model: nn.Module): + """Wraps the model layers with FSDP configurations. + + Args: + model (nn.Module): The model to wrap. + + Returns: + nn.Module: The wrapped model. + + """ + dp_mesh = self.device_mesh["data_parallel"] + assert dp_mesh.size() > 1, "FSDP requires at least two devices." + + fsdp_policy = { + "mesh": dp_mesh, + "mp_policy": self.MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + ), + } + if self.args.enable_cpu_offload: + fsdp_policy["offload_policy"] = self.CPUOffloadPolicy() + + for layer_id, module in enumerate(model.model.layers): + reshard_after_forward = layer_id < len(model.model.layers) - 1 + if self.args.enable_gradient_checkpointing: + module = self.checkpoint_wrapper(module) + self.fully_shard( + module, + **fsdp_policy, + reshard_after_forward=reshard_after_forward, + ) + model.model.layers[layer_id] = module + + self.fully_shard(model, **fsdp_policy) + return model diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/torch_compile_handler.py b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/torch_compile_handler.py new file mode 100644 index 0000000000000..320757796a6d2 --- /dev/null +++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/handlers/torch_compile_handler.py @@ -0,0 +1,106 @@ +import logging +import operator + +import torch +import torch.nn as nn +from lightning_utilities.core.imports import compare_version + +log = logging.getLogger(__name__) + + +class TorchCompileHandler: + """Handler for compiling specific layers of the model using torch.compile. + + Args: + enable_compile (bool): Whether to enable compilation. + model_path (str): Path to the model, used to determine default compilable layers. + compile_layers (List[str], optional): List of layer class names to compile. If None, defaults are used. + compile_args (dict, optional): Additional arguments to pass to torch.compile. 
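+
+    Example Usage (illustrative; ``SimpleLayer`` is the block compiled in ``train.py``):
+        handler = TorchCompileHandler(
+            enable_compile=True,
+            model_path="dummy",
+            compile_layers=["SimpleLayer"],
+            compile_args={"backend": "inductor", "mode": "default"},
+        )
+        handler.compile_model(model)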
+ + """ + + # Default mapping of model names to compilable layer class names + DEFAULT_COMPILABLE_LAYERS = { + "llama": ["LlamaMLP"], + "mixtral": ["MixtralBlockSparseTop2MLP"], + } + + def __init__( + self, + enable_compile: bool, + model_path: str, + compile_layers: list = None, + compile_args: dict = None, + ): + self.enable_compile = enable_compile + self.model_path = model_path.lower() + self.compile_args = compile_args if compile_args is not None else {} + self.compile_layers = compile_layers # User-provided layers to compile + + if self.enable_compile: + # Check PyTorch version for torch.compile support (requires PyTorch >= 2.6.0) + try: + compare_version("torch", operator.ge, "2.6.0") + except RuntimeError as e: + log.error(str(e)) + raise + + # Determine default layers to compile if not provided + if self.compile_layers is None: + self.compile_layers = self._get_default_compile_layers() + if not self.compile_layers: + log.warning( + "No default compilable layers found for the model. " "Please provide compile_layers explicitly." + ) + + def _get_default_compile_layers(self): + """Determines the default layers to compile based on the model name. + + Returns: + List[str]: List of layer class names to compile. + + """ + for model_name, layers in self.DEFAULT_COMPILABLE_LAYERS.items(): + if model_name in self.model_path: + return layers + return [] + + def compile_model(self, model: nn.Module): + """Compiles specified layers in the model. + + Args: + model (nn.Module): The model to compile. + + """ + if not self.enable_compile: + return + + if not self.compile_layers: + log.warning("No layers specified for compilation. Skipping compilation.") + return + + log.warning(f"Compiling layers: {self.compile_layers} with args: {self.compile_args}") + + self._compile_layers(model) + + def _compile_layers(self, module: nn.Module): + """Recursively compiles specified layers in the module. + + Args: + module (nn.Module): The module to process. 
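+
+        Note:
+            A child whose class name matches ``compile_layers`` is replaced by its compiled
+            wrapper and not traversed further; all other children are searched recursively.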
+ + """ + for name, child in module.named_children(): + child_class_name = type(child).__name__ + if child_class_name in self.compile_layers: + log.warning(f"Compiling layer {name} ({child_class_name})") + try: + # Compile the layer with provided arguments + compiled_child = torch.compile(child, **self.compile_args) + setattr(module, name, compiled_child) + except Exception as e: + log.error(f"Failed to compile layer {name}: {e}") + raise + else: + # Recursively process child modules + self._compile_layers(child) diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/requirements.txt b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/requirements.txt new file mode 100644 index 0000000000000..c0605af1f88e0 --- /dev/null +++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/requirements.txt @@ -0,0 +1 @@ +lightning diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/setup.sh b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/setup.sh new file mode 100644 index 0000000000000..5118ce1899402 --- /dev/null +++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/setup.sh @@ -0,0 +1,3 @@ +pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 +pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 +pip install -r requirements.txt diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_fp8_training_handler.py b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_fp8_training_handler.py new file mode 100644 index 0000000000000..015032ae6819f --- /dev/null +++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_fp8_training_handler.py @@ -0,0 +1,115 @@ +import unittest +from unittest.mock import patch + +import pytest +import torch.nn as nn +from handlers.fp8_training_handler import Float8TrainingHandler, FP8Config +from lightning.pytorch.demos import Transformer +from torchao.float8 import Float8Linear + + +class TestFloat8TrainingHandler(unittest.TestCase): + def setUp(self): + self.args = FP8Config( + enable_fp8=True, + enable_amax_init=True, + scaling_type_input="delayed", + scaling_type_weight="delayed", + scaling_type_grad_output="delayed", + enable_fsdp_float8_all_gather=False, + precompute_float8_dynamic_scale_for_fsdp=False, + pad_inner_dim=False, + emulate_fp8=False, # Set to True for testing without FP8 hardware + enable_torch_compile=False, + enable_pre_and_post_forward=False, + ) + + self.model_path = "test_mixtral_model" + self.parallel_dims = {"dp_shard_enabled": False} + + # Simple model for testing + self.model = Transformer( + vocab_size=32000, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) + + @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True) + def test_handler_initialization(self, mock_sm89): + handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims) + assert handler.enable_fp8 + assert not handler.compile + assert handler.args is not None + assert handler.parallel_dims is not None + + @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True) + def test_compile_flag(self, mock_sm89): + self.args.enable_torch_compile = True + handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims) + assert handler.compile + + @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=False) + def test_handler_disabled_on_unsupported_hardware(self, mock_sm89): + # Assert that the RuntimeError is raised + with pytest.raises(RuntimeError) as 
context: + Float8TrainingHandler(self.args, self.model_path, self.parallel_dims) + + # Check that the error message matches the expected text + assert "Float8Linear operation is not supported on the current hardware." in str(context.exception) + + def test_handler_disabled_when_fp8_not_enabled(self): + self.args.enable_fp8 = False + handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims) + assert not handler.enable_fp8 + + @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True) + def test_convert_to_float8_training(self, mock_sm89): + handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims) + handler.convert_to_float8_training(self.model) + + # Check if nn.Linear layers have been converted to Float8Linear + print(self.model) + for module_name, module in self.model.named_modules(): + if any(proj in module_name for proj in ["w1", "w2", "w3"]): # Float8Linear + assert isinstance(module, Float8Linear), f"{module_name} should be Float8Linear" + elif isinstance(module, nn.Linear): + assert not isinstance(module, Float8Linear), f"{module_name} should not be Float8Linear" + + @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True) + def test_precompute_float8_dynamic_scale_for_fsdp(self, mock_sm89): + handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims) + handler.convert_to_float8_training(self.model) + + with patch("torchao.float8.precompute_float8_dynamic_scale_for_fsdp") as mock_precompute: + handler.precompute_float8_dynamic_scale_for_fsdp(self.model) + mock_precompute.assert_not_called() # Should not be called since precompute_scale is False + + # Enable precompute_scale + args = self.args + args.scaling_type_input = "dynamic" + args.scaling_type_weight = "dynamic" + args.scaling_type_grad_output = "dynamic" + args.enable_fsdp_float8_all_gather = True + args.precompute_float8_dynamic_scale_for_fsdp = True + handler = Float8TrainingHandler(args, self.model_path, {"dp_shard_enabled": True}) + model = Transformer( + vocab_size=32000, + nlayers=16, + nhid=4096, + ninp=1024, + nhead=32, + ) # recreate the model + with patch("torchao.float8.precompute_float8_dynamic_scale_for_fsdp") as mock_precompute: + handler.precompute_float8_dynamic_scale_for_fsdp(model) + mock_precompute.assert_called() + + @patch("handlers.fp8_training_handler.is_sm89_or_later", return_value=True) + def test_sync_float8_amax_and_scale_history(self, mock_sm89): + handler = Float8TrainingHandler(self.args, self.model_path, self.parallel_dims) + handler.convert_to_float8_training(self.model) + + with patch("torchao.float8.sync_float8_amax_and_scale_history") as mock_sync: + handler.sync_float8_amax_and_scale_history(self.model) + mock_sync.assert_called_once_with(self.model) diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_fsdp2_handler.py b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_fsdp2_handler.py new file mode 100644 index 0000000000000..f0909f7cd1bcb --- /dev/null +++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_fsdp2_handler.py @@ -0,0 +1,120 @@ +import unittest +from unittest.mock import MagicMock, patch + +import pytest +import torch.nn as nn +from handlers.fsdp2_handler import FSDP2Config, FSDP2Handler + + +# Define mock functions +def mock_fully_shard(module, **kwargs): + """Mock for torch.distributed._composable.fsdp.fully_shard. + + Returns the module unchanged to simulate sharding without actual processing. 
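+
+    The real ``fully_shard`` shards the module's parameters in place and registers FSDP hooks;
+    returning the module unchanged is sufficient here because these tests only assert call
+    counts and keyword arguments.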
+ + """ + return module + + +def mock_checkpoint_wrapper(module): + """Mock for torch.distributed.algorithms._checkpoint.checkpoint_wrapper. + + Returns the module unchanged to simulate checkpoint wrapping without actual processing. + + """ + return module + + +class TestFSDP2Handler(unittest.TestCase): + def setUp(self): + self.args = FSDP2Config( + enable_gradient_checkpointing=True, + enable_cpu_offload=False, + ) + + # Mock device mesh + self.device_mesh = {"data_parallel": MagicMock()} + self.device_mesh["data_parallel"].size.return_value = 2 # Simulate more than one device + + class ModelWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model # The wrapped Transformer model + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + class InnerModel(nn.Module): + def __init__(self, num_layers, input_size, hidden_size): + super().__init__() + # Initialize a ModuleList to store the layers + self.layers = nn.ModuleList() + for _ in range(num_layers): + layer = nn.Linear(input_size, hidden_size) + self.layers.append(layer) + # You can add more complex layers or custom layers here + + def forward(self, x): + # Pass the input through each layer sequentially + for layer in self.layers: + x = layer(x) + return x + + self.model = ModelWrapper( + InnerModel( + num_layers=16, + input_size=4096, + hidden_size=1024, + ) + ) + + @patch("torch.distributed._composable.fsdp.fully_shard", side_effect=mock_fully_shard) + @patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper", + side_effect=mock_checkpoint_wrapper, + ) + def test_wrap_model(self, mock_checkpoint_wrapper_func, mock_fully_shard_func): + handler = FSDP2Handler(self.args, self.device_mesh) + wrapped_model = handler.wrap_model(self.model) + + # Ensure fully_shard and checkpoint_wrapper are called + assert mock_fully_shard_func.called, "fully_shard was not called" + assert mock_checkpoint_wrapper_func.called, "checkpoint_wrapper was not called" + + # Verify that the model's layers have been wrapped + assert wrapped_model is not None, "wrapped_model is None" + mock_fully_shard_func.assert_called() + + # Ensure that checkpoint_wrapper is called for each layer + assert mock_checkpoint_wrapper_func.call_count == len(self.model.model.layers) + # Ensure that fully_shard is called for each layer + full module + assert mock_fully_shard_func.call_count == len(self.model.model.layers) + 1 + + def test_wrap_model_with_single_device(self): + # Simulate single device + self.device_mesh["data_parallel"].size.return_value = 1 + handler = FSDP2Handler(self.args, self.device_mesh) + with pytest.raises(AssertionError): + handler.wrap_model(self.model) + + @patch("torch.distributed._composable.fsdp.fully_shard", side_effect=mock_fully_shard) + def test_enable_cpu_offload(self, mock_fully_shard_func): + self.args.enable_cpu_offload = True + handler = FSDP2Handler(self.args, self.device_mesh) + handler.wrap_model(self.model) + # Check if CPUOffloadPolicy is used + args, kwargs = mock_fully_shard_func.call_args + assert "offload_policy" in kwargs + assert kwargs["offload_policy"] is not None + + @patch("torch.distributed._composable.fsdp.fully_shard", side_effect=mock_fully_shard) + @patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper", + side_effect=mock_checkpoint_wrapper, + ) + def test_diable_gradient_checkpointing(self, mock_checkpoint_wrapper_func, mock_fully_shard_func): + self.args.enable_gradient_checkpointing = False + handler 
= FSDP2Handler(self.args, self.device_mesh) + handler.wrap_model(self.model) + # Check if gradient checkpointing is disabled + assert not mock_checkpoint_wrapper_func.called, "Error: checkpoint_wrapper was unexpectedly called." diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_torch_compile_handler.py b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_torch_compile_handler.py new file mode 100644 index 0000000000000..bc286a8511481 --- /dev/null +++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/tests/test_torch_compile_handler.py @@ -0,0 +1,74 @@ +# test_torch_compile_handler.py + +import unittest +from unittest.mock import patch + +import torch.nn as nn +from handlers.torch_compile_handler import TorchCompileHandler +from lightning.pytorch.demos import Transformer + + +def mock_torch_compile(module, **kwargs): + """Mock function for torch.compile that returns the module unchanged. + + This avoids actual compilation during testing. + + """ + return module + + +class TestTorchCompileHandler(unittest.TestCase): + def setUp(self): + self.enable_compile = True + self.model_path = "test_custom_transformer_model" + + self.num_layers = 16 + self.model = Transformer( + vocab_size=32000, + nlayers=self.num_layers, + nhid=4096, + ninp=1024, + nhead=32, + ) + self.compile_args = {"backend": "inductor", "mode": "default"} + + @patch("torch.compile", side_effect=mock_torch_compile) + def test_compile_transformer_encoder_layers(self, mock_compile): + handler = TorchCompileHandler( + enable_compile=self.enable_compile, + model_path=self.model_path, + compile_layers=["TransformerEncoderLayer"], # Explicitly specify layers + compile_args=self.compile_args, + ) + handler.compile_model(self.model) + + # Ensure torch.compile was called with the correct layer + assert mock_compile.call_count == self.num_layers, f"Expected mock_compile to be called {self.num_layers} times" + + def test_compile_disabled(self): + handler = TorchCompileHandler(False, self.model_path) + with patch("torch.compile") as mock_torch_compile: + handler.compile_model(self.model) + mock_torch_compile.assert_not_called() + + @patch("torch.compile", side_effect=mock_torch_compile) + def test_compile_recursive(self, mock_compile): + # Nested modules + class NestedModel(nn.Module): + def __init__(self, child_module): + super().__init__() + self.layer = nn.Sequential( + nn.Linear(128, 128), + child_module, + ) + + def forward(self, x): + return self.layer(x) + + model = NestedModel(child_module=self.model) + handler = TorchCompileHandler(self.enable_compile, self.model_path, compile_layers=["TransformerDecoderLayer"]) + handler.compile_model(model) + + # LlamaMLP inside NestedModel should be compiled + assert mock_compile.called + assert mock_compile.call_count == self.num_layers, f"Expected mock_compile to be called {self.num_layers} times" diff --git a/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/train.py b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/train.py new file mode 100644 index 0000000000000..b3a75398e4bf1 --- /dev/null +++ b/examples/pytorch/custom_handler_fp8_fsdp1n2_compile/train.py @@ -0,0 +1,262 @@ +import argparse +import logging +from dataclasses import dataclass + +import lightning as L +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from lightning.pytorch.demos import WikiText2 +from lightning.pytorch.strategies import FSDPStrategy, ModelParallelStrategy +from torch.distributed.fsdp import BackwardPrefetch, 
MixedPrecision +from torch.utils.data import DataLoader + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +log = logging.getLogger(__name__) + + +@dataclass +class Args: + vocab_size: int = 32000 + enable_fp8: bool = False + enable_torch_compile: bool = False + enable_cpu_offload: bool = False + enable_gradient_checkpointing: bool = False + enable_fsdp2: bool = False + + +class SimpleLayer(nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.linear = nn.Linear(hidden_size, hidden_size) + self.activation = nn.ReLU() + + def forward(self, x): + print(f"Input shape before Linear: {x.shape}") + x = self.linear(x) + print(f"Output shape after Linear: {x.shape}") + return self.activation(x) + + +class InnerModel(nn.Module): + def __init__(self, num_layers, hidden_size, vocab_size=32000): + super().__init__() + # Embedding layer + self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size) + # Initialize a ModuleList to store the intermediate layers + self.layers = nn.ModuleList([SimpleLayer(hidden_size) for _ in range(num_layers)]) + self.lm_head = nn.Linear(hidden_size, vocab_size) + + def forward(self, x): + x = self.embedding(x) + # Pass the input through each layer sequentially + for layer in self.layers: + x = layer(x) + return self.lm_head(x) + + +class ModelWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model # The wrapped Transformer model + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + +class LanguageModel(L.LightningModule): + def __init__( + self, + vocab_size=32000, + enable_fp8=False, + enable_fsdp2=False, + enable_torch_compile=False, + enable_gradient_checkpointing=False, + enable_cpu_offload=False, + ): + super().__init__() + self.model = None + self.vocab_size = vocab_size + self.enable_fp8 = enable_fp8 + self.enable_fsdp2 = enable_fsdp2 + self.enable_torch_compile = enable_torch_compile + self.enable_gradient_checkpointing = enable_gradient_checkpointing + self.enable_cpu_offload = enable_cpu_offload + self.model_path = "dummy" # placeholder + self.parallel_dims = {"dp_shard_enabled": torch.cuda.device_count() > 1} # only used for FP8 training + + def log_model_stage(self, stage: str): + """Logs the current state of the model with a description of the stage. + + Args: + stage (str): Description of the current model stage. 
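+
+        Note:
+            Called after FP8 conversion, FSDP2 wrapping, and ``torch.compile`` so the log shows
+            how the module tree changes at each configuration stage.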
+ + """ + log.warning(f"Model at stage: {stage}\n{self.model}") + + def configure_torch_compile(self): + if self.enable_torch_compile: + from handlers.torch_compile_handler import TorchCompileHandler + + torch_compile_handler = TorchCompileHandler( + enable_compile=self.enable_torch_compile, + model_path=self.model_path, + # Implicitly specify layers, default only support compile HuggingFace llama and mixtral model with llama MLP block and Mixtral MixtralBlockSparseTop2MLP block compiled + compile_layers=["SimpleLayer"], + compile_args=None, + ) + torch_compile_handler.compile_model(self.model) + + self.log_model_stage("Model after compile") + + def configure_fsdp2(self): + if self.enable_fsdp2: + self.all_gpus = dist.new_group(backend="nccl") + dp_mesh = self.device_mesh["data_parallel"] + assert dp_mesh.size() > 1 + + from handlers.fsdp2_handler import FSDP2Config, FSDP2Handler + + fsdp2_config = FSDP2Config( + enable_cpu_offload=self.enable_cpu_offload, + enable_gradient_checkpointing=self.enable_gradient_checkpointing, + ) + fsdp2_handler = FSDP2Handler(fsdp2_config, self.device_mesh) + self.model = fsdp2_handler.wrap_model(self.model) + + self.log_model_stage("Model after FSDP wrapper") + + def configure_fp8(self): + # Setup fp8 training, if enable_fp8 is false, it will create a fake handler + from handlers.fp8_training_handler import Float8TrainingHandler, FP8Config + + fp8_config = FP8Config( + enable_fp8=self.enable_fp8, + enable_amax_init=False, + scaling_type_input="delayed", + scaling_type_weight="delayed", + scaling_type_grad_output="delayed", + enable_fsdp_float8_all_gather=False, + precompute_float8_dynamic_scale_for_fsdp=False, + pad_inner_dim=True, + emulate_fp8=False, # Set to True for testing without FP8 hardware + enable_torch_compile=self.enable_torch_compile, + enable_pre_and_post_forward=False, + ) + self.fp8_handler = Float8TrainingHandler(fp8_config, self.model_path, self.parallel_dims) + self.fp8_handler.convert_to_float8_training(self.model) + self.log_model_stage("Model after FP8 wrapper") + + def configure_model(self): + if self.model is not None: + return + + with torch.device("meta"): + self.model = ModelWrapper( + InnerModel( + num_layers=16, + hidden_size=1024, + vocab_size=self.vocab_size, + ) + ) + self.configure_fp8() + self.configure_fsdp2() + self.configure_torch_compile() + self.model.train() + + def on_train_batch_start(self, batch, batch_idx): + super().on_train_batch_start(batch, batch_idx) + self.hand_roll_base_zero_grad() + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): + super().on_validation_batch_start(batch, batch_idx, dataloader_idx) + self.hand_roll_base_zero_grad() + + def hand_roll_base_zero_grad(self): + # to resolve the torch compile + FSDP1 issue https://github.com/pytorch/pytorch/issues/139110 + if self.enable_torch_compile and not self.enable_fsdp2: + self.zero_grad(set_to_none=True) + for p in self.parameters(): + if p._base is not None and p._base.grad is not None: + p._base._grad = None + + def on_before_optimizer_step(self, optimizer): + self.fp8_handler.sync_float8_amax_and_scale_history(self.model) + super().on_before_optimizer_step(optimizer) + + def on_train_batch_end(self, outputs, batch, batch_idx): + self.fp8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model) + super().on_train_batch_end(outputs, batch, batch_idx) + + def training_step(self, batch): + input, target = batch + output = self.model(input) + log_softmax = nn.LogSoftmax(dim=1) + loss = F.nll_loss(log_softmax(output).view(-1, 
self.vocab_size), target.view(-1)) + self.log("train_loss", loss, prog_bar=True) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=1e-4) + + +def train(args): + L.seed_everything(42) + + dataset = WikiText2() + train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1) + + model = LanguageModel( + vocab_size=args.vocab_size, + enable_fp8=args.enable_fp8, + enable_fsdp2=args.enable_fsdp2, + enable_torch_compile=args.enable_torch_compile, + enable_gradient_checkpointing=args.enable_gradient_checkpointing, + enable_cpu_offload=args.enable_cpu_offload, + ) + + if args.enable_fsdp2: + strategy = ModelParallelStrategy( + data_parallel_size=1, + tensor_parallel_size=1, + ) + else: + layers = {SimpleLayer} + strategy = FSDPStrategy( + auto_wrap_policy=layers, + sharding_strategy="FULL_SHARD", + forward_prefetch=True, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + sync_module_states=True, + activation_checkpointing_policy=layers if args.enable_gradient_checkpointing else None, + # for FSDP, we set mixed precision here instead of passing precision to PL trainer. + # precision="bf16-true" in PL trainer means pure half precision (including optimizer update etc.) + # while precision="bf16-mixed" results in unshard allgather performed in fp32: + # https://github.com/Lightning-AI/pytorch-lightning/blob/bf25167bbf64f50ba335aa759318946b21775cd2/src/lightning/fabric/plugins/precision/fsdp.py#L83 + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + cpu_offload=args.enable_cpu_offload, + ) + trainer = L.Trainer(strategy=strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8) + + trainer.fit(model, train_dataloader) + + trainer.print(torch.cuda.memory_summary()) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train a language model.") + parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size. Default is 32000.") + parser.add_argument("--enable_fp8", action="store_true", help="Enable FP8 precision.") + parser.add_argument("--enable_torch_compile", action="store_true", help="Enable Torch Compile.") + parser.add_argument("--enable_cpu_offload", action="store_true", help="Enable CPU offload.") + parser.add_argument("--enable_gradient_checkpointing", action="store_true", help="Enable gradient checkpointing.") + parser.add_argument("--enable_fsdp2", action="store_true", help="Enable FSDP2.") + args = parser.parse_args() + return Args(**vars(args)) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + args = parse_args() + train(args)