This repository was archived by the owner on Aug 7, 2024. It is now read-only.

QOL improvements to benchmarks/profile_linear_float8.py #281

Closed
wants to merge 3 commits into from
106 changes: 58 additions & 48 deletions benchmarks/profile_linear_float8.py
@@ -3,6 +3,8 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import random
from contextlib import nullcontext
from dataclasses import dataclass, field
@@ -12,15 +14,30 @@
import fire

import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
get_float8_linear,
linear_requires_sync,
LinearType,
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from torch.profiler import profile, ProfilerActivity, record_function


class LNLinear(torch.nn.Module):
def __init__(self, fc_dim1, fc_dim2):
super().__init__()
self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False)
self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False)

def forward(self, x):
x = self.ln(x)
x = self.fc(x)
return x
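
For context (not part of the diff): the LNLinear block above is just a LayerNorm followed by a bias-free Linear. A minimal shape check, with arbitrary sizes that are assumptions for illustration only, would look like:

    # Illustrative sketch only -- sizes are made up, not taken from the benchmark.
    import torch

    m = LNLinear(fc_dim1=8192, fc_dim2=7168)
    x = torch.randn(16, 8192)
    y = m(x)  # LayerNorm over the last dim, then fc: (16, 8192) -> (16, 7168)
    assert y.shape == (16, 7168)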


@dataclass
class ProfileConfig:
file_path: Optional[str] = None
@@ -77,65 +94,58 @@ def profile_function(


@dataclass(frozen=True)
class LinearParams:
class ModelParams:
M: int
K: int
N: int
input_bias: bool
ref_dtype: torch.dtype
layer_norm: bool = True
torch_compile: Optional[bool] = False


def main(profile_path: Path, compile: bool, linear_type: str):
profile_path = Path(profile_path)
assert profile_path.is_dir(), f"Path {profile_path} must be a directory"
params = LinearParams(
def main(
profile_path_prefix: Path,
compile: bool = True,
linear_type: str = "dynamic",
use_layer_norm: bool = False,
):
params = ModelParams(
M=4 * 4096,
K=8192,
N=7168,
input_bias=False,
ref_dtype=torch.bfloat16,
layer_norm=True,
torch_compile=compile,
layer_norm=use_layer_norm,
)
print(f"Compile is set to | {compile}")
print(f"Using Linear type: | {linear_type}")
print(f"Use layer norm is set to | {params.layer_norm}")
linear_ref = torch.nn.Linear(
params.K,
params.N,
bias=params.input_bias,
device="cuda",
dtype=params.ref_dtype,
)

device = "cuda"
if params.layer_norm:
m_ref = LNLinear(params.K, params.N)
else:
m_ref = torch.nn.Sequential(
torch.nn.Linear(params.K, params.N, bias=False),
)
m_ref = m_ref.to(device).to(params.ref_dtype)

linear_type = LinearType[linear_type.upper()]
linear_float8 = get_float8_linear(linear_type, linear_ref)
linear_cls = (
Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
)

m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(m_float8, linear_cls)

input_tensor = torch.randn(
params.M, params.K, device="cuda", dtype=params.ref_dtype, requires_grad=True
)

if params.layer_norm:
ln = torch.nn.LayerNorm(
params.K, elementwise_affine=False, device="cuda", dtype=params.ref_dtype
)

def ref_forw_backward(x):
if params.layer_norm:
with record_function("layer_norm"):
x = ln(x)
with record_function("forward"):
out = linear_ref(x)
with record_function("backward"):
out.sum().backward()
out = m_ref(x)
out.sum().backward()

def float8_forw_backward(x):
if params.layer_norm:
with record_function("layer_norm"):
x = ln(x)
with record_function("forward"):
out = linear_float8(x)
def float8_forw(x):
out = m_float8(x)
return out

def float8_forw_backward_wrapper(x):
@@ -146,34 +156,34 @@ def float8_forw_backward_wrapper(x):
# TODO(future): make this better
if linear_requires_sync(linear_type):
with record_function("scale_amax_and_scales"):
sync_float8_amax_and_scale_history(linear_float8)
out = float8_forw_backward(x)
sync_float8_amax_and_scale_history(m_float8)
out = float8_forw(x)

# out.sum().backward() is also not torch.compile fullgraph
# friendly
with record_function("backward"):
out.sum().backward()

if params.torch_compile:
if compile:
ref_forw_backward = torch.compile(ref_forw_backward)
float8_forw_backward = torch.compile(float8_forw_backward, fullgraph=True)
float8_forw = torch.compile(float8_forw, fullgraph=True)

for _ in range(5):
ref_forw_backward(input_tensor)
float8_forw_backward_wrapper(input_tensor)

# Profile Reference Linear
ref_string = f"linear_ref_dtype_{params.ref_dtype}_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}.json"
# Profile Reference Model
ref_suffix = f"_ref_compile_{compile}.json"
profile_config = ProfileConfig(
str(profile_path / ref_string), ref_string, iters=5, warmup_iters=5, sync=True
profile_path_prefix + ref_suffix, ref_suffix, iters=5, warmup_iters=5, sync=True
)
profile_function(profile_config, ref_forw_backward, input_tensor)

# # Profile Float8 Linear
float8_string = f"linear_float8_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}_{linear_type}.json"
# Profile Float8 Model
float8_suffix = f"_float8_compile_{compile}_{linear_type}.json"
profile_config = ProfileConfig(
str(profile_path / float8_string),
float8_string,
profile_path_prefix + float8_suffix,
float8_suffix,
iters=5,
warmup_iters=5,
sync=True,
@@ -182,7 +192,7 @@ def float8_forw_backward_wrapper(x):


def invoke_main() -> None:
# Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles --compile=True --linear_type="dynamic"
# Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles/current_profile --compile=True --linear_type="dynamic"
fire.Fire(main)
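
A hedged sketch of driving the updated signature directly from Python rather than through fire (the module import path and the argument values below are assumptions based on the file location and the defaults in this diff, not something the PR documents):

    # Hypothetical direct call mirroring the new CLI shape; normally fire.Fire(main)
    # handles this from the command line.
    from benchmarks.profile_linear_float8 import main  # assumed import path

    main(
        profile_path_prefix="benchmarks/data/profiles/current_profile",
        compile=True,
        linear_type="dynamic",
        use_layer_norm=False,
    )
    # With the suffix rules above, the traces land next to the prefix, e.g.
    # current_profile_ref_compile_True.json and
    # current_profile_float8_compile_True_<linear_type>.json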

