This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 36405a7

drisspg authored and facebook-github-bot committed
Add a Float8LinearInference module to support static, dynamic, and wo quant (#287)
Summary:

# Perf script: https://gist.github.com/drisspg/f7a553710d64cce013227a2249d582d2

## Performance

In eager mode this produces:

| Operation                    | Time (μs) |
|------------------------------|-----------|
| bf16                         | 2667.9172 |
| fp8_dynamic_activations      | 2494.7294 |
| fp8_static_activations       | 2449.1784 |
| fp8_weight_only_activations  | 4084.7190 |

With compile this produces:

| Operation                    | Time (μs) |
|------------------------------|-----------|
| bf16                         | 2547.1938 |
| fp8_dynamic_activations      | 1542.0729 |
| fp8_static_activations       | 1407.0310 |
| fp8_weight_only_activations  | 2750.6369 |

## UX

#### Dynamic activation quantization

```Python
original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

dynamic_fp8_mlp = copy.deepcopy(original_mlp)

quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantize_to_float8(dynamic_fp8_mlp, quant_config)
```

#### Static activation quantization

```Python
original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

static_fp8_mlp = copy.deepcopy(original_mlp)

quant_config = QuantConfig(
    ActivationCasting.STATIC,
    static_quantization_scale=torch.tensor(
        [1.0], device="cuda", dtype=torch.float32
    ),
)
quantize_to_float8(static_fp8_mlp, quant_config)
```

#### Weight-only quantization

```Python
original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

wo_fp8_mlp = copy.deepcopy(original_mlp)

quant_config = QuantConfig(ActivationCasting.WEIGHT_ONLY)
quantize_to_float8(wo_fp8_mlp, quant_config)
```

All of these examples use per-tensor scaling; a follow-up PR will add row-wise scaling and will likely make it the default.

Pull Request resolved: #287

Reviewed By: vkuzo

Differential Revision: D59179113

Pulled By: drisspg

fbshipit-source-id: 7938efbcbc51109d2ff7261275ca04d1b90732d3
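The `FeedForward` module used in these snippets is not part of this diff; a minimal stand-in that the examples would run against might look like the sketch below. The SwiGLU-style layout, layer sizes, and the reset logic are assumptions for illustration, not taken from this PR.

```Python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Hypothetical MLP stand-in for the module used in the UX examples."""

    def __init__(self, dim: int = 4096, hidden_dim: int = 14336):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style feed-forward: gate with silu, then project back down.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

    def reset_parameters(self) -> None:
        # Re-initialize the weights; the examples call this after construction.
        for lin in (self.w1, self.w2, self.w3):
            nn.init.kaiming_uniform_(lin.weight, a=5 ** 0.5)
```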
1 parent 0b60496 · commit 36405a7

15 files changed: +559 −36 lines

.github/workflows/python-app.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -25,6 +25,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
+        pip3 install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
         pip install -e .
         pip install -e .'[dev]'
         pip install -e .'[test]'
```

benchmarks/profile_linear_float8.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -21,7 +21,6 @@
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
-    get_float8_linear,
     linear_requires_sync,
     LinearType,
     swap_linear_with_float8_linear,
```

benchmarks/utils.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import collections
-import json
 import re


```
float8_experimental/__init__.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -5,6 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+
+# Needed to load Float8Tensor with weights_only = True
+from torch.serialization import add_safe_globals
+
+add_safe_globals([Float8Tensor, ScaledMMConfig])

 __all__ = ["Float8Tensor", "Float8Linear"]
```
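The `add_safe_globals` registration above is what lets checkpoints containing `Float8Tensor` weights pass `torch.load`'s restricted `weights_only=True` unpickler. A rough round-trip sketch, assuming the `wo_fp8_mlp` module from the weight-only example in the summary and a placeholder file path:

```Python
import torch

import float8_experimental  # importing runs add_safe_globals([Float8Tensor, ScaledMMConfig])

# Assuming `wo_fp8_mlp` from the weight-only example above, whose weights are
# Float8Tensor subclasses after quantize_to_float8 was applied.
torch.save(wo_fp8_mlp.state_dict(), "/tmp/fp8_mlp.pt")  # placeholder path

# weights_only=True uses a restricted unpickler; without the add_safe_globals
# registration in float8_experimental/__init__.py it would refuse to
# deserialize the Float8Tensor / ScaledMMConfig globals.
state_dict = torch.load("/tmp/fp8_mlp.pt", weights_only=True)
```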

float8_experimental/float8_dynamic_linear.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -62,8 +62,8 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)

-    def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
```
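For context on the forward pass above, the dynamic path derives a fresh per-tensor scale from the input's absolute maximum on every call. A minimal sketch of that idea follows; the real `cast_to_float8_e4m3fn` additionally wraps the result in a `Float8Tensor` carrying a `ScaledMMConfig`, and the exact clamping details here are assumptions.

```Python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def dynamic_cast_to_e4m3fn_sketch(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative per-tensor dynamic quantization to float8_e4m3fn (not the repo's helper)."""
    amax = x.abs().max().to(torch.float32)
    # Guard against an all-zero input so the scale stays finite.
    scale = E4M3_MAX / torch.clamp(amax, min=1e-12)
    # Saturate to the representable range before casting down to fp8.
    x_scaled = (x.to(torch.float32) * scale).clamp(-E4M3_MAX, E4M3_MAX)
    return x_scaled.to(torch.float8_e4m3fn), scale
```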

float8_experimental/float8_linear.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -312,10 +312,10 @@ def float8_post_forward(self):
         self.is_amax_initialized = True
         self.amax_and_scale_synced = False

-    def forward(self, x):
-        self.float8_pre_forward(x)
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        self.float8_pre_forward(input)

-        x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
+        x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

         y = torch.matmul(x_fp8, w_fp8.t())
```

float8_experimental/float8_linear_utils.py

Lines changed: 41 additions & 18 deletions
```diff
@@ -6,7 +6,7 @@
 import copy
 import logging
 from enum import auto, Enum
-from typing import Callable, List, Optional, Type
+from typing import Callable, List, Optional, Type, Union

 import torch
 import torch.distributed as dist
@@ -97,45 +97,51 @@ def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear],
     )


-def swap_linear_with_float8_linear(
+def swap_linear_layers(
     module: nn.Module,
-    module_cls: Type[nn.Module],
+    from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
     skip_fqn_list: Optional[List[str]] = None,
-    emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-) -> nn.Module:
+) -> Optional[nn.Module]:
     """
-    Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
-    of ``module_cls`` (either ``Float8Linear`` or ``Float8DynamicLinear``).
+    Generic function to swap linear layers in a module with a new type of linear layer.
+
+    Note:
+        If applied to a root-level nn.Linear, the module will not be modified in place
+        and returned instead

     Args:
-        module (torch.nn.Module): Module to modify.
-        module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap.
-        skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
-            Linear submodules of these skipped modules will also be skipped.
-        emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
-        linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
+        module: Module to modify.
+        from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
+        skip_fqn_list: If specified, a list of module FQNs to skip.
+        linear_layer_filter: If specified, only the linear layers
             that pass the filter function will be swapped.
+        from_float_kwargs: Additional keyword arguments for from_float_func.
+
+    Returns:
+        nn.Module: The modified module with swapped linear layers.
     """
     module_names_to_skip = set(skip_fqn_list or [])
+
     if isinstance(module, nn.Linear) and (
         linear_layer_filter is None or linear_layer_filter(module)
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
                 f"Does not support a root nn.Linear with children: {module}"
             )
-        return module_cls.from_float(module, emulate=emulate)
+        return from_float_func(
+            module,
+        )

-    # Mark all modules to skip as visited
     root_module = module
     visited_modules = {root_module}
+
     for module_name, module in root_module.named_modules():
         if module_name in module_names_to_skip:
             visited_modules.add(module)

-    # Run a post-order traversal to swap linears
     def post_order_traversal(
         module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
     ):
@@ -144,14 +150,15 @@ def post_order_traversal(
         if child_module not in visited_modules:
             visited_modules.add(child_module)
             post_order_traversal(child_module, child_module_name, module)
+
     if isinstance(module, nn.Linear) and (
         linear_layer_filter is None or linear_layer_filter(module)
     ):
         assert (
             parent_module is not None
         ), f"Linear root module should return early: {module}"
-        float8linear_module = module_cls.from_float(module, emulate=emulate)
-        setattr(parent_module, module_name, float8linear_module)
+        new_linear_module = from_float_func(module)
+        setattr(parent_module, module_name, new_linear_module)

     post_order_traversal(root_module, "", None)
     # Without this explicit `del`, this set only gets deleted upon an explicit
@@ -160,6 +167,22 @@ def post_order_traversal(
     return root_module


+def swap_linear_with_float8_linear(
+    module: nn.Module,
+    module_cls: Union[Type[Float8Linear], Type[Float8DynamicLinear]],
+    *,
+    skip_fqn_list: Optional[List[str]] = None,
+    emulate: bool = False,
+    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+) -> Optional[nn.Module]:
+    return swap_linear_layers(
+        module,
+        lambda m: module_cls.from_float(m, emulate=emulate),
+        skip_fqn_list=skip_fqn_list,
+        linear_layer_filter=linear_layer_filter,
+    )
+
+
 def get_float8_layers(model: torch.nn.Module):
     """Iterates through the model and returns all the Float8Linear layers.
     Args:
```
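To show how this refactor is meant to be used: the existing `swap_linear_with_float8_linear` entry point now forwards to the generic `swap_linear_layers`, which accepts any `nn.Linear -> nn.Linear` conversion callable. A hedged usage sketch follows; the layer sizes, dtype, and CUDA device are arbitrary choices for illustration, not taken from this PR.

```Python
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import (
    swap_linear_layers,
    swap_linear_with_float8_linear,
)

model = nn.Sequential(
    nn.Linear(1024, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 1024, bias=False),
).to("cuda", dtype=torch.bfloat16)

# Existing entry point: swap every nn.Linear for Float8DynamicLinear.
model = swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Equivalent call through the new generic helper, passing the conversion
# as a callable; any nn.Linear -> nn.Linear function can be plugged in here.
other = nn.Sequential(nn.Linear(1024, 1024, bias=False)).to("cuda", dtype=torch.bfloat16)
other = swap_linear_layers(
    other,
    lambda m: Float8DynamicLinear.from_float(m, emulate=False),
)
```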

float8_experimental/float8_tensor.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -266,6 +266,7 @@ def to_float8(
         scale: the scale to use to convert the tensor
         float8_dtype: the float8 dtype to use
         amax_buffer: a buffer to store the amax value in prior to conversion
+        mm_config: Defines the configuration for the scaled_mm

     Returns:
         Float8Tensor: a float8 tensor
```

float8_experimental/float8_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -139,7 +139,7 @@ def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")


-def compute_error(x: torch.Tensor, y: torch.Tensor):
+def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     """Computes the error between two tensors in dB.

     For more details see:
```
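As a usage note, `compute_error` reports how closely one tensor tracks another in dB, which is convenient for comparing fp8 outputs against a higher-precision reference. A small sketch follows; the round trip through `float8_e4m3fn` is only an illustration, and treating the result as an SQNR-style figure is an assumption about the helper rather than something stated in this diff.

```Python
import torch

from float8_experimental.float8_utils import compute_error

# Compare a tensor against a lossy float8 round trip of itself.
x = torch.randn(16, 1024, dtype=torch.float32)
x_lossy = x.to(torch.float8_e4m3fn).to(torch.float32)

# compute_error returns a dB figure; higher means the two tensors are closer.
print(compute_error(x, x_lossy))
```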
