from typing import Callable, Optional

import fire
+ import pandas as pd

import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    sync_float8_amax_and_scale_history,
)
from torch.profiler import profile, ProfilerActivity, record_function
+ from utils import (
+     profiler_output_to_gpu_time_for_key,
+     profiler_output_to_time_by_kernel_name,
+ )
+
+ # don't truncate long kernel names
+ pd.options.display.max_colwidth = 100
+ # display 3 trailing decimal points for floats
+ pd.set_option("display.float_format", "{:.3f}".format)


class LNLinear(torch.nn.Module):
@@ -38,6 +50,105 @@ def forward(self, x):
        return x


+ # copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
+ class RMSNorm(nn.Module):
+     """
+     Initialize the RMSNorm normalization layer.
+
+     Args:
+         dim (int): The dimension of the input tensor.
+         eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+     Attributes:
+         eps (float): A small value added to the denominator for numerical stability.
+         weight (nn.Parameter): Learnable scaling parameter.
+
+     """
+
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x: torch.Tensor):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x: torch.Tensor):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+     def reset_parameters(self):
+         torch.nn.init.ones_(self.weight)  # type: ignore
+
+
+ # copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+ class FeedForward(nn.Module):
+     """
+     FeedForward module
+
+     Args:
+         dim (int): Input dimension.
+         hidden_dim (int): Hidden dimension of the feedforward layer.
+         multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+         ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+     Attributes:
+         w1 (Linear): Linear transformation for the first layer.
+         w2 (Linear): Linear transformation for the second layer.
+         w3 (Linear): Linear transformation for the third layer.
+
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+         multiple_of: int,
+         ffn_dim_multiplier: Optional[float],
+     ):
+         super().__init__()
+         hidden_dim = int(2 * hidden_dim / 3)
+         # custom dim factor multiplier
+         if ffn_dim_multiplier is not None:
+             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+     def init_weights(self, init_std: float):
+         nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+         for linear in (self.w2, self.w3):
+             nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+ class NormFFNResidualNorm(nn.Module):
+     """
+     A fragment representing the end of TransformerBlock n and the start
+     of TransformerBlock n + 1, intended to include the fusions relevant
+     to float8 gemms in the FFN module in forward and backward.
+     """
+
+     def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier):
+         super().__init__()
+         self.ffn_norm = RMSNorm(dim)
+         self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)
+         self.attn_norm = RMSNorm(dim)
+
+     def forward(self, h):
+         # end of transformer block n
+         x = self.ffn_norm(h)
+         x = self.ffn(x)
+         x = h + x
+         # start of transformer block n + 1
+         x = self.attn_norm(x)
+         return x
+
+
@dataclass
class ProfileConfig:
    file_path: Optional[str] = None
@@ -87,46 +198,51 @@ def profile_function(
    if config.file_path is not None:
        prof.export_chrome_trace(config.file_path)

-     if config.file_path is None:
-         print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-
    return prof


- @dataclass(frozen=True)
- class ModelParams:
-     M: int
-     K: int
-     N: int
-     ref_dtype: torch.dtype
-     layer_norm: bool = True
-
-
def main(
    profile_path_prefix: Path,
    compile: bool = True,
    linear_type: str = "dynamic",
-     use_layer_norm: bool = False,
+     model_type: str = "linear",
+     dtype_filter: str = "both",
):
-     params = ModelParams(
-         M=4 * 4096,
-         K=8192,
-         N=7168,
-         ref_dtype=torch.bfloat16,
-         layer_norm=use_layer_norm,
-     )
+     assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+     assert dtype_filter in ("both", "float8", "bfloat16")
+
    print(f"Compile is set to | {compile}")
    print(f"Using Linear type: | {linear_type}")
-     print(f"Use layer norm is set to | {params.layer_norm}")
+     print(f"model_type is set to | {model_type}")

    device = "cuda"
-     if params.layer_norm:
-         m_ref = LNLinear(params.K, params.N)
+     ref_dtype = torch.bfloat16
+     if model_type == "ln_linear":
+         M, K, N = 4 * 4096, 8192, 7168
+         m_ref = LNLinear(K, N)
+         input_tensor = torch.randn(
+             M, K, device=device, dtype=ref_dtype, requires_grad=True
+         )
+     elif model_type == "norm_ffn_norm":
+         m_ref = NormFFNResidualNorm(
+             dim=4096,
+             hidden_dim=16384,
+             multiple_of=1024,
+             ffn_dim_multiplier=1.3,
+         )
+         input_tensor = torch.randn(
+             1, 8192, 4096, device=device, dtype=ref_dtype
+         ).requires_grad_()
    else:
+         M, K, N = 4 * 4096, 8192, 7168
        m_ref = torch.nn.Sequential(
-             torch.nn.Linear(params.K, params.N, bias=False),
+             torch.nn.Linear(K, N, bias=False),
+         )
+         input_tensor = torch.randn(
+             M, K, device=device, dtype=ref_dtype, requires_grad=True
        )
-     m_ref = m_ref.to(device).to(params.ref_dtype)
+
+     m_ref = m_ref.to(device).to(ref_dtype)

    linear_type = LinearType[linear_type.upper()]
    linear_cls = (
@@ -136,10 +252,6 @@ def main(
    m_float8 = copy.deepcopy(m_ref)
    swap_linear_with_float8_linear(m_float8, linear_cls)

-     input_tensor = torch.randn(
-         params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
-     )
-
    def ref_forw_backward(x):
        out = m_ref(x)
        out.sum().backward()
@@ -148,6 +260,8 @@ def float8_forw(x):
        out = m_float8(x)
        return out

+     sync_amax_history = sync_float8_amax_and_scale_history
+
    def float8_forw_backward_wrapper(x):
        # sync_float8_amax_and_scale_history is not full graph torch
        # compile friendly, so we add a high level wrapper to allow
@@ -156,7 +270,7 @@ def float8_forw_backward_wrapper(x):
        # TODO(future): make this better
        if linear_requires_sync(linear_type):
            with record_function("scale_amax_and_scales"):
-                 sync_float8_amax_and_scale_history(m_float8)
+                 sync_amax_history(m_float8)
        out = float8_forw(x)

        # out.sum().backward() is also not torch.compile fullgraph
@@ -165,30 +279,120 @@ def float8_forw_backward_wrapper(x):
        out.sum().backward()

    if compile:
-         ref_forw_backward = torch.compile(ref_forw_backward)
+         m_ref = torch.compile(m_ref, fullgraph=True)
        float8_forw = torch.compile(float8_forw, fullgraph=True)
+         # Note: it's faster to compile the combination of sync_amax_history with
+         # forward because we only look up from dynamo cache once.
+         # However, compiling the sync function separately makes it more
+         # convenient to analyze the total time spent on it.
+         sync_amax_history = torch.compile(sync_amax_history)
+
+     # warm up
+     for _ in range(1):
+         if dtype_filter != "float8":
+             ref_forw_backward(input_tensor)
+         if dtype_filter != "bfloat16":
+             float8_forw_backward_wrapper(input_tensor)
+
+     # profile_iters = 5
+     profile_iters = 2
+     ref_times, float8_times = None, None
+
+     if dtype_filter != "float8":
+         # Profile Reference Model
+         print("profiling ref")
+         ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+         ref_path = profile_path_prefix + ref_suffix
+         profile_config = ProfileConfig(
+             ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+         )
+         p = profile_function(profile_config, ref_forw_backward, input_tensor)
+         print(f"saved {ref_path}")
+         ref_times = profiler_output_to_time_by_kernel_name(p)
+
+     if dtype_filter != "bfloat16":
+         # Profile Float8 Model
+         print("profiling float8")
+         float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
+         float8_path = profile_path_prefix + float8_suffix
+         profile_config = ProfileConfig(
+             float8_path,
+             float8_suffix,
+             iters=profile_iters,
+             warmup_iters=2,
+             sync=True,
+         )
+         p = profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)
+         print(f"saved {float8_path}")
+         float8_times = profiler_output_to_time_by_kernel_name(p)
+
+         # get the time spent per user annotation
+         sync_time_us = profiler_output_to_gpu_time_for_key(p, "scale_amax_and_scales")
+         sync_time_ms = sync_time_us / profile_iters / 1e3
+         print(f"Sync time ms: {sync_time_ms}")
+
+     if dtype_filter == "both":
+         data = []
+
+         def kernel_name_to_category(k):
+             # number prefix is for easy sorting
+             if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"):
+                 return "0_gemm"
+             elif (
+                 # max(abs(tensor))
+                 ("abs" in k and "max" in k)
+                 or
+                 # casting pointwise to float8
+                 ("clamp" in k)
+                 or
+                 # things related to scaled_mm
+                 ("scaled_mm" in k)
+                 or
+                 # syncing amaxes and scales
+                 ("roll" in k)
+             ):
+                 # note: the above filter is approximate and will give false
+                 # positives if model code contains other code to abs/max/clamp
+                 return "1_f8_overhead"
+             return "2_other"
+
+         for k, v in ref_times.items():
+             data.append(
+                 ["0_ref", k, kernel_name_to_category(k), v / 1e3 / profile_iters]
+             )
+         for k, v in float8_times.items():
+             data.append(
+                 ["1_float8", k, kernel_name_to_category(k), v / 1e3 / profile_iters]
+             )
+
+         df = pd.DataFrame(data, columns=["experiment", "kernel", "category", "time_ms"])
+         print("\nSummary of GPU time by CPU kernel\n\n", df)
+
+         # compare gemm and overhead time
+         df_p = df.pivot_table(
+             columns=["category"],
+             index="experiment",
+             values="time_ms",
+             aggfunc="sum",
+             fill_value=0,
+             margins=True,
+         )
+         # drop last row, which has totals across ref + float8 which does not make sense
+         df_p = df_p[:-1]
+
+         df_p = df_p.transpose()
+         df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"]
+         df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"]
+         print(
+             "\nSummary of time (ms) by kernel category, across ref and float8\n\n", df_p
+         )

-     for _ in range(5):
-         ref_forw_backward(input_tensor)
-         float8_forw_backward_wrapper(input_tensor)
-
-     # Profile Reference Model
-     ref_suffix = f"_ref_compile_{compile}.json"
-     profile_config = ProfileConfig(
-         profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
-     )
-     profile_function(profile_config, ref_forw_backward, input_tensor)
-
-     # Profile Float8 Model
-     float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
-     profile_config = ProfileConfig(
-         profile_path_prefix + float8_suffix,
-         float8_suffix,
-         iters=5,
-         warmup_iters=5,
-         sync=True,
-     )
-     profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)
+         # calculate sync time as pct of total float time
+         total_float8_ms = df_p.iloc[3]["1_float8"]
+         sync_approx_ratio = sync_time_ms / total_float8_ms
+         print(
+             f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}"
+         )


def invoke_main() -> None:
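
Note: the utils helpers imported at the top of this file (profiler_output_to_time_by_kernel_name, profiler_output_to_gpu_time_for_key) are not part of this diff. A minimal sketch of the behavior the script appears to assume — GPU time in microseconds, aggregated from prof.key_averages() either per kernel name or for a single record_function key — might look like the following; this is an illustration, not the repository's actual utils.py.

# Illustrative sketch only; the real utils.py may differ.
# "prof" is the torch.profiler.profile object returned by profile_function.

def profiler_output_to_time_by_kernel_name(prof):
    # kernel name -> total self GPU time in microseconds, summed across all iterations
    times = {}
    for event in prof.key_averages():
        if event.self_cuda_time_total > 0:
            times[event.key] = times.get(event.key, 0.0) + event.self_cuda_time_total
    return times


def profiler_output_to_gpu_time_for_key(prof, key):
    # total GPU time (microseconds) attributed to a record_function(key) annotation
    for event in prof.key_averages():
        if event.key == key:
            return event.cuda_time_total
    return 0.0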
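
For reference, a hypothetical way to drive the updated benchmark directly from Python (the output prefix below is made up; invoke_main presumably exposes the same parameters as command-line flags via fire):

# Hypothetical invocation; the path prefix is made up. Traces are written to
# profile_path_prefix + the per-experiment suffix, e.g.
# /tmp/f8_profile_norm_ffn_norm_ref_compile_True.json
main(
    profile_path_prefix="/tmp/f8_profile",
    compile=True,
    linear_type="dynamic",
    model_type="norm_ffn_norm",
    dtype_filter="both",
)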