This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 64bc5d7

add norm_ffn_norm to profile script
Summary: This PR adds an example FFN with the preceding and subsequent norms to the profile script. I hope for this to speed up debugging of kernel performance on LLaMa.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 0affdc3
Pull Request resolved: #282
1 parent b240ce7 commit 64bc5d7
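For reference, assuming the script's fire-based entry point is unchanged, the new model type can be exercised roughly along these lines (the profile path prefix is only an example value):

    python benchmarks/profile_linear_float8.py --profile_path_prefix /tmp/prof_norm_ffn_norm --model_type norm_ffn_norm --dtype_filter both --compile True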

File tree

2 files changed: +312 -51 lines changed


benchmarks/profile_linear_float8.py

Lines changed: 242 additions & 51 deletions
@@ -12,8 +12,11 @@
 from typing import Callable, Optional
 
 import fire
+import pandas as pd
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
@@ -24,6 +27,16 @@
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function
+from utils import (
+    kernel_name_to_category,
+    profiler_output_to_gpu_time_for_key,
+    profiler_output_to_time_by_kernel_name,
+)
+
+# don't truncate long kernel names
+pd.options.display.max_colwidth = 100
+# display 3 trailing decimal points for floats
+pd.set_option("display.float_format", "{:.3f}".format)
 
 
 class LNLinear(torch.nn.Module):
@@ -38,6 +51,105 @@ def forward(self, x):
         return x
 
 
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
+class RMSNorm(nn.Module):
+    """
+    Initialize the RMSNorm normalization layer.
+
+    Args:
+        dim (int): The dimension of the input tensor.
+        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+    Attributes:
+        eps (float): A small value added to the denominator for numerical stability.
+        weight (nn.Parameter): Learnable scaling parameter.
+
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def reset_parameters(self):
+        torch.nn.init.ones_(self.weight)  # type: ignore
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class NormFFNResidualNorm(nn.Module):
+    """
+    A fragment representing the end of TransformerBlock n and the start
+    of TransformerBlock n + 1, intended to include the fusions relevant
+    to float8 gemms in the FFN module in forward and backward.
+    """
+
+    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier):
+        super().__init__()
+        self.ffn_norm = RMSNorm(dim)
+        self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)
+        self.attn_norm = RMSNorm(dim)
+
+    def forward(self, h):
+        # end of transformer block n
+        x = self.ffn_norm(h)
+        x = self.ffn(x)
+        x = h + x
+        # start of transformer block n + 1
+        x = self.attn_norm(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -87,46 +199,51 @@ def profile_function(
     if config.file_path is not None:
         prof.export_chrome_trace(config.file_path)
 
-    if config.file_path is None:
-        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-
     return prof
 
 
-@dataclass(frozen=True)
-class ModelParams:
-    M: int
-    K: int
-    N: int
-    ref_dtype: torch.dtype
-    layer_norm: bool = True
-
-
 def main(
     profile_path_prefix: Path,
     compile: bool = True,
     linear_type: str = "dynamic",
-    use_layer_norm: bool = False,
+    model_type: str = "linear",
+    dtype_filter: str = "both",
 ):
-    params = ModelParams(
-        M=4 * 4096,
-        K=8192,
-        N=7168,
-        ref_dtype=torch.bfloat16,
-        layer_norm=use_layer_norm,
-    )
+    assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+    assert dtype_filter in ("both", "float8", "bfloat16")
+
     print(f"Compile is set to | {compile}")
     print(f"Using Linear type: | {linear_type}")
-    print(f"Use layer norm is set to | {params.layer_norm}")
+    print(f"model_type is set to | {model_type}")
 
     device = "cuda"
-    if params.layer_norm:
-        m_ref = LNLinear(params.K, params.N)
+    ref_dtype = torch.bfloat16
+    if model_type == "ln_linear":
+        M, K, N = 4 * 4096, 8192, 7168
+        m_ref = LNLinear(K, N)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+    elif model_type == "norm_ffn_norm":
+        m_ref = NormFFNResidualNorm(
+            dim=4096,
+            hidden_dim=16384,
+            multiple_of=1024,
+            ffn_dim_multiplier=1.3,
+        )
+        input_tensor = torch.randn(
+            1, 8192, 4096, device=device, dtype=ref_dtype
+        ).requires_grad_()
     else:
+        M, K, N = 4 * 4096, 8192, 7168
         m_ref = torch.nn.Sequential(
-            torch.nn.Linear(params.K, params.N, bias=False),
+            torch.nn.Linear(K, N, bias=False),
         )
-    m_ref = m_ref.to(device).to(params.ref_dtype)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+
+    m_ref = m_ref.to(device).to(ref_dtype)
 
     linear_type = LinearType[linear_type.upper()]
     linear_cls = (
@@ -136,10 +253,6 @@ def main(
     m_float8 = copy.deepcopy(m_ref)
     swap_linear_with_float8_linear(m_float8, linear_cls)
 
-    input_tensor = torch.randn(
-        params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
-    )
-
     def ref_forw_backward(x):
         out = m_ref(x)
         out.sum().backward()
@@ -148,6 +261,8 @@ def float8_forw(x):
         out = m_float8(x)
         return out
 
+    sync_amax_history = sync_float8_amax_and_scale_history
+
     def float8_forw_backward_wrapper(x):
        # sync_float8_amax_and_scale_history is not full graph torch
        # compile friendly, so we add a high level wrapper to allow
@@ -156,7 +271,7 @@ def float8_forw_backward_wrapper(x):
         # TODO(future): make this better
         if linear_requires_sync(linear_type):
             with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(m_float8)
+                sync_amax_history(m_float8)
         out = float8_forw(x)
 
         # out.sum().backward() is also not torch.compile fullgraph
@@ -165,30 +280,106 @@ def float8_forw_backward_wrapper(x):
         out.sum().backward()
 
     if compile:
-        ref_forw_backward = torch.compile(ref_forw_backward)
+        m_ref = torch.compile(m_ref, fullgraph=True)
         float8_forw = torch.compile(float8_forw, fullgraph=True)
-
-    for _ in range(5):
-        ref_forw_backward(input_tensor)
-        float8_forw_backward_wrapper(input_tensor)
-
-    # Profile Reference Model
-    ref_suffix = f"_ref_compile_{compile}.json"
-    profile_config = ProfileConfig(
-        profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
+        # Note: it's faster to compile the combination of sync_amax_history with
+        # forward because we only look up from dynamo cache once.
+        # However, compiling the sync function separately makes it more
+        # convenient to analyze the total time spent on it.
+        sync_amax_history = torch.compile(sync_amax_history)
+
+    # warm up
+    for _ in range(1):
+        if dtype_filter != "float8":
+            ref_forw_backward(input_tensor)
+        if dtype_filter != "bfloat16":
+            float8_forw_backward_wrapper(input_tensor)
+
+    profile_iters = 5
+    ref_times, float8_times = None, None
+    data = []
+
+    if dtype_filter != "float8":
+        # Profile Reference Model
+        print("profiling ref")
+        ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+        ref_path = profile_path_prefix + ref_suffix
+        profile_config = ProfileConfig(
+            ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+        )
+        p = profile_function(profile_config, ref_forw_backward, input_tensor)
+        print(f"saved {ref_path}")
+        ref_times = profiler_output_to_time_by_kernel_name(p)
+        total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
+        for k, v in ref_times.items():
+            v_ms = v / 1e3 / profile_iters
+            data.append(
+                ["0_ref", k, kernel_name_to_category(k), v_ms, v_ms / total_time_ms]
+            )
+
+    if dtype_filter != "bfloat16":
+        # Profile Float8 Model
+        print("profiling float8")
+        float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
+        float8_path = profile_path_prefix + float8_suffix
+        profile_config = ProfileConfig(
+            float8_path,
+            float8_suffix,
+            iters=profile_iters,
+            warmup_iters=2,
+            sync=True,
+        )
+        p = profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)
+        print(f"saved {float8_path}")
+        float8_times = profiler_output_to_time_by_kernel_name(p)
+        total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
+        for k, v in float8_times.items():
+            v_ms = v / 1e3 / profile_iters
+            data.append(
+                [
+                    "1_float8",
+                    k,
+                    kernel_name_to_category(k),
+                    v / 1e3 / profile_iters,
+                    v_ms / total_time_ms,
+                ]
+            )
+
+        # get the time spent per user annotation
+        sync_time_us = profiler_output_to_gpu_time_for_key(p, "scale_amax_and_scales")
+        sync_time_ms = sync_time_us / profile_iters / 1e3
+        print(f"Sync time ms: {sync_time_ms}")
+
+    df = pd.DataFrame(
+        data, columns=["experiment", "kernel", "category", "time_ms", "pct_gpu_time"]
     )
-    profile_function(profile_config, ref_forw_backward, input_tensor)
-
-    # Profile Float8 Model
-    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
-    profile_config = ProfileConfig(
-        profile_path_prefix + float8_suffix,
-        float8_suffix,
-        iters=5,
-        warmup_iters=5,
-        sync=True,
+    print("\nSummary of GPU time by CPU kernel\n\n", df)
+
+    # compare gemm and overhead time
+    df_p = df.pivot_table(
+        columns=["category"],
+        index="experiment",
+        values="time_ms",
+        aggfunc="sum",
+        fill_value=0,
+        margins=True,
     )
-    profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)
+    # drop last row, which has totals across ref + float8 which does not make sense
+    df_p = df_p[:-1]
+    df_p = df_p.transpose()
+
+    if dtype_filter == "both":
+        df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"]
+        df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"]
+
+        # calculate sync time as pct of total float time
+        total_float8_ms = df_p.iloc[3]["1_float8"]
+        sync_approx_ratio = sync_time_ms / total_float8_ms
+        print(
+            f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}"
+        )
+
+    print("\nSummary of time (ms) by kernel category\n\n", df_p)
 
 
 def invoke_main() -> None:
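As a side note (not part of the diff), the norm_ffn_norm configuration above, with dim=4096, hidden_dim=16384, multiple_of=1024 and ffn_dim_multiplier=1.3, implies the following FFN sizing via the arithmetic in FeedForward.__init__:

    # illustrative sketch of the hidden_dim calculation, not part of the commit
    hidden_dim = int(2 * 16384 / 3)                         # 10922
    hidden_dim = int(1.3 * hidden_dim)                      # 14198
    hidden_dim = 1024 * ((hidden_dim + 1024 - 1) // 1024)   # 14336

so w1 and w3 map 4096 -> 14336, w2 maps 14336 -> 4096, and with the (1, 8192, 4096) input each of the three forward gemms involves the dimensions 8192 x 4096 x 14336 (K and N swapped for w2).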
