Libtorch cuda graphs #2441

Merged
Changes from 16 commits (17 total, authored by Aidyn-A):

- 00251a0 add libtorch_cuda_graphs tutorial
- c5dcb0b minor fixes
- 48fd415 minor fixes
- f1763b2 remove unnecessary line
- 2546e0a minor fixes
- e07989e add stream_sync
- 0c9a6ba rename files and folders for consistency
- af5807b add more text
- ebe92d0 fix typos and better phrasing
- 63b4b63 Merge branch 'main' into libtorch_cuda_graphs
- 7218062 Merge branch 'pytorch:main' into libtorch_cuda_graphs
- 073021c apply text comments
- 42d27d1 apply source comments
- a586265 use cout and apply clang-format
- 23ddd51 Merge branch 'main' into libtorch_cuda_graphs
- 46dd9ca Apply suggestions from code review
- 6df8287 Require CMake >= 3.18
advanced_source/cpp_cuda_graphs.rst (new file, 193 lines):
Using CUDA Graphs in PyTorch C++ API
====================================

.. note::
   |edit| View and edit this tutorial in `GitHub <https://github.com/pytorch/tutorials/blob/main/advanced_source/cpp_cuda_graphs.rst>`__. The full source code is available on `GitHub <https://github.com/pytorch/tutorials/blob/main/advanced_source/cpp_cuda_graphs>`__.

Prerequisites:

- `Using the PyTorch C++ Frontend <../advanced_source/cpp_frontend.html>`__
- `CUDA semantics <https://pytorch.org/docs/master/notes/cuda.html>`__
- PyTorch 2.0 or later
- CUDA 11 or later

NVIDIA's CUDA Graphs have been a part of the CUDA Toolkit since the
release of `version 10 <https://developer.nvidia.com/blog/cuda-graphs/>`_.
They can greatly reduce CPU overhead, thereby increasing the
performance of applications.
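At the CUDA level, a graph records the work submitted to a stream once and then
replays it with a single launch, which is where the CPU-overhead savings come
from. The following sketch with the raw CUDA runtime API illustrates the idea
(the captured kernels and ``num_iters`` are placeholders, and error checking is
omitted):

.. code-block:: cpp

  cudaGraph_t graph;
  cudaGraphExec_t instance;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Record the stream's work into a graph instead of executing it.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // ... enqueue kernels and memory copies on `stream` here ...
  cudaStreamEndCapture(stream, &graph);

  // Instantiate once, then launch the whole recorded graph with a single call.
  cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);
  for (int i = 0; i < num_iters; i++) {
    cudaGraphLaunch(instance, stream);
  }

The ``at::cuda::CUDAGraph`` class used below wraps this
capture/instantiate/launch sequence.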
In this tutorial, we will be focusing on using CUDA Graphs with the `C++
frontend of PyTorch <https://pytorch.org/tutorials/advanced/cpp_frontend.html>`_.
The C++ frontend is mostly utilized in production and deployment applications,
which are an important part of PyTorch use cases. Since `their first appearance
<https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/>`_,
CUDA Graphs have won users' and developers' hearts for being a very performant
and, at the same time, simple-to-use tool. In fact, CUDA Graphs are used by
``torch.compile`` of PyTorch 2.0 (in its ``reduce-overhead`` mode) to boost the
performance of training and inference.

We would like to demonstrate CUDA Graphs usage on PyTorch's `MNIST
example <https://github.com/pytorch/examples/tree/main/cpp/mnist>`_.
The usage of CUDA Graphs in LibTorch (the C++ frontend) is very similar to its
`Python counterpart <https://pytorch.org/docs/main/notes/cuda.html#cuda-graphs>`_,
with some differences in syntax and functionality.
Getting Started
---------------

The main training loop consists of several steps, as depicted in the
following code chunk:

.. code-block:: cpp

  for (auto& batch : data_loader) {
    auto data = batch.data.to(device);
    auto targets = batch.target.to(device);
    optimizer.zero_grad();
    auto output = model.forward(data);
    auto loss = torch::nll_loss(output, targets);
    loss.backward();
    optimizer.step();
  }

The example above includes a forward pass, a backward pass, and weight updates.

In this tutorial, we will apply a CUDA Graph to all of the compute steps through whole-network
graph capture. Before doing so, we need to slightly modify the source code:
we preallocate tensors and reuse them in the main training loop. Here is an example
implementation:
.. code-block:: cpp

  torch::TensorOptions FloatCUDA =
      torch::TensorOptions(device).dtype(torch::kFloat);
  torch::TensorOptions LongCUDA =
      torch::TensorOptions(device).dtype(torch::kLong);

  torch::Tensor data = torch::zeros({kTrainBatchSize, 1, 28, 28}, FloatCUDA);
  torch::Tensor targets = torch::zeros({kTrainBatchSize}, LongCUDA);
  torch::Tensor output = torch::zeros({1}, FloatCUDA);
  torch::Tensor loss = torch::zeros({1}, FloatCUDA);

  for (auto& batch : data_loader) {
    data.copy_(batch.data);
    targets.copy_(batch.target);
    training_step(model, optimizer, data, targets, output, loss);
  }

Here, ``training_step`` simply consists of forward and backward passes with the corresponding optimizer calls:
.. code-block:: cpp

  void training_step(
      Net& model,
      torch::optim::Optimizer& optimizer,
      torch::Tensor& data,
      torch::Tensor& targets,
      torch::Tensor& output,
      torch::Tensor& loss) {
    optimizer.zero_grad();
    output = model.forward(data);
    loss = torch::nll_loss(output, targets);
    loss.backward();
    optimizer.step();
  }

PyTorch's CUDA Graphs API relies on stream capture, which in our case would be used like this:
.. code-block:: cpp

  at::cuda::CUDAGraph graph;
  at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool();
  at::cuda::setCurrentCUDAStream(captureStream);

  graph.capture_begin();
  training_step(model, optimizer, data, targets, output, loss);
  graph.capture_end();

Before the actual graph capture, it is important to run several warm-up iterations on a side stream to
prepare the CUDA cache as well as the CUDA libraries (like cuBLAS and cuDNN) that will be used during
the training:
.. code-block:: cpp

  at::cuda::CUDAStream warmupStream = at::cuda::getStreamFromPool();
  at::cuda::setCurrentCUDAStream(warmupStream);
  for (int iter = 0; iter < num_warmup_iters; iter++) {
    training_step(model, optimizer, data, targets, output, loss);
  }
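The warm-up iterations launch real kernels, so their completion should be
awaited before capture begins. A minimal sketch (assuming the variables above;
the exact synchronization used by the full example source may differ):

.. code-block:: cpp

  // Make sure all warm-up work has finished before starting stream capture.
  torch::cuda::synchronize();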
After a successful graph capture, the ``training_step(model, optimizer, data, targets, output, loss);``
call can be replaced with ``graph.replay();`` to perform the training step.
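Putting the pieces together, the preallocation loop shown earlier can then be
sketched as follows (an illustration assuming ``graph`` was captured as above,
not the full example source):

.. code-block:: cpp

  for (auto& batch : data_loader) {
    // Refresh the inputs in place: the graph replays the captured kernels
    // on the same memory addresses, so the tensors must not be reallocated.
    data.copy_(batch.data);
    targets.copy_(batch.target);
    // A single launch replays the whole forward/backward/optimizer step.
    graph.replay();
  }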
Training Results
----------------

Taking the code for a spin, we can see the following output from ordinary, non-graphed training:
.. code-block:: shell

  $ time ./mnist
  Train Epoch: 1 [59584/60000] Loss: 0.3921
  Test set: Average loss: 0.2051 | Accuracy: 0.938
  Train Epoch: 2 [59584/60000] Loss: 0.1826
  Test set: Average loss: 0.1273 | Accuracy: 0.960
  Train Epoch: 3 [59584/60000] Loss: 0.1796
  Test set: Average loss: 0.1012 | Accuracy: 0.968
  Train Epoch: 4 [59584/60000] Loss: 0.1603
  Test set: Average loss: 0.0869 | Accuracy: 0.973
  Train Epoch: 5 [59584/60000] Loss: 0.2315
  Test set: Average loss: 0.0736 | Accuracy: 0.978
  Train Epoch: 6 [59584/60000] Loss: 0.0511
  Test set: Average loss: 0.0704 | Accuracy: 0.977
  Train Epoch: 7 [59584/60000] Loss: 0.0802
  Test set: Average loss: 0.0654 | Accuracy: 0.979
  Train Epoch: 8 [59584/60000] Loss: 0.0774
  Test set: Average loss: 0.0604 | Accuracy: 0.980
  Train Epoch: 9 [59584/60000] Loss: 0.0669
  Test set: Average loss: 0.0544 | Accuracy: 0.984
  Train Epoch: 10 [59584/60000] Loss: 0.0219
  Test set: Average loss: 0.0517 | Accuracy: 0.983

  real    0m44.287s
  user    0m44.018s
  sys     0m1.116s
While training with the CUDA Graph produces the following output:

.. code-block:: shell

  $ time ./mnist --use-train-graph
  Train Epoch: 1 [59584/60000] Loss: 0.4092
  Test set: Average loss: 0.2037 | Accuracy: 0.938
  Train Epoch: 2 [59584/60000] Loss: 0.2039
  Test set: Average loss: 0.1274 | Accuracy: 0.961
  Train Epoch: 3 [59584/60000] Loss: 0.1779
  Test set: Average loss: 0.1017 | Accuracy: 0.968
  Train Epoch: 4 [59584/60000] Loss: 0.1559
  Test set: Average loss: 0.0871 | Accuracy: 0.972
  Train Epoch: 5 [59584/60000] Loss: 0.2240
  Test set: Average loss: 0.0735 | Accuracy: 0.977
  Train Epoch: 6 [59584/60000] Loss: 0.0520
  Test set: Average loss: 0.0710 | Accuracy: 0.978
  Train Epoch: 7 [59584/60000] Loss: 0.0935
  Test set: Average loss: 0.0666 | Accuracy: 0.979
  Train Epoch: 8 [59584/60000] Loss: 0.0744
  Test set: Average loss: 0.0603 | Accuracy: 0.981
  Train Epoch: 9 [59584/60000] Loss: 0.0762
  Test set: Average loss: 0.0547 | Accuracy: 0.983
  Train Epoch: 10 [59584/60000] Loss: 0.0207
  Test set: Average loss: 0.0525 | Accuracy: 0.983

  real    0m6.952s
  user    0m7.048s
  sys     0m0.619s
Conclusion
----------

As we can see, just by applying a CUDA Graph to the `MNIST example
<https://github.com/pytorch/examples/tree/main/cpp/mnist>`_ we were able to improve training
performance by more than six times. Such a large performance improvement was possible because of
the small model size: for larger models with heavy GPU usage, the CPU overhead is less impactful,
so the improvement will be smaller. Nevertheless, it is always advantageous to use CUDA Graphs to
get the most performance out of GPUs.
CMakeLists.txt (new file, 31 lines):
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)

Review comment (resolved): "I don't think it will actually compile with 3.1"; a follow-up commit in this PR raises the requirement to CMake >= 3.18.
project(mnist)
set(CMAKE_CXX_STANDARD 17)

find_package(Torch REQUIRED)
find_package(Threads REQUIRED)

option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)
  message(STATUS "Downloading MNIST dataset")
  execute_process(
    COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py
      -d ${CMAKE_BINARY_DIR}/data
    ERROR_VARIABLE DOWNLOAD_ERROR)
  if (DOWNLOAD_ERROR)
    message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")
  endif()
endif()

add_executable(mnist mnist.cpp)
target_compile_features(mnist PUBLIC cxx_range_for)
target_link_libraries(mnist ${TORCH_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})

if (MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  add_custom_command(TARGET mnist
    POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_if_different
      ${TORCH_DLLS}
      $<TARGET_FILE_DIR:mnist>)
endif (MSVC)
README.md (new file, 38 lines):
# MNIST Example with the PyTorch C++ Frontend

This folder contains an example of training a computer vision model to recognize
digits in images from the MNIST dataset, using the PyTorch C++ frontend.

The entire training code is contained in `mnist.cpp`.

To build the code, run the following commands from your terminal:

```shell
$ cd mnist
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
$ make
```

where `/path/to/libtorch` should be the path to the unzipped _LibTorch_
distribution, which you can get from the [PyTorch
homepage](https://pytorch.org/get-started/locally/).

Execute the compiled binary to train the model:

```shell
$ ./mnist
Train Epoch: 1 [59584/60000] Loss: 0.4232
Test set: Average loss: 0.1989 | Accuracy: 0.940
Train Epoch: 2 [59584/60000] Loss: 0.1926
Test set: Average loss: 0.1338 | Accuracy: 0.959
Train Epoch: 3 [59584/60000] Loss: 0.1390
Test set: Average loss: 0.0997 | Accuracy: 0.969
Train Epoch: 4 [59584/60000] Loss: 0.1239
Test set: Average loss: 0.0875 | Accuracy: 0.972
...
```

To run with CUDA Graphs, add `--use-train-graph` and/or `--use-test-graph`
for the training and testing passes, respectively.