From 66800eb8dd885b1adb8e7ee89f4f96e81626a61a Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 26 Sep 2024 07:13:53 -0700 Subject: [PATCH 01/19] add max-autotune tutorial --- ...totune_CPU_with_gemm_template_tutorial.rst | 198 ++++++++++++++++++ prototype_source/prototype_index.rst | 8 + 2 files changed, 206 insertions(+) create mode 100644 prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst diff --git a/prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst b/prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst new file mode 100644 index 00000000000..3d446cf6bdc --- /dev/null +++ b/prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst @@ -0,0 +1,198 @@ +Max-autotune Support on CPU with GEMM Template Tutorial +============================================================== + +**Author**: `Jiong Gong `__, `Leslie Fang `__, `Chunyuan Wu `__ + +Prerequisites: +---------------- +- `torch.compile and TorchInductor concepts in PyTorch `__ + +Introduction +------------ +``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` profiles multiple implementations of operations at compile time and selects the best-performing one, +trading longer compilation times for improved runtime performance. This enhancement is particularly beneficial for GEMM-related operations. +In the Inductor CPU backend, we’ve introduced a C++ template-based GEMM implementation as an alternative to the ATen-based approach that relies on oneDNN and MKL libraries. +This is similar to the max-autotune mode on CUDA, where implementations from ATen, Triton, and CUTLASS are considered. + +We have covered most popular data types, including FP32, BF16, FP16, and INT8, with epilogue fusions for x86 CPUs. + +How to activate ``max-autotune`` mode +------------ +To activate the ``max-autotune`` mode in PyTorch, set the ``mode`` argument to ``max-autotune`` when compiling your model using ``torch.compile``. +If you prefer to bypass the tuning process and always use the CPP template implementations, you can configure this via an environment variable: +``export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=CPP``. + + +Example code +------------ +The below code is an example of using the ``max-autotune`` mode on a simple neural network with a linear layer followed by a ReLU activation. +You could run the example code by setting this environment variable ``export TORCHINDUCTOR_FREEZING=1``. + + +.. code:: python + + import torch + from torch._inductor import config + config.trace.log_autotuning_results = True # enable the log of autotuning results + + class M(torch.nn.Module): + def __init__( + self, + in_features, + out_features, + bias, + **kwargs, + ): + super().__init__() + self.linear = torch.nn.Linear( + in_features, + out_features, + bias, + **kwargs, + ) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.linear(x) + x = self.relu(x) + return x + + amp_enabled = True + batch_size = 64 + in_features = 16 + out_features = 32 + bias = True + + x = torch.randn(batch_size, in_features) + model = M(in_features, out_features, bias) + + with torch.no_grad(), torch.cpu.amp.autocast(enabled=amp_enabled): + compiled = torch.compile(model, mode="max-autotune") # turn on "max-autotune" mode + y = compiled(x) + + +When running the above code snippet, you will see the autotuning result (the performance numbers are for demonstration purposes). +In this case, CPP template outperforms ATen kernel so that it will be selected. + +.. code:: shell + + AUTOTUNE linear_unary(64x16, 32x16, 32) + cpp_packed_gemm_0 0.2142 ms 100.0% + _linear_pointwise 0.2441 ms 87.7% + + +We could check the generated output code by setting ``export TORCH_LOGS="+output_code"``. +When CPP template is selected, we won't have ``torch.ops.mkldnn._linear_pointwise.default`` (for bfloat16) or ``torch.ops.mkl._mkl_linear.default`` (for float32) +in the generated code anymore, instead, we'll find kernel based on CPP GEMM template ``cpp_fused__to_copy_relu_1`` +(only part of the code is demonstrated below for simplicity) with the bias and relu epilogues fused inside the CPP GEMM template kernel. + +.. code:: python + + cpp_fused__to_copy_relu_1 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*'], ''' + + ... + + template + inline void kernel_micro_gemm_amx_kernel_32_2( + AMXState& amx_state, + const bfloat16* __restrict__ A, + const bfloat16* __restrict__ B, + float* __restrict__ C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + uint8_t tilecfg_rows + ) { + ... + } + + ... + + template + inline void kernel_micro_gemm( + AMXState& amx_state, + const bfloat16* __restrict__ A, + const bfloat16* __restrict__ B, + float* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc + ) { + ... + } + + extern "C" + void kernel(const bfloat16* X, const bfloat16* W, const bfloat16* inp, bfloat16* Y) + { + constexpr int64_t num_threads = 40; + constexpr int64_t N = 32; + constexpr int64_t K = 16; + constexpr int64_t M = static_cast(64L); + ... + #pragma omp parallel num_threads(40) + { + const int tid = omp_get_thread_num(); + ... + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + ... + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + ... + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + ... + for (int64_t nci = nc; nci < nc_block_end; nci++) { + if (kc == k_block_start) { + kernel_micro_gemm(false)>( + ... + ); + + } else { + kernel_micro_gemm(true)>( + ... + ); + + } + } + } + { + { + // Epilogue fusion here for bias and relu + #pragma GCC ivdep + for(int64_t x0=static_cast(0L); x0(m_end + ((-1L)*m_start)); x0+=static_cast(1L)) + { + for(int64_t x1=static_cast(0L); x1(16L*(c10::div_floor_integer(static_cast((n_end + ((-1L)*n_start))), static_cast(16L)))); x1+=static_cast(16L)) + { + auto tmp0 = at::vec::Vectorized::loadu(inp + static_cast(n_start + x1), static_cast(16)); + auto tmp2 = at::vec::Vectorized::loadu(local_acc_buf + static_cast(x1 + (Nc_blocks*Nr*x0)), static_cast(16)); + auto tmp1 = at::vec::convert(tmp0); + auto tmp3 = tmp1 + tmp2; + auto tmp4 = at::vec::convert(tmp3); + auto tmp5 = static_cast(0.0); + auto tmp6 = at::vec::Vectorized(tmp5); + auto tmp7 = at::vec::maximum(tmp3, tmp6); + auto tmp8 = at::vec::convert(tmp7); + tmp8.store(Y + static_cast(n_start + x1 + (32L*m_start) + (32L*x0)), static_cast(16)); + } + + ... + + } + } + + } + } + } + ... + } + } + ''') + +Conclusion +------------ +In this tutorial, we introduced max-autotune support on CPU with GEMM template. We explained the API to activate this feature and demonstrated +the generated code of GEMM template. + +This feature is in prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues `_. diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst index 1eaedb6a1d9..56279d63625 100644 --- a/prototype_source/prototype_index.rst +++ b/prototype_source/prototype_index.rst @@ -217,6 +217,13 @@ Prototype features are not available as part of binary distributions like PyPI o :link: ../prototype/inductor_cpp_wrapper_tutorial.html :tags: Model-Optimization +.. customcarditem:: + :header: Max-autotune Support on CPU with GEMM Template Tutorial + :card_description: Tutorial for max-autotune mode support for torch.compile with GEMM template + :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: ../prototype/max_autotune_CPU_with_gemm_template_tutorial.html + :tags: Model-Optimization + .. Distributed .. customcarditem:: :header: Flight Recorder Tutorial @@ -265,3 +272,4 @@ Prototype features are not available as part of binary distributions like PyPI o prototype/maskedtensor_sparsity.html prototype/maskedtensor_advanced_semantics.html prototype/maskedtensor_adagrad.html + prototype/max_autotune_CPU_with_gemm_template_tutorial.html From 1d575431f70fec5619e10bf44cd5c2191ce9d4fd Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 26 Sep 2024 20:33:51 -0700 Subject: [PATCH 02/19] Rename the tutorial --- ...mplate_tutorial.rst => max_autotune_on_CPU_tutorial.rst} | 2 +- prototype_source/prototype_index.rst | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename prototype_source/{max_autotune_CPU_with_gemm_template_tutorial.rst => max_autotune_on_CPU_tutorial.rst} (99%) diff --git a/prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst similarity index 99% rename from prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst rename to prototype_source/max_autotune_on_CPU_tutorial.rst index 3d446cf6bdc..c9ead5f8f95 100644 --- a/prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -1,4 +1,4 @@ -Max-autotune Support on CPU with GEMM Template Tutorial +Use max-autotune compilation on CPU to gain further performance boost ============================================================== **Author**: `Jiong Gong `__, `Leslie Fang `__, `Chunyuan Wu `__ diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst index 56279d63625..bfc8880f155 100644 --- a/prototype_source/prototype_index.rst +++ b/prototype_source/prototype_index.rst @@ -218,10 +218,10 @@ Prototype features are not available as part of binary distributions like PyPI o :tags: Model-Optimization .. customcarditem:: - :header: Max-autotune Support on CPU with GEMM Template Tutorial - :card_description: Tutorial for max-autotune mode support for torch.compile with GEMM template + :header: Use max-autotune compilation on CPU to gain further performance boost + :card_description: Tutorial for max-autotune mode on CPU to gain further performance boost :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png - :link: ../prototype/max_autotune_CPU_with_gemm_template_tutorial.html + :link: ../prototype/max_autotune_on_CPU_tutorial.html :tags: Model-Optimization .. Distributed From 243c58e228bccf835b68d10c9ab25c67e4df73d4 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 26 Sep 2024 20:40:32 -0700 Subject: [PATCH 03/19] add RFC link and mention that code is subject to change --- prototype_source/max_autotune_on_CPU_tutorial.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index c9ead5f8f95..799f160622d 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -9,7 +9,8 @@ Prerequisites: Introduction ------------ -``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` profiles multiple implementations of operations at compile time and selects the best-performing one, +``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` (`RFC link `) +profiles multiple implementations of operations at compile time and selects the best-performing one, trading longer compilation times for improved runtime performance. This enhancement is particularly beneficial for GEMM-related operations. In the Inductor CPU backend, we’ve introduced a C++ template-based GEMM implementation as an alternative to the ATen-based approach that relies on oneDNN and MKL libraries. This is similar to the max-autotune mode on CUDA, where implementations from ATen, Triton, and CUTLASS are considered. @@ -85,6 +86,7 @@ We could check the generated output code by setting ``export TORCH_LOGS="+output When CPP template is selected, we won't have ``torch.ops.mkldnn._linear_pointwise.default`` (for bfloat16) or ``torch.ops.mkl._mkl_linear.default`` (for float32) in the generated code anymore, instead, we'll find kernel based on CPP GEMM template ``cpp_fused__to_copy_relu_1`` (only part of the code is demonstrated below for simplicity) with the bias and relu epilogues fused inside the CPP GEMM template kernel. +The generated code differs by CPU architecture and is implementation-specific, which is subject to change. .. code:: python From 9380b9d16d96be195575f7f4a03d844924c7b951 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 26 Sep 2024 20:47:22 -0700 Subject: [PATCH 04/19] fix link --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index 799f160622d..1cfbaabdb81 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -9,7 +9,7 @@ Prerequisites: Introduction ------------ -``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` (`RFC link `) +``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` (`RFC link `_) profiles multiple implementations of operations at compile time and selects the best-performing one, trading longer compilation times for improved runtime performance. This enhancement is particularly beneficial for GEMM-related operations. In the Inductor CPU backend, we’ve introduced a C++ template-based GEMM implementation as an alternative to the ATen-based approach that relies on oneDNN and MKL libraries. From 29effc56df5b9b6a0733ebcfbcb80fcc9b4ca183 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 26 Sep 2024 20:56:34 -0700 Subject: [PATCH 05/19] add request on frozen and no_grad --- prototype_source/max_autotune_on_CPU_tutorial.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index 1cfbaabdb81..fa13e44da67 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -27,8 +27,10 @@ If you prefer to bypass the tuning process and always use the CPP template imple Example code ------------ The below code is an example of using the ``max-autotune`` mode on a simple neural network with a linear layer followed by a ReLU activation. -You could run the example code by setting this environment variable ``export TORCHINDUCTOR_FREEZING=1``. +We only support frozen model with ``torch.no_grad`` or the inference mode +Therefore, you need to set the environment variable ``export TORCHINDUCTOR_FREEZING=1`` +and ensure that both the compilation and inference steps are executed within the ``torch.no_grad`` context. .. code:: python @@ -86,6 +88,7 @@ We could check the generated output code by setting ``export TORCH_LOGS="+output When CPP template is selected, we won't have ``torch.ops.mkldnn._linear_pointwise.default`` (for bfloat16) or ``torch.ops.mkl._mkl_linear.default`` (for float32) in the generated code anymore, instead, we'll find kernel based on CPP GEMM template ``cpp_fused__to_copy_relu_1`` (only part of the code is demonstrated below for simplicity) with the bias and relu epilogues fused inside the CPP GEMM template kernel. + The generated code differs by CPU architecture and is implementation-specific, which is subject to change. .. code:: python From b8639c10123e3abf2c2a3ff93729499f83537d39 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 26 Sep 2024 22:26:40 -0700 Subject: [PATCH 06/19] add description on perf boost --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index fa13e44da67..36bb18986e2 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -17,6 +17,8 @@ This is similar to the max-autotune mode on CUDA, where implementations from ATe We have covered most popular data types, including FP32, BF16, FP16, and INT8, with epilogue fusions for x86 CPUs. +While the development is still in progress, we have already seen promising speedups over pure ATen-based GEMMs as measured by the three benchmark suites and the inference of LLMs. + How to activate ``max-autotune`` mode ------------ To activate the ``max-autotune`` mode in PyTorch, set the ``mode`` argument to ``max-autotune`` when compiling your model using ``torch.compile``. From 806ee21793570116054df0ddd5a88d5e350c98b9 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 27 Sep 2024 00:46:32 -0700 Subject: [PATCH 07/19] change from further to additional in the title --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- prototype_source/prototype_index.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index 36bb18986e2..6d6e50cc3dc 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -1,4 +1,4 @@ -Use max-autotune compilation on CPU to gain further performance boost +Use max-autotune compilation on CPU to gain additional performance boost ============================================================== **Author**: `Jiong Gong `__, `Leslie Fang `__, `Chunyuan Wu `__ diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst index bfc8880f155..4e104aa40e8 100644 --- a/prototype_source/prototype_index.rst +++ b/prototype_source/prototype_index.rst @@ -218,8 +218,8 @@ Prototype features are not available as part of binary distributions like PyPI o :tags: Model-Optimization .. customcarditem:: - :header: Use max-autotune compilation on CPU to gain further performance boost - :card_description: Tutorial for max-autotune mode on CPU to gain further performance boost + :header: Use max-autotune compilation on CPU to gain additional performance boost + :card_description: Tutorial for max-autotune mode on CPU to gain additional performance boost :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png :link: ../prototype/max_autotune_on_CPU_tutorial.html :tags: Model-Optimization From dce78f3497249166e3067e7ffe306cca79878d31 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Sat, 28 Sep 2024 07:09:46 -0700 Subject: [PATCH 08/19] Add more details for freezing --- prototype_source/max_autotune_on_CPU_tutorial.rst | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index 6d6e50cc3dc..c545fd7d519 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -30,8 +30,12 @@ Example code ------------ The below code is an example of using the ``max-autotune`` mode on a simple neural network with a linear layer followed by a ReLU activation. -We only support frozen model with ``torch.no_grad`` or the inference mode -Therefore, you need to set the environment variable ``export TORCHINDUCTOR_FREEZING=1`` +In the C++ template-based GEMM implementation, we will pre-pack the weight for good cache usage. +In the case of inference which is the primary scenario of CPU AI workloads, +model weights are constant and we pack them upfront during compilation +so that the data accesses are contiguous within the cache blocks. +Thus, We only support frozen model with ``torch.no_grad`` or the inference mode. +You need to set the environment variable ``export TORCHINDUCTOR_FREEZING=1`` and ensure that both the compilation and inference steps are executed within the ``torch.no_grad`` context. .. code:: python From 7540b9a5c12b1a10c52d575c862e993773151bfa Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:27:14 +0800 Subject: [PATCH 09/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index c545fd7d519..ff4c0aa6dc8 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -2,6 +2,11 @@ Use max-autotune compilation on CPU to gain additional performance boost ============================================================== **Author**: `Jiong Gong `__, `Leslie Fang `__, `Chunyuan Wu `__ +In this tutorial, you will learn how to boost your PyTorch models' performance on CPU by +leveraging the max-autotune mode in the Inductor CPU backend. Explore the activation +process, understand the differences from traditional methods, and integrate max-autotune +into your code for enhanced computational efficiency. Dive into the use of advanced +GEMM templates for faster processing and superior runtime performance. Prerequisites: ---------------- From fb8f415cb7a4acfda4731ab68bbe04207deafa41 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:27:29 +0800 Subject: [PATCH 10/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index ff4c0aa6dc8..c1c5723ccfa 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -24,8 +24,8 @@ We have covered most popular data types, including FP32, BF16, FP16, and INT8, w While the development is still in progress, we have already seen promising speedups over pure ATen-based GEMMs as measured by the three benchmark suites and the inference of LLMs. -How to activate ``max-autotune`` mode ------------- +Activating the ``max-autotune`` mode +------------------------------------- To activate the ``max-autotune`` mode in PyTorch, set the ``mode`` argument to ``max-autotune`` when compiling your model using ``torch.compile``. If you prefer to bypass the tuning process and always use the CPP template implementations, you can configure this via an environment variable: ``export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=CPP``. From 4323a708cf219aa98bf4eca374bb54531b462ec4 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:27:38 +0800 Subject: [PATCH 11/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index c1c5723ccfa..b53039d00b5 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -31,7 +31,7 @@ If you prefer to bypass the tuning process and always use the CPP template imple ``export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=CPP``. -Example code +Example ------------ The below code is an example of using the ``max-autotune`` mode on a simple neural network with a linear layer followed by a ReLU activation. From e480573b7ef0c9a7b561ee3794b7c6bbdace127d Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:27:48 +0800 Subject: [PATCH 12/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index b53039d00b5..4c4639c625d 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -86,7 +86,7 @@ and ensure that both the compilation and inference steps are executed within the When running the above code snippet, you will see the autotuning result (the performance numbers are for demonstration purposes). -In this case, CPP template outperforms ATen kernel so that it will be selected. +In this example, C++ template outperforms ATen kernel so that it will be selected. .. code:: shell From f9b415959d1d1159851597dfedcec45286c72d41 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:27:55 +0800 Subject: [PATCH 13/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index 4c4639c625d..5989b417b5b 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -27,7 +27,7 @@ While the development is still in progress, we have already seen promising speed Activating the ``max-autotune`` mode ------------------------------------- To activate the ``max-autotune`` mode in PyTorch, set the ``mode`` argument to ``max-autotune`` when compiling your model using ``torch.compile``. -If you prefer to bypass the tuning process and always use the CPP template implementations, you can configure this via an environment variable: +If you prefer to bypass the tuning process and always use the C++ template implementations, you can configure this via an environment variable: ``export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=CPP``. From a43f7b99a5d0e363e72f8e2885f2096cd126c776 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:28:04 +0800 Subject: [PATCH 14/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index 5989b417b5b..ef88647ba6a 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -96,7 +96,7 @@ In this example, C++ template outperforms ATen kernel so that it will be selecte We could check the generated output code by setting ``export TORCH_LOGS="+output_code"``. -When CPP template is selected, we won't have ``torch.ops.mkldnn._linear_pointwise.default`` (for bfloat16) or ``torch.ops.mkl._mkl_linear.default`` (for float32) +When C++ template is selected, we won't have ``torch.ops.mkldnn._linear_pointwise.default`` (for bfloat16) or ``torch.ops.mkl._mkl_linear.default`` (for float32) in the generated code anymore, instead, we'll find kernel based on CPP GEMM template ``cpp_fused__to_copy_relu_1`` (only part of the code is demonstrated below for simplicity) with the bias and relu epilogues fused inside the CPP GEMM template kernel. From 5aa9e8409b90bca52222766ad1790a09fee01d0c Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:28:13 +0800 Subject: [PATCH 15/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index ef88647ba6a..127a325a1b0 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -98,7 +98,7 @@ In this example, C++ template outperforms ATen kernel so that it will be selecte We could check the generated output code by setting ``export TORCH_LOGS="+output_code"``. When C++ template is selected, we won't have ``torch.ops.mkldnn._linear_pointwise.default`` (for bfloat16) or ``torch.ops.mkl._mkl_linear.default`` (for float32) in the generated code anymore, instead, we'll find kernel based on CPP GEMM template ``cpp_fused__to_copy_relu_1`` -(only part of the code is demonstrated below for simplicity) with the bias and relu epilogues fused inside the CPP GEMM template kernel. +(only part of the code is demonstrated below for simplicity) with the bias and relu epilogues fused inside the C++ GEMM template kernel. The generated code differs by CPU architecture and is implementation-specific, which is subject to change. From c320d58f32d97c802be0ad6c0461fd89fe8b98ef Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:28:24 +0800 Subject: [PATCH 16/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index 127a325a1b0..9bcaba7be14 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -208,7 +208,7 @@ The generated code differs by CPU architecture and is implementation-specific, w Conclusion ------------ -In this tutorial, we introduced max-autotune support on CPU with GEMM template. We explained the API to activate this feature and demonstrated +In this tutorial, we introduced max-autotune support on CPU with GEMM template. We explained the API to activate this feature, and demonstrated the generated code of GEMM template. This feature is in prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues `_. From f1fae2ebc98bab4a50b8a79fdee0410844f3aa1e Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 2 Oct 2024 10:28:31 +0800 Subject: [PATCH 17/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst Co-authored-by: Svetlana Karslioglu --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index 9bcaba7be14..dc046776b9c 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -209,6 +209,6 @@ The generated code differs by CPU architecture and is implementation-specific, w Conclusion ------------ In this tutorial, we introduced max-autotune support on CPU with GEMM template. We explained the API to activate this feature, and demonstrated -the generated code of GEMM template. +the generated code of the GEMM template. This feature is in prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues `_. From 2f90eae0d12016101e20226a0a8c000c6133694c Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Wed, 9 Oct 2024 11:52:17 -0700 Subject: [PATCH 18/19] Formatting fixes. --- prototype_source/max_autotune_on_CPU_tutorial.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index dc046776b9c..a199265c743 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -1,7 +1,8 @@ -Use max-autotune compilation on CPU to gain additional performance boost -============================================================== +Using Max-Autotune Compilation on CPU for Better Performance +================================================================================ **Author**: `Jiong Gong `__, `Leslie Fang `__, `Chunyuan Wu `__ + In this tutorial, you will learn how to boost your PyTorch models' performance on CPU by leveraging the max-autotune mode in the Inductor CPU backend. Explore the activation process, understand the differences from traditional methods, and integrate max-autotune From ad0c00cd8d652e641507449c6ba7b0c9d6cce9e3 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Wed, 9 Oct 2024 14:53:14 -0700 Subject: [PATCH 19/19] Update prototype_source/max_autotune_on_CPU_tutorial.rst --- prototype_source/max_autotune_on_CPU_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/max_autotune_on_CPU_tutorial.rst b/prototype_source/max_autotune_on_CPU_tutorial.rst index a199265c743..47374744938 100644 --- a/prototype_source/max_autotune_on_CPU_tutorial.rst +++ b/prototype_source/max_autotune_on_CPU_tutorial.rst @@ -15,7 +15,7 @@ Prerequisites: Introduction ------------ -``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` (`RFC link `_) +The ``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` (`RFC link `_) profiles multiple implementations of operations at compile time and selects the best-performing one, trading longer compilation times for improved runtime performance. This enhancement is particularly beneficial for GEMM-related operations. In the Inductor CPU backend, we’ve introduced a C++ template-based GEMM implementation as an alternative to the ATen-based approach that relies on oneDNN and MKL libraries.