This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 2f00319

Commit message: comments
1 parent f0e8d7c commit 2f00319

File tree

4 files changed: +45, -7 lines


float8_experimental/float8_linear.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def __init__(self, *args, **kwargs):
         )
         self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))
 
-        # Whether to emulate the fp8 matmul logic in float32
+        # Defines the behavior of the matmul in the forward and backward pass
         self.forward_config = ScaledMMConfig()
         self.backward_config = ScaledMMConfig()
 
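The reworded comment reflects that matmul behavior is now carried by two per-pass ScaledMMConfig values rather than a single emulate flag. A minimal sketch of configuring emulation, based only on identifiers visible in this commit (the positional ScaledMMConfig(True) form is taken from the test below, where the first field appears to be emulate):

# Sketch only: configure both passes for emulated (float32) matmuls.
from float8_experimental.float8_tensor import ScaledMMConfig

forward_config = ScaledMMConfig(True)   # first field appears to be `emulate`
backward_config = ScaledMMConfig(True)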
float8_experimental/float8_linear_utils.py

Lines changed: 1 addition & 4 deletions
@@ -122,10 +122,7 @@ def swap_linear_with_float8_linear(
             raise AssertionError(
                 f"Does not support a root nn.Linear with children: {module}"
             )
-        print(f"Emulating: {emulate}")
-        new_mod = module_cls.from_float(module, emulate=emulate)
-        print(f"New mod: {new_mod.forward_config}")
-        return new_mod
+        return module_cls.from_float(module, emulate=emulate)
 
     # Mark all modules to skip as visited
     root_module = module
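The change above simply drops the debug prints from swap_linear_with_float8_linear. For context, a hedged usage sketch of the helper; Float8Linear as the module_cls, the argument order, and the return value are assumptions, since none of them are shown in this hunk:

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear  # assumed class; not named in this hunk
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
# Swap every nn.Linear for the float8 variant; emulate=True keeps the matmul
# in float32 emulation (per the old comment in float8_linear.py), so this can
# run without fp8-capable hardware.
model = swap_linear_with_float8_linear(model, Float8Linear, emulate=True)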

float8_experimental/float8_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def to_fp8_no_autograd(
         x: the tensor to convert
         scale: the scale to use to convert the tensor
         float8_dtype: the float8 dtype to use
-        mm_config: configuration for the scaled_mm will bread from this dataclass
+        mm_config: Defines the configuration for the scaled_mm
     """
     x_scaled = x * x_scale
     bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
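The docstring being fixed belongs to to_fp8_no_autograd, which scales the input and then does a saturated cast. A minimal standalone sketch of that scale-then-saturate step; to_fp8_saturated_sketch is a hypothetical stand-in for the library's to_fp8_saturated, not its actual implementation:

import torch

def to_fp8_saturated_sketch(x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Clamp to the representable range of the target fp8 dtype before casting,
    # so out-of-range values saturate instead of becoming inf/nan.
    max_val = torch.finfo(float8_dtype).max
    return x.clamp(min=-max_val, max=max_val).to(float8_dtype)

# Mirrors the two lines of to_fp8_no_autograd shown above: scale, then cast.
x = torch.randn(16, 16)
x_scale = torch.tensor(1.0)
x_scaled = x * x_scale
bits_fp8 = to_fp8_saturated_sketch(x_scaled, torch.float8_e4m3fn)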

test/test_base.py

Lines changed: 42 additions & 1 deletion
@@ -23,7 +23,11 @@
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    merge_mm_configs,
+    ScaledMMConfig,
+)
 from float8_experimental.float8_utils import (
     amax_to_scale,
     compute_error,
@@ -326,6 +330,43 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         atol, rtol = 2e-3, 2e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
+    @unittest.skipIf(not is_H100, "CUDA not available")
+    def test_different_configs_error(self):
+        x_fp32 = torch.randn(16, 16, device="cuda")
+        x_scale = torch.tensor(1.0, device="cuda")
+        fp8_dtype = torch.float8_e4m3fn
+        a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
+        b = Float8Tensor.to_float8(
+            x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
+        )
+        with pytest.raises(
+            AssertionError,
+            match="Both mm_configs must have the same emulate value, but got False and True",
+        ):
+            a @ b
+
+    def test_merge_configs(self):
+        a = ScaledMMConfig(False, True, True)
+        b = ScaledMMConfig(True, False, False)
+        with pytest.raises(
+            AssertionError,
+            match="Both mm_configs must have the same emulate value, but got False and True",
+        ):
+            merge_mm_configs(a, b)
+        a = ScaledMMConfig(False, True, True)
+        b = ScaledMMConfig(False, False, False)
+        c = merge_mm_configs(a, b)
+        assert c.emulate is False
+        assert c.use_fast_accum is False
+        assert c.fp8_output is False
+
+        a = ScaledMMConfig(False, True, False)
+        b = ScaledMMConfig(False, True, False)
+        c = merge_mm_configs(a, b)
+        assert c.emulate is False
+        assert c.use_fast_accum is True
+        assert c.fp8_output is False
+
 
 class TestNumerics:
     @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
