 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+
+import copy
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass, field
@@ -12,15 +14,30 @@
 import fire
 
 import torch
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
     linear_requires_sync,
     LinearType,
+    swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function


+class LNLinear(torch.nn.Module):
+    def __init__(self, fc_dim1, fc_dim2):
+        super().__init__()
+        self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False)
+        self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False)
+
+    def forward(self, x):
+        x = self.ln(x)
+        x = self.fc(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -77,65 +94,58 @@ def profile_function(


 @dataclass(frozen=True)
-class LinearParams:
+class ModelParams:
     M: int
     K: int
     N: int
-    input_bias: bool
     ref_dtype: torch.dtype
     layer_norm: bool = True
-    torch_compile: Optional[bool] = False


-def main(profile_path: Path, compile: bool, linear_type: str):
-    profile_path = Path(profile_path)
-    assert profile_path.is_dir(), f"Path {profile_path} must be a directory"
-    params = LinearParams(
+def main(
+    profile_path_prefix: Path,
+    compile: bool = True,
+    linear_type: str = "dynamic",
+    use_layer_norm: bool = False,
+):
+    params = ModelParams(
         M=4 * 4096,
         K=8192,
         N=7168,
-        input_bias=False,
         ref_dtype=torch.bfloat16,
-        layer_norm=True,
-        torch_compile=compile,
+        layer_norm=use_layer_norm,
     )
     print(f"Compile is set to | {compile}")
     print(f"Using Linear type: | {linear_type}")
     print(f"Use layer norm is set to | {params.layer_norm}")
-    linear_ref = torch.nn.Linear(
-        params.K,
-        params.N,
-        bias=params.input_bias,
-        device="cuda",
-        dtype=params.ref_dtype,
-    )
+
+    device = "cuda"
+    if params.layer_norm:
+        m_ref = LNLinear(params.K, params.N)
+    else:
+        m_ref = torch.nn.Sequential(
+            torch.nn.Linear(params.K, params.N, bias=False),
+        )
+    m_ref = m_ref.to(device).to(params.ref_dtype)
+
     linear_type = LinearType[linear_type.upper()]
-    linear_float8 = get_float8_linear(linear_type, linear_ref)
+    linear_cls = (
+        Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
+    )
+
+    m_float8 = copy.deepcopy(m_ref)
+    swap_linear_with_float8_linear(m_float8, linear_cls)

     input_tensor = torch.randn(
         params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
     )

-    if params.layer_norm:
-        ln = torch.nn.LayerNorm(
-            params.K, elementwise_affine=False, device="cuda", dtype=params.ref_dtype
-        )
-
     def ref_forw_backward(x):
-        if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_ref(x)
-        with record_function("backward"):
-            out.sum().backward()
+        out = m_ref(x)
+        out.sum().backward()

-    def float8_forw_backward(x):
-        if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_float8(x)
+    def float8_forw(x):
+        out = m_float8(x)
         return out

     def float8_forw_backward_wrapper(x):
@@ -146,34 +156,34 @@ def float8_forw_backward_wrapper(x):
         # TODO(future): make this better
         if linear_requires_sync(linear_type):
             with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(linear_float8)
-        out = float8_forw_backward(x)
+                sync_float8_amax_and_scale_history(m_float8)
+        out = float8_forw(x)

         # out.sum().backward() is also not torch.compile fullgraph
         # friendly
         with record_function("backward"):
             out.sum().backward()

-    if params.torch_compile:
+    if compile:
         ref_forw_backward = torch.compile(ref_forw_backward)
-        float8_forw_backward = torch.compile(float8_forw_backward, fullgraph=True)
+        float8_forw = torch.compile(float8_forw, fullgraph=True)

     for _ in range(5):
         ref_forw_backward(input_tensor)
         float8_forw_backward_wrapper(input_tensor)

-    # Profile Reference Linear
-    ref_string = f"linear_ref_dtype_{params.ref_dtype}_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}.json"
+    # Profile Reference Model
+    ref_suffix = f"_ref_compile_{compile}.json"
     profile_config = ProfileConfig(
-        str(profile_path / ref_string), ref_string, iters=5, warmup_iters=5, sync=True
+        profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
     )
     profile_function(profile_config, ref_forw_backward, input_tensor)

-    # # Profile Float8 Linear
-    float8_string = f"linear_float8_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}_{linear_type}.json"
+    # Profile Float8 Model
+    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
     profile_config = ProfileConfig(
-        str(profile_path / float8_string),
-        float8_string,
+        profile_path_prefix + float8_suffix,
+        float8_suffix,
         iters=5,
         warmup_iters=5,
         sync=True,
@@ -182,7 +192,7 @@ def float8_forw_backward_wrapper(x):


 def invoke_main() -> None:
-    # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles --compile=True --linear_type="dynamic"
+    # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --linear_type="dynamic"
     fire.Fire(main)

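Note on the core change in this diff: the benchmark no longer converts a single nn.Linear in place via get_float8_linear; it deep-copies the bf16 reference model and swaps its Linear children for float8 modules. A minimal sketch of that pattern, using the same helpers imported in the diff above (the toy model shape and dtype here are made up for illustration, not taken from the benchmark):

import copy

import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Toy stand-in for the benchmark model; any module tree containing nn.Linear works.
m_ref = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))
m_ref = m_ref.to("cuda").to(torch.bfloat16)

# Keep the bf16 reference untouched and swap Linear layers only in the copy,
# mirroring the m_float8 = copy.deepcopy(m_ref) / swap_linear_with_float8_linear(...)
# lines in the diff.
m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(m_float8, Float8DynamicLinear)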