diff --git a/README.md b/README.md
index 32974079..c7ac029d 100644
--- a/README.md
+++ b/README.md
@@ -90,13 +90,17 @@ for _ in range(N_ITER):
     optimizer.step()
 ```
 
-# code tips
+# 🧭 Code Organization
 
-* `float8_experimental/float8_linear.py` - `Float8Linear` (main user facing entry point for delayed scaling)
-* `float8_experimental/float8_dynamic_linear.py` - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
-* `float8_experimental/float8_tensor.py` - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
+* `float8_experimental/float8_linear.py`
+  - `Float8Linear` (main user facing entry point for delayed scaling)
+* `float8_experimental/float8_dynamic_linear.py`
+  - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
+* `float8_experimental/float8_tensor.py`
+  - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
+  - `ScaledMMConfig` defines the semantics for the matmul in the forward and backward pass
 
-# testing
+# Testing
 
 ```bash
 # run single-GPU unit tests
@@ -117,7 +121,7 @@ pytest test/test_compile.py
 ./test/run_everything.sh
 ```
 
-# benchmarking
+# Benchmarking
 
 ```bash
 # benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
@@ -130,4 +134,3 @@ pytest test/test_compile.py
 # License
 
 PyTorch has a BSD 3-Clause License, as found in the LICENSE file.
-
diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py
index 42341923..58bf8848 100644
--- a/benchmarks/bench_linear_float8.py
+++ b/benchmarks/bench_linear_float8.py
@@ -16,6 +16,7 @@ import torch.utils.benchmark as benchmark
 
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
+from float8_experimental.float8_tensor import ScaledMMConfig
 from tqdm import tqdm
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
@@ -54,8 +55,8 @@ class Experiment:
     ref_time_sec: float
     float8_time_sec: float
     dtype: torch.dtype
-    compiled: bool = False
-    float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
+    compiled: bool
+    use_fast_accum: bool
 
     # 3 Times since we are calculating forward backward
     @property
@@ -74,7 +75,7 @@ def float8_tops_sec(self):
 
     @property
     def float8_pct_top_peak(self):
-        return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
+        return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]
 
 
 def main(
@@ -95,9 +96,10 @@ def main(
     }
     input_bias = False
     ref_dtypes = [torch.bfloat16, torch.float16]
+    use_fast_accum = [True, False]
     experiment_list: List[Experiment] = []
-    for idx, (dtype, (name, (K, N))) in enumerate(
-        tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
+    for idx, (dtype, fast_accum, (name, (K, N))) in enumerate(
+        tqdm(list(product(ref_dtypes, use_fast_accum, name_to_shapes_70b.items())))
     ):
         if n_limit is not None and idx >= n_limit:
             break
@@ -108,6 +110,10 @@ def main(
         linear_float8 = Float8Linear.from_float(
             copy.deepcopy(linear_ref), emulate=False
         )
+        if fast_accum:
+            linear_float8.forward_config = ScaledMMConfig(False, True, False)
+        else:
+            linear_float8.forward_config = ScaledMMConfig(False, False, False)
 
         bsz, seq_len = 4, 4096
         M = bsz * seq_len
@@ -155,6 +161,7 @@ def wrapper(*args, **kwargs):
             float8_time,
             dtype,
             compile,
+            use_fast_accum=fast_accum,
         )
         print(experiment)
         print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -168,7 +175,7 @@ def wrapper(*args, **kwargs):
         "N",
         "ref_dtype",
         "compiled",
-        "fp8_dtype",
+        "use_fast_accum",
        "ref_time_sec",
        "pt_fp8_time_sec",
        "ref_tops_sec",
@@ -186,7 +193,7 @@ def wrapper(*args, **kwargs):
             experiment.shape[2],
             experiment.dtype,
             experiment.compiled,
-            experiment.float_8_dtype,
+            experiment.use_fast_accum,
             experiment.ref_time_sec,
             experiment.float8_time_sec,
             experiment.ref_tops_sec,
@@ -214,6 +221,7 @@ def wrapper(*args, **kwargs):
             "shape",
             "ref_dtype",
             "compiled",
+            "use_fast_accum",
             "ref_time_sec",
             "pt_fp8_time_sec",
             "pt_fp8_speedup",
diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 4a5ac2e1..5c6edb8c 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -10,6 +10,7 @@
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    ScaledMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -27,9 +28,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     def forward(
         ctx,
         tensor,
-        emulate: bool,
+        mm_config: ScaledMMConfig,
     ):
-        ctx.emulate = emulate
+        ctx.mm_config = mm_config
         return tensor
 
     @staticmethod
@@ -39,7 +40,7 @@ def backward(ctx, gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
+            gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
         )
         return fp8_tensor, None
@@ -73,11 +74,11 @@ def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
             return inpt_tensor
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
-            inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
+            inpt_tensor, scale, torch.float8_e4m3fn, mm_config=self.forward_config
         )
 
     def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
-        return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)
+        return NoopFwToFloat8E5M2Bw.apply(gradY, self.backward_config)
 
     @classmethod
     def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
@@ -97,5 +98,6 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
         new_mod = cls(**super_kwargs)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
-        new_mod.emulate = emulate
+        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
+        new_mod.backward_config = ScaledMMConfig(emulate, False)
         return new_mod
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 09dd7147..ebaefd3e 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -20,7 +20,11 @@
 
 import torch
 
-from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    ScaledMMConfig,
+    to_fp8_no_autograd,
+)
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -73,12 +77,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        emulate: bool,
+        mm_config: ScaledMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.emulate = emulate
+        ctx.mm_config = mm_config
         return tensor
 
     @staticmethod
@@ -99,7 +103,9 @@ def backward(ctx, go):
 
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
 
-        res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)
+        res = to_fp8_no_autograd(
+            go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
+        )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -154,8 +160,9 @@ def __init__(self, *args, **kwargs):
         )
         self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))
 
-        # Whether to emulate the fp8 matmul logic in float32
-        self.emulate = False
+        # Defines the behavior of the matmul in the forward and backward pass
+        self.forward_config = ScaledMMConfig()
+        self.backward_config = ScaledMMConfig()
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
@@ -216,7 +223,11 @@ def cast_x_to_float8(
             is_amax_initialized,
         )
         x_fp8 = Float8Tensor.to_float8(
-            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
+            x,
+            self.fp8_scale_x,
+            torch.float8_e4m3fn,
+            self.fp8_amax_x,
+            self.forward_config,
         )
         return x_fp8
@@ -239,13 +250,11 @@ def cast_w_to_float8(
             self.fp8_scale_w,
             torch.float8_e4m3fn,
             self.fp8_amax_w,
-            self.emulate,
+            self.forward_config,
         )
         return w_fp8
 
-    def cast_y_to_float8_in_bw(
-        self, y: torch.Tensor, emulate: bool = False
-    ) -> torch.Tensor:
+    def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         y = NoopFwToFloat8E5M2Bw.apply(
             y,
@@ -254,7 +263,7 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
             self.fp8_scale_dL_dY,
             scale_fn_name,
             self.is_amax_initialized,
-            emulate,
+            self.backward_config,
         )
         return y
@@ -295,7 +304,7 @@ def forward(self, x):
         y = torch.matmul(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_y_to_float8_in_bw(y, self.emulate)
+        y = self.cast_y_to_float8_in_bw(y)
 
         if self.bias is not None:
             y = y + self.bias.to(y.dtype)
@@ -318,7 +327,12 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
-        new_mod.emulate = emulate
+
+        # Defines the behavior of the matmul in the forward and backward pass
+        # In the forward we use fast_accum, in the backward we do not
+        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
+        new_mod.backward_config = ScaledMMConfig(emulate, False)
+
        # I think its okay to send all params and buffers to device
        new_mod.to(mod.weight.device)
        return new_mod
diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index 955649a3..7eec3b6c 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -8,7 +8,11 @@
 import torch
 
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    merge_mm_configs,
+    ScaledMMConfig,
+)
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map
 
@@ -41,7 +45,9 @@ def decorator(func):
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
-    return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)
+    return Float8Tensor(
+        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+    )
 
 
 @implements([aten.sum.dim_IntList])
@@ -89,13 +95,22 @@ def float8_mm(aten_op, args, kwargs=None):
     )
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
-    if a._emulate:
-        assert a._emulate == b._emulate
+    a_mm_config: ScaledMMConfig = a._mm_config
+    b_mm_config: ScaledMMConfig = b._mm_config
+    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    if mm_config.emulate:
         return torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )[0]
     tensor_out, amax = addmm_float8_unwrapped(
-        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None
+        a_data,
+        a_scale,
+        b_data,
+        b_scale,
+        output_dtype,
+        output_scale=None,
+        bias=None,
+        use_fast_accum=mm_config.use_fast_accum,
     )
     return tensor_out
@@ -113,14 +128,23 @@ def float8_addmm(aten_op, args, kwargs=None):
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     assert bias.dtype == output_dtype, "bias dtype must match output dtype"
-    if a._emulate:
-        assert a._emulate == b._emulate
+    a_mm_config: ScaledMMConfig = a._mm_config
+    b_mm_config: ScaledMMConfig = b._mm_config
+    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    if mm_config.emulate:
         out = torch.ops.aten.mm_float8_emulated(
             a._data, a._scale, b._data, b._scale, output_dtype
         )[0]
         return out + bias
     tensor_out, amax = addmm_float8_unwrapped(
-        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=bias
+        a_data,
+        a_scale,
+        b_data,
+        b_scale,
+        output_dtype,
+        output_scale=None,
+        bias=bias,
+        use_fast_accum=mm_config.use_fast_accum,
     )
     return tensor_out
@@ -145,7 +169,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         torch.bfloat16,
     }, "Only support floating point conversion for autocast w/ Float8Tensor"
     return Float8Tensor(
-        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
+        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
     )
 
 
@@ -170,7 +194,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     fp8_out = fp8_out.view(fp8_input._data.dtype)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._emulate
+        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
 
 
@@ -182,5 +206,5 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_input._data
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._emulate
+        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py
index 9182f626..6cb406d4 100644
--- a/float8_experimental/float8_python_api.py
+++ b/float8_experimental/float8_python_api.py
@@ -15,17 +15,22 @@
 import float8_experimental.float8_aten_api  # noqa
 
 import torch
-from float8_experimental.float8_tensor import Float8Tensor
+# [Note] Usage of scales
+# The meaning of scale in this library can be found in the definition of the Float8Tensor
+# Cublas defines scale to always mean a multiplicative factor for the respective matrices
+# For a,b going from fp8 -> fp32 we multiply by the inverse of the scale
+# For output going from fp32 -> fp8 we multiply by the scale
 
 
 def addmm_float8_unwrapped(
     a_data: torch.Tensor,
     a_scale: torch.Tensor,
     b_data: torch.Tensor,
     b_scale: torch.tensor,
     output_dtype: torch.dtype,
-    output_scale: Optional[torch.Tensor],
+    output_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
+    use_fast_accum: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
@@ -43,6 +48,7 @@
             scale_a=a_inverse_scale,
             scale_b=b_inverse_scale,
             scale_result=output_scale,
+            use_fast_accum=use_fast_accum,
         )
         output += bias
         return output, output_amax
@@ -54,41 +60,6 @@
         scale_a=a_inverse_scale,
         scale_b=b_inverse_scale,
         scale_result=output_scale,
+        use_fast_accum=use_fast_accum,
     )
     return output, output_amax
-
-
-# [Note] Usage of scales
-# The meaning of scale in this library can be found in the definition of the Float8Tensor
-# Cublas defines scale to always mean a multiplicative factor for the respective matrices
-# For a,b going from fp8 -> fp32 we multiple by the inverse of the scale
-# For output going from fp32 -> fp8 we multiply by the scale
-def mm_float8(
-    a: Float8Tensor,  # input 1
-    b: Float8Tensor,  # input 2
-    output_dtype: torch.dtype,  # output dtype
-    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
-    emulate: bool = False,  # whether to emulate the operation using fp32
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Performs a matrix multiplication of two Float8Tensors `a` and `b`.
-
-    Args:
-        a: The first matrix multiplication term.
-        b: The second matrix multiplication term.
-        output_dtype: The output tensor's dtype.
-        output_scale: The output tensor's scale, precomputed.
-        emulate: Whether to emulate the operation using fp32.
-
-    Returns:
-        torch.Tensor: The result of the matrix multiplication.
-    """
-    if emulate:
-        assert output_scale is None, "unsupported"
-        return torch.ops.aten.mm_float8_emulated(
-            a._data, a._scale, b._data, b._scale, output_dtype
-        )
-
-    return addmm_float8_unwrapped(
-        a._data, a._scale, b._data, b._scale, output_dtype, output_scale
-    )
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 2bf854f1..7287c11d 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+from collections import namedtuple
 from typing import Dict, Optional
 
 import torch
@@ -14,6 +15,38 @@
 aten = torch.ops.aten
 
 
+# ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass.
+# emulate: whether to emulate the matmuls in fp32
+# use_fast_accum: whether to use the fast-accumulation option for scaled_mm
+# fp8_output: whether to output the result of the scaled_mm in fp8
+ScaledMMConfig = namedtuple(
+    "ScaledMMConfig",
+    ["emulate", "use_fast_accum", "fp8_output"],
+    defaults=[False, False, False],
+)
+
+
+def merge_mm_configs(
+    a_mm_config: ScaledMMConfig, b_mm_config: ScaledMMConfig
+) -> ScaledMMConfig:
+    """Merges two mm_configs together. The emulate behavior must match;
+    however, we want use_fast_accum enabled in the forward pass but not in the backward pass.
+    We do this by populating the fields of the backpropagating grad's config accordingly. The same applies to fp8_output.
+
+    For both use_fast_accum and fp8_output, if either config is False, the merged config will be False.
+    """
+    assert (
+        a_mm_config.emulate == b_mm_config.emulate
+    ), "Both mm_configs must have the same emulate value, but got {} and {}".format(
+        a_mm_config.emulate, b_mm_config.emulate
+    )
+    return ScaledMMConfig(
+        emulate=a_mm_config.emulate,
+        use_fast_accum=a_mm_config.use_fast_accum and b_mm_config.use_fast_accum,
+        fp8_output=a_mm_config.fp8_output and b_mm_config.fp8_output,
+    )
+
+
 def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
     """
     Check if the tensor is already casted to fp8
@@ -30,7 +63,10 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
 
 
 def to_fp8_no_autograd(
-    x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
+    x: torch.Tensor,
+    x_scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+    mm_config: Optional[ScaledMMConfig],
 ) -> "Float8Tensor":
     """Convert a tensor to float8 without autograd
     This is used in multiple places in the codebase to convert a tensor to float8
@@ -48,7 +84,7 @@
         x: the tensor to convert
         scale: the scale to use to convert the tensor
         float8_dtype: the float8 dtype to use
-        emulate: whether to emulate the matmuls in fp32
+        mm_config: Defines the configuration for the scaled_mm
     """
     x_scaled = x * x_scale
     bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
@@ -62,7 +98,7 @@
         local_bits = bits_fp8.to_local()
         local_scale = x_scale.to_local()
         inner_float8_tensor = Float8Tensor(
-            local_bits, local_scale, x.dtype, emulate=emulate
+            local_bits, local_scale, x.dtype, mm_config=mm_config
         )
         return DTensor.from_local(
             inner_float8_tensor,
@@ -73,7 +109,7 @@
             stride=bits_fp8.stride(),
         )
 
-    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
+    return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config)
 
 
 def from_fp8_no_autograd(x: torch.Tensor) -> torch.Tensor:
@@ -122,7 +158,7 @@ def forward(
         scale: torch.Tensor,
         float8_dtype=torch.float8_e4m3fn,
         amax_buffer: Optional[torch.Tensor] = None,
-        emulate: bool = False,
+        mm_config: Optional[ScaledMMConfig] = None,
     ):
         """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
         Args
@@ -135,7 +171,7 @@
         if amax_buffer is not None:
             amax_buffer.fill_(tensor_to_amax(tensor))
 
-        return to_fp8_no_autograd(tensor, scale, float8_dtype, emulate)
+        return to_fp8_no_autograd(tensor, scale, float8_dtype, mm_config=mm_config)
 
     @staticmethod
     def backward(ctx, g):
@@ -182,15 +218,15 @@ class Float8Tensor(torch.Tensor):
     _data: torch.Tensor
     _scale: torch.Tensor
     _orig_dtype: torch.dtype
-    _emulate: bool
-    __slots__ = ["_data", "_scale", "_orig_dtype", "_emulate"]
+    _mm_config: ScaledMMConfig
+    __slots__ = ["_data", "_scale", "_orig_dtype", "_mm_config"]
 
     def __new__(
         cls,
         data: torch.Tensor,
         scale: torch.Tensor,
         orig_dtype: torch.dtype,
-        emulate=False,
+        mm_config: Optional[ScaledMMConfig],
     ):
         assert (
             scale.numel() == 1
@@ -211,16 +247,17 @@
         self._data = data
         self._scale = scale
         self._orig_dtype = orig_dtype
-        self._emulate = emulate
+        self._mm_config = mm_config if mm_config is not None else ScaledMMConfig()
+
         return self
 
     def __repr__(self):
-        return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, emulate={self._emulate}\nas_orig_prec={self.to_original_precision()}"
+        return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, mm_config={self._mm_config}\nas_orig_prec={self.to_original_precision()}"
 
     def __tensor_flatten__(self):
         ctx = {
             "_orig_dtype": self._orig_dtype,
-            "_emulate": self._emulate,
+            "_mm_config": self._mm_config,
         }
         return ["_data", "_scale"], ctx
 
@@ -231,7 +268,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
             inner_tensors["_data"],
             inner_tensors["_scale"],
             metadata["_orig_dtype"],
-            metadata["_emulate"],
+            metadata["_mm_config"],
         )
 
     def to_original_precision(self):
@@ -244,7 +281,7 @@ def to_float8(
         scale: torch.Tensor,
         float8_dtype: torch.dtype,
         amax_buffer: Optional[torch.Tensor] = None,
-        emulate: bool = False,
+        mm_config: Optional[ScaledMMConfig] = None,
     ):
         """Converts a higher precision tensor to float8 in a differentiable way.
 
@@ -258,11 +295,7 @@ def to_float8(
             Float8Tensor: a float8 tensor
         """
         return ToFloat8ConstrFunc.apply(
-            tensor,
-            scale,
-            float8_dtype,
-            amax_buffer,
-            emulate,
+            tensor, scale, float8_dtype, amax_buffer, mm_config
         )
 
     @classmethod
diff --git a/test/test_base.py b/test/test_base.py
index 8a8233d4..b5ad1102 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -22,8 +22,12 @@
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_python_api import addmm_float8_unwrapped
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    merge_mm_configs,
+    ScaledMMConfig,
+)
 from float8_experimental.float8_utils import (
     amax_to_scale,
     compute_error,
@@ -281,7 +285,8 @@ class TestScaledMM:
     @pytest.mark.parametrize(
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    def test_scaled_mm_vs_emulated(self, base_dtype):
+    @pytest.mark.parametrize("use_fast_accum", [True, False])
+    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
         input_dtype = torch.float8_e4m3fn
         output_dtype = base_dtype
@@ -296,11 +301,16 @@ def test_scaled_mm_vs_emulated(self, base_dtype):
         a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
         b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
 
-        out_scaled_mm, output_amax_scaled = mm_float8(
-            a_fp8, b_fp8, output_dtype=output_dtype, emulate=False
+        out_scaled_mm, output_amax_scaled = addmm_float8_unwrapped(
+            a_fp8._data,
+            a_fp8._scale,
+            b_fp8._data,
+            b_fp8._scale,
+            output_dtype=output_dtype,
+            use_fast_accum=use_fast_accum,
         )
-        out_emulated, output_amax_emulated = mm_float8(
-            a_fp8, b_fp8, output_dtype=output_dtype, emulate=True
+        out_emulated, output_amax_emulated = torch.ops.aten.mm_float8_emulated(
+            a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype
         )
 
         if output_dtype != base_dtype:
@@ -320,6 +330,43 @@ def test_scaled_mm_vs_emulated(self, base_dtype):
             atol, rtol = 2e-3, 2e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
+    @unittest.skipIf(not is_H100, "H100 not available")
+    def test_different_configs_error(self):
+        x_fp32 = torch.randn(16, 16, device="cuda")
+        x_scale = torch.tensor(1.0, device="cuda")
+        fp8_dtype = torch.float8_e4m3fn
+        a = Float8Tensor.to_float8(x_fp32, x_scale, fp8_dtype)
+        b = Float8Tensor.to_float8(
+            x_fp32, x_scale, fp8_dtype, mm_config=ScaledMMConfig(True)
+        )
+        with pytest.raises(
+            AssertionError,
+            match="Both mm_configs must have the same emulate value, but got False and True",
+        ):
+            a @ b
+
+    def test_merge_configs(self):
+        a = ScaledMMConfig(False, True, True)
+        b = ScaledMMConfig(True, False, False)
+        with pytest.raises(
+            AssertionError,
+            match="Both mm_configs must have the same emulate value, but got False and True",
+        ):
+            merge_mm_configs(a, b)
+        a = ScaledMMConfig(False, True, True)
+        b = ScaledMMConfig(False, False, False)
+        c = merge_mm_configs(a, b)
+        assert c.emulate is False
+        assert c.use_fast_accum is False
+        assert c.fp8_output is False
+
+        a = ScaledMMConfig(False, True, False)
+        b = ScaledMMConfig(False, True, False)
+        c = merge_mm_configs(a, b)
+        assert c.emulate is False
+        assert c.use_fast_accum is True
+        assert c.fp8_output is False
+
 
 class TestNumerics:
@@ -356,7 +403,8 @@ def test_swap_root_linear(self):
             module = nn.Linear(3, 3)
             module = swap_linear_with_float8_linear(module, module_cls, emulate=emulate)
             self.assertIsInstance(module, module_cls)
-            self.assertEqual(module.emulate, emulate)
+            self.assertEqual(module.forward_config.emulate, emulate)
+            self.assertEqual(module.backward_config.emulate, emulate)
 
     def test_swap_root_linear_with_children_raises(self):
         for module_cls, emulate in itertools.product(
diff --git a/test/test_compile.py b/test/test_compile.py
index 2a9abba9..9cc64d32 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -21,7 +21,7 @@
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend
 
@@ -118,7 +118,7 @@ def forward(self, x):
             self.fp8_scale_x,
             torch.float8_e4m3fn,
             self.fp8_amax_x,
-            emulate=True,  # TODO: I set this to True so that people on A100 can test, but once fix is in, set to False
+            ScaledMMConfig(),
         )
         if self.graph_break:
             torch._dynamo.graph_break()
@@ -181,9 +181,9 @@ def test_float8_graph_output(self):
             type(y_compiled._orig_dtype)
         )
         assert isinstance(
-            y_compiled._emulate, bool
+            y_compiled._mm_config.emulate, bool
        ), "Float8Tensor._emulate should be a bool but got {}".format(
-            type(y_compiled._emulate)
+            type(y_compiled._mm_config.emulate)
        )
 
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index cb76325d..54a87fa4 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -18,7 +18,7 @@
     NoopFwToFloat8E5M2Bw,
 )
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_tensor_parallel import (
     Float8ColwiseParallel,
     Float8RowwiseParallel,
@@ -152,7 +152,7 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )
 
     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2Bw.apply(out, False)
+    out = NoopFwToFloat8E5M2Bw.apply(out, ScaledMMConfig())
     assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
     loss = torch.sum(torch.abs(out - dist_target))
     loss.backward()
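
Reviewer aid, not part of the patch: a minimal sketch of how the `ScaledMMConfig` introduced above is intended to be used, based only on the APIs added in this diff (`ScaledMMConfig`, `merge_mm_configs`, and the `mm_config` argument to `Float8Tensor.to_float8`). It assumes a PyTorch build with float8 dtypes; the tensor names and shapes are illustrative.

```python
import torch

from float8_experimental.float8_tensor import (
    Float8Tensor,
    ScaledMMConfig,
    merge_mm_configs,
)

# Mirror what Float8Linear.from_float does: fast accumulation in the forward
# matmul, but not in the backward matmul.
forward_config = ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False)
backward_config = ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False)

# When two fp8 operands carry different configs, float8_mm merges them:
# emulate must agree, and use_fast_accum / fp8_output are AND-ed together.
merged = merge_mm_configs(forward_config, backward_config)
assert merged.use_fast_accum is False and merged.fp8_output is False

# Tagging a tensor with a config at cast time (scale must be a single-element tensor).
x = torch.randn(16, 16)
x_scale = torch.tensor(1.0)
x_fp8 = Float8Tensor.to_float8(x, x_scale, torch.float8_e4m3fn, mm_config=forward_config)
print(x_fp8._mm_config)  # ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False)
```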