diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 2709b223..91a723ba 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -25,6 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip3 install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 pip install -e . pip install -e .'[dev]' pip install -e .'[test]' diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index 148ca6d1..d8447e3a 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -21,7 +21,6 @@ from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( - get_float8_linear, linear_requires_sync, LinearType, swap_linear_with_float8_linear, diff --git a/benchmarks/utils.py b/benchmarks/utils.py index fd4d5014..aec19e2c 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import collections -import json import re diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 72c09052..88227968 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -5,6 +5,11 @@ # LICENSE file in the root directory of this source tree. # Lets define a few top level things here from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_tensor import Float8Tensor +from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig + +# Needed to load Float8Tensor with weights_only = True +from torch.serialization import add_safe_globals + +add_safe_globals([Float8Tensor, ScaledMMConfig]) __all__ = ["Float8Tensor", "Float8Linear"] diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index ef0be7f5..bc75f772 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -62,8 +62,8 @@ class Float8DynamicLinear(torch.nn.Linear): def __init__(self, **super_kwargs): super().__init__(**super_kwargs) - def forward(self, x): - x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config) + def forward(self, input: torch.Tensor) -> torch.Tensor: + x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config) if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 3b3caed0..35380b94 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -312,10 +312,10 @@ def float8_post_forward(self): self.is_amax_initialized = True self.amax_and_scale_synced = False - def forward(self, x): - self.float8_pre_forward(x) + def forward(self, input: torch.Tensor) -> torch.Tensor: + self.float8_pre_forward(input) - x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized) + x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized) w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized) y = torch.matmul(x_fp8, w_fp8.t()) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 92392006..881f40fe 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -6,7 +6,7 @@ import copy import logging from enum import auto, Enum -from 
typing import Callable, List, Optional, Type +from typing import Callable, List, Optional, Type, Union import torch import torch.distributed as dist @@ -97,28 +97,33 @@ def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], ) -def swap_linear_with_float8_linear( +def swap_linear_layers( module: nn.Module, - module_cls: Type[nn.Module], + from_float_func: Callable[[nn.Linear], nn.Linear], *, skip_fqn_list: Optional[List[str]] = None, - emulate: bool = False, linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None, -) -> nn.Module: +) -> Optional[nn.Module]: """ - Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances - of ``module_cls`` (either ``Float8Linear`` or ``Float8DynamicLinear``). + Generic function to swap linear layers in a module with a new type of linear layer. + + Note: + If applied to a root-level nn.Linear, the module is not modified in place; + the new linear module is returned instead. Args: - module (torch.nn.Module): Module to modify. - module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap. - skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip. - Linear submodules of these skipped modules will also be skipped. - emulate (bool): Whether to emulate the fp8 matmul logic in fp32. - linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers + module: Module to modify. + from_float_func: Function that accepts a linear layer and returns a new type of linear layer. + skip_fqn_list: If specified, a list of module FQNs to skip. + linear_layer_filter: If specified, only the linear layers that pass the filter function will be swapped. + + Returns: + nn.Module: The modified module with swapped linear layers.
""" module_names_to_skip = set(skip_fqn_list or []) + if isinstance(module, nn.Linear) and ( linear_layer_filter is None or linear_layer_filter(module) ): @@ -126,16 +131,17 @@ def swap_linear_with_float8_linear( raise AssertionError( f"Does not support a root nn.Linear with children: {module}" ) - return module_cls.from_float(module, emulate=emulate) + return from_float_func( + module, + ) - # Mark all modules to skip as visited root_module = module visited_modules = {root_module} + for module_name, module in root_module.named_modules(): if module_name in module_names_to_skip: visited_modules.add(module) - # Run a post-order traversal to swap linears def post_order_traversal( module: nn.Module, module_name: str, parent_module: Optional[nn.Module] ): @@ -144,14 +150,15 @@ def post_order_traversal( if child_module not in visited_modules: visited_modules.add(child_module) post_order_traversal(child_module, child_module_name, module) + if isinstance(module, nn.Linear) and ( linear_layer_filter is None or linear_layer_filter(module) ): assert ( parent_module is not None ), f"Linear root module should return early: {module}" - float8linear_module = module_cls.from_float(module, emulate=emulate) - setattr(parent_module, module_name, float8linear_module) + new_linear_module = from_float_func(module) + setattr(parent_module, module_name, new_linear_module) post_order_traversal(root_module, "", None) # Without this explicit `del`, this set only gets deleted upon an explicit @@ -160,6 +167,22 @@ def post_order_traversal( return root_module +def swap_linear_with_float8_linear( + module: nn.Module, + module_cls: Union[Type[Float8Linear], Type[Float8DynamicLinear]], + *, + skip_fqn_list: Optional[List[str]] = None, + emulate: bool = False, + linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None, +) -> Optional[nn.Module]: + return swap_linear_layers( + module, + lambda m: module_cls.from_float(m, emulate=emulate), + skip_fqn_list=skip_fqn_list, + linear_layer_filter=linear_layer_filter, + ) + + def get_float8_layers(model: torch.nn.Module): """Iterates through the model and returns all the Float8Linear layers. Args: diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 4644a4b7..26d4688c 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -266,6 +266,7 @@ def to_float8( scale: the scale to use to convert the tensor float8_dtype: the float8 dtype to use amax_buffer: a buffer to store the amax value in prior to conversion + mm_config: Defines the configuration for the scaled_mm Returns: Float8Tensor: a float8 tensor diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index f6d95a92..2be568eb 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -139,7 +139,7 @@ def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") -def compute_error(x: torch.Tensor, y: torch.Tensor): +def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes the error between two tensors in dB. For more details see: diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py new file mode 100644 index 00000000..1c931eed --- /dev/null +++ b/float8_experimental/inference.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +Defines an nn.Module designed to be used during inference +""" + +from dataclasses import dataclass + +from enum import auto, Enum +from typing import List, Optional + +import float8_experimental.config as config + +import torch +import torch.nn as nn +from float8_experimental.float8_linear_utils import swap_linear_layers + +from float8_experimental.float8_tensor import ( + Float8Tensor, + ScaledMMConfig, + tensor_already_casted_to_fp8, + to_fp8_no_autograd, +) +from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale + + +class ActivationCasting(Enum): + """Types of quantization to perform on the activations + + WEIGHT_ONLY: Only quantize the weight; no activation casting is performed and the weight is dequantized in the forward pass + STATIC: Activation is quantized during model initialization with a static scale + DYNAMIC: Activation is quantized during the forward pass with a dynamic scale calculated from the input activation + """ + + # TODO: A better name would be NONE, we should unify this with torchao + WEIGHT_ONLY = auto() + DYNAMIC = auto() + STATIC = auto() + + +@dataclass(frozen=True) +class QuantConfig: + """Defines the configuration for the quantization to fp8 of a linear module + + Args: + activation_casting: The type of quantization to perform on the activations + static_quantization_scale: The scale of the input to this linear module, used for static quantization only + """ + + activation_casting: ActivationCasting + static_quantization_scale: Optional[torch.Tensor] = None + + def __post_init__(self): + if self.activation_casting == ActivationCasting.STATIC: + assert isinstance( + self.static_quantization_scale, torch.Tensor + ), "When activation_casting is 'static', static_quantization_scale must be a tensor."
+ + +class Float8InferenceLinear(torch.nn.Linear): + """ + This is a wrapper around torch.nn.Linear that supports FP8 inference. + Supported forms of inference: + - FP8 inference with high precision matmul - weight only + - FP8 inference with fp8 matmul and dynamic activation casting + - FP8 inference with fp8 matmul and static activation casting + """ + + def __init__( + self, + # FP8 specific arguments + quant_config: QuantConfig, + forward_config: ScaledMMConfig, + # nn.Linear arguments + in_features: int, + out_features: int, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + # Construct the superclass; this will create dummy weights and biases + super().__init__(in_features, out_features, bias, device, dtype) + self.forward_config = forward_config + self.activation_casting = quant_config.activation_casting + if self.activation_casting == ActivationCasting.STATIC: + self.register_buffer( + "static_quantization_scale", quant_config.static_quantization_scale + ) + else: + self.static_quantization_scale = None + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.activation_casting == ActivationCasting.WEIGHT_ONLY: + return torch.nn.functional.linear( + input, self.weight.to_original_precision() + ) + + x_fp8 = cast_to_float8_e4m3_inference( + input, + self.forward_config, + static_quantization_scale=self.static_quantization_scale, + ) + return torch.nn.functional.linear(x_fp8, self.weight, self.bias) + + # Builder functions for Float8InferenceLinear + def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: + """This function converts the weight to a Float8Tensor and sets its requires_grad to False. + + Args: + dtype: The dtype to quantize the weight to. Default is e4m3_dtype. + + Note: + This function is typically called during inference to quantize the weight once since + the weight is not updated during inference. + + """ + assert not isinstance( + self.weight, Float8Tensor + ), "Weight has already been quantized, cannot quantize again." + scale = tensor_to_scale(self.weight, dtype) + quantized_weight = to_fp8_no_autograd( + self.weight, + scale, + dtype, + self.forward_config, + ) + self.weight = nn.Parameter(quantized_weight) + self.weight.requires_grad = False + + def set_weight_and_bias( + self, weight: torch.nn.Parameter, bias: Optional[torch.nn.Parameter] + ): + self.weight = weight + self.bias = bias + + @classmethod + def from_float( + cls, module: nn.Module, quant_config: QuantConfig, use_fast_accum: bool + ) -> "Float8InferenceLinear": + """ + Create an nn.Linear with fp8 compute from another nn.Linear + + Args: + module (torch.nn.Linear): nn.Linear to convert + quant_config (QuantConfig): Configuration for the weight and activation casting + use_fast_accum (bool): Whether to enable fast accumulation for the fp8 matmul + """ + forward_config = ScaledMMConfig( + False, use_fast_accum, pad_inner_dim=config.pad_inner_dim + ) + linear = cls( + quant_config, + forward_config, + module.in_features, + module.out_features, + False, + device=torch.device("meta"), + ) + linear.set_weight_and_bias(module.weight, module.bias) + linear.quantize_weight() + return linear + + +def cast_to_float8_e4m3_inference( + inpt_tensor: torch.Tensor, + mm_config: ScaledMMConfig, + reduce_amax: bool = False, + static_quantization_scale: Optional[torch.Tensor] = None, +) -> Float8Tensor: + """Casts an input tensor to Float8 (e4m3fn). + + Args: + inpt_tensor: The input tensor to be cast.
+ mm_config: Configuration settings for the matrix multiplication + reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group. + static_quantization_scale: Optional tensor specifying the scale for activation. Default is None. + + Returns: + Float8Tensor: The input tensor cast to Float8 (e4m3fn) format. + + Note: + If the input tensor is already in Float8 format, it is returned as is without re-casting. + """ + if tensor_already_casted_to_fp8(inpt_tensor): + return inpt_tensor + scale = ( + static_quantization_scale + if static_quantization_scale is not None + else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) + ) + return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config) + + +def quantize_to_float8( + module: nn.Module, + quant_config: QuantConfig, + *, + skip_fqn_list: Optional[List[str]] = None, + use_fast_accum: bool = True, +) -> Optional[nn.Module]: + """ + Converts torch.nn.Linear layers in the given module to Float8InferenceLinear. + + Note: + If applied to a root-level nn.Linear, the module is not modified in place; + the new linear module is returned instead. + + Args: + module (nn.Module): The module to modify. + quant_config (QuantConfig): Quantization configuration for Float8 conversion. + skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion. + use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True. + + Returns: + nn.Module: The modified module with applicable Linear layers converted to Float8. + + Raises: + AssertionError: If a root-level nn.Linear with children is encountered. + """ + return swap_linear_layers( + module, + lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum), + skip_fqn_list=skip_fqn_list, + ) diff --git a/pyproject.toml b/pyproject.toml index 858e53b4..addd5220 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers = [ ] dependencies = [ - "torch >= 2.2", + "torch >= 2.3", ] [project.optional-dependencies] diff --git a/test/test_base.py b/test/test_base.py index 742a4b1e..7ce0b7bd 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree.
+import io import itertools import random import re @@ -13,6 +14,7 @@ import torch import torch.nn as nn + from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( @@ -37,6 +39,11 @@ FP8_TYPES, tensor_to_scale, ) +from float8_experimental.inference import ( + ActivationCasting, + QuantConfig, + quantize_to_float8, +) random.seed(0) torch.manual_seed(0) @@ -121,6 +128,21 @@ def test_copy_(self): fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) + def test_weights_only_load(self): + module = nn.Linear(16, 16) + # Save model state dict + buffer = io.BytesIO() + fp8_module = quantize_to_float8( + module, + QuantConfig( + ActivationCasting.DYNAMIC, + ), + ) + + torch.save(fp8_module.state_dict(), buffer) + buffer.seek(0) + _ = torch.load(buffer, weights_only=True) + class TestFloat8Linear: def _test_linear_impl( diff --git a/test/test_everything.sh b/test/test_everything.sh index b9893933..ada305d7 100755 --- a/test/test_everything.sh +++ b/test/test_everything.sh @@ -7,6 +7,7 @@ IS_ROCM=$(rocm-smi --version || true) pytest test/test_base.py pytest test/test_sam.py pytest test/test_compile.py +pytest test/test_inference_flows.py # These tests do not work on ROCm yet if [ -z "$IS_ROCM" ] diff --git a/test/test_fsdp.py b/test/test_fsdp.py index e0875573..1b5acd40 100644 --- a/test/test_fsdp.py +++ b/test/test_fsdp.py @@ -85,7 +85,7 @@ def fsdp_main(rank, world_size, args): model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to( rank ) - model.load_state_dict(torch.load(sd_in_fname)) + model.load_state_dict(torch.load(sd_in_fname, weights_only=True)) # To compile FSDP, we need use_orig_params to True model = FSDP(model, use_orig_params=True) # TODO: The following line doesn't work. We should fix it. 
@@ -95,7 +95,7 @@ def fsdp_main(rank, world_size, args): # optimizer update optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) - ref_input_global = torch.load(input_fname).to(base_dtype) + ref_input_global = torch.load(input_fname, weights_only=True).to(base_dtype) # basic distributed data sampling assert B % world_size == 0 @@ -175,11 +175,11 @@ def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = F torch.save(model.state_dict(), sd_in_fname) elif mode == "single_gpu": - ref_input = torch.load(input_fname).to(base_dtype) + ref_input = torch.load(input_fname, weights_only=True).to(base_dtype) model = get_model( K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype ).cuda() - model.load_state_dict(torch.load(sd_in_fname)) + model.load_state_dict(torch.load(sd_in_fname, weights_only=True)) optimizer = torch.optim.SGD(model.parameters(), lr=lr) def forward_backward(): @@ -203,8 +203,8 @@ def forward_backward(): mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) elif mode == "analyze": - y_single_gpu = torch.load(output_single_gpu_fname).cpu() - y_fsdp = torch.load(output_fsdp_fname).cpu() + y_single_gpu = torch.load(output_single_gpu_fname, weights_only=True).cpu() + y_fsdp = torch.load(output_fsdp_fname, weights_only=True).cpu() if is_fp8 and not emulate: atol, rtol = 2e-2, 2e-2 else: @@ -212,8 +212,8 @@ def forward_backward(): torch.testing.assert_close(y_single_gpu, y_fsdp, atol=atol, rtol=rtol) print("output testing single_gpu vs FSDP success") - sd_out_single_gpu = torch.load(sd_out_single_gpu_fname) - sd_out_fsdp = torch.load(sd_out_fsdp_fname) + sd_out_single_gpu = torch.load(sd_out_single_gpu_fname, weights_only=True) + sd_out_fsdp = torch.load(sd_out_fsdp_fname, weights_only=True) for k, v1 in sd_out_single_gpu.items(): if compile_fsdp: # The state-dict for compiled fsdp has a `_orig_mod` prefix diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py new file mode 100644 index 00000000..b0c00c6b --- /dev/null +++ b/test/test_inference_flows.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+import copy +import io +import random +import unittest + +import pytest + +import torch +import torch.nn as nn +import torch.nn.functional as F +from float8_experimental.float8_dynamic_linear import Float8DynamicLinear +from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear +from float8_experimental.float8_tensor import Float8Tensor +from float8_experimental.float8_utils import compute_error +from float8_experimental.inference import ( + ActivationCasting, + Float8InferenceLinear, + QuantConfig, + quantize_to_float8, +) + + +random.seed(0) +torch.manual_seed(0) + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + +class FeedForward(nn.Module): + def __init__(self) -> None: + super().__init__() + self.w1 = nn.Linear(4096, 14336, bias=False) + self.w3 = nn.Linear(4096, 14336, bias=False) + self.w2 = nn.Linear(14336, 4096, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + m.reset_parameters() + + +class TestHPTrainToFP8LinearInference: + def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor): + with torch.no_grad(): + base_output = base_mlp(input_tensor) + transformed_output = quantized_mlp(input_tensor) + + # Compute and check SQNR + sqnr = compute_error(base_output, transformed_output) + assert sqnr.item() > 20, f"SQNR is too low: {sqnr.item()} dB" + + @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + @unittest.skipIf( + not torch.cuda.is_available() or not is_H100, + "CUDA not available or on non H100 machine", + ) + def test_dynamic_fp8_mlp(self, compile_backend, dtype): + original_mlp = FeedForward().to("cuda", dtype=dtype) + original_mlp.reset_parameters() + + dynamic_fp8_mlp = copy.deepcopy(original_mlp) + + quant_config = QuantConfig(ActivationCasting.DYNAMIC) + quantize_to_float8(dynamic_fp8_mlp, quant_config) + + batch_size = 4 + num_tokens = 1024 + embedding_dim = 4096 + + input_tensor = torch.randn( + batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype + ) + + # Compile the models + compiled_original_mlp = torch.compile( + original_mlp, backend=compile_backend, fullgraph=True + ) + compiled_dynamic_fp8_mlp = torch.compile( + dynamic_fp8_mlp, backend=compile_backend, fullgraph=True + ) + + self.base_test_mlp_transform( + compiled_original_mlp, compiled_dynamic_fp8_mlp, input_tensor + ) + + @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + @unittest.skipIf( + not torch.cuda.is_available() or not is_H100, + "CUDA not available or on non H100 machine", + ) + def test_static_fp8_mlp(self, compile_backend, dtype): + original_mlp = FeedForward().to("cuda", dtype=dtype) + original_mlp.reset_parameters() + + static_fp8_mlp = copy.deepcopy(original_mlp) + quant_config = QuantConfig( + ActivationCasting.STATIC, + static_quantization_scale=torch.tensor( + [1.0], device="cuda", dtype=torch.float32 + ), + ) + quantize_to_float8(static_fp8_mlp, quant_config) + + batch_size = 4 + num_tokens = 1024 + embedding_dim = 4096 + + input_tensor = torch.randn( + batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype + ) + + # Compile the models + compiled_original_mlp = torch.compile( + original_mlp, backend=compile_backend, fullgraph=True + ) + 
compiled_static_fp8_mlp = torch.compile( + static_fp8_mlp, backend=compile_backend, fullgraph=True + ) + + self.base_test_mlp_transform( + compiled_original_mlp, compiled_static_fp8_mlp, input_tensor + ) + + @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + @unittest.skipIf( + not torch.cuda.is_available() or not is_H100, + "CUDA not available or on non H100 machine", + ) + def test_weight_only_fp8_mlp(self, compile_backend, dtype): + original_mlp = FeedForward().to("cuda", dtype=dtype) + original_mlp.reset_parameters() + + static_fp8_mlp = copy.deepcopy(original_mlp) + quant_config = QuantConfig(ActivationCasting.WEIGHT_ONLY) + quantize_to_float8(static_fp8_mlp, quant_config) + + batch_size = 4 + num_tokens = 1024 + embedding_dim = 4096 + + input_tensor = torch.randn( + batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype + ) + + # Compile the models + compiled_original_mlp = torch.compile( + original_mlp, backend=compile_backend, fullgraph=True + ) + compiled_static_fp8_mlp = torch.compile( + static_fp8_mlp, backend=compile_backend, fullgraph=True + ) + + self.base_test_mlp_transform( + compiled_original_mlp, compiled_static_fp8_mlp, input_tensor + ) + + +class TestFP8TrainToFP8LinearInference: + def train(self, model: nn.Module, dtype: torch.dtype): + model.train() + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + criterion = nn.MSELoss() + target_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) + for _ in range(10): + input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) + optimizer.zero_grad() + output = model(input_tensor) + loss = criterion(output, target_tensor) + loss.backward() + optimizer.step() + model.eval() + return model + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + @unittest.skipIf( + not torch.cuda.is_available() or not is_H100, + "CUDA not available or on non H100 machine", + ) + def test_fp8_save_and_load(self, dtype: torch.dtype): + # Initialize FP8 model + fp8_mlp = FeedForward().to("cuda", dtype=torch.float32) + fp8_mlp.reset_parameters() + swap_linear_with_float8_linear( + fp8_mlp, + Float8DynamicLinear, + ) + + # Train the model + self.train(fp8_mlp, dtype) + + # Generate input tensor and original out + input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) + og_out = fp8_mlp(input_tensor) + + # Save model state dict + buffer = io.BytesIO() + torch.save(fp8_mlp.state_dict(), buffer) + + # Reset buffer position to the beginning + buffer.seek(0) + + # Later on you load the model, will be w/ Float8DynamicLinear on meta device + with torch.device("meta"): + new_fp8_mlp = FeedForward().to(dtype=dtype) + swap_linear_with_float8_linear( + new_fp8_mlp, + Float8DynamicLinear, + ) + + # Load the actual data + new_fp8_mlp.load_state_dict( + torch.load(buffer, weights_only=True), strict=True, assign=True + ) + + quant_config = QuantConfig(ActivationCasting.DYNAMIC) + quantize_to_float8(new_fp8_mlp, quant_config) + + fp8_mod_count = 0 + for module in new_fp8_mlp.modules(): + if isinstance(module, Float8InferenceLinear): + assert isinstance(module.weight, Float8Tensor) + assert module.weight.requires_grad is False + fp8_mod_count += 1 + assert fp8_mod_count == 3, "Expected 3 FP8 modules, got {}".format( + fp8_mod_count + ) + + new_out = new_fp8_mlp(input_tensor) + + # Assert exact equality + assert torch.all(og_out == new_out).item() + + +if __name__ == "__main__": + pytest.main([__file__])
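
Example usage of the inference API added in float8_experimental/inference.py (a minimal sketch; the model, shapes, and dtype below are illustrative, and the fp8 matmul path assumes an H100-class GPU):

import io

import torch
import torch.nn as nn
from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)


class TinyMLP(nn.Module):
    # Illustrative module; any model containing nn.Linear layers works the same way.
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4096, 4096, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyMLP().to("cuda", dtype=torch.bfloat16)

# DYNAMIC: activations are scaled per forward pass; weights are quantized once at swap time.
quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC))

with torch.no_grad():
    out = model(torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16))

# Because Float8Tensor and ScaledMMConfig are registered via add_safe_globals,
# the quantized state dict round-trips with weights_only=True.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
state_dict = torch.load(buffer, weights_only=True)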