Skip to content

Add logging recipe #2704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 24, 2024
9 changes: 9 additions & 0 deletions recipes_source/recipes_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
:link: ../recipes/recipes/module_load_state_dict_tips.html
:tags: Basics

.. customcarditem::
:header: (beta) Using TORCH_LOGS to observe torch.compile
:card_description: Learn how to use the torch logging APIs to observe the compilation process.
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../recipes/torch_logs.html
:tags: Basics


.. Interpretability

.. customcarditem::
Expand Down Expand Up @@ -362,6 +370,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu

/recipes/recipes/loading_data_recipe
/recipes/recipes/defining_a_neural_network
/recipes/torch_logs
/recipes/recipes/what_is_state_dict
/recipes/recipes/saving_and_loading_models_for_inference
/recipes/recipes/saving_and_loading_a_general_checkpoint
Expand Down
100 changes: 100 additions & 0 deletions recipes_source/torch_logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
(beta) Using TORCH_LOGS python API with torch.compile
==========================================================================================
**Author:** `Michael Lazos <https://github.com/mlazos>`_
"""

import logging

######################################################################
#
# This tutorial introduces the ``TORCH_LOGS`` environment variable, as well ass the Python API, and
# demonstrates how to apply it to observe the phases of ``torch.compile``.
#
# .. note::
#
# This tutorial requires PyTorch 2.2.0 or later.
#
#


######################################################################
# Setup
# ~~~~~~~~~~~~~~~~~~~~~
# In this example, we'll set up a simple Python function which performs an elementwise
# add and observe the compilation process with ``TORCH_LOGS`` Python API.
#
# .. note::
#
# There is also an environment variable ``TORCH_LOGS``, which can be used to
# change logging settings at the command line. The equivalent environment
# variable setting is shown for each example.

import torch

# exit cleanly if we are on a device that doesn't support torch.compile
if torch.cuda.get_device_capability() < (7, 0):
print("Exiting because torch.compile is not supported on this device.")
import sys

sys.exit(0)


@torch.compile()
def fn(x, y):
z = x + y
return z + 2


inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))


# print separator and reset dynamo
# between each example
def separator(name):
print(f"==================={name}=========================")
torch._dynamo.reset()


separator("Dynamo Tracing")
# View dynamo tracing
# TORCH_LOGS="+dynamo"
torch._logging.set_logs(dynamo=logging.DEBUG)
fn(*inputs)

separator("Traced Graph")
# View traced graph
# TORCH_LOGS="graph"
torch._logging.set_logs(graph=True)
fn(*inputs)

separator("Fusion Decisions")
# View fusion decisions
# TORCH_LOGS="fusion"
torch._logging.set_logs(fusion=True)
fn(*inputs)

separator("Output Code")
# View output code generated by inductor
# TORCH_LOGS="output_code"
torch._logging.set_logs(output_code=True)
fn(*inputs)

separator("")

######################################################################
# Conclusion
# ~~~~~~~~~~
#
# In this tutorial we introduced the TORCH_LOGS environment variable and python API
# by experimenting with a small number of the available logging options.
# To view descriptions of all available options, run any python script
# which imports torch and set TORCH_LOGS to "help".
#
# Alternatively, you can view the `torch._logging documentation`_ to see
# descriptions of all available logging options.
#
# For more information on torch.compile, see the `torch.compile tutorial`_.
#
# .. _torch._logging documentation: https://pytorch.org/docs/main/logging.html
# .. _torch.compile tutorial: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html