@@ -47,170 +47,6 @@ def reset_parameters(self):
                m.reset_parameters()


-class TestHPTrainToFP8:
-    def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
-        with torch.no_grad():
-            base_output = base_mlp(input_tensor)
-            transformed_output = quantized_mlp(input_tensor)
-
-        # Compute and check SQNR
-        sqnr = compute_error(base_output, transformed_output)
-        assert sqnr.item() > 20, f"SQNR is too low: {sqnr.item()} dB"
-
-    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
-    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not available or on non H100 machine",
-    )
-    def test_dynamic_fp8_mlp(self, compile_backend, dtype):
-        original_mlp = FeedForward().to("cuda", dtype=dtype)
-        original_mlp.reset_parameters()
-
-        dynamic_fp8_mlp = copy.deepcopy(original_mlp)
-        swap_linear_with_float8_linear(
-            dynamic_fp8_mlp,
-            Float8DynamicLinear,
-            from_float_kwargs={"pre_quantize_weight": True},
-        )
-
-        batch_size = 4
-        num_tokens = 1024
-        embedding_dim = 4096
-
-        input_tensor = torch.randn(
-            batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype
-        )
-
-        # Compile the models
-        compiled_original_mlp = torch.compile(original_mlp, backend=compile_backend)
-        compiled_dynamic_fp8_mlp = torch.compile(
-            dynamic_fp8_mlp, backend=compile_backend
-        )
-
-        self.base_test_mlp_transform(
-            compiled_original_mlp, compiled_dynamic_fp8_mlp, input_tensor
-        )
-
-    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
-    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not available or on non H100 machine",
-    )
-    def test_static_fp8_mlp(self, compile_backend, dtype):
-        original_mlp = FeedForward().to("cuda", dtype=dtype)
-        original_mlp.reset_parameters()
-
-        static_fp8_mlp = copy.deepcopy(original_mlp)
-        swap_linear_with_float8_linear(
-            static_fp8_mlp,
-            Float8DynamicLinear,
-            from_float_kwargs={
-                "pre_quantize_weight": True,
-                "activation_scale": torch.tensor(
-                    [1.0], device="cuda", dtype=torch.float32
-                ),
-            },
-        )
-
-        batch_size = 4
-        num_tokens = 1024
-        embedding_dim = 4096
-
-        input_tensor = torch.randn(
-            batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype
-        )
-
-        # Compile the models
-        compiled_original_mlp = torch.compile(original_mlp, backend=compile_backend)
-        compiled_static_fp8_mlp = torch.compile(static_fp8_mlp, backend=compile_backend)
-
-        self.base_test_mlp_transform(
-            compiled_original_mlp, compiled_static_fp8_mlp, input_tensor
-        )
-
-
-class TestFP8TrainToFP8:
-    def train(self, model: nn.Module, dtype: torch.dtype):
-        model.train()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
-        criterion = nn.MSELoss()
-        target_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
-        for _ in range(10):
-            input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
-            optimizer.zero_grad()
-            output = model(input_tensor)
-            loss = criterion(output, target_tensor)
-            loss.backward()
-            optimizer.step()
-        model.eval()
-        return model
-
-    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
-    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not available or on non H100 machine",
-    )
-    def test_fp8_save_and_load(self, compile_backend: str, dtype: torch.dtype):
-        # Initialize FP8 model
-        fp8_mlp = FeedForward().to("cuda", dtype=torch.float32)
-        fp8_mlp.reset_parameters()
-        swap_linear_with_float8_linear(
-            fp8_mlp,
-            Float8DynamicLinear,
-        )
-
-        # Train the model
-        self.train(fp8_mlp, dtype)
-
-        # Generate input tensor and the original output
-        input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
-        og_out = fp8_mlp(input_tensor)
-
-        # Save model state dict
-        buffer = io.BytesIO()
-        torch.save(fp8_mlp.state_dict(), buffer)
-
-        # Reset buffer position to the beginning
-        buffer.seek(0)
-
-        # Later on, you load the model; it will have Float8DynamicLinear modules on the meta device
-        with torch.device("meta"):
-            new_fp8_mlp = FeedForward().to(dtype=dtype)
-            swap_linear_with_float8_linear(
-                new_fp8_mlp,
-                Float8DynamicLinear,
-            )
-
-        # Load the actual data
-        new_fp8_mlp.load_state_dict(
-            torch.load(buffer, weights_only=True), strict=True, assign=True
-        )
-
-        # Dynamic Activations + Quantized Weights
-        def quantize_dynamic_linear(x: nn.Module):
-            if isinstance(x, Float8DynamicLinear):
-                x.set_quantization_scales(pre_quantize_weight=True)
-            return x
-
-        new_fp8_mlp.apply(quantize_dynamic_linear)
-
-        for module in new_fp8_mlp.modules():
-            if isinstance(module, Float8DynamicLinear):
-                assert isinstance(module.weight, Float8Tensor)
-                assert module.weight.requires_grad is False
-
-        new_out = new_fp8_mlp(input_tensor)
-
-        # Assert exact equality
-        assert torch.all(og_out == new_out).item()
-
-
-# WE ARE GOING TO KEEP ONE OR THE OTHER; BELOW IS THE SEPARATE MODULE WORKFLOW
-
-
class TestHPTrainToFP8LinearInference:
    def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
        with torch.no_grad():