This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 9ba2a8e

vkuzo authored and facebook-github-bot committed
QOL improvements to benchmarks/profile_linear_float8.py (#281)
Summary:
Pull Request resolved: #281

Cleaning up this script in preparation of adding some more comprehensive benchmarking to target commonly occurring fusions.

Reviewed By: drisspg

Differential Revision: D59163494

fbshipit-source-id: 4b4bd21a5b65e0704d360e6f79d24142a9e35ad4
1 parent b5a444a commit 9ba2a8e
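
For orientation, the core structural change in the diff below is that the benchmark now builds a small reference module (an LNLinear, or an nn.Sequential wrapping a bias-free nn.Linear), deep-copies it, and swaps the copy's nn.Linear children to float8 linears, rather than constructing and wrapping a single torch.nn.Linear directly. A minimal sketch of that pattern, assuming float8_experimental is installed and a CUDA device is available (shapes here are illustrative, not the benchmark's defaults):

import copy

import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Reference model in bfloat16: a bias-free Linear (optionally preceded by LayerNorm).
m_ref = torch.nn.Sequential(torch.nn.Linear(1024, 2048, bias=False))
m_ref = m_ref.to("cuda").to(torch.bfloat16)

# Float8 variant: deep-copy the reference, then swap its nn.Linear modules in place.
m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(m_float8, Float8DynamicLinear)

# One forward + backward pass, mirroring what the profiled wrappers do.
x = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = m_float8(x)
out.sum().backward()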

1 file changed: benchmarks/profile_linear_float8.py (+58 −48 lines)
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+
+import copy
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass, field
@@ -12,15 +14,30 @@
 import fire

 import torch
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
     linear_requires_sync,
     LinearType,
+    swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function


+class LNLinear(torch.nn.Module):
+    def __init__(self, fc_dim1, fc_dim2):
+        super().__init__()
+        self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False)
+        self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False)
+
+    def forward(self, x):
+        x = self.ln(x)
+        x = self.fc(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
@@ -77,65 +94,58 @@ def profile_function(


 @dataclass(frozen=True)
-class LinearParams:
+class ModelParams:
     M: int
     K: int
     N: int
-    input_bias: bool
     ref_dtype: torch.dtype
     layer_norm: bool = True
-    torch_compile: Optional[bool] = False


-def main(profile_path: Path, compile: bool, linear_type: str):
-    profile_path = Path(profile_path)
-    assert profile_path.is_dir(), f"Path {profile_path} must be a directory"
-    params = LinearParams(
+def main(
+    profile_path_prefix: Path,
+    compile: bool = True,
+    linear_type: str = "dynamic",
+    use_layer_norm: bool = False,
+):
+    params = ModelParams(
         M=4 * 4096,
         K=8192,
         N=7168,
-        input_bias=False,
         ref_dtype=torch.bfloat16,
-        layer_norm=True,
-        torch_compile=compile,
+        layer_norm=use_layer_norm,
     )
     print(f"Compile is set to | {compile}")
     print(f"Using Linear type: | {linear_type}")
     print(f"Use layer norm is set to | {params.layer_norm}")
-    linear_ref = torch.nn.Linear(
-        params.K,
-        params.N,
-        bias=params.input_bias,
-        device="cuda",
-        dtype=params.ref_dtype,
-    )
+
+    device = "cuda"
+    if params.layer_norm:
+        m_ref = LNLinear(params.K, params.N)
+    else:
+        m_ref = torch.nn.Sequential(
+            torch.nn.Linear(params.K, params.N, bias=False),
+        )
+    m_ref = m_ref.to(device).to(params.ref_dtype)
+
     linear_type = LinearType[linear_type.upper()]
-    linear_float8 = get_float8_linear(linear_type, linear_ref)
+    linear_cls = (
+        Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
+    )
+
+    m_float8 = copy.deepcopy(m_ref)
+    swap_linear_with_float8_linear(m_float8, linear_cls)

     input_tensor = torch.randn(
         params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
     )

-    if params.layer_norm:
-        ln = torch.nn.LayerNorm(
-            params.K, elementwise_affine=False, device="cuda", dtype=params.ref_dtype
-        )
-
     def ref_forw_backward(x):
-        if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_ref(x)
-        with record_function("backward"):
-            out.sum().backward()
+        out = m_ref(x)
+        out.sum().backward()

-    def float8_forw_backward(x):
-        if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_float8(x)
+    def float8_forw(x):
+        out = m_float8(x)
         return out

     def float8_forw_backward_wrapper(x):
@@ -146,34 +156,34 @@ def float8_forw_backward_wrapper(x):
         # TODO(future): make this better
         if linear_requires_sync(linear_type):
             with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(linear_float8)
-        out = float8_forw_backward(x)
+                sync_float8_amax_and_scale_history(m_float8)
+        out = float8_forw(x)

         # out.sum().backward() is also not torch.compile fullgraph
         # friendly
         with record_function("backward"):
             out.sum().backward()

-    if params.torch_compile:
+    if compile:
         ref_forw_backward = torch.compile(ref_forw_backward)
-        float8_forw_backward = torch.compile(float8_forw_backward, fullgraph=True)
+        float8_forw = torch.compile(float8_forw, fullgraph=True)

     for _ in range(5):
         ref_forw_backward(input_tensor)
         float8_forw_backward_wrapper(input_tensor)

-    # Profile Reference Linear
-    ref_string = f"linear_ref_dtype_{params.ref_dtype}_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}.json"
+    # Profile Reference Model
+    ref_suffix = f"_ref_compile_{compile}.json"
     profile_config = ProfileConfig(
-        str(profile_path / ref_string), ref_string, iters=5, warmup_iters=5, sync=True
+        profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
     )
     profile_function(profile_config, ref_forw_backward, input_tensor)

-    # # Profile Float8 Linear
-    float8_string = f"linear_float8_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}_{linear_type}.json"
+    # Profile Float8 Model
+    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
     profile_config = ProfileConfig(
-        str(profile_path / float8_string),
-        float8_string,
+        profile_path_prefix + float8_suffix,
+        float8_suffix,
         iters=5,
         warmup_iters=5,
         sync=True,
@@ -182,7 +192,7 @@ def float8_forw_backward_wrapper(x):


 def invoke_main() -> None:
-    # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles --compile=True --linear_type="dynamic"
+    # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --linear_type="dynamic"
     fire.Fire(main)

