This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Allow for modifying the scaled_mm compute #144

Closed
wants to merge 6 commits
17 changes: 10 additions & 7 deletions README.md
@@ -90,13 +90,17 @@ for _ in range(N_ITER):
optimizer.step()
```

# code tips
# 🧭 Code Organization

* `float8_experimental/float8_linear.py` - `Float8Linear` (main user facing entry point for delayed scaling)
* `float8_experimental/float8_dynamic_linear.py` - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
* `float8_experimental/float8_tensor.py` - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
* `float8_experimental/float8_linear.py`
- `Float8Linear` (main user facing entry point for delayed scaling)
* `float8_experimental/float8_dynamic_linear.py`
- `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
* `float8_experimental/float8_tensor.py`
- `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
- `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass
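
To make the `ScaledMMConfig` bullet above concrete, here is a minimal sketch of a namedtuple-style config with these semantics; the third field name (`fp8_output`) and the defaults are assumptions for illustration, not taken from this PR.

```python
from collections import namedtuple

# Sketch only: the diff constructs the config with up to three positional
# fields and reads .emulate / .use_fast_accum; the third field's name
# ("fp8_output") and all defaults are assumed here.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "fp8_output"],
    defaults=[False, False, False],
)

# One config per direction, as the linear modules in this PR do:
forward_config = ScaledMMConfig(emulate=False, use_fast_accum=True)
backward_config = ScaledMMConfig(emulate=False, use_fast_accum=False)
print(forward_config)
# ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False)
```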

# testing
# Testing

```bash
# run single-GPU unit tests
@@ -117,7 +121,7 @@ pytest test/test_compile.py
./test/run_everything.sh
```

# benchmarking
# Benchmarking

```bash
# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
@@ -130,4 +134,3 @@ pytest test/test_compile.py

# License
PyTorch has a BSD 3-Clause License, as found in the LICENSE file.

22 changes: 15 additions & 7 deletions benchmarks/bench_linear_float8.py
@@ -16,6 +16,7 @@
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
from float8_experimental.float8_tensor import ScaledMMConfig
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
@@ -54,8 +55,8 @@ class Experiment:
ref_time_sec: float
float8_time_sec: float
dtype: torch.dtype
compiled: bool = False
float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
compiled: bool
use_fast_accum: bool

# 3 Times since we are calculating forward backward
@property
@@ -74,7 +75,7 @@ def float8_tops_sec(self):

@property
def float8_pct_top_peak(self):
return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]


def main(
@@ -95,9 +96,10 @@ def main(
}
input_bias = False
ref_dtypes = [torch.bfloat16, torch.float16]
use_fast_accum = [True, False]
experiment_list: List[Experiment] = []
for idx, (dtype, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
for idx, (dtype, fast_accum, (name, (K, N))) in enumerate(
tqdm(list(product(ref_dtypes, use_fast_accum, name_to_shapes_70b.items())))
):
if n_limit is not None and idx >= n_limit:
break
@@ -108,6 +110,10 @@
linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref), emulate=False
)
if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
else:
linear_float8.forward_config = ScaledMMConfig(False, False, False)

bsz, seq_len = 4, 4096
M = bsz * seq_len
@@ -155,6 +161,7 @@ def wrapper(*args, **kwargs):
float8_time,
dtype,
compile,
use_fast_accum=fast_accum,
)
print(experiment)
print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -168,7 +175,7 @@ def wrapper(*args, **kwargs):
"N",
"ref_dtype",
"compiled",
"fp8_dtype",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"ref_tops_sec",
@@ -186,7 +193,7 @@ def wrapper(*args, **kwargs):
experiment.shape[2],
experiment.dtype,
experiment.compiled,
experiment.float_8_dtype,
experiment.use_fast_accum,
experiment.ref_time_sec,
experiment.float8_time_sec,
experiment.ref_tops_sec,
@@ -214,6 +221,7 @@ def wrapper(*args, **kwargs):
"shape",
"ref_dtype",
"compiled",
"use_fast_accum",
"ref_time_sec",
"pt_fp8_time_sec",
"pt_fp8_speedup",
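For orientation, a trimmed sketch of one configuration of the sweep above; the device, dtype, shape, and fast-accum values here are illustrative, and the real script additionally times a reference linear and a compiled variant.

```python
import copy

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
from float8_experimental.float8_tensor import ScaledMMConfig

# One point of the sweep: bf16 reference, one LLaMa-2-70B-like shape, fast accum on.
dtype, K, N, fast_accum = torch.bfloat16, 8192, 8192, True

linear_ref = nn.Linear(K, N, bias=False, device="cuda", dtype=dtype)
linear_float8 = Float8Linear.from_float(copy.deepcopy(linear_ref), emulate=False)
# As in the diff: override only the forward gemm's accumulation behavior.
linear_float8.forward_config = ScaledMMConfig(False, fast_accum, False)

x = torch.randn(4 * 4096, K, device="cuda", dtype=dtype)
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(x).sum().backward()
```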
14 changes: 8 additions & 6 deletions float8_experimental/float8_dynamic_linear.py
@@ -10,6 +10,7 @@

from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
@@ -27,9 +28,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
def forward(
ctx,
tensor,
emulate: bool,
mm_config: ScaledMMConfig,
):
ctx.emulate = emulate
ctx.mm_config = mm_config
return tensor

@staticmethod
@@ -39,7 +40,7 @@ def backward(ctx, gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
gradY, gradY_scale, torch.float8_e5m2, mm_config=ctx.mm_config
)
return fp8_tensor, None

@@ -73,11 +74,11 @@ def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
inpt_tensor, scale, torch.float8_e4m3fn, mm_config=self.forward_config
)

def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)
return NoopFwToFloat8E5M2Bw.apply(gradY, self.backward_config)

@classmethod
def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
@@ -97,5 +98,6 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
new_mod = cls(**super_kwargs)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
new_mod.backward_config = ScaledMMConfig(emulate, False)
return new_mod
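
A small, hypothetical check of the defaults that `from_float` now sets; the comments describe the expected flags rather than captured output.

```python
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# With emulate=False, this diff enables fast accumulation for the forward
# gemm and disables it for the backward gemm.
m = Float8DynamicLinear.from_float(nn.Linear(64, 64, bias=False), emulate=False)
print(m.forward_config)   # expected: use_fast_accum=True
print(m.backward_config)  # expected: use_fast_accum=False
```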
42 changes: 28 additions & 14 deletions float8_experimental/float8_linear.py
@@ -20,7 +20,11 @@

import torch

from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
to_fp8_no_autograd,
)

from float8_experimental.float8_utils import (
amax_history_to_scale,
@@ -73,12 +77,12 @@ def forward(
fp8_scale_dL_dY,
scale_fn_name,
is_amax_initialized,
emulate: bool,
mm_config: ScaledMMConfig,
):
ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
ctx.scale_fn_name = scale_fn_name
ctx.is_amax_initialized = is_amax_initialized
ctx.emulate = emulate
ctx.mm_config = mm_config
return tensor

@staticmethod
@@ -99,7 +103,9 @@ def backward(ctx, go):

fp8_amax_dL_dY.fill_(tensor_to_amax(go))

res = to_fp8_no_autograd(go, fp8_scale_dL_dY, torch.float8_e5m2, ctx.emulate)
res = to_fp8_no_autograd(
go, fp8_scale_dL_dY, torch.float8_e5m2, mm_config=ctx.mm_config
)
empty_grads = None, None, None, None, None, None
return res, *empty_grads

@@ -154,8 +160,9 @@ def __init__(self, *args, **kwargs):
)
self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))

# Whether to emulate the fp8 matmul logic in float32
self.emulate = False
# Defines the behavior of the matmul in the forward and backward pass
self.forward_config = ScaledMMConfig()
self.backward_config = ScaledMMConfig()
Contributor:
is it possible to configure the two backward gemms separately?

Contributor Author:
This is somewhat challenging, since as written today we don't have a very clean way of knowing which matmul is which.

cc @bdhirsh, maybe I'm not thinking of something.

We have out = x @ W where x = Float8Tensor and W = Float8Tensor.

Since W will not be used in calculating gradW, you could tag some extra info on the activation Float8Tensor, and since that tensor gets used for backward, the info should get carried through to the backward calcs.

I think this would be better as a follow-up, though, since the logic gets spread out over multiple Float8Tensor instances.

Contributor:
> since as written today we don't have a very clean way of knowing which matmul is which

Yeah, this is weird because the config is really per-gemm but we have to stick it on a tensor. How about something like:

  1. local: given matmul(A, B), the config on B (the second argument) always overrides the config on A (the first argument).
  2. global: the float8 UX allows setting options for the 3 gemms, and under the hood maps them to be implemented via (1).

While not the most intuitive to implement, I think that could work?

Contributor:
separate PR sgtm, I do feel like we need to make all 3 gemms configurable before we lock the API down.
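
To make the discussion above concrete, a minimal sketch of one conservative per-tensor merge policy, reusing the real config type; this is an illustration and not necessarily what the `merge_mm_configs` imported in `float8_ops.py` below does.

```python
from float8_experimental.float8_tensor import ScaledMMConfig


def merge_mm_configs_sketch(a: ScaledMMConfig, b: ScaledMMConfig) -> ScaledMMConfig:
    """Illustrative policy: both operands must agree on emulation, and fast
    accumulation is used only when both operands allow it."""
    assert a.emulate == b.emulate, "mismatched emulate settings on operands"
    # Positional construction mirrors the diff: (emulate, use_fast_accum).
    return ScaledMMConfig(a.emulate, a.use_fast_accum and b.use_fast_accum)
```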


# Note: is_amax_initialized is not a buffer to avoid data dependent
# control flow visible to dynamo
@@ -216,7 +223,11 @@ def cast_x_to_float8(
is_amax_initialized,
)
x_fp8 = Float8Tensor.to_float8(
x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
x,
self.fp8_scale_x,
torch.float8_e4m3fn,
self.fp8_amax_x,
self.forward_config,
)
return x_fp8

@@ -239,13 +250,11 @@ def cast_w_to_float8(
self.fp8_scale_w,
torch.float8_e4m3fn,
self.fp8_amax_w,
self.emulate,
self.forward_config,
)
return w_fp8

def cast_y_to_float8_in_bw(
self, y: torch.Tensor, emulate: bool = False
) -> torch.Tensor:
def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
scale_fn_name = self.recipe.scale_fn_name
y = NoopFwToFloat8E5M2Bw.apply(
y,
@@ -254,7 +263,7 @@ def cast_y_to_float8_in_bw(
self.fp8_scale_dL_dY,
scale_fn_name,
self.is_amax_initialized,
emulate,
self.backward_config,
)
return y

@@ -295,7 +304,7 @@ def forward(self, x):
y = torch.matmul(x_fp8, w_fp8.t())

# Cast gradY to float8_e5m2 during backward
y = self.cast_y_to_float8_in_bw(y, self.emulate)
y = self.cast_y_to_float8_in_bw(y)

if self.bias is not None:
y = y + self.bias.to(y.dtype)
@@ -318,7 +327,12 @@ def from_float(cls, mod, emulate: bool = False):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate

# Defines the behavior of the matmul in the forward and backward
# Forward we use fast_accum, backwards we do not
new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
new_mod.backward_config = ScaledMMConfig(emulate, False)

# I think its okay to send all params and buffers to device
new_mod.to(mod.weight.device)
return new_mod
46 changes: 35 additions & 11 deletions float8_experimental/float8_ops.py
@@ -8,7 +8,11 @@
import torch

from float8_experimental.float8_python_api import addmm_float8_unwrapped
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_tensor import (
Float8Tensor,
merge_mm_configs,
ScaledMMConfig,
)
from float8_experimental.float8_utils import is_row_major
from torch.utils._pytree import tree_map

@@ -41,7 +45,9 @@ def decorator(func):
)
def float8_desugar_op(aten_op, args, kwargs=None):
new_data = aten_op(args[0]._data, *args[1:], **kwargs)
return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)
return Float8Tensor(
new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
)


@implements([aten.sum.dim_IntList])
@@ -89,13 +95,22 @@ def float8_mm(aten_op, args, kwargs=None):
)
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
if a._emulate:
assert a._emulate == b._emulate
a_mm_config: ScaledMMConfig = a._mm_config
b_mm_config: ScaledMMConfig = b._mm_config
mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
if mm_config.emulate:
return torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)[0]
tensor_out, amax = addmm_float8_unwrapped(
a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None
a_data,
a_scale,
b_data,
b_scale,
output_dtype,
output_scale=None,
bias=None,
use_fast_accum=mm_config.use_fast_accum,
)
return tensor_out

@@ -113,14 +128,23 @@ def float8_addmm(aten_op, args, kwargs=None):
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
assert bias.dtype == output_dtype, "bias dtype must match output dtype"
if a._emulate:
assert a._emulate == b._emulate
a_mm_config: ScaledMMConfig = a._mm_config
b_mm_config: ScaledMMConfig = b._mm_config
mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
if mm_config.emulate:
out = torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)[0]
return out + bias
tensor_out, amax = addmm_float8_unwrapped(
a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=bias
a_data,
a_scale,
b_data,
b_scale,
output_dtype,
output_scale=None,
bias=bias,
use_fast_accum=mm_config.use_fast_accum,
)
return tensor_out

@@ -145,7 +169,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
torch.bfloat16,
}, "Only support floating point conversion for autocast w/ Float8Tensor"
return Float8Tensor(
args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
)


@@ -170,7 +194,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
fp8_out = fp8_out.view(fp8_input._data.dtype)
return Float8Tensor(
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._emulate
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
)


@@ -182,5 +206,5 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
fp8_data = fp8_input._data
fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
return Float8Tensor(
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._emulate
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
)