 from typing import Callable, Optional
 
 import fire
+import pandas as pd
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     LinearType,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function
+from utils import (
+    kernel_name_to_category,
+    profiler_output_to_gpu_time_for_key,
+    profiler_output_to_time_by_kernel_name,
+)
+
+# don't truncate long kernel names
+pd.options.display.max_colwidth = 100
+# display 3 trailing decimal points for floats
+pd.set_option("display.float_format", "{:.3f}".format)
 
 
 class LNLinear(torch.nn.Module):
@@ -38,6 +51,105 @@ def forward(self, x):
         return x
 
 
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
+class RMSNorm(nn.Module):
+    """
+    Initialize the RMSNorm normalization layer.
+
+    Args:
+        dim (int): The dimension of the input tensor.
+        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+    Attributes:
+        eps (float): A small value added to the denominator for numerical stability.
+        weight (nn.Parameter): Learnable scaling parameter.
+
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def reset_parameters(self):
+        torch.nn.init.ones_(self.weight)  # type: ignore
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class NormFFNResidualNorm(nn.Module):
+    """
+    A fragment representing the end of TransformerBlock n and the start
+    of TransformerBlock n + 1, intended to include the fusions relevant
+    to float8 gemms in the FFN module in forward and backward.
+    """
+
+    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier):
+        super().__init__()
+        self.ffn_norm = RMSNorm(dim)
+        self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)
+        self.attn_norm = RMSNorm(dim)
+
+    def forward(self, h):
+        # end of transformer block n
+        x = self.ffn_norm(h)
+        x = self.ffn(x)
+        x = h + x
+        # start of transformer block n + 1
+        x = self.attn_norm(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -87,46 +199,51 @@ def profile_function(
     if config.file_path is not None:
         prof.export_chrome_trace(config.file_path)
 
-    if config.file_path is None:
-        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-
     return prof
 
 
-@dataclass(frozen=True)
-class ModelParams:
-    M: int
-    K: int
-    N: int
-    ref_dtype: torch.dtype
-    layer_norm: bool = True
-
-
 def main(
     profile_path_prefix: Path,
     compile: bool = True,
     linear_type: str = "dynamic",
-    use_layer_norm: bool = False,
+    model_type: str = "linear",
+    dtype_filter: str = "both",
 ):
-    params = ModelParams(
-        M=4 * 4096,
-        K=8192,
-        N=7168,
-        ref_dtype=torch.bfloat16,
-        layer_norm=use_layer_norm,
-    )
+    assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+    assert dtype_filter in ("both", "float8", "bfloat16")
+
     print(f"Compile is set to | {compile}")
     print(f"Using Linear type: | {linear_type}")
-    print(f"Use layer norm is set to | {params.layer_norm}")
+    print(f"model_type is set to | {model_type}")
 
     device = "cuda"
-    if params.layer_norm:
-        m_ref = LNLinear(params.K, params.N)
+    ref_dtype = torch.bfloat16
+    if model_type == "ln_linear":
+        M, K, N = 4 * 4096, 8192, 7168
+        m_ref = LNLinear(K, N)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+    elif model_type == "norm_ffn_norm":
+        m_ref = NormFFNResidualNorm(
+            dim=4096,
+            hidden_dim=16384,
+            multiple_of=1024,
+            ffn_dim_multiplier=1.3,
+        )
+        input_tensor = torch.randn(
+            1, 8192, 4096, device=device, dtype=ref_dtype
+        ).requires_grad_()
     else:
+        M, K, N = 4 * 4096, 8192, 7168
         m_ref = torch.nn.Sequential(
-            torch.nn.Linear(params.K, params.N, bias=False),
+            torch.nn.Linear(K, N, bias=False),
         )
-    m_ref = m_ref.to(device).to(params.ref_dtype)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+
+    m_ref = m_ref.to(device).to(ref_dtype)
 
     linear_type = LinearType[linear_type.upper()]
     linear_cls = (
@@ -136,10 +253,6 @@ def main(
     m_float8 = copy.deepcopy(m_ref)
     swap_linear_with_float8_linear(m_float8, linear_cls)
 
-    input_tensor = torch.randn(
-        params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
-    )
-
     def ref_forw_backward(x):
         out = m_ref(x)
         out.sum().backward()
@@ -148,6 +261,8 @@ def float8_forw(x):
         out = m_float8(x)
         return out
 
+    sync_amax_history = sync_float8_amax_and_scale_history
+
     def float8_forw_backward_wrapper(x):
         # sync_float8_amax_and_scale_history is not full graph torch
         # compile friendly, so we add a high level wrapper to allow
@@ -156,7 +271,7 @@ def float8_forw_backward_wrapper(x):
         # TODO(future): make this better
         if linear_requires_sync(linear_type):
             with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(m_float8)
+                sync_amax_history(m_float8)
         out = float8_forw(x)
 
         # out.sum().backward() is also not torch.compile fullgraph
@@ -165,30 +280,106 @@ def float8_forw_backward_wrapper(x):
             out.sum().backward()
 
     if compile:
-        ref_forw_backward = torch.compile(ref_forw_backward)
+        m_ref = torch.compile(m_ref, fullgraph=True)
         float8_forw = torch.compile(float8_forw, fullgraph=True)
-
-    for _ in range(5):
-        ref_forw_backward(input_tensor)
-        float8_forw_backward_wrapper(input_tensor)
-
-    # Profile Reference Model
-    ref_suffix = f"_ref_compile_{compile}.json"
-    profile_config = ProfileConfig(
-        profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
+        # Note: it's faster to compile the combination of sync_amax_history with
+        # forward because we only look up from dynamo cache once.
+        # However, compiling the sync function separately makes it more
+        # convenient to analyze the total time spent on it.
+        sync_amax_history = torch.compile(sync_amax_history)
+
+    # warm up
+    for _ in range(1):
+        if dtype_filter != "float8":
+            ref_forw_backward(input_tensor)
+        if dtype_filter != "bfloat16":
+            float8_forw_backward_wrapper(input_tensor)
+
+    profile_iters = 5
+    ref_times, float8_times = None, None
+    data = []
+
+    if dtype_filter != "float8":
+        # Profile Reference Model
+        print("profiling ref")
+        ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+        ref_path = profile_path_prefix + ref_suffix
+        profile_config = ProfileConfig(
+            ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+        )
+        p = profile_function(profile_config, ref_forw_backward, input_tensor)
+        print(f"saved {ref_path}")
+        ref_times = profiler_output_to_time_by_kernel_name(p)
+        total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
+        for k, v in ref_times.items():
+            v_ms = v / 1e3 / profile_iters
+            data.append(
+                ["0_ref", k, kernel_name_to_category(k), v_ms, v_ms / total_time_ms]
+            )
+
+    if dtype_filter != "bfloat16":
+        # Profile Float8 Model
+        print("profiling float8")
+        float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
+        float8_path = profile_path_prefix + float8_suffix
+        profile_config = ProfileConfig(
+            float8_path,
+            float8_suffix,
+            iters=profile_iters,
+            warmup_iters=2,
+            sync=True,
+        )
+        p = profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)
+        print(f"saved {float8_path}")
+        float8_times = profiler_output_to_time_by_kernel_name(p)
+        total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
+        for k, v in float8_times.items():
+            v_ms = v / 1e3 / profile_iters
+            data.append(
+                [
+                    "1_float8",
+                    k,
+                    kernel_name_to_category(k),
+                    v / 1e3 / profile_iters,
+                    v_ms / total_time_ms,
+                ]
+            )
+
+        # get the time spent per user annotation
+        sync_time_us = profiler_output_to_gpu_time_for_key(p, "scale_amax_and_scales")
+        sync_time_ms = sync_time_us / profile_iters / 1e3
+        print(f"Sync time ms: {sync_time_ms}")
+
+    df = pd.DataFrame(
+        data, columns=["experiment", "kernel", "category", "time_ms", "pct_gpu_time"]
     )
-    profile_function(profile_config, ref_forw_backward, input_tensor)
-
-    # Profile Float8 Model
-    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
-    profile_config = ProfileConfig(
-        profile_path_prefix + float8_suffix,
-        float8_suffix,
-        iters=5,
-        warmup_iters=5,
-        sync=True,
+    print("\nSummary of GPU time by CPU kernel\n\n", df)
+
+    # compare gemm and overhead time
+    df_p = df.pivot_table(
+        columns=["category"],
+        index="experiment",
+        values="time_ms",
+        aggfunc="sum",
+        fill_value=0,
+        margins=True,
     )
-    profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)
+    # drop last row, which has totals across ref + float8 which does not make sense
+    df_p = df_p[:-1]
+    df_p = df_p.transpose()
+
+    if dtype_filter == "both":
+        df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"]
+        df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"]
+
+        # calculate sync time as pct of total float time
+        total_float8_ms = df_p.iloc[3]["1_float8"]
+        sync_approx_ratio = sync_time_ms / total_float8_ms
+        print(
+            f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}"
+        )
+
+    print("\nSummary of time (ms) by kernel category\n\n", df_p)
 
 
 def invoke_main() -> None:
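
Usage note (not part of the diff): a minimal sketch of how the updated main() could be driven directly, assuming the changed file is importable as profile_linear_float8 and that a CUDA device and float8_experimental are available; the module name and output prefix below are placeholders, not values from the commit.

    # hypothetical driver exercising the new model_type / dtype_filter arguments
    from profile_linear_float8 import main

    # profile the norm -> FFN -> residual -> norm fragment in both bfloat16 and float8,
    # writing chrome traces with the given path prefix
    main(
        profile_path_prefix="/tmp/norm_ffn_norm_trace",
        compile=True,
        linear_type="dynamic",
        model_type="norm_ffn_norm",
        dtype_filter="both",
    )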