@@ -62,8 +62,12 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)
 
+        self.activation_scale: Optional[torch.Tensor] = None
+
     def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(x, self.forward_config)
+        x_fp8 = cast_to_float8_e4m3fn(
+            x, self.forward_config, activation_scale=self.activation_scale
+        )
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
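For readers skimming the hunk above: `forward` now prefers a statically provided `activation_scale` over recomputing one from the activation's amax on every call. Below is a minimal standalone sketch of the two scaling modes, assuming the usual `scale = fp8_max / amax` convention; the helper here is illustrative, not the repo's actual internals.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # Dynamic quantization: derive the scale from this tensor's amax.
    amax = x.abs().max().float()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

x = torch.randn(16, 32)
activation_scale = dynamic_scale(x)  # e.g. calibrated offline, then frozen

# The new forward logic in one line: use the static scale when present,
# otherwise fall back to per-call dynamic scaling.
scale = activation_scale if activation_scale is not None else dynamic_scale(x)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)
```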
@@ -86,7 +90,11 @@ def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> No
 
     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+        cls,
+        mod,
+        emulate: bool = False,
+        static_quantize_weight: bool = False,
+        activation_scale: Optional[torch.Tensor] = None,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -96,6 +104,8 @@ def from_float(
             emulate (bool): whether to emulate fp8 matmul logic in float32
             static_quantize_weight (bool): whether to quantize the weight statically, this is useful
                 for inference where weights are not updated.
+            activation_scale (torch.Tensor): The scale of the input to this linear module, used
+                for inference when a statically known scale is available.
         """
         with torch.device("meta"):
             super_kwargs = {
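A hypothetical call site for the extended `from_float`, assuming `Float8DynamicLinear` is importable from `float8_experimental.float8_dynamic_linear` and that a calibration pass has already produced the input amax (the calibration itself is outside this PR):

```python
import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

lin = torch.nn.Linear(32, 64)

# Assumed: amax of this layer's input, recorded during an offline calibration run.
calib_amax = torch.tensor(3.2)
act_scale = torch.finfo(torch.float8_e4m3fn).max / calib_amax

fp8_lin = Float8DynamicLinear.from_float(
    lin,
    static_quantize_weight=True,  # freeze the weight in fp8 for inference
    activation_scale=act_scale,   # skip per-forward amax computation
)
```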
@@ -116,16 +126,38 @@ def from_float(
         if static_quantize_weight:
             new_mod.static_quantize_weight()
 
+        new_mod.activation_scale = activation_scale
         new_mod.bias = mod.bias
         return new_mod
 
 
 def cast_to_float8_e4m3fn(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    mm_config: ScaledMMConfig,
+    reduce_amax: bool = False,
+    activation_scale: Optional[torch.Tensor] = None,
 ) -> Float8Tensor:
+    """Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.
+
+    Args:
+        inpt_tensor: The input tensor to be cast.
+        mm_config: Configuration settings for the matrix multiplication.
+        reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
+        activation_scale: Optional tensor specifying the scale for activation. Default is None.
+
+    Returns:
+        Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.
+
+    Note:
+        If the input tensor is already in Float8 format, it is returned as is without re-casting.
+    """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
+    scale = (
+        activation_scale
+        if activation_scale is not None
+        else tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
+    )
     return Float8Tensor.to_float8(
         inpt_tensor, scale, torch.float8_e4m3fn, mm_config=mm_config
     )
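To make the two code paths of the updated helper concrete, a hedged usage sketch; the `ScaledMMConfig` construction below is assumed (its exact fields live elsewhere in the repo):

```python
x = torch.randn(8, 32, dtype=torch.bfloat16)
mm_config = ScaledMMConfig()  # assumed default construction

# Dynamic path: scale computed from the tensor's amax on every call.
x_fp8_dynamic = cast_to_float8_e4m3fn(x, mm_config)

# Static path: a precomputed scale bypasses tensor_to_scale entirely.
precomputed = torch.tensor(64.0)
x_fp8_static = cast_to_float8_e4m3fn(x, mm_config, activation_scale=precomputed)
```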