-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Add logging recipe #2704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add logging recipe #2704
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
f4b98a9
Added logging recipe
mlazos 83b075d
Add links to documentation and help
mlazos 4a17ff2
Update recipes_source/torch_logs.py
mlazos 582e298
Update recipes_source/torch_logs.py
mlazos 301ad90
Update recipes_source/torch_logs.py
mlazos 2ebaf1f
Update recipes_source/torch_logs.py
mlazos 9412aac
Update recipes_source/torch_logs.py
mlazos 53db40d
Add link to torch.compile tutorial
mlazos 1010779
Add cuda version check
mlazos 789aa89
Merge branch 'main' into mlazos/log-recipe
svekars dff3cf8
Merge branch 'main' into mlazos/log-recipe
svekars 39784aa
Merge branch 'main' into mlazos/log-recipe
svekars e486157
Merge branch 'main' into mlazos/log-recipe
svekars File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
""" | ||
(beta) Using TORCH_LOGS python API with torch.compile | ||
========================================================================================== | ||
**Author:** `Michael Lazos <https://github.com/mlazos>`_ | ||
""" | ||
|
||
import logging | ||
|
||
###################################################################### | ||
# | ||
# This tutorial introduces the ``TORCH_LOGS`` environment variable, as well ass the Python API, and | ||
# demonstrates how to apply it to observe the phases of ``torch.compile``. | ||
# | ||
# .. note:: | ||
# | ||
# This tutorial requires PyTorch 2.2.0 or later. | ||
# | ||
# | ||
|
||
|
||
###################################################################### | ||
# Setup | ||
# ~~~~~~~~~~~~~~~~~~~~~ | ||
# In this example, we'll set up a simple Python function which performs an elementwise | ||
# add and observe the compilation process with ``TORCH_LOGS`` Python API. | ||
# | ||
# .. note:: | ||
# | ||
# There is also an environment variable ``TORCH_LOGS``, which can be used to | ||
# change logging settings at the command line. The equivalent environment | ||
# variable setting is shown for each example. | ||
|
||
import torch | ||
|
||
# exit cleanly if we are on a device that doesn't support torch.compile | ||
if torch.cuda.get_device_capability() < (7, 0): | ||
print("Exiting because torch.compile is not supported on this device.") | ||
import sys | ||
|
||
sys.exit(0) | ||
|
||
|
||
@torch.compile() | ||
def fn(x, y): | ||
z = x + y | ||
return z + 2 | ||
|
||
|
||
inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda")) | ||
|
||
|
||
# print separator and reset dynamo | ||
# between each example | ||
def separator(name): | ||
print(f"==================={name}=========================") | ||
torch._dynamo.reset() | ||
|
||
|
||
separator("Dynamo Tracing") | ||
# View dynamo tracing | ||
# TORCH_LOGS="+dynamo" | ||
torch._logging.set_logs(dynamo=logging.DEBUG) | ||
fn(*inputs) | ||
|
||
separator("Traced Graph") | ||
# View traced graph | ||
# TORCH_LOGS="graph" | ||
torch._logging.set_logs(graph=True) | ||
fn(*inputs) | ||
|
||
separator("Fusion Decisions") | ||
# View fusion decisions | ||
# TORCH_LOGS="fusion" | ||
torch._logging.set_logs(fusion=True) | ||
fn(*inputs) | ||
|
||
separator("Output Code") | ||
# View output code generated by inductor | ||
# TORCH_LOGS="output_code" | ||
torch._logging.set_logs(output_code=True) | ||
fn(*inputs) | ||
|
||
separator("") | ||
|
||
###################################################################### | ||
# Conclusion | ||
# ~~~~~~~~~~ | ||
# | ||
# In this tutorial we introduced the TORCH_LOGS environment variable and python API | ||
# by experimenting with a small number of the available logging options. | ||
# To view descriptions of all available options, run any python script | ||
# which imports torch and set TORCH_LOGS to "help". | ||
# | ||
# Alternatively, you can view the `torch._logging documentation`_ to see | ||
# descriptions of all available logging options. | ||
# | ||
# For more information on torch.compile, see the `torch.compile tutorial`_. | ||
# | ||
# .. _torch._logging documentation: https://pytorch.org/docs/main/logging.html | ||
# .. _torch.compile tutorial: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.