
Updated Doc for Intel XPU Profile #3013


Merged: merged 2 commits on Oct 24, 2024
Binary file added _static/img/itt_tutorial/vtune_xpu_config.png
Binary file added _static/img/itt_tutorial/vtune_xpu_timeline.png
Binary file added _static/img/trace_xpu_img.png
9 changes: 8 additions & 1 deletion en-wordlist.txt
@@ -647,4 +647,11 @@ url
colab
sharders
Criteo
torchrec
torchrec
_batch_norm_impl_index
convolution_overrideable
aten
XPU
XPUs
impl
overrideable
36 changes: 33 additions & 3 deletions recipes_source/profile_with_itt.rst
@@ -58,6 +58,10 @@ Launch Intel® VTune™ Profiler

To verify the functionality, you need to start an Intel® VTune™ Profiler instance. Please check the `Intel® VTune™ Profiler User Guide <https://www.intel.com/content/www/us/en/develop/documentation/vtune-help/top/launch.html>`__ for steps to launch Intel® VTune™ Profiler.

.. note::
   You can also use the web server UI by following the `Intel® VTune™ Profiler Web Server UI Guide <https://www.intel.com/content/www/us/en/docs/vtune-profiler/user-guide/2024-1/web-server-ui.html>`__,
   for example: ``vtune-backend --web-port=8080 --allow-remote-access --enable-server-profiling``

Once the Intel® VTune™ Profiler GUI is launched, you should see a user interface like the one below:

.. figure:: /_static/img/itt_tutorial/vtune_start.png
@@ -66,8 +70,8 @@

Three sample results are available on the left side navigation bar under the `sample (matrix)` project. If you do not want profiling results to appear in this default sample project, you can create a new project via the `New Project...` button under the blue `Configure Analysis...` button. To start a new profiling run, click the blue `Configure Analysis...` button to initiate its configuration.

Configure Profiling
~~~~~~~~~~~~~~~~~~~
Configure Profiling for CPU
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Once you click the `Configure Analysis...` button, you should see the screen below:

@@ -77,6 +81,16 @@

The right side of the window is split into 3 parts: `WHERE` (top left), `WHAT` (bottom left), and `HOW` (right). With `WHERE`, you can assign the machine on which you want to run the profiling. With `WHAT`, you can set the path of the application that you want to profile. To profile a PyTorch script, it is recommended to wrap all manual steps, including activating a Python environment and setting required environment variables, into a bash script, then profile this bash script. In the screenshot above, we wrapped all steps into the `launch.sh` bash script and profile `bash` with the parameter `<path_of_launch.sh>`. On the right side, `HOW`, you can choose the analysis type that you would like to run. Intel® VTune™ Profiler provides many profiling types to choose from. Details can be found in the `Intel® VTune™ Profiler User Guide <https://www.intel.com/content/www/us/en/develop/documentation/vtune-help/top/analyze-performance.html>`__.
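
As a reference, a minimal sketch of such a `launch.sh` wrapper is shown below; the environment path, environment variable, and script name are illustrative placeholders, not prescribed by this tutorial:

.. code-block:: sh

   #!/bin/bash
   # Activate the Python environment where PyTorch is installed (placeholder path)
   source /path/to/venv/bin/activate
   # Set any environment variables the workload requires (illustrative example)
   export OMP_NUM_THREADS=16
   # Run the PyTorch script to be profiled
   python sample.py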


Configure Profiling for XPU
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Pick `GPU Offload` as the profiling type instead of `Hotspots`, and follow the same instructions as in the CPU case to launch the application.

.. figure:: /_static/img/itt_tutorial/vtune_xpu_config.png
:width: 100%
:align: center


Read Profiling Result
~~~~~~~~~~~~~~~~~~~~~

@@ -101,6 +115,18 @@ As illustrated on the right side navigation bar, brown portions in the timeline

Intel® VTune™ Profiler provides a much richer set of profiling features to help you understand a performance issue. Once you understand the root cause of a performance issue, you can fix it. More detailed usage instructions are available in the `Intel® VTune™ Profiler User Guide <https://www.intel.com/content/www/us/en/develop/documentation/vtune-help/top/analyze-performance.html>`__.

Read XPU Profiling Result
~~~~~~~~~~~~~~~~~~~~~~~~~

After a successful profiling run with ITT, you can open the `Platform` tab of the profiling result to see labels in the Intel® VTune™ Profiler timeline.

.. figure:: /_static/img/itt_tutorial/vtune_xpu_timeline.png
:width: 100%
:align: center


The timeline shows the main thread as a `python` thread on the top. Labeled PyTorch operators and customized regions are shown in the main thread row. All operators starting with `aten::` are labeled implicitly by the ITT feature in PyTorch. The timeline also shows the GPU Computing Queue, where you can see the different XPU kernels dispatched into it.

A short sample code showcasing how to use PyTorch ITT APIs
----------------------------------------------------------

@@ -128,8 +154,12 @@ The topology is formed by two operators, `Conv2d` and `Linear`. Three iterations
        return x

def main():
    m = ITTSample()
    # Uncomment the following line to run the sample on XPU
    # m = m.to("xpu")
    x = torch.rand(10, 3, 244, 244)
    # Uncomment the following line to run the sample on XPU
    # x = x.to("xpu")
    with torch.autograd.profiler.emit_itt():
        for i in range(3):
            # Labeling a region with a pair of range_push and range_pop
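            # The diff view truncates the loop body here; what follows is a
            # hedged completion using the ITT helper APIs this recipe
            # demonstrates (the label text is an illustrative assumption):
            torch.profiler.itt.range_push("iteration_{}".format(i))
            m(x)
            torch.profiler.itt.range_pop()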
85 changes: 67 additions & 18 deletions recipes_source/recipes/profiler_recipe.py
@@ -70,6 +70,7 @@
# - ``ProfilerActivity.CPU`` - PyTorch operators, TorchScript functions and
# user-defined code labels (see ``record_function`` below);
# - ``ProfilerActivity.CUDA`` - on-device CUDA kernels;
# - ``ProfilerActivity.XPU`` - on-device XPU kernels;
# - ``record_shapes`` - whether to record shapes of the operator inputs;
# - ``profile_memory`` - whether to report amount of memory consumed by
# model's Tensors;
@@ -160,17 +161,28 @@
# Note the occurrence of ``aten::convolution`` twice with different input shapes.

######################################################################
# Profiler can also be used to analyze performance of models executed on GPUs:

model = models.resnet18().cuda()
inputs = torch.randn(5, 3, 224, 224).cuda()

with profile(activities=[
ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
# Profiler can also be used to analyze performance of models executed on GPUs and XPUs:
# Users can switch among cpu, cuda, and xpu
if torch.cuda.is_available():
device = 'cuda'
elif torch.xpu.is_available():
device = 'xpu'
else:
print('Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices')
import sys
sys.exit(0)

activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU]
sort_by_keyword = device + "_time_total"

model = models.resnet18().to(device)
inputs = torch.randn(5, 3, 224, 224).to(device)

with profile(activities=activities, record_shapes=True) as prof:
with record_function("model_inference"):
model(inputs)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))

######################################################################
# (Note: the first use of CUDA profiling may incur extra overhead.)
@@ -197,6 +209,36 @@
# Self CPU time total: 23.015ms
# Self CUDA time total: 11.666ms
#
######################################################################


######################################################################
# (Note: the first use of XPU profiling may incur extra overhead.)

######################################################################
# The resulting table output (omitting some columns):
#
# .. code-block:: sh
#
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------
#                                                    Name      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------
#                                         model_inference       0.000us         0.00%       2.567ms       2.567ms             1
#                                            aten::conv2d       0.000us         0.00%       1.871ms      93.560us            20
#                                       aten::convolution       0.000us         0.00%       1.871ms      93.560us            20
#                                      aten::_convolution       0.000us         0.00%       1.871ms      93.560us            20
#                          aten::convolution_overrideable       1.871ms        72.89%       1.871ms      93.560us            20
#                                                gen_conv       1.484ms        57.82%       1.484ms      74.216us            20
#                                        aten::batch_norm       0.000us         0.00%     432.640us      21.632us            20
#                            aten::_batch_norm_impl_index       0.000us         0.00%     432.640us      21.632us            20
#                                 aten::native_batch_norm     432.640us        16.85%     432.640us      21.632us            20
#                                            conv_reorder     386.880us        15.07%     386.880us       6.448us            60
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------
# Self CPU time total: 712.486ms
# Self XPU time total: 2.567ms
Comment on lines +215 to +238

Contributor: What value does this extra table bring to the user?

Contributor Author: There are indeed just minor changes, including some different operators and XPU instead of GPU in this table; we just want people to understand what the output might look like.

#


######################################################################
# Note the occurrence of on-device kernels in the output (e.g. ``sgemm_32x32x32_NN``).
@@ -266,17 +308,22 @@
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiling results can be output as a ``.json`` trace file:
# Tracing CUDA or XPU kernels
# Users can switch among cpu, cuda, and xpu
device = 'cuda'

activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU]

model = models.resnet18().cuda()
inputs = torch.randn(5, 3, 224, 224).cuda()
model = models.resnet18().to(device)
inputs = torch.randn(5, 3, 224, 224).to(device)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
with profile(activities=activities) as prof:
model(inputs)

prof.export_chrome_trace("trace.json")

######################################################################
# You can examine the sequence of profiled operators and CUDA kernels
# You can examine the sequence of profiled operators and CUDA/XPU kernels
# in Chrome trace viewer (``chrome://tracing``):
#
# .. image:: ../../_static/img/trace_img.png
@@ -287,15 +334,16 @@
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Profiler can be used to analyze Python and TorchScript stack traces:
sort_by_keyword = "self_" + device + "_time_total"

with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
activities=activities,
with_stack=True,
) as prof:
model(inputs)

# Print aggregated stats
print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=2))
print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=2))

#################################################################################
# The output might look like this (omitting some columns):
@@ -384,15 +432,17 @@
# To send the signal to the profiler that the next step has started, call ``prof.step()`` function.
# The current profiler step is stored in ``prof.step_num``.
#
# The following example shows how to use all of the concepts above:
# The following example shows how to use all of the concepts above for CUDA and XPU kernels:

sort_by_keyword = "self_" + device + "_time_total"

def trace_handler(p):
output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)
output = p.key_averages().table(sort_by=sort_by_keyword, row_limit=10)
print(output)
p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")

with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
activities=activities,
schedule=torch.profiler.schedule(
wait=1,
warmup=1,
@@ -403,7 +453,6 @@ def trace_handler(p):
model(inputs)
p.step()
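
######################################################################
# The diff view folds the middle of the ``with profile(...)`` block above.
# A hedged reconstruction of the complete example is sketched below; the
# ``active=2`` value, the ``on_trace_ready=trace_handler`` wiring, and the
# eight-step loop are assumptions rather than text taken from this diff.

with profile(
    activities=activities,
    schedule=torch.profiler.schedule(
        wait=1,      # idle through the first step
        warmup=1,    # warm up on the next step; its results are discarded
        active=2),   # record the following two steps, then call the handler
    on_trace_ready=trace_handler
) as p:
    for idx in range(8):
        model(inputs)
        p.step()     # signal the profiler that the next step has started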


######################################################################
# Learn More
# ----------