This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit c7e087d

add inference module workflow
1 parent 2ea0ab5 commit c7e087d

File tree: 3 files changed, +412 −0 lines

float8_experimental/float8_tensor.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -264,6 +264,7 @@ def to_float8(
             scale: the scale to use to convert the tensor
             float8_dtype: the float8 dtype to use
             amax_buffer: a buffer to store the amax value in prior to conversion
+            mm_config: Defines the configuration for the scaled_mm
 
         Returns:
             Float8Tensor: a float8 tensor
```
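For context, the new `mm_config` argument threads a `ScaledMMConfig` through the cast. A minimal sketch of passing it to `Float8Tensor.to_float8` (tensor shape, dtype, device, and flag values are illustrative, not from the diff):

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

# Illustrative activation tensor.
x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")

# Dynamic scale derived from the tensor's absolute maximum.
scale = tensor_to_scale(x, e4m3_dtype)

# Same construction as set_mm_config(emulate=False) in inference.py below.
mm_config = ScaledMMConfig(False, True)

x_fp8 = Float8Tensor.to_float8(x, scale, e4m3_dtype, mm_config=mm_config)
```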

float8_experimental/inference.py

Lines changed: 249 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Defines an nn module designed to be used during inference
"""
from dataclasses import dataclass

from enum import auto, Enum
from typing import Callable, List, Optional

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import (
    Float8Tensor,
    ScaledMMConfig,
    tensor_already_casted_to_fp8,
    to_fp8_no_autograd,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale


class ActivationCasting(Enum):
    """Types of quantization to perform on the activations

    WEIGHT_ONLY: Only quantize the weight, no activation casting, weight will be dequantized in the forward pass
    STATIC: Activation is quantized during model initialization with a static scale
    DYNAMIC: Activation is quantized during forward pass with a dynamic scale calculated from the input activation
    """

    WEIGHT_ONLY = auto()
    DYNAMIC = auto()
    STATIC = auto()


@dataclass(frozen=True)
class QuantConfig:
    """Defines the configuration for the quantization to fp8 of a linear module

    Args:
        activation_casting: The type of quantization to perform on the activations
        activation_scale: The scale of the input to this linear module, used for static quantization only
    """

    activation_casting: ActivationCasting
    activation_scale: Optional[torch.Tensor] = None

    def __post_init__(self):
        if self.activation_casting == ActivationCasting.STATIC:
            assert isinstance(
                self.activation_scale, torch.Tensor
            ), "When activation_casting is 'static', activation_scale must be a tensor."
```
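For illustration (not part of this diff), a minimal sketch of building a config for each casting mode; the static scale value is a placeholder one would normally obtain from calibration:

```python
import torch

from float8_experimental.inference import ActivationCasting, QuantConfig

# Weight-only: fp8 weight storage, matmul runs after dequantizing the weight.
weight_only_config = QuantConfig(ActivationCasting.WEIGHT_ONLY)

# Dynamic: the activation scale is recomputed from each input at forward time.
dynamic_config = QuantConfig(ActivationCasting.DYNAMIC)

# Static: a pre-computed activation scale must be provided (placeholder value here).
static_config = QuantConfig(
    ActivationCasting.STATIC,
    activation_scale=torch.tensor(1.0, dtype=torch.float32),
)
```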
```python
class Float8LinearInference(torch.nn.Linear):
    """
    This is a wrapper around torch.nn.Linear that supports FP8 inference
    Supported forms of inference:
        - FP8 inference with high precision matmul - weight only
        - FP8 inference with fp8 matmul and dynamic activation casting
        - FP8 inference with fp8 matmul and static activation casting
    """

    def __init__(self, **super_kwargs):
        super().__init__(**super_kwargs)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.activation_casting == ActivationCasting.WEIGHT_ONLY:
            return torch.nn.functional.linear(
                input, self.weight.to_original_precision()
            )

        x_fp8 = cast_to_float8_e4m3fn(
            input, self.forward_config, activation_scale=self.activation_scale
        )
        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)

    # Builder functions for Float8LinearInference
    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
        """This function converts the weight to a Float8Tensor and sets its requires_grad to False.

        Args:
            dtype: The dtype to quantize the weight to. Default is e4m3_dtype.

        Note:
            This function is typically called during inference to quantize the weight once since
            the weight is not updated during inference.

        """
        assert not isinstance(
            self.weight, Float8Tensor
        ), "Weight has already been quantized, cannot quantize again."
        scale = tensor_to_scale(self.weight, dtype)
        quantized_weight = to_fp8_no_autograd(
            self.weight,
            scale,
            dtype,
            self.forward_config,
        )
        self.weight = nn.Parameter(quantized_weight)
        self.weight.requires_grad = False

    @classmethod
    def create_meta_class(
        cls, in_features: int, out_features: int
    ) -> "Float8LinearInference":
        with torch.device("meta"):
            return cls(in_features=in_features, out_features=out_features, bias=False)

    def set_mm_config(self, emulate: bool) -> "Float8LinearInference":
        self.forward_config: ScaledMMConfig = ScaledMMConfig(emulate, not emulate)
        return self

    def set_weight_and_bias(
        self, weight: torch.nn.Parameter, bias: Optional[torch.nn.Parameter]
    ) -> "Float8LinearInference":
        self.weight = weight
        self.bias = bias
        return self

    def set_quantization_config(
        self,
        quant_config: QuantConfig,
    ) -> "Float8LinearInference":
        # We destructure the quant_config into its individual fields.
        # If a static activation scale is passed in we register it as a buffer.
        self.activation_casting: ActivationCasting = quant_config.activation_casting
        self.quantize_weight()

        if self.activation_casting == ActivationCasting.STATIC:
            self.register_buffer("activation_scale", quant_config.activation_scale)
        else:
            self.activation_scale = None
        return self

    @classmethod
    def from_float(
        cls,
        module: nn.Module,
        quant_config: QuantConfig,
    ) -> "Float8LinearInference":
        """
        Create an nn.Linear with fp8 compute from a regular nn.Linear

        Args:
            module (torch.nn.Linear): nn.Linear to convert
            quant_config (QuantConfig): Configuration for the weight and activation casting
        """
        return (
            cls.create_meta_class(module.in_features, module.out_features)
            .set_weight_and_bias(module.weight, module.bias)
            .set_mm_config(False)
            .set_quantization_config(quant_config)
        )
```
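As a usage sketch (not from the diff): converting a single layer via `from_float` and running it. Sizes, dtype, and device are illustrative, and the non-emulated fp8 matmul assumes hardware with fp8 matmul support (e.g. H100); the `WEIGHT_ONLY` config avoids that requirement.

```python
import torch
import torch.nn as nn

from float8_experimental.inference import (
    ActivationCasting,
    Float8LinearInference,
    QuantConfig,
)

# A plain bf16 linear layer (sizes are illustrative).
lin = nn.Linear(32, 64, bias=True, dtype=torch.bfloat16, device="cuda")

# Swap it for the fp8 wrapper with dynamic activation casting.
fp8_lin = Float8LinearInference.from_float(lin, QuantConfig(ActivationCasting.DYNAMIC))

x = torch.randn(8, 32, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = fp8_lin(x)  # input cast to e4m3, fp8 matmul against the quantized weight
print(y.shape)  # torch.Size([8, 64])
```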
```python
def cast_to_float8_e4m3fn(
    inpt_tensor: torch.Tensor,
    mm_config: ScaledMMConfig,
    reduce_amax: bool = False,
    activation_scale: Optional[torch.Tensor] = None,
) -> Float8Tensor:
    """Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.

    Args:
        inpt_tensor: The input tensor to be cast.
        mm_config: Configuration settings for the matrix multiplication.
        reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
        activation_scale: Optional tensor specifying the scale for activation. Default is None.

    Returns:
        Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.

    Note:
        If the input tensor is already in Float8 format, it is returned as is without re-casting.
    """
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
    scale = (
        activation_scale
        if activation_scale is not None
        else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
    )
    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
```
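For reference, a small sketch (not from the diff) contrasting the dynamic and static paths of this helper; the calibration scale is a placeholder:

```python
import torch

from float8_experimental.float8_tensor import ScaledMMConfig
from float8_experimental.inference import cast_to_float8_e4m3fn

mm_config = ScaledMMConfig(False, True)
x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")

# Dynamic: the scale is computed from the input tensor itself.
x_fp8_dynamic = cast_to_float8_e4m3fn(x, mm_config)

# Static: the caller supplies a pre-computed activation scale (placeholder value here).
calib_scale = torch.tensor(4.0, dtype=torch.float32, device="cuda")
x_fp8_static = cast_to_float8_e4m3fn(x, mm_config, activation_scale=calib_scale)
```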
```python
def quantize_to_float8(
    module: nn.Module,
    quant_config: QuantConfig,
    *,
    skip_fqn_list: Optional[List[str]] = None,
    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> nn.Module:
    """
    Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
    of ``Float8LinearInference``.

    Args:
        module (torch.nn.Module): Module to modify.
        quant_config (QuantConfig): Configuration for the weight and activation casting.
        skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
            Linear submodules of these skipped modules will also be skipped.
        linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
            that pass the filter function will be swapped.
    """
    module_names_to_skip = set(skip_fqn_list or [])
    if isinstance(module, nn.Linear) and (
        linear_layer_filter is None or linear_layer_filter(module)
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Linear with children: {module}"
            )
        return Float8LinearInference.from_float(module, quant_config)

    # Mark all modules to skip as visited
    root_module = module
    visited_modules = {root_module}
    for module_name, module in root_module.named_modules():
        if module_name in module_names_to_skip:
            visited_modules.add(module)

    # Run a post-order traversal to swap linears
    def post_order_traversal(
        module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
    ):
        nonlocal visited_modules
        for child_module_name, child_module in module.named_children():
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                post_order_traversal(child_module, child_module_name, module)
        if isinstance(module, nn.Linear) and (
            linear_layer_filter is None or linear_layer_filter(module)
        ):
            assert (
                parent_module is not None
            ), f"Linear root module should return early: {module}"
            float8linear_module = Float8LinearInference.from_float(module, quant_config)
            setattr(parent_module, module_name, float8linear_module)

    post_order_traversal(root_module, "", None)
    # Without this explicit `del`, this set only gets deleted upon an explicit
    # garbage collection (not from when its refcount hits zero)
    del visited_modules
    return root_module
```
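Putting it together, a sketch (not from the diff) of swapping every linear in a toy model; the model, sizes, and filter threshold are illustrative, and fp8-capable hardware is assumed as above:

```python
import torch
import torch.nn as nn

from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

# Toy model; sizes are illustrative.
model = nn.Sequential(
    nn.Linear(64, 256, dtype=torch.bfloat16, device="cuda"),
    nn.ReLU(),
    nn.Linear(256, 64, dtype=torch.bfloat16, device="cuda"),
).eval()

model = quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    # Only swap layers that are large enough to benefit (illustrative threshold).
    linear_layer_filter=lambda lin: lin.in_features >= 64,
)

with torch.no_grad():
    out = model(torch.randn(8, 64, dtype=torch.bfloat16, device="cuda"))
```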
