From 899002b0ee8f8fc5fef8660f1c49f8a7fe8c326f Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 13 Jun 2024 21:44:05 -0700 Subject: [PATCH 1/5] add norm_ffn_norm to profile script Summary: This PR adds an example FFN with the preceding and subsequent norms to the profile script. I hope for this to speed up debugging of kernel performance on LLaMa. Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- benchmarks/profile_linear_float8.py | 159 +++++++++++++++++++++++----- 1 file changed, 131 insertions(+), 28 deletions(-) diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index 2c1ae6e..b6629fd 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -14,6 +14,8 @@ import fire import torch +import torch.nn as nn +import torch.nn.functional as F from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( @@ -38,6 +40,105 @@ def forward(self, x): return x +# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. 
+ + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class NormFFNResidualNorm(nn.Module): + """ + A fragment representing the end of TransformerBlock n and the start + of TransformerBlock n + 1, intended to include the fusions relevant + to float8 gemms in the FFN module in forward and backward. + """ + + def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier): + super().__init__() + self.ffn_norm = RMSNorm(dim) + self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier) + self.attn_norm = RMSNorm(dim) + + def forward(self, h): + # end of transformer block n + x = self.ffn_norm(h) + x = self.ffn(x) + x = h + x + # start of transformer block n + 1 + x = self.attn_norm(x) + return x + + @dataclass class ProfileConfig: file_path: Optional[str] = None @@ -93,40 +194,46 @@ def profile_function( return prof -@dataclass(frozen=True) -class ModelParams: - M: int - K: int - N: int - ref_dtype: torch.dtype - layer_norm: bool = True - - def main( profile_path_prefix: Path, compile: bool = True, linear_type: str = "dynamic", - use_layer_norm: bool = False, + model_type: str = "linear", ): - params = ModelParams( - M=4 * 4096, - K=8192, - N=7168, - ref_dtype=torch.bfloat16, - layer_norm=use_layer_norm, - ) + assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported" + print(f"Compile is set to | {compile}") print(f"Using Linear type: | {linear_type}") - print(f"Use layer norm is set to | {params.layer_norm}") + print(f"model_type is set to | {model_type}") device = "cuda" - if params.layer_norm: - m_ref = LNLinear(params.K, params.N) + ref_dtype = torch.bfloat16 + if model_type == "ln_linear": + M, K, N = 4 * 4096, 8192, 7168 + m_ref = LNLinear(K, N) + input_tensor = torch.randn( + M, K, device=device, dtype=ref_dtype, requires_grad=True + ) + elif model_type == "norm_ffn_norm": + m_ref = NormFFNResidualNorm( + dim=4096, + hidden_dim=16384, + multiple_of=1024, + ffn_dim_multiplier=1.3, + ) + input_tensor = torch.randn( + 1, 8192, 4096, device=device, dtype=ref_dtype + ).requires_grad_() else: + M, K, N = 4 * 4096, 8192, 7168 m_ref = torch.nn.Sequential( - torch.nn.Linear(params.K, params.N, bias=False), + torch.nn.Linear(K, N, bias=False), ) - m_ref = m_ref.to(device).to(params.ref_dtype) + input_tensor = torch.randn( + M, K, device=device, dtype=ref_dtype, requires_grad=True + ) + + m_ref = m_ref.to(device).to(ref_dtype) linear_type = LinearType[linear_type.upper()] linear_cls = ( @@ -136,10 +243,6 @@ def main( m_float8 = copy.deepcopy(m_ref) swap_linear_with_float8_linear(m_float8, linear_cls) - input_tensor = torch.randn( - params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True - ) - def ref_forw_backward(x): out = m_ref(x) out.sum().backward() @@ -173,14 +276,14 @@ def 
float8_forw_backward_wrapper(x): float8_forw_backward_wrapper(input_tensor) # Profile Reference Model - ref_suffix = f"_ref_compile_{compile}.json" + ref_suffix = f"_{model_type}_ref_compile_{compile}.json" profile_config = ProfileConfig( profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True ) profile_function(profile_config, ref_forw_backward, input_tensor) # Profile Float8 Model - float8_suffix = f"_float8_compile_{compile}_{linear_type}.json" + float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json" profile_config = ProfileConfig( profile_path_prefix + float8_suffix, float8_suffix, From 998b0c89ca557f02210cc7e3a5b80f3da87518f2 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 17 Jun 2024 14:30:07 -0700 Subject: [PATCH 2/5] Update on "add norm_ffn_norm to profile script" Summary: This PR adds an example FFN with the preceding and subsequent norms to the profile script. I hope for this to speed up debugging of kernel performance on LLaMa. Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- benchmarks/profile_linear_float8.py | 153 +++++++++++++++++++++++----- benchmarks/utils.py | 47 +++++++++ 2 files changed, 174 insertions(+), 26 deletions(-) create mode 100644 benchmarks/utils.py diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index b6629fd..9d75308 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -12,6 +12,7 @@ from typing import Callable, Optional import fire +import pandas as pd import torch import torch.nn as nn @@ -26,6 +27,15 @@ sync_float8_amax_and_scale_history, ) from torch.profiler import profile, ProfilerActivity, record_function +from utils import ( + profiler_output_to_gpu_time_for_key, + profiler_output_to_time_by_kernel_name, +) + +# don't truncate long kernel names +pd.options.display.max_colwidth = 100 +# display 3 trailing decimal points for floats +pd.set_option("display.float_format", "{:.3f}".format) class LNLinear(torch.nn.Module): @@ -188,9 +198,6 @@ def profile_function( if config.file_path is not None: prof.export_chrome_trace(config.file_path) - if config.file_path is None: - print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) - return prof @@ -199,8 +206,10 @@ def main( compile: bool = True, linear_type: str = "dynamic", model_type: str = "linear", + dtype_filter: str = "both", ): assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported" + assert dtype_filter in ("both", "float8", "bfloat16") print(f"Compile is set to | {compile}") print(f"Using Linear type: | {linear_type}") @@ -251,6 +260,8 @@ def float8_forw(x): out = m_float8(x) return out + sync_amax_history = sync_float8_amax_and_scale_history + def float8_forw_backward_wrapper(x): # sync_float8_amax_and_scale_history is not full graph torch # compile friendly, so we add a high level wrapper to allow @@ -259,7 +270,7 @@ def float8_forw_backward_wrapper(x): # TODO(future): make this better if linear_requires_sync(linear_type): with record_function("scale_amax_and_scales"): - sync_float8_amax_and_scale_history(m_float8) + sync_amax_history(m_float8) out = float8_forw(x) # out.sum().backward() is also not torch.compile fullgraph @@ -268,30 +279,120 @@ def float8_forw_backward_wrapper(x): out.sum().backward() if compile: - ref_forw_backward = torch.compile(ref_forw_backward) + m_ref = torch.compile(m_ref, fullgraph=True) float8_forw = torch.compile(float8_forw, fullgraph=True) + # Note: it's faster to compile the 
combination of sync_amax_history with + # forward because we only look up from dynamo cache once. + # However, compiling the sync function separately makes it more + # convenient to analyze the total time spent on it. + sync_amax_history = torch.compile(sync_amax_history) + + # warm up + for _ in range(1): + if dtype_filter != "float8": + ref_forw_backward(input_tensor) + if dtype_filter != "bfloat16": + float8_forw_backward_wrapper(input_tensor) + + # profile_iters = 5 + profile_iters = 2 + ref_times, float8_times = None, None + + if dtype_filter != "float8": + # Profile Reference Model + print("profiling ref") + ref_suffix = f"_{model_type}_ref_compile_{compile}.json" + ref_path = profile_path_prefix + ref_suffix + profile_config = ProfileConfig( + ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True + ) + p = profile_function(profile_config, ref_forw_backward, input_tensor) + print(f"saved {ref_path}") + ref_times = profiler_output_to_time_by_kernel_name(p) + + if dtype_filter != "bfloat16": + # Profile Float8 Model + print("profiling float8") + float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json" + float8_path = profile_path_prefix + float8_suffix + profile_config = ProfileConfig( + float8_path, + float8_suffix, + iters=profile_iters, + warmup_iters=2, + sync=True, + ) + p = profile_function(profile_config, float8_forw_backward_wrapper, input_tensor) + print(f"saved {float8_path}") + float8_times = profiler_output_to_time_by_kernel_name(p) + + # get the time spent per user annotation + sync_time_us = profiler_output_to_gpu_time_for_key(p, "scale_amax_and_scales") + sync_time_ms = sync_time_us / profile_iters / 1e3 + print(f"Sync time ms: {sync_time_ms}") + + if dtype_filter == "both": + data = [] + + def kernel_name_to_category(k): + # number prefix is for easy sorting + if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"): + return "0_gemm" + elif ( + # max(abs(tensor)) + ("abs" in k and "max" in k) + or + # casting pointwise to float8 + ("clamp" in k) + or + # things related to scaled_mm + ("scaled_mm" in k) + or + # syncing amaxes and scales + ("roll" in k) + ): + # note: the above filter is approximate and will give false + # positives if model code contains other code to abs/max/clamp + return "1_f8_overhead" + return "2_other" + + for k, v in ref_times.items(): + data.append( + ["0_ref", k, kernel_name_to_category(k), v / 1e3 / profile_iters] + ) + for k, v in float8_times.items(): + data.append( + ["1_float8", k, kernel_name_to_category(k), v / 1e3 / profile_iters] + ) + + df = pd.DataFrame(data, columns=["experiment", "kernel", "category", "time_ms"]) + print("\nSummary of GPU time by CPU kernel\n\n", df) + + # compare gemm and overhead time + df_p = df.pivot_table( + columns=["category"], + index="experiment", + values="time_ms", + aggfunc="sum", + fill_value=0, + margins=True, + ) + # drop last row, which has totals across ref + float8 which does not make sense + df_p = df_p[:-1] + + df_p = df_p.transpose() + df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"] + df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"] + print( + "\nSummary of time (ms) by kernel category, across ref and float8\n\n", df_p + ) - for _ in range(5): - ref_forw_backward(input_tensor) - float8_forw_backward_wrapper(input_tensor) - - # Profile Reference Model - ref_suffix = f"_{model_type}_ref_compile_{compile}.json" - profile_config = ProfileConfig( - profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True - ) -
profile_function(profile_config, ref_forw_backward, input_tensor) - - # Profile Float8 Model - float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json" - profile_config = ProfileConfig( - profile_path_prefix + float8_suffix, - float8_suffix, - iters=5, - warmup_iters=5, - sync=True, - ) - profile_function(profile_config, float8_forw_backward_wrapper, input_tensor) + # calculate sync time as pct of total float time + total_float8_ms = df_p.iloc[3]["1_float8"] + sync_approx_ratio = sync_time_ms / total_float8_ms + print( + f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}" + ) def invoke_main() -> None: diff --git a/benchmarks/utils.py b/benchmarks/utils.py new file mode 100644 index 0000000..9377e27 --- /dev/null +++ b/benchmarks/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import json + + +def profiler_output_to_time_by_kernel_name(prof): + """ + Input: a profiler with captured events. + Output: a deduplicated dict of GPU time in microseconds, keyed by CPU kernel name + + Note that if there are user_annotations in the captured events, `torch.profiler` + will include their time in the total GPU time displayed at the bottom of + `key_averages().table()`. The filter below excludes them to prevent double + counting. + """ + key_averages = prof.key_averages() + thresh = 1e-10 + kernel_name_to_gpu_time_us = collections.defaultdict(float) + for e in key_averages: + # manually filter top-level CPU events with attributed CUDA time + # example CPU event row: + # aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1 + # and it maps to this CUDA event: + # sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize256x64... 0.00% 0.000us 0.00% 0.000us 0.000us 1.022ms 31.82% 1.022ms 1.022ms 1 + if not (e.self_cpu_time_total > thresh and e.self_device_time_total > thresh): + continue + kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total + return kernel_name_to_gpu_time_us + + +def profiler_output_to_gpu_time_for_key(prof, key): + """ + Input: an event name + Output: sum of GPU time of all events with that name in `prof` + + This is useful to get the total time of a user annotation + """ + total = 0 + for e in prof.profiler.function_events: + if e.key == key: + total += e.device_time_total + return total From 82bdec71e035f62b44317e640727a6a63f909263 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 17 Jun 2024 15:08:36 -0700 Subject: [PATCH 3/5] Update on "add norm_ffn_norm to profile script" Summary: This PR adds an example FFN with the preceding and subsequent norms to the profile script. It also adds a couple of automatic data extraction QOL items: 1. extract GPU time and aggregate it per CPU kernel name 2. attribute the kernel GPU time to gemms, float8 overhead or other 3. approximate the time spent syncing scales/amaxes and display as pct of total time I hope for this to speed up debugging of kernel performance on various models, as this automates a lot of high-level metrics which take more time to get from visualizing the traces.
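For orientation, the aggregation and categorization boil down to the following condensed sketch. This is a restatement of `profiler_output_to_time_by_kernel_name` and `kernel_name_to_category` from the diff for readability, not new behavior; the shortened helper name is illustrative only:

```python
import collections

def times_by_kernel_name(prof):
    # keep only top-level CPU events with attributed GPU time, so that
    # user annotations are not double counted
    thresh = 1e-10
    times_us = collections.defaultdict(float)
    for e in prof.key_averages():
        if e.self_cpu_time_total > thresh and e.self_device_time_total > thresh:
            times_us[e.key] = e.self_device_time_total
    return times_us

def kernel_name_to_category(k):
    # number prefix is for easy sorting
    if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"):
        return "0_gemm"
    if ("abs" in k and "max" in k) or "clamp" in k or "scaled_mm" in k or "roll" in k:
        # approximate: matches amax/cast/sync kernels, may give false positives
        return "1_f8_overhead"
    return "2_other"
```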
Example output when testing `norm_ffn_norm` with delayed scaling and compile: ``` Summary of GPU time by CPU kernel experiment kernel category time_ms 0 0_ref triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 2_other 0.061 1 0_ref aten::mm 0_gemm 14.691 2 0_ref triton_poi_fused_mul_silu_1 2_other 0.304 3 0_ref triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2 2_other 0.083 4 0_ref aten::sum 2_other 0.050 5 0_ref aten::fill_ 2_other 0.002 6 0_ref aten::copy_ 2_other 0.059 7 0_ref triton_red_fused__to_copy_add_div_mul_pow_sum_0 2_other 0.129 8 0_ref triton_poi_fused_add_fill_mul_sigmoid_silu_sub_1 2_other 0.520 9 0_ref triton_red_fused__to_copy_add_mul_sum_2 2_other 1.363 10 0_ref triton_red_fused__to_copy_add_div_mul_pow_sum_3 2_other 0.151 11 0_ref aten::add_ 2_other 0.567 12 1_float8 triton_per_fused_cat_copy_max_roll_0 1_f8_overhead 0.009 13 1_float8 triton_poi_fused_copy_1 2_other 0.004 14 1_float8 triton_poi_fused_copy_2 2_other 0.002 15 1_float8 triton_poi_fused_copy_3 2_other 0.004 16 1_float8 triton_poi_fused_copy_4 2_other 0.002 17 1_float8 triton_poi_fused_copy_5 2_other 0.004 18 1_float8 triton_poi_fused_copy_6 2_other 0.002 19 1_float8 triton_red_fused__to_copy_abs_add_clamp_max_mean_mul_pow_rsqrt_0 1_f8_overhead 0.089 20 1_float8 triton_red_fused__to_copy_abs_fill_max_mul_1 1_f8_overhead 0.003 21 1_float8 triton_red_fused_abs_max_2 1_f8_overhead 0.140 22 1_float8 triton_per_fused_abs_fill_max_3 1_f8_overhead 0.010 23 1_float8 triton_poi_fused_reciprocal_4 2_other 0.014 24 1_float8 triton_poi_fused__scaled_mm_clone_5 1_f8_overhead 0.289 25 1_float8 aten::_scaled_mm 0_gemm 8.054 26 1_float8 triton_red_fused_abs_max_mul_silu_6 1_f8_overhead 0.246 27 1_float8 triton_red_fused_abs_max_7 1_f8_overhead 0.061 28 1_float8 triton_poi_fused__to_copy_clamp_mul_silu_8 1_f8_overhead 0.254 29 1_float8 triton_poi_fused__scaled_mm_clone_9 1_f8_overhead 0.149 30 1_float8 triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10 2_other 0.115 31 1_float8 triton_poi_fused_clone_11 2_other 0.151 32 1_float8 triton_poi_fused_clone_12 2_other 0.089 33 1_float8 aten::sum 2_other 0.049 34 1_float8 aten::fill_ 2_other 0.013 35 1_float8 aten::copy_ 2_other 0.060 36 1_float8 triton_red_fused__to_copy_mul_sum_0 2_other 0.548 37 1_float8 triton_red_fused__scaled_mm__to_copy_abs_add_div_max_mul_pow_reciprocal_sum_1 1_f8_overhead 0.109 38 1_float8 triton_red_fused_abs_fill_max_2 1_f8_overhead 0.004 39 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_3 1_f8_overhead 0.049 40 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_4 1_f8_overhead 0.007 41 1_float8 triton_red_fused_abs_add_fill_max_mul_sigmoid_silu_sub_5 1_f8_overhead 0.316 42 1_float8 triton_per_fused_abs_fill_max_mul_silu_6 1_f8_overhead 0.005 43 1_float8 triton_poi_fused__to_copy_add_clamp_fill_mul_sigmoid_silu_sub_7 1_f8_overhead 0.408 44 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_8 1_f8_overhead 0.294 45 1_float8 triton_red_fused__to_copy_add_mul_sum_9 2_other 0.810 46 1_float8 triton_red_fused__to_copy_add_div_mul_pow_sum_10 2_other 0.148 47 1_float8 aten::add_ 2_other 0.567 Summary of time (ms) by kernel category, across ref and float8 experiment 0_ref 1_float8 f8_div_ref ref_div_f8 category 0_gemm 14.691 8.054 0.548 1.824 1_f8_overhead 0.000 2.441 inf 0.000 2_other 3.291 2.582 0.785 1.274 All 17.981 13.077 0.727 1.375 Float8 amax/scale sync approx ratio of total time: 0.014 ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- benchmarks/profile_linear_float8.py | 95 +++++++++++++---------------- 
benchmarks/utils.py | 23 +++++++ 2 files changed, 64 insertions(+), 54 deletions(-) diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index 9d75308..ee6ebfe 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -28,6 +28,7 @@ ) from torch.profiler import profile, ProfilerActivity, record_function from utils import ( + kernel_name_to_category, profiler_output_to_gpu_time_for_key, profiler_output_to_time_by_kernel_name, ) @@ -294,9 +295,9 @@ def float8_forw_backward_wrapper(x): if dtype_filter != "bfloat16": float8_forw_backward_wrapper(input_tensor) - # profile_iters = 5 - profile_iters = 2 + profile_iters = 5 ref_times, float8_times = None, None + data = [] if dtype_filter != "float8": # Profile Reference Model @@ -309,6 +310,12 @@ def float8_forw_backward_wrapper(x): p = profile_function(profile_config, ref_forw_backward, input_tensor) print(f"saved {ref_path}") ref_times = profiler_output_to_time_by_kernel_name(p) + total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters + for k, v in ref_times.items(): + v_ms = v / 1e3 / profile_iters + data.append( + ["0_ref", k, kernel_name_to_category(k), v_ms, v_ms / total_time_ms] + ) if dtype_filter != "bfloat16": # Profile Float8 Model @@ -325,67 +332,45 @@ def float8_forw_backward_wrapper(x): p = profile_function(profile_config, float8_forw_backward_wrapper, input_tensor) print(f"saved {float8_path}") float8_times = profiler_output_to_time_by_kernel_name(p) + total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters + for k, v in float8_times.items(): + v_ms = v / 1e3 / profile_iters + data.append( + [ + "1_float8", + k, + kernel_name_to_category(k), + v / 1e3 / profile_iters, + v_ms / total_time_ms, + ] + ) # get the time spent per user annotation sync_time_us = profiler_output_to_gpu_time_for_key(p, "scale_amax_and_scales") sync_time_ms = sync_time_us / profile_iters / 1e3 print(f"Sync time ms: {sync_time_ms}") - if dtype_filter == "both": - data = [] - - def kernel_name_to_category(k): - # number prefix is for easy sorting - if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"): - return "0_gemm" - elif ( - # max(abs(tensor)) - ("abs" in k and "max" in k) - or - # casting pointwise to float8 - ("clamp" in k) - or - # things related to scaled_mm - ("scaled_mm" in k) - or - # syncing amaxes and scales - ("roll" in k) - ): - # note: the above filter is approximate and will give false - # positives if model code contains other code to abs/max/clamp - return "1_f8_overhead" - return "2_other" - - for k, v in ref_times.items(): - data.append( - ["0_ref", k, kernel_name_to_category(k), v / 1e3 / profile_iters] - ) - for k, v in float8_times.items(): - data.append( - ["1_float8", k, kernel_name_to_category(k), v / 1e3 / profile_iters] - ) - - df = pd.DataFrame(data, columns=["experiment", "kernel", "category", "time_ms"]) - print("\nSummary of GPU time by CPU kernel\n\n", df) - - # compare gemm and overhead time - df_p = df.pivot_table( - columns=["category"], - index="experiment", - values="time_ms", - aggfunc="sum", - fill_value=0, - margins=True, - ) - # drop last row, which has totals across ref + float8 which does not make sense - df_p = df_p[:-1] + df = pd.DataFrame( + data, columns=["experiment", "kernel", "category", "time_ms", "pct_gpu_time"] + ) + print("\nSummary of GPU time by CPU kernel\n\n", df) + + # compare gemm and overhead time + df_p = df.pivot_table( + columns=["category"], + index="experiment", + values="time_ms", + 
aggfunc="sum", + fill_value=0, + margins=True, + ) + # drop last row, which has totals across ref + float8 which does not make sense + df_p = df_p[:-1] + df_p = df_p.transpose() - df_p = df_p.transpose() + if dtype_filter == "both": df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"] df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"] - print( - "\nSummary of time (ms) by kernel category, across ref and float8\n\n", df_p - ) # calculate sync time as pct of total float time total_float8_ms = df_p.iloc[3]["1_float8"] @@ -394,6 +379,8 @@ def kernel_name_to_category(k): f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}" ) + print("\nSummary of time (ms) by kernel category\n\n", df_p) + def invoke_main() -> None: # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --linear_type="dynamic" diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 9377e27..c102a54 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -45,3 +45,26 @@ def profiler_output_to_gpu_time_for_key(prof, key): if e.key == key: total += e.device_time_total return total + + +def kernel_name_to_category(k): + # number prefix is for easy sorting + if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"): + return "0_gemm" + elif ( + # max(abs(tensor)) + ("abs" in k and "max" in k) + or + # casting pointwise to float8 + ("clamp" in k) + or + # things related to scaled_mm + ("scaled_mm" in k) + or + # syncing amaxes and scales + ("roll" in k) + ): + # note: the above filter is approximate and will give false + # positives if model code contains other code to abs/max/clamp + return "1_f8_overhead" + return "2_other" From 0925ce8bf84660cce56afcfb388e4fef8bffbf68 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 17 Jun 2024 15:39:20 -0700 Subject: [PATCH 4/5] Update on "add norm_ffn_norm to profile script" Summary: This PR adds an example FFN with the preceding and subsequent norms to the profile script. It also adds a couple of automatic data exctration QOL items: 1. extract GPU time and aggregate it per CPU kernel name 2. attribute the kernel GPU time to gemms, float8 overhead or other 3. approximate the time spent syncing scales/amaxes and display as pct of total time I hope for this to speed up debugging of kernel performance on various models, as this automates a lot of high level metrics which take more time to get from visualizing the traces. 
Example output when testing `norm_ffn_norm` with delayed scaling and compile: ``` Summary of GPU time by CPU kernel experiment kernel category time_ms 0 0_ref triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 2_other 0.061 1 0_ref aten::mm 0_gemm 14.691 2 0_ref triton_poi_fused_mul_silu_1 2_other 0.304 3 0_ref triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2 2_other 0.083 4 0_ref aten::sum 2_other 0.050 5 0_ref aten::fill_ 2_other 0.002 6 0_ref aten::copy_ 2_other 0.059 7 0_ref triton_red_fused__to_copy_add_div_mul_pow_sum_0 2_other 0.129 8 0_ref triton_poi_fused_add_fill_mul_sigmoid_silu_sub_1 2_other 0.520 9 0_ref triton_red_fused__to_copy_add_mul_sum_2 2_other 1.363 10 0_ref triton_red_fused__to_copy_add_div_mul_pow_sum_3 2_other 0.151 11 0_ref aten::add_ 2_other 0.567 12 1_float8 triton_per_fused_cat_copy_max_roll_0 1_f8_overhead 0.009 13 1_float8 triton_poi_fused_copy_1 2_other 0.004 14 1_float8 triton_poi_fused_copy_2 2_other 0.002 15 1_float8 triton_poi_fused_copy_3 2_other 0.004 16 1_float8 triton_poi_fused_copy_4 2_other 0.002 17 1_float8 triton_poi_fused_copy_5 2_other 0.004 18 1_float8 triton_poi_fused_copy_6 2_other 0.002 19 1_float8 triton_red_fused__to_copy_abs_add_clamp_max_mean_mul_pow_rsqrt_0 1_f8_overhead 0.089 20 1_float8 triton_red_fused__to_copy_abs_fill_max_mul_1 1_f8_overhead 0.003 21 1_float8 triton_red_fused_abs_max_2 1_f8_overhead 0.140 22 1_float8 triton_per_fused_abs_fill_max_3 1_f8_overhead 0.010 23 1_float8 triton_poi_fused_reciprocal_4 2_other 0.014 24 1_float8 triton_poi_fused__scaled_mm_clone_5 1_f8_overhead 0.289 25 1_float8 aten::_scaled_mm 0_gemm 8.054 26 1_float8 triton_red_fused_abs_max_mul_silu_6 1_f8_overhead 0.246 27 1_float8 triton_red_fused_abs_max_7 1_f8_overhead 0.061 28 1_float8 triton_poi_fused__to_copy_clamp_mul_silu_8 1_f8_overhead 0.254 29 1_float8 triton_poi_fused__scaled_mm_clone_9 1_f8_overhead 0.149 30 1_float8 triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10 2_other 0.115 31 1_float8 triton_poi_fused_clone_11 2_other 0.151 32 1_float8 triton_poi_fused_clone_12 2_other 0.089 33 1_float8 aten::sum 2_other 0.049 34 1_float8 aten::fill_ 2_other 0.013 35 1_float8 aten::copy_ 2_other 0.060 36 1_float8 triton_red_fused__to_copy_mul_sum_0 2_other 0.548 37 1_float8 triton_red_fused__scaled_mm__to_copy_abs_add_div_max_mul_pow_reciprocal_sum_1 1_f8_overhead 0.109 38 1_float8 triton_red_fused_abs_fill_max_2 1_f8_overhead 0.004 39 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_3 1_f8_overhead 0.049 40 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_4 1_f8_overhead 0.007 41 1_float8 triton_red_fused_abs_add_fill_max_mul_sigmoid_silu_sub_5 1_f8_overhead 0.316 42 1_float8 triton_per_fused_abs_fill_max_mul_silu_6 1_f8_overhead 0.005 43 1_float8 triton_poi_fused__to_copy_add_clamp_fill_mul_sigmoid_silu_sub_7 1_f8_overhead 0.408 44 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_8 1_f8_overhead 0.294 45 1_float8 triton_red_fused__to_copy_add_mul_sum_9 2_other 0.810 46 1_float8 triton_red_fused__to_copy_add_div_mul_pow_sum_10 2_other 0.148 47 1_float8 aten::add_ 2_other 0.567 Summary of time (ms) by kernel category, across ref and float8 experiment 0_ref 1_float8 f8_div_ref ref_div_f8 category 0_gemm 14.691 8.054 0.548 1.824 1_f8_overhead 0.000 2.441 inf 0.000 2_other 3.291 2.582 0.785 1.274 All 17.981 13.077 0.727 1.375 Float8 amax/scale sync approx ratio of total time: 0.014 ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- benchmarks/profile_linear_float8.py | 156 +++++++++++++++++----------- 
benchmarks/utils.py | 14 +++ 2 files changed, 112 insertions(+), 58 deletions(-) diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index ee6ebfe..221d6b8 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -5,8 +5,9 @@ # LICENSE file in the root directory of this source tree. import copy +import io import random -from contextlib import nullcontext +from contextlib import nullcontext, redirect_stdout from dataclasses import dataclass, field from pathlib import Path from typing import Callable, Optional @@ -29,6 +30,7 @@ from torch.profiler import profile, ProfilerActivity, record_function from utils import ( kernel_name_to_category, + parse_bw_and_kernel_name, profiler_output_to_gpu_time_for_key, profiler_output_to_time_by_kernel_name, ) @@ -288,70 +290,106 @@ def float8_forw_backward_wrapper(x): # convenient to analyze the total time spent on it. sync_amax_history = torch.compile(sync_amax_history) - # warm up - for _ in range(1): + # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output + # to populate triton kernel bandwidth further down in the script + f = io.StringIO() + with redirect_stdout(f): + # warm up + for _ in range(1): + if dtype_filter != "float8": + ref_forw_backward(input_tensor) + if dtype_filter != "bfloat16": + float8_forw_backward_wrapper(input_tensor) + + profile_iters = 5 + ref_times, float8_times = None, None + data = [] + if dtype_filter != "float8": - ref_forw_backward(input_tensor) - if dtype_filter != "bfloat16": - float8_forw_backward_wrapper(input_tensor) - - profile_iters = 5 - ref_times, float8_times = None, None - data = [] - - if dtype_filter != "float8": - # Profile Reference Model - print("profiling ref") - ref_suffix = f"_{model_type}_ref_compile_{compile}.json" - ref_path = profile_path_prefix + ref_suffix - profile_config = ProfileConfig( - ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True - ) - p = profile_function(profile_config, ref_forw_backward, input_tensor) - print(f"saved {ref_path}") - ref_times = profiler_output_to_time_by_kernel_name(p) - total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters - for k, v in ref_times.items(): - v_ms = v / 1e3 / profile_iters - data.append( - ["0_ref", k, kernel_name_to_category(k), v_ms, v_ms / total_time_ms] + # Profile Reference Model + print("profiling ref") + ref_suffix = f"_{model_type}_ref_compile_{compile}.json" + ref_path = profile_path_prefix + ref_suffix + profile_config = ProfileConfig( + ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True ) + p = profile_function(profile_config, ref_forw_backward, input_tensor) + print(f"saved {ref_path}") + ref_times = profiler_output_to_time_by_kernel_name(p) + total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters + for k, v in ref_times.items(): + v_ms = v / 1e3 / profile_iters + data.append( + [ + "0_ref", + k, + kernel_name_to_category(k), + v_ms, + v_ms / total_time_ms, + None, + ] + ) - if dtype_filter != "bfloat16": - # Profile Float8 Model - print("profiling float8") - float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json" - float8_path = profile_path_prefix + float8_suffix - profile_config = ProfileConfig( - float8_path, - float8_suffix, - iters=profile_iters, - warmup_iters=2, - sync=True, - ) - p = profile_function(profile_config, float8_forw_backward_wrapper, input_tensor) - print(f"saved {float8_path}") - float8_times = profiler_output_to_time_by_kernel_name(p) - 
total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters - for k, v in float8_times.items(): - v_ms = v / 1e3 / profile_iters - data.append( - [ - "1_float8", - k, - kernel_name_to_category(k), - v / 1e3 / profile_iters, - v_ms / total_time_ms, - ] + if dtype_filter != "bfloat16": + # Profile Float8 Model + print("profiling float8") + float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json" + float8_path = profile_path_prefix + float8_suffix + profile_config = ProfileConfig( + float8_path, + float8_suffix, + iters=profile_iters, + warmup_iters=2, + sync=True, + ) + p = profile_function( + profile_config, float8_forw_backward_wrapper, input_tensor ) + print(f"saved {float8_path}") + float8_times = profiler_output_to_time_by_kernel_name(p) + total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters + for k, v in float8_times.items(): + v_ms = v / 1e3 / profile_iters + data.append( + [ + "1_float8", + k, + kernel_name_to_category(k), + v / 1e3 / profile_iters, + v_ms / total_time_ms, + None, + ] + ) + + # get the time spent per user annotation + sync_time_us = profiler_output_to_gpu_time_for_key( + p, "scale_amax_and_scales" + ) + sync_time_ms = sync_time_us / profile_iters / 1e3 + print(f"Sync time ms: {sync_time_ms}") + + # print the redirected stdout back to regular stdout + print(f.getvalue()) - # get the time spent per user annotation - sync_time_us = profiler_output_to_gpu_time_for_key(p, "scale_amax_and_scales") - sync_time_ms = sync_time_us / profile_iters / 1e3 - print(f"Sync time ms: {sync_time_ms}") + # populate the triton kernel bandwidth + for line in f.getvalue().split("\n"): + maybe_bw, maybe_kernel_name = parse_bw_and_kernel_name(line) + if maybe_kernel_name is not None: + # O(N) search, but it's ok since lists are small + for datum in data: + if datum[1] == maybe_kernel_name: + datum[-1] = maybe_bw df = pd.DataFrame( - data, columns=["experiment", "kernel", "category", "time_ms", "pct_gpu_time"] + data, + columns=[ + "experiment", + "kernel", + "category", + "time_ms", + "pct_gpu_time", + "bw_gbps", + ], ) print("\nSummary of GPU time by CPU kernel\n\n", df) @@ -373,6 +411,7 @@ def float8_forw_backward_wrapper(x): df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"] # calculate sync time as pct of total float time + # note: this time is not useful if TORCHINDUCTOR_PROFILE is on total_float8_ms = df_p.iloc[3]["1_float8"] sync_approx_ratio = sync_time_ms / total_float8_ms print( @@ -384,6 +423,7 @@ def float8_forw_backward_wrapper(x): def invoke_main() -> None: # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --linear_type="dynamic" + # You can set TORCHINDUCTOR_PROFILE=1 to also capture triton kernel bandwidth fire.Fire(main) diff --git a/benchmarks/utils.py b/benchmarks/utils.py index c102a54..fd4d501 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -6,6 +6,7 @@ import collections import json +import re def profiler_output_to_time_by_kernel_name(prof): @@ -68,3 +69,16 @@ def kernel_name_to_category(k): # positives if model code contains other code to abs/max/clamp return "1_f8_overhead" return "2_other" + + +def parse_bw_and_kernel_name(line): + """ + Input: a single line of stdout of TORCHINDUCTOR_PROFILE=1 output, such as + 0.257ms 0.537 GB 2092.43GB/s triton_red_fused_native_layer_norm_0 + Output: the bandwidth value and the kernel name, or None and None + """ + result = re.search(r".* ([0-9\.]+)GB/s.*(triton_[a-z_0-9]+)", line) + if
result: + return result.group(1), result.group(2) + else: + return None, None From 321e3248ef801327afe6350045235b1d5762b85b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 17 Jun 2024 15:49:30 -0700 Subject: [PATCH 5/5] Update on "add norm_ffn_norm to profile script" Summary: This PR adds an example FFN with the preceding and subsequent norms to the profile script. It also adds a couple of automatic data extraction QOL items: 1. extract GPU time and aggregate it per CPU kernel name 2. attribute the kernel GPU time to gemms, float8 overhead or other 3. approximate the time spent syncing scales/amaxes and display as pct of total time I hope for this to speed up debugging of kernel performance on various models, as this automates a lot of high-level metrics which take more time to get from visualizing the traces. Example output when testing `norm_ffn_norm` with delayed scaling and compile: bandwidth off ``` Summary of GPU time by CPU kernel experiment kernel category time_ms 0 0_ref triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 2_other 0.061 1 0_ref aten::mm 0_gemm 14.691 2 0_ref triton_poi_fused_mul_silu_1 2_other 0.304 3 0_ref triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2 2_other 0.083 4 0_ref aten::sum 2_other 0.050 5 0_ref aten::fill_ 2_other 0.002 6 0_ref aten::copy_ 2_other 0.059 7 0_ref triton_red_fused__to_copy_add_div_mul_pow_sum_0 2_other 0.129 8 0_ref triton_poi_fused_add_fill_mul_sigmoid_silu_sub_1 2_other 0.520 9 0_ref triton_red_fused__to_copy_add_mul_sum_2 2_other 1.363 10 0_ref triton_red_fused__to_copy_add_div_mul_pow_sum_3 2_other 0.151 11 0_ref aten::add_ 2_other 0.567 12 1_float8 triton_per_fused_cat_copy_max_roll_0 1_f8_overhead 0.009 13 1_float8 triton_poi_fused_copy_1 2_other 0.004 14 1_float8 triton_poi_fused_copy_2 2_other 0.002 15 1_float8 triton_poi_fused_copy_3 2_other 0.004 16 1_float8 triton_poi_fused_copy_4 2_other 0.002 17 1_float8 triton_poi_fused_copy_5 2_other 0.004 18 1_float8 triton_poi_fused_copy_6 2_other 0.002 19 1_float8 triton_red_fused__to_copy_abs_add_clamp_max_mean_mul_pow_rsqrt_0 1_f8_overhead 0.089 20 1_float8 triton_red_fused__to_copy_abs_fill_max_mul_1 1_f8_overhead 0.003 21 1_float8 triton_red_fused_abs_max_2 1_f8_overhead 0.140 22 1_float8 triton_per_fused_abs_fill_max_3 1_f8_overhead 0.010 23 1_float8 triton_poi_fused_reciprocal_4 2_other 0.014 24 1_float8 triton_poi_fused__scaled_mm_clone_5 1_f8_overhead 0.289 25 1_float8 aten::_scaled_mm 0_gemm 8.054 26 1_float8 triton_red_fused_abs_max_mul_silu_6 1_f8_overhead 0.246 27 1_float8 triton_red_fused_abs_max_7 1_f8_overhead 0.061 28 1_float8 triton_poi_fused__to_copy_clamp_mul_silu_8 1_f8_overhead 0.254 29 1_float8 triton_poi_fused__scaled_mm_clone_9 1_f8_overhead 0.149 30 1_float8 triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10 2_other 0.115 31 1_float8 triton_poi_fused_clone_11 2_other 0.151 32 1_float8 triton_poi_fused_clone_12 2_other 0.089 33 1_float8 aten::sum 2_other 0.049 34 1_float8 aten::fill_ 2_other 0.013 35 1_float8 aten::copy_ 2_other 0.060 36 1_float8 triton_red_fused__to_copy_mul_sum_0 2_other 0.548 37 1_float8 triton_red_fused__scaled_mm__to_copy_abs_add_div_max_mul_pow_reciprocal_sum_1 1_f8_overhead 0.109 38 1_float8 triton_red_fused_abs_fill_max_2 1_f8_overhead 0.004 39 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_3 1_f8_overhead 0.049 40 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_4 1_f8_overhead 0.007 41 1_float8 triton_red_fused_abs_add_fill_max_mul_sigmoid_silu_sub_5 1_f8_overhead 0.316 42 1_float8
triton_per_fused_abs_fill_max_mul_silu_6 1_f8_overhead 0.005 43 1_float8 triton_poi_fused__to_copy_add_clamp_fill_mul_sigmoid_silu_sub_7 1_f8_overhead 0.408 44 1_float8 triton_poi_fused__scaled_mm_clone_reciprocal_8 1_f8_overhead 0.294 45 1_float8 triton_red_fused__to_copy_add_mul_sum_9 2_other 0.810 46 1_float8 triton_red_fused__to_copy_add_div_mul_pow_sum_10 2_other 0.148 47 1_float8 aten::add_ 2_other 0.567 Summary of time (ms) by kernel category, across ref and float8 experiment 0_ref 1_float8 f8_div_ref ref_div_f8 category 0_gemm 14.691 8.054 0.548 1.824 1_f8_overhead 0.000 2.441 inf 0.000 2_other 3.291 2.582 0.785 1.274 All 17.981 13.077 0.727 1.375 Float8 amax/scale sync approx ratio of total time: 0.014 ``` bandwidth on ``` experiment kernel category time_ms pct_gpu_time bw_gbps 0 0_ref triton_red_fused_native_layer_norm_0 2_other 0.242 0.021 2085.93 1 0_ref aten::mm 0_gemm 10.120 0.877 None 2 0_ref aten::sum 2_other 0.121 0.010 None 3 0_ref aten::fill_ 2_other 0.002 0.000 None 4 0_ref aten::copy_ 2_other 0.200 0.017 None 5 0_ref triton_red_fused_native_layer_norm_native_layer_norm_backward_0 2_other 0.350 0.030 2207.69 6 0_ref aten::add_ 2_other 0.511 0.044 None 7 1_float8 triton_per_fused_copy_max_roll_0 1_f8_overhead 0.005 0.001 0.01 8 1_float8 triton_per_fused_copy_max_roll_1 1_f8_overhead 0.003 0.000 0.01 9 1_float8 triton_red_fused__to_copy_abs_clamp_max_mul_native_layer_norm_0 1_f8_overhead 0.367 0.048 1083.39 10 1_float8 triton_red_fused_abs_fill_max_native_layer_norm_1 1_f8_overhead 0.004 0.001 5.52 11 1_float8 triton_red_fused_abs_max_2 1_f8_overhead 0.069 0.009 1486.46 12 1_float8 triton_per_fused_abs_fill_max_3 1_f8_overhead 0.002 0.000 0.20 13 1_float8 triton_poi_fused_reciprocal_4 2_other 0.004 0.001 0.00 14 1_float8 triton_poi_fused__scaled_mm_clone_5 1_f8_overhead 0.152 0.020 1460.56 15 1_float8 aten::_scaled_mm 0_gemm 5.213 0.683 None 16 1_float8 triton_poi_fused_clone_6 2_other 0.172 0.023 1488.13 17 1_float8 aten::sum 2_other 0.126 0.017 None 18 1_float8 aten::fill_ 2_other 0.006 0.001 None 19 1_float8 aten::copy_ 2_other 0.200 0.026 None 20 1_float8 triton_red_fused_abs_max_0 1_f8_overhead 0.129 0.017 1732.38 21 1_float8 triton_per_fused_abs_fill_max_1 1_f8_overhead 0.002 0.000 0.20 22 1_float8 triton_poi_fused__scaled_mm__to_copy_clamp_clone_mul_reciprocal_2 1_f8_overhead 0.310 0.041 1459.47 23 1_float8 triton_poi_fused__scaled_mm__to_copy_clamp_clone_mul_reciprocal_3 1_f8_overhead 0.002 0.000 0.00 24 1_float8 triton_red_fused_native_layer_norm_native_layer_norm_backward_4 2_other 0.352 0.046 2205.86 25 1_float8 aten::add_ 2_other 0.510 0.067 None ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- benchmarks/profile_linear_float8.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index 221d6b8..148ca6d 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -391,6 +391,11 @@ def float8_forw_backward_wrapper(x): "bw_gbps", ], ) + df.sort_values( + ["experiment", "category", "pct_gpu_time"], + ascending=[True, True, False], + inplace=True, + ) print("\nSummary of GPU time by CPU kernel\n\n", df) # compare gemm and overhead time
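For quick iteration on the new model fragment outside the profiling harness, a minimal driving sketch (shapes, dtype, and constructor arguments mirror the `norm_ffn_norm` branch of `main()`; it assumes a CUDA device and that `NormFFNResidualNorm` is importable from the script, which is a usage assumption rather than part of this PR):

```python
import torch

# assumption: run from the benchmarks/ directory so this import resolves
from profile_linear_float8 import NormFFNResidualNorm

m = NormFFNResidualNorm(
    dim=4096, hidden_dim=16384, multiple_of=1024, ffn_dim_multiplier=1.3
)
m = m.to("cuda").to(torch.bfloat16)
x = torch.randn(1, 8192, 4096, device="cuda", dtype=torch.bfloat16).requires_grad_()
out = m(x)           # norm -> FFN -> residual add -> norm
out.sum().backward()
print(out.shape)     # torch.Size([1, 8192, 4096])
```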