This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 5d5a48e

add static scaling option

1 parent d1eae9a

File tree (2 files changed, +41 −4 lines):

  float8_experimental/__init__.py
  float8_experimental/float8_dynamic_linear.py

float8_experimental/__init__.py (5 additions, 0 deletions)

@@ -7,4 +7,9 @@
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_tensor import Float8Tensor
 
+# Needed to load Float8Tensor with weights_only = True
+from torch.serialization import add_safe_globals
+
+add_safe_globals([Float8Tensor])
+
 __all__ = ["Float8Tensor", "Float8Linear"]
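
The reason for the allow-listing above: torch.load with weights_only=True uses a restricted unpickler that rejects classes not registered as safe globals, so a checkpoint whose state_dict contains Float8Tensor objects would otherwise fail to load. A minimal sketch of the effect, assuming a PyTorch build that provides torch.serialization.add_safe_globals; the checkpoint path is hypothetical:

    # Sketch only: importing float8_experimental runs add_safe_globals([Float8Tensor])
    # as in the diff above, allow-listing the class for weights_only loading.
    import torch
    import float8_experimental  # noqa: F401

    # Hypothetical checkpoint containing Float8Tensor weights; without the
    # allow-listing, weights_only=True would refuse to unpickle Float8Tensor.
    state_dict = torch.load("fp8_checkpoint.pt", weights_only=True)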

float8_experimental/float8_dynamic_linear.py (36 additions, 4 deletions)

@@ -62,8 +62,12 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
+        self.activation_scale: Optional[torch.Tensor] = None
+
     def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3fn(
+            x, self.forward_config, activation_scale=self.activation_scale
+        )
         if isinstance(self.weight, Float8Tensor): # cast by FSDP
             w_fp8 = self.weight
         else:
@@ -86,7 +90,11 @@ def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
 
     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+        cls,
+        mod,
+        emulate: bool = False,
+        static_quantize_weight: bool = False,
+        activation_scale: Optional[torch.Tensor] = None,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -96,6 +104,8 @@ def from_float(
             emulate (bool): whether to emulate fp8 matmul logic in float32
             static_quantize_weight (bool): whether to quantize the weight statically, this is useful
                 for inference where weights are not updated.
+            activation_scale (torch.Tensor): The scale of the input to this linear module, used
+                for inference when a statically known scale is available.
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -116,16 +126,38 @@ def from_float(
         if static_quantize_weight:
             new_mod.static_quantize_weight()
 
+        new_mod.activation_scale = activation_scale
         new_mod.bias = mod.bias
         return new_mod
 
 
 def cast_to_float8_e4m3fn(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    reduce_amax: bool = False,
+    activation_scale: Optional[torch.Tensor] = None,
 ) -> Float8Tensor:
+    """Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.
+
+    Args:
+        inpt_tensor: The input tensor to be cast.
+        mm_config: Configuration settings for the matrix multiplication.
+        reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
+        activation_scale: Optional tensor specifying the scale for activation. Default is None.
+
+    Returns:
+        Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.
+
+    Note:
+        If the input tensor is already in Float8 format, it is returned as is without re-casting.
+    """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
+    scale = (
+        activation_scale
+        if activation_scale is not None
+        else tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
+    )
     return Float8Tensor.to_float8(
         inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
     )
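
Taken together, these changes let an inference caller fix both the weight quantization and the activation scale ahead of time instead of recomputing amax on every forward pass. A hedged usage sketch, assuming the API exactly as in this diff; the layer sizes and the calibration-derived scale value are made up, and emulate=True is used here in case native fp8 matmul is unavailable:

    import torch
    from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

    linear = torch.nn.Linear(1024, 4096)

    # Hypothetical statically known input scale, e.g. taken from a calibration pass.
    static_input_scale = torch.tensor(1.0)

    fp8_linear = Float8DynamicLinear.from_float(
        linear,
        emulate=True,                         # emulate fp8 matmul logic in float32
        static_quantize_weight=True,          # quantize the weight once for inference
        activation_scale=static_input_scale,  # reused by forward() for the input cast
    )

    x = torch.randn(16, 1024)
    y = fp8_linear(x)  # the input is cast with the static scale, not a per-call amax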
