This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b89515c

add inference module workflow
1 parent 2ea0ab5 commit b89515c

3 files changed: +413 -0 lines changed

float8_experimental/float8_tensor.py

Lines changed: 1 addition & 0 deletions
@@ -264,6 +264,7 @@ def to_float8(
        scale: the scale to use to convert the tensor
        float8_dtype: the float8 dtype to use
        amax_buffer: a buffer to store the amax value in prior to conversion
+       mm_config: Defines the configuration for the scaled_mm

    Returns:
        Float8Tensor: a float8 tensor
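
For context (not part of the commit), a minimal sketch of how the newly documented mm_config argument flows into Float8Tensor.to_float8. It assumes the call signatures used in inference.py below; the tensor shape is arbitrary and ScaledMMConfig's first positional field is assumed to be the emulation flag:

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

x = torch.randn(16, 32)
# Compute a per-tensor scale for the e4m3 target dtype.
scale = tensor_to_scale(x, e4m3_dtype)
# Mirrors ScaledMMConfig(False, use_fast_accum) as constructed in inference.py;
# the first field is assumed to be the emulate flag.
mm_config = ScaledMMConfig(False, True)
# Cast to fp8, attaching the scaled_mm configuration to the resulting Float8Tensor.
x_fp8 = Float8Tensor.to_float8(x, scale, e4m3_dtype, mm_config=mm_config)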

float8_experimental/inference.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Defines an nn module designed to be used during inference
"""
from dataclasses import dataclass

from enum import auto, Enum
from typing import Callable, List, Optional

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import (
    Float8Tensor,
    ScaledMMConfig,
    tensor_already_casted_to_fp8,
    to_fp8_no_autograd,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale


class ActivationCasting(Enum):
    """Types of quantization to perform on the activations

    WEIGHT_ONLY: Only quantize the weight, no activation casting; the weight will be dequantized in the forward pass
    STATIC: Activation is quantized during model initialization with a static scale
    DYNAMIC: Activation is quantized during the forward pass with a dynamic scale calculated from the input activation
    """

    WEIGHT_ONLY = auto()
    DYNAMIC = auto()
    STATIC = auto()


@dataclass(frozen=True)
class QuantConfig:
    """Defines the configuration for the quantization to fp8 of a linear module

    Args:
        activation_casting: The type of quantization to perform on the activations
        activation_scale: The scale of the input to this linear module, used for static quantization only
    """

    activation_casting: ActivationCasting
    activation_scale: Optional[torch.Tensor] = None

    def __post_init__(self):
        if self.activation_casting == ActivationCasting.STATIC:
            assert isinstance(
                self.activation_scale, torch.Tensor
            ), "When activation_casting is 'static', activation_scale must be a tensor."


class Float8LinearInference(torch.nn.Linear):
    """
    This is a wrapper around torch.nn.Linear that supports FP8 inference
    Supported forms of inference:
        - FP8 inference with fp32 matmul - weight only
        - FP8 inference with fp8 matmul and dynamic weight casting
        - FP8 inference with fp8 matmul and static weight casting
    """

    def __init__(self, **super_kwargs):
        super().__init__(**super_kwargs)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.activation_casting == ActivationCasting.WEIGHT_ONLY:
            return torch.nn.functional.linear(
                input, self.weight.to_original_precision()
            )

        x_fp8 = cast_to_float8_e4m3fn(
            input, self.forward_config, activation_scale=self.activation_scale
        )
        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)

    # Builder functions for Float8LinearInference
    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
        """This function converts the weight to a Float8Tensor and sets its requires_grad to False.

        Args:
            dtype: The dtype to quantize the weight to. Default is e4m3_dtype.

        Note:
            This function is typically called during inference to quantize the weight once since
            the weight is not updated during inference.

        """
        assert not isinstance(
            self.weight, Float8Tensor
        ), "Weight has already been quantized, cannot quantize again."
        scale = tensor_to_scale(self.weight, dtype)
        quantized_weight = to_fp8_no_autograd(
            self.weight,
            scale,
            dtype,
            self.forward_config,
        )
        self.weight = nn.Parameter(quantized_weight)
        self.weight.requires_grad = False

    @classmethod
    def create_meta_class(
        cls, in_features: int, out_features: int
    ) -> "Float8LinearInference":
        with torch.device("meta"):
            return cls(in_features=in_features, out_features=out_features, bias=False)

    def set_mm_config(self, use_fast_accum: bool = True) -> "Float8LinearInference":
        """TODO: hardcoded for now, but we could/should likely add this to the constructor"""
        self.forward_config: ScaledMMConfig = ScaledMMConfig(False, use_fast_accum)
        return self

    def set_weight_and_bias(
        self, weight: torch.nn.Parameter, bias: Optional[torch.nn.Parameter]
    ) -> "Float8LinearInference":
        self.weight = weight
        self.bias = bias
        return self

    def set_quantization_config(
        self,
        quant_config: QuantConfig,
    ) -> "Float8LinearInference":
        # We destructure the quant_config into its individual fields.
        # If an activation scale is passed in, we register it as a buffer.
        self.activation_casting: ActivationCasting = quant_config.activation_casting
        self.quantize_weight()

        if self.activation_casting == ActivationCasting.STATIC:
            self.register_buffer("activation_scale", quant_config.activation_scale)
        else:
            self.activation_scale = None
        return self

    @classmethod
    def from_float(
        cls,
        module: nn.Module,
        quant_config: QuantConfig,
    ) -> "Float8LinearInference":
        """
        Create an nn.Linear with fp8 compute from a regular nn.Linear

        Args:
            module (torch.nn.Linear): nn.Linear to convert
            quant_config (QuantConfig): Configuration for the weight and activation casting
        """
        return (
            cls.create_meta_class(module.in_features, module.out_features)
            .set_weight_and_bias(module.weight, module.bias)
            .set_mm_config(False)
            .set_quantization_config(quant_config)
        )


def cast_to_float8_e4m3fn(
    inpt_tensor: torch.Tensor,
    mm_config: ScaledMMConfig,
    reduce_amax: bool = False,
    activation_scale: Optional[torch.Tensor] = None,
) -> Float8Tensor:
    """Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.

    Args:
        inpt_tensor: The input tensor to be cast.
        mm_config: Configuration settings for the matrix multiplication.
        reduce_amax: Whether to reduce the amax (absolute maximum) across the local distributed group.
        activation_scale: Optional tensor specifying the scale for the activation. Default is None.

    Returns:
        Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.

    Note:
        If the input tensor is already in Float8 format, it is returned as is without re-casting.
    """
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
    scale = (
        activation_scale
        if activation_scale is not None
        else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
    )
    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


def quantize_to_float8(
    module: nn.Module,
    quant_config: QuantConfig,
    *,
    skip_fqn_list: Optional[List[str]] = None,
    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
) -> nn.Module:
    """
    Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
    of ``Float8LinearInference``.

    Args:
        module (torch.nn.Module): Module to modify.
        quant_config (QuantConfig): Configuration for the weight and activation casting.
        skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
            Linear submodules of these skipped modules will also be skipped.
        linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
            that pass the filter function will be swapped.
    """
    module_names_to_skip = set(skip_fqn_list or [])
    if isinstance(module, nn.Linear) and (
        linear_layer_filter is None or linear_layer_filter(module)
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Linear with children: {module}"
            )
        return Float8LinearInference.from_float(module, quant_config)

    # Mark all modules to skip as visited
    root_module = module
    visited_modules = {root_module}
    for module_name, module in root_module.named_modules():
        if module_name in module_names_to_skip:
            visited_modules.add(module)

    # Run a post-order traversal to swap linears
    def post_order_traversal(
        module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
    ):
        nonlocal visited_modules
        for child_module_name, child_module in module.named_children():
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                post_order_traversal(child_module, child_module_name, module)
        if isinstance(module, nn.Linear) and (
            linear_layer_filter is None or linear_layer_filter(module)
        ):
            assert (
                parent_module is not None
            ), f"Linear root module should return early: {module}"
            float8linear_module = Float8LinearInference.from_float(module, quant_config)
            setattr(parent_module, module_name, float8linear_module)

    post_order_traversal(root_module, "", None)
    # Without this explicit `del`, this set only gets deleted upon an explicit
    # garbage collection (not when its refcount hits zero)
    del visited_modules
    return root_module
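
Taken together, a hedged usage sketch of the new inference workflow (not part of the commit): the toy model and shapes are stand-ins, the calls mirror the signatures defined in inference.py above, and the fp8 matmul paths assume hardware/backend support for scaled matrix multiplication.

import torch
import torch.nn as nn

from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

# Toy model standing in for a real network.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# Dynamic activation casting: scales are computed from each input at runtime.
model = quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC))

# Weight-only alternative: the fp8 weight is dequantized in forward and the
# matmul stays in the original precision.
#   quantize_to_float8(model, QuantConfig(ActivationCasting.WEIGHT_ONLY))

# Static alternative: a precomputed activation scale tensor must be supplied
# (the 1.0 scale here is purely illustrative).
#   quantize_to_float8(
#       model, QuantConfig(ActivationCasting.STATIC, activation_scale=torch.tensor(1.0))
#   )

with torch.no_grad():
    out = model(torch.randn(4, 64))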

0 commit comments
