This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 2ff810c

added fp8 to fp8 training flow tests

1 parent: 7983b78

2 files changed: 90 additions & 2 deletions

float8_experimental/float8_dynamic_linear.py

Lines changed: 13 additions & 2 deletions
@@ -74,9 +74,20 @@ def forward(self, x):
         y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y

-    def quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
-        """Used to perform static_quantization, useful for inference where weights are not updated."""
+    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
+        """This function converts the weight to a Float8Tensor and sets its requires_grad to False.

+        Args:
+            dtype: The dtype to quantize the weight to. Default is e4m3_dtype.
+
+        Note:
+            This function is typically called during inference to quantize the weight once,
+            since the weight is not updated during inference.
+
+        """
+        assert not isinstance(
+            self.weight, Float8Tensor
+        ), "Weight has already been quantized, cannot quantize again."
         scale = tensor_to_scale(self.weight, dtype)
         quantized_weight = to_fp8_no_autograd(
             self.weight,
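For context, a minimal sketch (not part of this commit) of how the updated quantize_weight could be applied across a model before inference. It assumes the model has already been converted with swap_linear_with_float8_linear, and the helper name quantize_weights_for_inference is introduced here purely for illustration:

import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear


def quantize_weights_for_inference(model: nn.Module) -> nn.Module:
    # Hypothetical helper: freeze every Float8DynamicLinear weight once before serving.
    model.eval()
    for module in model.modules():
        if isinstance(module, Float8DynamicLinear):
            # Converts module.weight to a Float8Tensor with requires_grad=False;
            # calling it a second time would trip the assert added in this commit.
            module.quantize_weight()
    return model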

test/test_inference_flows.py

Lines changed: 77 additions & 0 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import io
 import random
 import unittest

@@ -14,6 +15,7 @@
 import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error


@@ -123,5 +125,80 @@ def test_static_fp8_mlp(self, compile_backend, dtype):
         )


+class TestFP8TrainToFP8:
+    def train(self, model: nn.Module, dtype: torch.dtype):
+        model.train()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+        criterion = nn.MSELoss()
+        target_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+        for _ in range(10):
+            input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+            optimizer.zero_grad()
+            output = model(input_tensor)
+            loss = criterion(output, target_tensor)
+            loss.backward()
+            optimizer.step()
+        model.eval()
+        return model
+
+    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not is_H100,
+        "CUDA not available or on non H100 machine",
+    )
+    def test_fp8_save_and_load(self, compile_backend: str, dtype: torch.dtype):
+        # Initialize FP8 model
+        fp8_mlp = FeedForward().to("cuda", dtype=torch.float32)
+        fp8_mlp.reset_parameters()
+        swap_linear_with_float8_linear(
+            fp8_mlp,
+            Float8DynamicLinear,
+        )
+
+        # Train the model
+        self.train(fp8_mlp, dtype)
+
+        # Generate an input tensor and the original output
+        input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
+        og_out = fp8_mlp(input_tensor)
+
+        # Save the model state dict
+        buffer = io.BytesIO()
+        torch.save(fp8_mlp.state_dict(), buffer)
+
+        # Reset the buffer position to the beginning
+        buffer.seek(0)
+
+        # Later, when loading the model, it is rebuilt with Float8DynamicLinear on the meta device
+        with torch.device("meta"):
+            new_fp8_mlp = FeedForward().to(dtype=dtype)
+            swap_linear_with_float8_linear(
+                new_fp8_mlp,
+                Float8DynamicLinear,
+            )
+
+        # Load the actual data
+        new_fp8_mlp.load_state_dict(torch.load(buffer), strict=True, assign=True)
+
+        # Dynamic activations + quantized weights
+        def quantize_dynamic_linear(x: nn.Module):
+            if isinstance(x, Float8DynamicLinear):
+                x.set_quantization_scales(True)
+            return x
+
+        new_fp8_mlp.apply(quantize_dynamic_linear)
+
+        for module in new_fp8_mlp.modules():
+            if isinstance(module, Float8DynamicLinear):
+                assert isinstance(module.weight, Float8Tensor)
+                assert module.weight.requires_grad is False
+
+        new_out = new_fp8_mlp(input_tensor)
+
+        # Assert exact equality
+        assert torch.all(og_out == new_out).item()
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
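Distilled from the new test, a rough sketch of the same save/load pattern outside of pytest. The function names save_fp8_checkpoint and load_fp8_checkpoint, the model_factory argument, and the checkpoint path are placeholders introduced here, and the model is assumed to have been swapped with swap_linear_with_float8_linear before training:

import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


def save_fp8_checkpoint(trained_model: nn.Module, path: str) -> None:
    # Only the state dict is serialized; the fp8 module structure is rebuilt at load time.
    torch.save(trained_model.state_dict(), path)


def load_fp8_checkpoint(model_factory, path: str) -> nn.Module:
    # Build the swapped architecture on the meta device so no real storage is allocated,
    # then let load_state_dict(..., assign=True) attach the deserialized tensors in place
    # of the meta parameters, mirroring the flow in test_fp8_save_and_load.
    with torch.device("meta"):
        model = model_factory()
        swap_linear_with_float8_linear(model, Float8DynamicLinear)
    model.load_state_dict(torch.load(path), strict=True, assign=True)
    return model

After loading, the weights can be frozen for inference with quantize_weight, as in the sketch shown after the first file's diff.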
