diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py
index b62020b..2c1ae6e 100644
--- a/benchmarks/profile_linear_float8.py
+++ b/benchmarks/profile_linear_float8.py
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+
+import copy
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass, field
@@ -12,15 +14,30 @@
 import fire
 
 import torch
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
     linear_requires_sync,
     LinearType,
+    swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function
 
 
+class LNLinear(torch.nn.Module):
+    def __init__(self, fc_dim1, fc_dim2):
+        super().__init__()
+        self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False)
+        self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False)
+
+    def forward(self, x):
+        x = self.ln(x)
+        x = self.fc(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -77,65 +94,58 @@ def profile_function(
 
 
 @dataclass(frozen=True)
-class LinearParams:
+class ModelParams:
     M: int
     K: int
     N: int
-    input_bias: bool
     ref_dtype: torch.dtype
     layer_norm: bool = True
-    torch_compile: Optional[bool] = False
 
 
-def main(profile_path: Path, compile: bool, linear_type: str):
-    profile_path = Path(profile_path)
-    assert profile_path.is_dir(), f"Path {profile_path} must be a directory"
-    params = LinearParams(
+def main(
+    profile_path_prefix: Path,
+    compile: bool = True,
+    linear_type: str = "dynamic",
+    use_layer_norm: bool = False,
+):
+    params = ModelParams(
         M=4 * 4096,
         K=8192,
         N=7168,
-        input_bias=False,
         ref_dtype=torch.bfloat16,
-        layer_norm=True,
-        torch_compile=compile,
+        layer_norm=use_layer_norm,
     )
     print(f"Compile is set to          | {compile}")
     print(f"Using Linear type:         | {linear_type}")
     print(f"Use layer norm is set to   | {params.layer_norm}")
-    linear_ref = torch.nn.Linear(
-        params.K,
-        params.N,
-        bias=params.input_bias,
-        device="cuda",
-        dtype=params.ref_dtype,
-    )
+
+    device = "cuda"
+    if params.layer_norm:
+        m_ref = LNLinear(params.K, params.N)
+    else:
+        m_ref = torch.nn.Sequential(
+            torch.nn.Linear(params.K, params.N, bias=False),
+        )
+    m_ref = m_ref.to(device).to(params.ref_dtype)
+
     linear_type = LinearType[linear_type.upper()]
-    linear_float8 = get_float8_linear(linear_type, linear_ref)
+    linear_cls = (
+        Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
+    )
+
+    m_float8 = copy.deepcopy(m_ref)
+    swap_linear_with_float8_linear(m_float8, linear_cls)
 
     input_tensor = torch.randn(
         params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
     )
-    if params.layer_norm:
-        ln = torch.nn.LayerNorm(
-            params.K, elementwise_affine=False, device="cuda", dtype=params.ref_dtype
-        )
 
     def ref_forw_backward(x):
-        if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_ref(x)
-        with record_function("backward"):
-            out.sum().backward()
+        out = m_ref(x)
+        out.sum().backward()
 
-    def float8_forw_backward(x):
-        if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_float8(x)
+    def float8_forw(x):
+        out = m_float8(x)
         return out
 
     def float8_forw_backward_wrapper(x):
@@ -146,34 +156,34 @@ def float8_forw_backward_wrapper(x):
         # TODO(future): make this better
         if linear_requires_sync(linear_type):
             with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(linear_float8)
-        out = float8_forw_backward(x)
+                sync_float8_amax_and_scale_history(m_float8)
+        out = float8_forw(x)
 
         # out.sum().backward() is also not torch.compile fullgraph
         # friendly
         with record_function("backward"):
             out.sum().backward()
 
-    if params.torch_compile:
+    if compile:
         ref_forw_backward = torch.compile(ref_forw_backward)
-        float8_forw_backward = torch.compile(float8_forw_backward, fullgraph=True)
+        float8_forw = torch.compile(float8_forw, fullgraph=True)
 
     for _ in range(5):
         ref_forw_backward(input_tensor)
         float8_forw_backward_wrapper(input_tensor)
 
-    # Profile Reference Linear
-    ref_string = f"linear_ref_dtype_{params.ref_dtype}_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}.json"
+    # Profile Reference Model
+    ref_suffix = f"_ref_compile_{compile}.json"
     profile_config = ProfileConfig(
-        str(profile_path / ref_string), ref_string, iters=5, warmup_iters=5, sync=True
+        profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
    )
     profile_function(profile_config, ref_forw_backward, input_tensor)
 
-    # # Profile Float8 Linear
-    float8_string = f"linear_float8_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}_{linear_type}.json"
+    # Profile Float8 Model
+    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
     profile_config = ProfileConfig(
-        str(profile_path / float8_string),
-        float8_string,
+        profile_path_prefix + float8_suffix,
+        float8_suffix,
         iters=5,
         warmup_iters=5,
         sync=True,
@@ -182,7 +192,7 @@
 
 
 def invoke_main() -> None:
-    # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles --compile=True --linear_type="dynamic"
+    # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --linear_type="dynamic"
     fire.Fire(main)
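
For readers unfamiliar with the pattern this patch switches to, the sketch below shows the same reference-model / float8-model setup that the new `main()` performs, outside the profiling script: build a bf16 module, deep-copy it, and swap its `nn.Linear` layers for the selected float8 class. This is a minimal illustration, not part of the patch; it uses only names imported in the diff, assumes the `float8_experimental` package is installed and a CUDA device is available, and borrows the `M`/`K`/`N` sizes from `ModelParams`.

```python
import copy

import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    LinearType,
    swap_linear_with_float8_linear,
)

M, K, N = 4 * 4096, 8192, 7168  # sizes taken from ModelParams in the script
linear_type = LinearType["DYNAMIC"]  # or LinearType["DELAYED"]

# Reference model: a single linear layer in bfloat16 on the GPU.
m_ref = torch.nn.Sequential(torch.nn.Linear(K, N, bias=False))
m_ref = m_ref.to("cuda").to(torch.bfloat16)

# Float8 copy: delayed scaling uses Float8Linear, dynamic scaling uses
# Float8DynamicLinear, mirroring the dispatch in main().
linear_cls = Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(m_float8, linear_cls)

# One forward + backward pass through the swapped model.
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m_float8(x).sum().backward()
```

Note that the script keeps the compiled float8 forward (`float8_forw`, compiled with `fullgraph=True`) separate from the eager wrapper that runs the amax/scale sync and `out.sum().backward()`, since, as the in-line comments in the diff point out, neither of those is fullgraph-friendly.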