In this blog, we will demonstrate how we achieve up to **50% throughput speedup** over [FSDP1 bf16 training](https://pytorch.org/blog/maximizing-training-throughput/) while maintaining loss and evaluation benchmark parity. We achieve this speedup by leveraging FSDP2, DTensor, and torch.compile with torchao’s float8 linear layer updates for compute and float8 all_gathers for weight communication. We showcase these improvements across a spectrum of Meta Llama model architectures, ranging from a small 1.8B model all the way to a 405B model, making training faster than ever.
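
At a high level, the recipe is: swap the model's `nn.Linear` modules for torchao float8 linear layers, turn on float8 all_gathers for FSDP2's weight communication, and wrap the result in torch.compile. The sketch below shows one way to wire this up, assuming recent torchao and PyTorch APIs (`convert_to_float8_training`, `Float8LinearConfig`, and the composable `fully_shard`); the stand-in model is hypothetical and the exact configuration used for the runs in this post may differ.

```python
# Sketch only: convert nn.Linear layers to float8 compute with torchao,
# enable float8 all_gathers for FSDP2 weight communication, and compile.
# Launch with torchrun, e.g. `torchrun --nproc_per_node=8 train_float8.py`.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Stand-in for a real transformer: any stack of Linear-heavy blocks.
blocks = [nn.Sequential(nn.Linear(4096, 4096), nn.ReLU()) for _ in range(4)]
model = nn.Sequential(*blocks).to(torch.bfloat16).cuda()

# Swap nn.Linear -> Float8Linear (dynamic scaling) and request float8 all_gathers.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

# FSDP2 shards parameters as DTensors: shard each block, then the root module.
for block in model:
    fully_shard(block)
fully_shard(model)

model = torch.compile(model)

# One forward/backward step with dummy data.
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()

dist.destroy_process_group()
```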

We demonstrate these improvements using the Meta Llama3 architecture, and then perform model quality studies at two scales: 100B tokens at the 8B model size and 50B tokens at the 70B model size, providing a direct comparison of float8 and bf16 training loss curves. We show that the loss curves converge identically to their `bf16` counterparts across these training runs. Further, we train a 3B model to 1T tokens using the FineWeb-edu dataset and run standard evaluation benchmarks to ensure that the model quality is intact and comparable to a `bf16` run.

At IBM Research, we plan to adopt these capabilities for our data ablations to increase the number of experiments we can perform within a given GPU budget. Longer term, we will follow up with a larger scale model run to demonstrate the end-to-end feasibility of `float8` training.

## What is Float8?
The `float8` format for training models was introduced by NVIDIA, ARM, and Intel in a [2022 paper](https://arxiv.org/abs/2209.05433), which demonstrated the feasibility of training in lower precision float8 without sacrificing model quality. With the introduction of newer GPUs like the NVIDIA Hopper series, FP8 training became feasible, with the potential of more than 2x improvement in training throughput thanks to native float8 tensor core support (a short illustration of the format follows the list below). There are a few challenges to realizing this promise: \
(i) Enable the core model operations like `matmul` and `attention` in `float8`, \
(ii) Enable `float8` training in a distributed framework, and \
(iii) Enable weight communication between GPUs in `float8`. \
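
For a concrete sense of the format, the toy example below (illustrative only, not torchao's implementation) casts a tensor to the `torch.float8_e4m3fn` dtype using a single per-tensor scale, which is the basic idea behind tensorwise float8 scaling; the companion `torch.float8_e5m2` dtype trades mantissa precision for extra range and is typically used for gradients.

```python
# Illustrative only: a tensorwise-scaled cast to float8 (e4m3), roughly the idea
# behind float8 training; torchao handles scaling, casting, and the float8 matmul.
import torch

x = torch.randn(16, 16, dtype=torch.bfloat16)

# Scale so the tensor's largest magnitude maps near float8 e4m3's max (448).
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
scale = E4M3_MAX / x.abs().max().float().clamp(min=1e-12)

# Cast to float8 with the scale applied, then dequantize back for comparison.
x_fp8 = (x.float() * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
x_roundtrip = x_fp8.float() / scale

print("max abs error:", (x.float() - x_roundtrip).abs().max().item())
```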