# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
+import io
import random
import unittest

@@ -123,5 +124,76 @@ def test_static_fp8_mlp(self, compile_backend, dtype):
        )


+class TestFP8TrainToFP8:
+    def train(self, model: nn.Module, dtype: torch.dtype):
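+        # Short training loop used only to exercise the fp8 modules before inference:
+        # Adam + MSE on random data for a few steps, then switch to eval mode.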
+        model.train()
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+        criterion = nn.MSELoss()
+        target_tensor = torch.randn(4, 4096, device="cuda", dtype=dtype)
+        for _ in range(10):
+            input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+            optimizer.zero_grad()
+            output = model(input_tensor)
+            loss = criterion(output, target_tensor)
+            loss.backward()
+            optimizer.step()
+        model.eval()
+        return model
+
+    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_H100,
+        "CUDA not available or not on an H100 machine",
+    )
+    def test_fp8_save_and_load(self, compile_backend: str, dtype: torch.dtype):
+        # Initialize FP8 model
+        fp8_mlp = FeedForward().to("cuda", dtype=torch.float32)
+        fp8_mlp.reset_parameters()
+        swap_linear_with_float8_linear(
+            fp8_mlp,
+            Float8DynamicLinear,
+        )
+
+        # Train the model
+        self.train(fp8_mlp, dtype)
+
+        # Generate an input tensor and compute the original output
+        input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+        og_out = fp8_mlp(input_tensor)
+
+        # Save the model state dict to an in-memory buffer
+        buffer = io.BytesIO()
+        torch.save(fp8_mlp.state_dict(), buffer)
+
+        # Reset the buffer position to the beginning
+        buffer.seek(0)
+
+        # Later on, when the model is loaded, it is rebuilt with Float8DynamicLinear modules on the meta device
+        with torch.device("meta"):
+            new_fp8_mlp = FeedForward().to(dtype=dtype)
+            swap_linear_with_float8_linear(
+                new_fp8_mlp,
+                Float8DynamicLinear,
+            )
+
+        # Load the actual data
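+        # assign=True re-uses the loaded tensors directly, materializing the
+        # meta-device parameters from the checkpoint instead of copying in-place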
+        new_fp8_mlp.load_state_dict(torch.load(buffer), strict=True, assign=True)
+
+        # Dynamic Activations + Quantized Weights
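+        # Set precomputed weight quantization scales on each Float8DynamicLinear so
+        # inference uses quantized weights while activations remain dynamically scaled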
+        def quantize_dynamic_linear(x: nn.Module):
+            if isinstance(x, Float8DynamicLinear):
+                x.set_quantization_scales(True)
+            return x
+
+        new_fp8_mlp.apply(quantize_dynamic_linear)
+
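+        # Verify that every swapped linear now holds a frozen (non-trainable) quantized weight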
+        for module in new_fp8_mlp.modules():
+            if isinstance(module, Float8DynamicLinear):
+                # NOTE: the weight is expected to be a quantized fp8 tensor here, not a
+                # module; Float8Tensor is assumed to be imported in this test file
+                assert isinstance(module.weight, Float8Tensor)
+                assert module.weight.requires_grad is False
+
+        new_out = new_fp8_mlp(input_tensor)
+
+        # Assert exact equality
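+        # The reloaded model runs the same fp8 compute path on the same weights,
+        # so its output should match the original bit-for-bit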
+        assert torch.all(og_out == new_out).item()
+
+
if __name__ == "__main__":
    pytest.main([__file__])