This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ef585cf

add norm_ffn_norm to profile script

Summary: This PR adds an example FFN, together with the preceding and subsequent norms, to the profile script. I hope this will speed up debugging of kernel performance on LLaMa.

ghstack-source-id: 8eb0020
Pull Request resolved: #282

1 parent b240ce7 commit ef585cf
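For reproducing the profiles, a hedged invocation sketch (not part of this commit): the script imports fire, so main is presumably exposed as a CLI as well, in which case the keyword arguments below would map to --flags. The output prefix and the module import path are placeholders.

# Hedged sketch, not part of this commit. Assumes benchmarks/profile_linear_float8.py
# is importable (e.g. run from the repo root); the prefix is an arbitrary placeholder.
from benchmarks.profile_linear_float8 import main

main(
    profile_path_prefix="./prof/llama_ffn",  # placeholder output prefix
    compile=True,
    linear_type="dynamic",
    model_type="norm_ffn_norm",  # the mode added by this PR
)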

File tree: 1 file changed, +131 −28 lines changed


benchmarks/profile_linear_float8.py

Lines changed: 131 additions & 28 deletions
@@ -14,6 +14,8 @@
 import fire
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
@@ -38,6 +40,105 @@ def forward(self, x):
         return x
 
 
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
+class RMSNorm(nn.Module):
+    """
+    Initialize the RMSNorm normalization layer.
+
+    Args:
+        dim (int): The dimension of the input tensor.
+        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+    Attributes:
+        eps (float): A small value added to the denominator for numerical stability.
+        weight (nn.Parameter): Learnable scaling parameter.
+
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x: torch.Tensor):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x: torch.Tensor):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def reset_parameters(self):
+        torch.nn.init.ones_(self.weight)  # type: ignore
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class NormFFNResidualNorm(nn.Module):
+    """
+    A fragment representing the end of TransformerBlock n and the start
+    of TransformerBlock n + 1, intended to include the fusions relevant
+    to float8 gemms in the FFN module in forward and backward.
+    """
+
+    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier):
+        super().__init__()
+        self.ffn_norm = RMSNorm(dim)
+        self.ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)
+        self.attn_norm = RMSNorm(dim)
+
+    def forward(self, h):
+        # end of transformer block n
+        x = self.ffn_norm(h)
+        x = self.ffn(x)
+        x = h + x
+        # start of transformer block n + 1
+        x = self.attn_norm(x)
+        return x
+
+
 @dataclass
 class ProfileConfig:
     file_path: Optional[str] = None
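As a quick sanity check of the fragment added in this hunk, here is a minimal standalone sketch (not part of the commit) that exercises NormFFNResidualNorm with the same constructor arguments main() uses for model_type="norm_ffn_norm", assuming the classes above are in scope. It runs on CPU in float32 with a short sequence purely for illustration; the profile script uses CUDA, bfloat16, and sequence length 8192.

import torch

# FeedForward rescales hidden_dim: int(2 * 16384 / 3) = 10922, then
# int(1.3 * 10922) = 14198, rounded up to a multiple of 1024 -> 14336,
# so w1/w3 are 4096 -> 14336 projections and w2 is 14336 -> 4096.
frag = NormFFNResidualNorm(
    dim=4096, hidden_dim=16384, multiple_of=1024, ffn_dim_multiplier=1.3
)

# Small (batch, seq_len, dim) input; the profile script itself uses
# (1, 8192, 4096) in bfloat16 on CUDA.
h = torch.randn(1, 128, 4096, requires_grad=True)
out = frag(h)          # RMSNorm -> SwiGLU-style FFN -> residual add -> RMSNorm
out.sum().backward()   # exercise backward, as the profile script does
print(out.shape)       # torch.Size([1, 128, 4096])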
@@ -93,40 +194,46 @@ def profile_function(
     return prof
 
 
-@dataclass(frozen=True)
-class ModelParams:
-    M: int
-    K: int
-    N: int
-    ref_dtype: torch.dtype
-    layer_norm: bool = True
-
-
 def main(
     profile_path_prefix: Path,
     compile: bool = True,
     linear_type: str = "dynamic",
-    use_layer_norm: bool = False,
+    model_type: str = "linear",
 ):
-    params = ModelParams(
-        M=4 * 4096,
-        K=8192,
-        N=7168,
-        ref_dtype=torch.bfloat16,
-        layer_norm=use_layer_norm,
-    )
+    assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+
     print(f"Compile is set to | {compile}")
     print(f"Using Linear type: | {linear_type}")
-    print(f"Use layer norm is set to | {params.layer_norm}")
+    print(f"model_type is set to | {model_type}")
 
     device = "cuda"
-    if params.layer_norm:
-        m_ref = LNLinear(params.K, params.N)
+    ref_dtype = torch.bfloat16
+    if model_type == "ln_linear":
+        M, K, N = 4 * 4096, 8192, 7168
+        m_ref = LNLinear(K, N)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+    elif model_type == "norm_ffn_norm":
+        m_ref = NormFFNResidualNorm(
+            dim=4096,
+            hidden_dim=16384,
+            multiple_of=1024,
+            ffn_dim_multiplier=1.3,
+        )
+        input_tensor = torch.randn(
+            1, 8192, 4096, device=device, dtype=ref_dtype
+        ).requires_grad_()
     else:
+        M, K, N = 4 * 4096, 8192, 7168
         m_ref = torch.nn.Sequential(
-            torch.nn.Linear(params.K, params.N, bias=False),
+            torch.nn.Linear(K, N, bias=False),
         )
-    m_ref = m_ref.to(device).to(params.ref_dtype)
+        input_tensor = torch.randn(
+            M, K, device=device, dtype=ref_dtype, requires_grad=True
+        )
+
+    m_ref = m_ref.to(device).to(ref_dtype)
 
     linear_type = LinearType[linear_type.upper()]
     linear_cls = (
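For reference, the forward-pass GEMM shapes each model_type ends up profiling, given the constants hard-coded above. The helper below is hypothetical, not part of the script, and 14336 is FeedForward's rounded hidden dim as noted earlier.

# Hypothetical summary helper, not in the script: (M, K, N) per forward GEMM.
def gemm_shapes(model_type: str):
    if model_type in ("linear", "ln_linear"):
        M, K, N = 4 * 4096, 8192, 7168
        return [(M, K, N)]
    # norm_ffn_norm: 1 x 8192 tokens of width 4096, FFN hidden dim 14336
    tokens, dim, hidden = 1 * 8192, 4096, 14336
    return [
        (tokens, dim, hidden),   # w1
        (tokens, dim, hidden),   # w3
        (tokens, hidden, dim),   # w2
    ]

for mt in ("linear", "ln_linear", "norm_ffn_norm"):
    print(mt, gemm_shapes(mt))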
@@ -136,10 +243,6 @@ def main(
     m_float8 = copy.deepcopy(m_ref)
     swap_linear_with_float8_linear(m_float8, linear_cls)
 
-    input_tensor = torch.randn(
-        params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
-    )
-
     def ref_forw_backward(x):
         out = m_ref(x)
         out.sum().backward()
@@ -173,14 +276,14 @@ def float8_forw_backward_wrapper(x):
     float8_forw_backward_wrapper(input_tensor)
 
     # Profile Reference Model
-    ref_suffix = f"_ref_compile_{compile}.json"
+    ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
     profile_config = ProfileConfig(
         profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
     )
     profile_function(profile_config, ref_forw_backward, input_tensor)
 
     # Profile Float8 Model
-    float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
+    float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
     profile_config = ProfileConfig(
         profile_path_prefix + float8_suffix,
         float8_suffix,
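To make the renamed trace files concrete, a small sketch of what the two suffixes above expand to for the new mode. The linear_type string here is illustrative only; in the script it is a LinearType enum member at this point.

# Illustrative only: how the new suffixes embed model_type.
model_type, compile, linear_type = "norm_ffn_norm", True, "LinearType.DYNAMIC"
ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
print(ref_suffix)     # _norm_ffn_norm_ref_compile_True.json
print(float8_suffix)  # _norm_ffn_norm_float8_compile_True_LinearType.DYNAMIC.json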

0 commit comments