This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 12b32d3 (1 parent: d1eae9a)

add static scaling option

1 file changed: float8_experimental/float8_dynamic_linear.py (34 additions, 4 deletions)
@@ -62,8 +62,10 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
+        self.activation_scale: Optional[torch.Tensor] = None
+
     def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config, self.activation_scale)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
@@ -86,7 +88,11 @@ def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
 
     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+        cls,
+        mod,
+        emulate: bool = False,
+        static_quantize_weight: bool = False,
+        activation_scale: Optional[torch.Tensor] = None,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -96,6 +102,8 @@ def from_float(
             emulate (bool): whether to emulate fp8 matmul logic in float32
             static_quantize_weight (bool): whether to quantize the weight statically, this is useful
                 for inference where weights are not updated.
+            activation_scale (torch.Tensor): The scale of the input to this linear module, used
+                for inference when a statically known scale is available.
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -116,16 +124,38 @@ def from_float(
         if static_quantize_weight:
             new_mod.static_quantize_weight()
 
+        new_mod.activation_scale = activation_scale
         new_mod.bias = mod.bias
         return new_mod
 
 
 def cast_to_float8_e4m3fn(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    reduce_amax: bool = False,
+    activation_scale: Optional[torch.Tensor] = None,
 ) -> Float8Tensor:
+    """Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.
+
+    Args:
+        inpt_tensor: The input tensor to be cast.
+        mm_config: Configuration settings for the matrix multiplication.
+        reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
+        activation_scale: Optional tensor specifying the scale for activation. Default is None.
+
+    Returns:
+        Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.
+
+    Note:
+        If the input tensor is already in Float8 format, it is returned as is without re-casting.
+    """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
+    scale = (
+        activation_scale
+        if activation_scale is not None
+        else tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
+    )
     return Float8Tensor.to_float8(
         inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
     )
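
In practice, the new keyword turns the dynamic-scaling path into a static one for inference: the caller computes the activation scale once (for example, from a calibration batch) and forward() reuses it on every call instead of deriving a fresh scale from each input's amax. Below is a minimal usage sketch, assuming tensor_to_scale is importable from float8_experimental.float8_utils; the calibration approach, shapes, and import paths are illustrative assumptions, not part of this commit.

import torch

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_utils import tensor_to_scale  # assumed import path

# A regular linear layer to be swapped for its fp8 counterpart.
mod = torch.nn.Linear(1024, 1024, bias=True).cuda()

# Derive an amax-based scale from a representative calibration batch, using
# the same helper the dynamic path would otherwise call on every forward.
calibration_batch = torch.randn(32, 1024, device="cuda")
activation_scale = tensor_to_scale(calibration_batch, torch.float8_e4m3fn)

# Freeze both scales for inference: the weight is quantized once up front,
# and forward() reuses activation_scale instead of recomputing it per call.
fp8_mod = Float8DynamicLinear.from_float(
    mod,
    static_quantize_weight=True,
    activation_scale=activation_scale,
)

with torch.no_grad():
    y = fp8_mod(calibration_batch)

Passing activation_scale=None (the default) preserves the previous behavior, where the scale is recomputed from each input tensor at runtime.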
