@@ -14,6 +14,8 @@
 import fire

 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
@@ -38,6 +40,105 @@ def forward(self, x):
         return x


+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
+class RMSNorm(nn.Module):
+    """
+    Initialize the RMSNorm normalization layer.
+
+    Args:
+        dim (int): The dimension of the input tensor.
+        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+    Attributes:
+        eps (float): A small value added to the denominator for numerical stability.
+        weight (nn.Parameter): Learnable scaling parameter.
+
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def reset_parameters(self):
+        torch.nn.init.ones_(self.weight)  # type: ignore
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class NormFFNResidualNorm(nn.Module):
+    """
+    A fragment representing the end of TransformerBlock n and the start
+    of TransformerBlock n + 1, intended to include the fusions relevant
+    to float8 gemms in the FFN module in forward and backward.
+    """
+
+    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier):
+        super().__init__()
+        self.ffn_norm = RMSNorm(dim)
+        self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)
+        self.attn_norm = RMSNorm(dim)
+
+    def forward(self, h):
+        # end of transformer block n
+        x = self.ffn_norm(h)
+        x = self.ffn(x)
+        x = h + x
+        # start of transformer block n + 1
+        x = self.attn_norm(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
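As a reader's aid (not part of the change itself): the hidden-dimension rounding in `FeedForward.__init__` above, evaluated for the configuration the `norm_ffn_norm` branch of `main()` uses further down (`hidden_dim=16384`, `multiple_of=1024`, `ffn_dim_multiplier=1.3`), works out as follows.

```python
# Worked check of FeedForward.__init__'s hidden-dim rounding for the values
# used by the "norm_ffn_norm" branch of main() in this diff.
hidden_dim, multiple_of, ffn_dim_multiplier = 16384, 1024, 1.3
hidden_dim = int(2 * hidden_dim / 3)                # 10922
hidden_dim = int(ffn_dim_multiplier * hidden_dim)   # 14198
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)  # 14336: w1/w3 are 4096 -> 14336 gemms, w2 is 14336 -> 4096
```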
@@ -93,40 +194,46 @@ def profile_function(
     return prof


-@dataclass(frozen=True)
-class ModelParams:
-    M: int
-    K: int
-    N: int
-    ref_dtype: torch.dtype
-    layer_norm: bool = True
-
-
 def main(
     profile_path_prefix: Path,
     compile: bool = True,
     linear_type: str = "dynamic",
-    use_layer_norm: bool = False,
+    model_type: str = "linear",
 ):
-    params = ModelParams(
-        M=4 * 4096,
-        K=8192,
-        N=7168,
-        ref_dtype=torch.bfloat16,
-        layer_norm=use_layer_norm,
-    )
+    assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+
     print(f"Compile is set to | {compile}")
     print(f"Using Linear type: | {linear_type}")
-    print(f"Use layer norm is set to | {params.layer_norm}")
+    print(f"model_type is set to | {model_type}")

     device = "cuda"
-    if params.layer_norm:
-        m_ref = LNLinear(params.K, params.N)
+    ref_dtype = torch.bfloat16
+    if model_type == "ln_linear":
+        M, K, N = 4 * 4096, 8192, 7168
+        m_ref = LNLinear(K, N)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+    elif model_type == "norm_ffn_norm":
+        m_ref = NormFFNResidualNorm(
+            dim=4096,
+            hidden_dim=16384,
+            multiple_of=1024,
+            ffn_dim_multiplier=1.3,
+        )
+        input_tensor = torch.randn(
+            1, 8192, 4096, device=device, dtype=ref_dtype
+        ).requires_grad_()
     else:
+        M, K, N = 4 * 4096, 8192, 7168
         m_ref = torch.nn.Sequential(
-            torch.nn.Linear(params.K, params.N, bias=False),
+            torch.nn.Linear(K, N, bias=False),
         )
-    m_ref = m_ref.to(device).to(params.ref_dtype)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+
+    m_ref = m_ref.to(device).to(ref_dtype)

     linear_type = LinearType[linear_type.upper()]
     linear_cls = (
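For orientation, here is a minimal standalone sketch (not part of the diff) of what the new `norm_ffn_norm` reference path computes, assuming the `RMSNorm`, `FeedForward`, and `NormFFNResidualNorm` classes above are in scope. The script itself builds this module on CUDA in bfloat16 and then profiles a float8-swapped copy; the sketch stays on CPU in float32 so it can run anywhere (slowly, at these shapes).

```python
import torch

# Mirrors the "norm_ffn_norm" branch of main() above, on CPU for simplicity.
m = NormFFNResidualNorm(dim=4096, hidden_dim=16384, multiple_of=1024, ffn_dim_multiplier=1.3)
x = torch.randn(1, 8192, 4096, requires_grad=True)
out = m(x)
assert out.shape == x.shape  # the fragment preserves the (batch, seq, dim) shape
out.sum().backward()         # the same forward + backward the profiler measures
```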
@@ -136,10 +243,6 @@ def main(
     m_float8 = copy.deepcopy(m_ref)
     swap_linear_with_float8_linear(m_float8, linear_cls)

-    input_tensor = torch.randn(
-        params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
-    )
-
     def ref_forw_backward(x):
         out = m_ref(x)
         out.sum().backward()
@@ -173,14 +276,14 @@ def float8_forw_backward_wrapper(x):
     float8_forw_backward_wrapper(input_tensor)

     # Profile Reference Model
-    ref_suffix = f"_ref_compile_{compile}.json"
+    ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
     profile_config = ProfileConfig(
         profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
     )
     profile_function(profile_config, ref_forw_backward, input_tensor)

     # Profile Float8 Model
-    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
+    float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
     profile_config = ProfileConfig(
         profile_path_prefix + float8_suffix,
         float8_suffix,
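With these changes, the profiler can target any of the three model types, and the model type is encoded in the trace file names. Below is a sketch of driving the updated `main()` directly from Python; the script also imports `fire`, so it presumably exposes the same arguments on the command line, but that wiring is not shown in this diff, and the output prefix here is only illustrative.

```python
# Hypothetical invocation of the updated main(); traces are written next to the
# prefix as <prefix>_<model_type>_ref_compile_<compile>.json and
# <prefix>_<model_type>_float8_compile_<compile>_<linear_type>.json.
main(
    profile_path_prefix="./profiles/run1",  # illustrative prefix, not from the diff
    compile=True,
    linear_type="dynamic",
    model_type="norm_ffn_norm",             # new option introduced by this change
)
```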