@@ -62,8 +62,10 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
+        self.activation_scale: Optional[torch.Tensor] = None
+
     def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config, self.activation_scale)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
@@ -86,7 +88,11 @@ def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
 
     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+        cls,
+        mod,
+        emulate: bool = False,
+        static_quantize_weight: bool = False,
+        activation_scale: Optional[torch.Tensor] = None,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -96,6 +102,8 @@ def from_float(
             emulate (bool): whether to emulate fp8 matmul logic in float32
             static_quantize_weight (bool): whether to quantize the weight statically, this is useful
                 for inference where weights are not updated.
+            activation_scale (torch.Tensor): The scale of the input to this linear module, used
+                for inference when a statically known scale is available.
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -116,16 +124,38 @@ def from_float(
         if static_quantize_weight:
             new_mod.static_quantize_weight()
 
+        new_mod.activation_scale = activation_scale
         new_mod.bias = mod.bias
         return new_mod
 
 
 def cast_to_float8_e4m3fn(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    reduce_amax: bool = False,
+    activation_scale: Optional[torch.Tensor] = None,
 ) -> Float8Tensor:
+    """Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.
+
+    Args:
+        inpt_tensor: The input tensor to be cast.
+        mm_config: Configuration settings for the matrix multiplication.
+        reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
+        activation_scale: Optional tensor specifying the scale for the activation. Default is None.
+
+    Returns:
+        Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.
+
+    Note:
+        If the input tensor is already in Float8 format, it is returned as is without re-casting.
+    """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
+    scale = (
+        activation_scale
+        if activation_scale is not None
+        else tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
+    )
     return Float8Tensor.to_float8(
         inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
     )
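
For context, a minimal usage sketch of the new argument (not part of the diff; the import path, layer sizes, and calibration value are assumptions): a scale computed offline is passed to from_float, and forward then reuses it instead of deriving a fresh scale from each input's amax.

# Usage sketch only; import path, sizes, and the calibration value are assumptions.
import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

lin = torch.nn.Linear(1024, 1024, bias=True)

# Hypothetical per-tensor scale obtained from an offline calibration pass.
act_scale = torch.tensor(0.02, dtype=torch.float32)

fp8_lin = Float8DynamicLinear.from_float(
    lin,
    emulate=True,                  # emulate fp8 matmul logic in float32, per the docstring above
    static_quantize_weight=True,   # freeze the weight in fp8 for inference
    activation_scale=act_scale,    # forward() reuses this instead of recomputing a scale per batch
)

with torch.no_grad():
    y = fp8_lin(torch.randn(8, 1024))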