This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit c434c82

add norm_ffn_norm to profile script

Summary: This PR adds an example FFN, with the preceding and subsequent norms, to the profile script. The goal is to speed up debugging of kernel performance on LLaMa.

ghstack-source-id: 7f7ea6b
Pull Request resolved: #282
1 parent b240ce7 commit c434c82
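Usage sketch (illustrative, not part of the commit): assuming invoke_main() wires main() up to fire.Fire, the CLI accepts the same keyword arguments as main(); the direct call below uses values permitted by the asserts added in this PR. The import path and output prefix are placeholders, and a CUDA device is required since the script hard-codes device = "cuda".

    # hypothetical direct invocation of the profiling entry point
    from benchmarks.profile_linear_float8 import main  # path assumes running from the repo root

    main(
        profile_path_prefix="/tmp/llama_ffn",  # arbitrary prefix; traces are saved as <prefix> + <suffix>.json
        compile=True,
        linear_type="dynamic",
        model_type="norm_ffn_norm",  # the FFN + norms fragment added in this PR
        dtype_filter="both",         # profile both the bfloat16 reference and float8
    )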

File tree

2 files changed: +303, -52 lines

benchmarks/profile_linear_float8.py

Lines changed: 256 additions & 52 deletions
@@ -12,8 +12,11 @@
 from typing import Callable, Optional
 
 import fire
+import pandas as pd
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
@@ -24,6 +27,15 @@
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function
+from utils import (
+    profiler_output_to_gpu_time_for_key,
+    profiler_output_to_time_by_kernel_name,
+)
+
+# don't truncate long kernel names
+pd.options.display.max_colwidth = 100
+# display 3 trailing decimal points for floats
+pd.set_option("display.float_format", "{:.3f}".format)
 
 
 class LNLinear(torch.nn.Module):
@@ -38,6 +50,105 @@ def forward(self, x):
         return x
 
 
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
+class RMSNorm(nn.Module):
+    """
+    Initialize the RMSNorm normalization layer.
+
+    Args:
+        dim (int): The dimension of the input tensor.
+        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+    Attributes:
+        eps (float): A small value added to the denominator for numerical stability.
+        weight (nn.Parameter): Learnable scaling parameter.
+
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def reset_parameters(self):
+        torch.nn.init.ones_(self.weight)  # type: ignore
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class NormFFNResidualNorm(nn.Module):
+    """
+    A fragment representing the end of TransformerBlock n and the start
+    of TransformerBlock n + 1, intended to include the fusions relevant
+    to float8 gemms in the FFN module in forward and backward.
+    """
+
+    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier):
+        super().__init__()
+        self.ffn_norm = RMSNorm(dim)
+        self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)
+        self.attn_norm = RMSNorm(dim)
+
+    def forward(self, h):
+        # end of transformer block n
+        x = self.ffn_norm(h)
+        x = self.ffn(x)
+        x = h + x
+        # start of transformer block n + 1
+        x = self.attn_norm(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -87,46 +198,51 @@ def profile_function(
     if config.file_path is not None:
         prof.export_chrome_trace(config.file_path)
 
-    if config.file_path is None:
-        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
-
     return prof
 
 
-@dataclass(frozen=True)
-class ModelParams:
-    M: int
-    K: int
-    N: int
-    ref_dtype: torch.dtype
-    layer_norm: bool = True
-
-
 def main(
     profile_path_prefix: Path,
     compile: bool = True,
     linear_type: str = "dynamic",
-    use_layer_norm: bool = False,
+    model_type: str = "linear",
+    dtype_filter: str = "both",
 ):
-    params = ModelParams(
-        M=4 * 4096,
-        K=8192,
-        N=7168,
-        ref_dtype=torch.bfloat16,
-        layer_norm=use_layer_norm,
-    )
+    assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+    assert dtype_filter in ("both", "float8", "bfloat16")
+
     print(f"Compile is set to | {compile}")
     print(f"Using Linear type: | {linear_type}")
-    print(f"Use layer norm is set to | {params.layer_norm}")
+    print(f"model_type is set to | {model_type}")
 
     device = "cuda"
-    if params.layer_norm:
-        m_ref = LNLinear(params.K, params.N)
+    ref_dtype = torch.bfloat16
+    if model_type == "ln_linear":
+        M, K, N = 4 * 4096, 8192, 7168
+        m_ref = LNLinear(K, N)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+    elif model_type == "norm_ffn_norm":
+        m_ref = NormFFNResidualNorm(
+            dim=4096,
+            hidden_dim=16384,
+            multiple_of=1024,
+            ffn_dim_multiplier=1.3,
+        )
+        input_tensor = torch.randn(
+            1, 8192, 4096, device=device, dtype=ref_dtype
+        ).requires_grad_()
     else:
+        M, K, N = 4 * 4096, 8192, 7168
         m_ref = torch.nn.Sequential(
-            torch.nn.Linear(params.K, params.N, bias=False),
+            torch.nn.Linear(K, N, bias=False),
+        )
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
         )
-    m_ref = m_ref.to(device).to(params.ref_dtype)
+
+    m_ref = m_ref.to(device).to(ref_dtype)
 
     linear_type = LinearType[linear_type.upper()]
     linear_cls = (
@@ -136,10 +252,6 @@ def main(
     m_float8 = copy.deepcopy(m_ref)
     swap_linear_with_float8_linear(m_float8, linear_cls)
 
-    input_tensor = torch.randn(
-        params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
-    )
-
     def ref_forw_backward(x):
         out = m_ref(x)
         out.sum().backward()
@@ -148,6 +260,8 @@ def float8_forw(x):
         out = m_float8(x)
         return out
 
+    sync_amax_history = sync_float8_amax_and_scale_history
+
     def float8_forw_backward_wrapper(x):
         # sync_float8_amax_and_scale_history is not full graph torch
         # compile friendly, so we add a high level wrapper to allow
@@ -156,7 +270,7 @@ def float8_forw_backward_wrapper(x):
         # TODO(future): make this better
         if linear_requires_sync(linear_type):
             with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(m_float8)
+                sync_amax_history(m_float8)
         out = float8_forw(x)
 
         # out.sum().backward() is also not torch.compile fullgraph
@@ -165,30 +279,120 @@ def float8_forw_backward_wrapper(x):
             out.sum().backward()
 
     if compile:
-        ref_forw_backward = torch.compile(ref_forw_backward)
+        m_ref = torch.compile(m_ref, fullgraph=True)
         float8_forw = torch.compile(float8_forw, fullgraph=True)
+        # Note: it's faster to compile the combination of sync_amax_history with
+        # forward because we only look up from dynamo cache once.
+        # However, compiling the sync function separately makes it more
+        # convenient to analyze the total time spent on it.
+        sync_amax_history = torch.compile(sync_amax_history)
+
+    # warm up
+    for _ in range(1):
+        if dtype_filter != "float8":
+            ref_forw_backward(input_tensor)
+        if dtype_filter != "bfloat16":
+            float8_forw_backward_wrapper(input_tensor)
+
+    # profile_iters = 5
+    profile_iters = 2
+    ref_times, float8_times = None, None
+
+    if dtype_filter != "float8":
+        # Profile Reference Model
+        print("profiling ref")
+        ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+        ref_path = profile_path_prefix + ref_suffix
+        profile_config = ProfileConfig(
+            ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+        )
+        p = profile_function(profile_config, ref_forw_backward, input_tensor)
+        print(f"saved {ref_path}")
+        ref_times = profiler_output_to_time_by_kernel_name(p)
+
+    if dtype_filter != "bfloat16":
+        # Profile Float8 Model
+        print("profiling float8")
+        float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
+        float8_path = profile_path_prefix + float8_suffix
+        profile_config = ProfileConfig(
+            float8_path,
+            float8_suffix,
+            iters=profile_iters,
+            warmup_iters=2,
+            sync=True,
+        )
+        p = profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)
+        print(f"saved {float8_path}")
+        float8_times = profiler_output_to_time_by_kernel_name(p)
+
+        # get the time spent per user annotation
+        sync_time_us = profiler_output_to_gpu_time_for_key(p, "scale_amax_and_scales")
+        sync_time_ms = sync_time_us / profile_iters / 1e3
+        print(f"Sync time ms: {sync_time_ms}")
+
+    if dtype_filter == "both":
+        data = []
+
+        def kernel_name_to_category(k):
+            # number prefix is for easy sorting
+            if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"):
+                return "0_gemm"
+            elif (
+                # max(abs(tensor))
+                ("abs" in k and "max" in k)
+                or
+                # casting pointwise to float8
+                ("clamp" in k)
+                or
+                # things related to scaled_mm
+                ("scaled_mm" in k)
+                or
+                # syncing amaxes and scales
+                ("roll" in k)
+            ):
+                # note: the above filter is approximate and will give false
+                # positives if model code contains other code to abs/max/clamp
+                return "1_f8_overhead"
+            return "2_other"
+
+        for k, v in ref_times.items():
+            data.append(
+                ["0_ref", k, kernel_name_to_category(k), v / 1e3 / profile_iters]
+            )
+        for k, v in float8_times.items():
+            data.append(
+                ["1_float8", k, kernel_name_to_category(k), v / 1e3 / profile_iters]
+            )
+
+        df = pd.DataFrame(data, columns=["experiment", "kernel", "category", "time_ms"])
+        print("\nSummary of GPU time by CPU kernel\n\n", df)
+
+        # compare gemm and overhead time
+        df_p = df.pivot_table(
+            columns=["category"],
+            index="experiment",
+            values="time_ms",
+            aggfunc="sum",
+            fill_value=0,
+            margins=True,
+        )
+        # drop last row, which has totals across ref + float8 which does not make sense
+        df_p = df_p[:-1]
+
+        df_p = df_p.transpose()
+        df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"]
+        df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"]
+        print(
+            "\nSummary of time (ms) by kernel category, across ref and float8\n\n", df_p
+        )
 
-    for _ in range(5):
-        ref_forw_backward(input_tensor)
-        float8_forw_backward_wrapper(input_tensor)
-
-    # Profile Reference Model
-    ref_suffix = f"_ref_compile_{compile}.json"
-    profile_config = ProfileConfig(
-        profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
-    )
-    profile_function(profile_config, ref_forw_backward, input_tensor)
-
-    # Profile Float8 Model
-    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
-    profile_config = ProfileConfig(
-        profile_path_prefix + float8_suffix,
-        float8_suffix,
-        iters=5,
-        warmup_iters=5,
-        sync=True,
-    )
-    profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)
+        # calculate sync time as pct of total float time
+        total_float8_ms = df_p.iloc[3]["1_float8"]
+        sync_approx_ratio = sync_time_ms / total_float8_ms
+        print(
+            f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}"
+        )
 
 
 def invoke_main() -> None:
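A note on the shapes main() selects for model_type="norm_ffn_norm": FeedForward's constructor turns hidden_dim=16384 into int(2 * 16384 / 3) = 10922, scales it by ffn_dim_multiplier=1.3 to 14198, then rounds up to the next multiple of 1024, giving 14336. So w1 and w3 are 4096 -> 14336 gemms and w2 is 14336 -> 4096, fed by a (1, 8192, 4096) bf16 activation. A minimal sketch of exercising the fragment outside the profiler, assuming the class is importable from this module and a CUDA device is available:

    import torch

    # NormFFNResidualNorm as defined in benchmarks/profile_linear_float8.py; import path is illustrative
    from benchmarks.profile_linear_float8 import NormFFNResidualNorm

    m = NormFFNResidualNorm(dim=4096, hidden_dim=16384, multiple_of=1024, ffn_dim_multiplier=1.3)
    m = m.to("cuda").to(torch.bfloat16)
    x = torch.randn(1, 8192, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    m(x).sum().backward()  # runs the forward and backward gemms the profiler measures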
