
Commit 336579d

Update aoti tutorial
1 parent d8a9749 commit 336579d

File tree

1 file changed: +135 -69 lines changed


recipes_source/torch_export_aoti_python.py

Lines changed: 135 additions & 69 deletions
@@ -3,7 +3,7 @@
 """
 .. meta::
     :description: An end-to-end example of how to use AOTInductor for Python runtime.
-    :keywords: torch.export, AOTInductor, torch._inductor.aot_compile, torch._export.aot_load
+    :keywords: torch.export, AOTInductor, torch._inductor.aoti_compile_and_package, aot_compile, torch._export.aoti_load_package

 ``torch.export`` AOTInductor Tutorial for Python runtime (Beta)
 ===============================================================
@@ -14,19 +14,18 @@
 #
 # .. warning::
 #
-#     ``torch._inductor.aot_compile`` and ``torch._export.aot_load`` are in Beta status and are subject to backwards compatibility
-#     breaking changes. This tutorial provides an example of how to use these APIs for model deployment using Python runtime.
+#     ``torch._inductor.aoti_compile_and_package`` and
+#     ``torch._inductor.aoti_load_package`` are in Beta status and are subject
+#     to backwards compatibility breaking changes. This tutorial provides an
+#     example of how to use these APIs for model deployment using Python
+#     runtime.
 #
-# It has been shown `previously <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how AOTInductor can be used
-# to do Ahead-of-Time compilation of PyTorch exported models by creating
-# a shared library that can be run in a non-Python environment.
-#
-#
-# In this tutorial, you will learn an end-to-end example of how to use AOTInductor for Python runtime.
-# We will look at how to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a
-# shared library. Additionally, we will examine how to execute the shared library in Python runtime using :func:`torch._export.aot_load`.
-# You will learn about the speed up seen in the first inference time using AOTInductor, especially when using
-# ``max-autotune`` mode which can take some time to execute.
+# It has been shown `previously
+# <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`__ how
+# AOTInductor can be used to do Ahead-of-Time compilation of PyTorch exported
+# models by creating an artifact that can be run in a non-Python environment.
+# In this tutorial, you will learn an end-to-end example of how to use
+# AOTInductor for Python runtime.
 #
 # **Contents**
 #
@@ -36,115 +35,182 @@
 ######################################################################
 # Prerequisites
 # -------------
-# * PyTorch 2.4 or later
+# * PyTorch 2.6 or later
 # * Basic understanding of ``torch.export`` and AOTInductor
 # * Complete the `AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models <https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#>`_ tutorial

 ######################################################################
 # What you will learn
 # ----------------------
-# * How to use AOTInductor for python runtime.
-# * How to use :func:`torch._inductor.aot_compile` along with :func:`torch.export.export` to generate a shared library
-# * How to run a shared library in Python runtime using :func:`torch._export.aot_load`.
-# * When do you use AOTInductor for python runtime
+# * How to use AOTInductor for Python runtime.
+# * How to use :func:`torch._inductor.aoti_compile_and_package` along with :func:`torch.export.export` to generate a compiled artifact
+# * How to load and run the artifact in a Python runtime using :func:`torch._inductor.aoti_load_package`.
+# * When to use AOTInductor with a Python runtime

 ######################################################################
 # Model Compilation
 # -----------------
 #
-# We will use the TorchVision pretrained `ResNet18` model and TorchInductor on the
-# exported PyTorch program using :func:`torch._inductor.aot_compile`.
+# We will use the TorchVision pretrained ``ResNet18`` model as an example.
 #
-# .. note::
+# The first step is to export the model to a graph representation using
+# :func:`torch.export.export`. To learn more about using this function, you can
+# check out the `docs <https://pytorch.org/docs/main/export.html>`_ or the
+# `tutorial <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`_.
 #
-#     This API also supports :func:`torch.compile` options like ``mode``
-#     This means that if used on a CUDA enabled device, you can, for example, set ``"max_autotune": True``
-#     which leverages Triton based matrix multiplications & convolutions, and enables CUDA graphs by default.
+# Once we have exported the PyTorch model and obtained an ``ExportedProgram``,
+# we can apply :func:`torch._inductor.aoti_compile_and_package` to AOTInductor
+# compile the program to a specified device, and save the generated contents
+# into a ".pt2" artifact.
 #
-# We also specify ``dynamic_shapes`` for the batch dimension. In this example, ``min=2`` is not a bug and is
-# explained in `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`__
-
+# .. note::
+#
+#     This API supports the same available options that :func:`torch.compile`
+#     has, like ``mode`` and ``max_autotune`` (for those who want to enable
+#     CUDA graphs and leverage Triton based matrix multiplications and
+#     convolutions).

 import os
 import torch
+import torch._inductor
 from torchvision.models import ResNet18_Weights, resnet18

 model = resnet18(weights=ResNet18_Weights.DEFAULT)
 model.eval()

 with torch.inference_mode():
+    inductor_configs = {}

-    # Specify the generated shared library path
-    aot_compile_options = {
-        "aot_inductor.output_path": os.path.join(os.getcwd(), "resnet18_pt2.so"),
-    }
     if torch.cuda.is_available():
         device = "cuda"
-        aot_compile_options.update({"max_autotune": True})
+        inductor_configs["max_autotune"] = True
     else:
         device = "cpu"

     model = model.to(device=device)
     example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

-    # min=2 is not a bug and is explained in the 0/1 Specialization Problem
-    batch_dim = torch.export.Dim("batch", min=2, max=32)
     exported_program = torch.export.export(
         model,
         example_inputs,
         # Specify the first dimension of the input x as dynamic
-        dynamic_shapes={"x": {0: batch_dim}},
+        dynamic_shapes={"x": {0: torch.export.Dim.DYNAMIC}},
     )
-    so_path = torch._inductor.aot_compile(
-        exported_program.module(),
-        example_inputs,
-        # Specify the generated shared library path
-        options=aot_compile_options
+    path = torch._inductor.aoti_compile_and_package(
+        exported_program,
+        package_path=os.path.join(os.getcwd(), "resnet18.pt2"),
+        inductor_configs=inductor_configs
     )

+######################################################################
+# The result of :func:`aoti_compile_and_package` is an artifact "resnet18.pt2"
+# which can be loaded and executed in Python and C++.
+#
+# The artifact itself contains a bunch of AOTInductor generated code, such as
+# a generated C++ runner file, a shared library compiled from the C++ file, and
+# cubin files if optimizing for CUDA.
+#
+# Structurally, the artifact is a zip file with the following
+# specification:
+#
+# .. code::
+#    .
+#    ├── archive_format
+#    ├── version
+#    ├── data
+#    │   ├── aotinductor
+#    │   │   └── model
+#    │   │       ├── xxx.cpp            # AOTInductor generated cpp file
+#    │   │       ├── xxx.so             # AOTInductor generated shared library
+#    │   │       ├── xxx.cubin          # Cubin files (if running on CUDA)
+#    │   │       └── xxx_metadata.json  # Additional metadata to save
+#    │   ├── weights
+#    │   │   └── TBD
+#    │   └── constants
+#    │       └── TBD
+#    └── extra
+#        └── metadata.json
+#
+# We can use the following command to inspect the artifact contents:
+#
+# .. code:: bash
+#
+#    $ unzip -l resnet18.pt2
+#
+# .. code::
+#
+#    Archive:  resnet18.pt2
+#      Length      Date    Time    Name
+#    ---------  ---------- -----   ----
+#            1  01-08-2025 16:40   version
+#            3  01-08-2025 16:40   archive_format
+#        10088  01-08-2025 16:40   data/aotinductor/model/cagzt6akdaczvxwtbvqe34otfe5jlorktbqlojbzqjqvbfsjlge4.cubin
+#        17160  01-08-2025 16:40   data/aotinductor/model/c6oytfjmt5w4c7onvtm6fray7clirxt7q5xjbwx3hdydclmwoujz.cubin
+#        16616  01-08-2025 16:40   data/aotinductor/model/c7ydp7nocyz323hij4tmlf2kcedmwlyg6r57gaqzcsy3huneamu6.cubin
+#        17776  01-08-2025 16:40   data/aotinductor/model/cyqdf46ordevqhiddvpdpp3uzwatfbzdpl3auj2nx23uxvplnne2.cubin
+#        10856  01-08-2025 16:40   data/aotinductor/model/cpzfebfgrusqslui7fxsuoo4tvwulmrxirc5tmrpa4mvrbdno7kn.cubin
+#        14608  01-08-2025 16:40   data/aotinductor/model/c5ukeoz5wmaszd7vczdz2qhtt6n7tdbl3b6wuy4rb2se24fjwfoy.cubin
+#        11376  01-08-2025 16:40   data/aotinductor/model/csu3nstcp56tsjfycygaqsewpu64l5s6zavvz7537cm4s4cv2k3r.cubin
+#        10984  01-08-2025 16:40   data/aotinductor/model/cp76lez4glmgq7gedf2u25zvvv6rksv5lav4q22dibd2zicbgwj3.cubin
+#        14736  01-08-2025 16:40   data/aotinductor/model/c2bb5p6tnwz4elgujqelsrp3unvkgsyiv7xqxmpvuxcm4jfl7pc2.cubin
+#        11376  01-08-2025 16:40   data/aotinductor/model/c6eopmb2b4ngodwsayae4r5q6ni3jlfogfbdk3ypg56tgpzhubfy.cubin
+#        11624  01-08-2025 16:40   data/aotinductor/model/chmwe6lvoekzfowdbiizitm3haiiuad5kdm6sd2m6mv6dkn2zk32.cubin
+#        15632  01-08-2025 16:40   data/aotinductor/model/c3jop5g344hj3ztsu4qm6ibxyaaerlhkzh2e6emak23rxfje6jam.cubin
+#        25472  01-08-2025 16:40   data/aotinductor/model/chaiixybeiuuitm2nmqnxzijzwgnn2n7uuss4qmsupgblfh3h5hk.cubin
+#       139389  01-08-2025 16:40   data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.cpp
+#           27  01-08-2025 16:40   data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t_metadata.json
+#     47195424  01-08-2025 16:40   data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.so
+#    ---------                     -------
+#     47523148                     18 files
+

 ######################################################################
 # Model Inference in Python
 # -------------------------
 #
-# Typically, the shared object generated above is used in a non-Python environment. In PyTorch 2.3,
-# we added a new API called :func:`torch._export.aot_load` to load the shared library in the Python runtime.
-# The API follows a structure similar to the :func:`torch.jit.load` API . You need to specify the path
-# of the shared library and the device where it should be loaded.
+# To load and run the artifact in Python, we can use :func:`torch._inductor.aoti_load_package`.
+#
+# There are mainly two reasons why one would use AOTInductor with the Python runtime:
+#
+# - ``torch._inductor.aoti_compile_and_package`` generates a compiled artifact. This is useful for model
+#   versioning for deployments and tracking model performance over time.
+# - With :func:`torch.compile` being a JIT compiler, there is a warmup
+#   cost associated with the first compilation. Your deployment needs to account for the
+#   compilation time taken for the first inference. With AOTInductor, the compilation is
+#   done offline using ``torch.export.export`` and ``torch._inductor.aoti_compile_and_package``. The deployment
+#   would only load the compiled artifact using ``torch._inductor.aoti_load_package`` and run inference.
 #
-# .. note::
-#     In the example above, we specified ``batch_size=1`` for inference and it still functions correctly even though we specified ``min=2`` in
-#     :func:`torch.export.export`.


 import os
 import torch
+import torch._inductor

-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_so_path = os.path.join(os.getcwd(), "resnet18_pt2.so")
+model_path = os.path.join(os.getcwd(), "resnet18.pt2")

-model = torch._export.aot_load(model_so_path, device)
-example_inputs = (torch.randn(1, 3, 224, 224, device=device),)
+compiled_model = torch._inductor.aoti_load_package(model_path)
+example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

 with torch.inference_mode():
-    output = model(example_inputs)
+    output = compiled_model(example_inputs)
+

 ######################################################################
-# When to use AOTInductor for Python Runtime
-# ------------------------------------------
+# When to use AOTInductor with a Python Runtime
+# ---------------------------------------------
 #
-# One of the requirements for using AOTInductor is that the model shouldn't have any graph breaks.
-# Once this requirement is met, the primary use case for using AOTInductor Python Runtime is for
-# model deployment using Python.
-# There are mainly two reasons why you would use AOTInductor Python Runtime:
+# There are mainly two reasons why one would use AOTInductor with a Python Runtime:
 #
-# - ``torch._inductor.aot_compile`` generates a shared library. This is useful for model
-#   versioning for deployments and tracking model performance over time.
+# - ``torch._inductor.aoti_compile_and_package`` generates a singular
+#   serialized artifact. This is useful for model versioning for deployments
+#   and tracking model performance over time.
 # - With :func:`torch.compile` being a JIT compiler, there is a warmup
-#   cost associated with the first compilation. Your deployment needs to account for the
-#   compilation time taken for the first inference. With AOTInductor, the compilation is
-#   done offline using ``torch.export.export`` & ``torch._indutor.aot_compile``. The deployment
-#   would only load the shared library using ``torch._export.aot_load`` and run inference.
+#   cost associated with the first compilation. Your deployment needs to
+#   account for the compilation time taken for the first inference. With
+#   AOTInductor, the compilation is done ahead of time using
+#   ``torch.export.export`` and ``torch._inductor.aoti_compile_and_package``.
+#   At deployment time, after loading the model, running inference does not
+#   have any additional cost.
 #
 #
 # The section below shows the speedup achieved with AOTInductor for first inference
@@ -185,7 +251,7 @@ def timed(fn):

 torch._dynamo.reset()

-model = torch._export.aot_load(model_so_path, device)
+model = torch._inductor.aoti_load_package(model_path)
 example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

 with torch.inference_mode():
@@ -217,8 +283,8 @@ def timed(fn):
 # ----------
 #
 # In this recipe, we have learned how to effectively use the AOTInductor for Python runtime by
-# compiling and loading a pretrained ``ResNet18`` model using the ``torch._inductor.aot_compile``
-# and ``torch._export.aot_load`` APIs. This process demonstrates the practical application of
-# generating a shared library and running it within a Python environment, even with dynamic shape
-# considerations and device-specific optimizations. We also looked at the advantage of using
+# compiling and loading a pretrained ``ResNet18`` model. This process
+# demonstrates the practical application of generating a compiled artifact and
+# running it within a Python environment, even with dynamic shape considerations
+# and device-specific optimizations. We also looked at the advantage of using
 # AOTInductor in model deployments, with regards to speed up in first inference time.
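Since the new ``resnet18.pt2`` artifact is a zip archive (as the ``unzip -l`` step added above shows), its contents can also be listed directly from Python with the standard-library ``zipfile`` module — a minimal sketch, assuming the ``resnet18.pt2`` file produced by the compilation snippet above:

    import os
    import zipfile

    # Path of the artifact written by aoti_compile_and_package() in the snippet above.
    pt2_path = os.path.join(os.getcwd(), "resnet18.pt2")

    # A .pt2 package is a zip archive, so zipfile can list it much like `unzip -l`.
    with zipfile.ZipFile(pt2_path) as archive:
        for info in archive.infolist():
            print(f"{info.file_size:>10}  {info.filename}")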

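The updated export call marks the batch dimension as dynamic via ``torch.export.Dim.DYNAMIC``, so the packaged model is not fixed to the example batch size of 2. A minimal sketch of exercising that dynamic batch dimension with ``torch._inductor.aoti_load_package``, assuming the artifact was compiled for the device selected below and mirroring the tutorial's calling convention:

    import os
    import torch
    import torch._inductor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = os.path.join(os.getcwd(), "resnet18.pt2")

    # aoti_load_package returns a callable runner for the packaged model.
    compiled_model = torch._inductor.aoti_load_package(model_path)

    # The batch dimension was exported as dynamic, so several batch sizes should work.
    with torch.inference_mode():
        for batch_size in (1, 4, 8):
            example_inputs = (torch.randn(batch_size, 3, 224, 224, device=device),)
            output = compiled_model(example_inputs)  # same calling convention as the tutorial
            print(f"ran batch_size={batch_size}")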
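The last two hunks touch a benchmarking section built around a ``timed`` helper whose body is not shown in this diff. As a rough illustration of the first-inference comparison that section makes, here is one possible wall-clock version (the tutorial's own helper may differ, for example by using CUDA events), assuming the ``resnet18.pt2`` artifact and the same ``ResNet18`` weights as above:

    import os
    import time

    import torch
    import torch._inductor
    from torchvision.models import ResNet18_Weights, resnet18

    device = "cuda" if torch.cuda.is_available() else "cpu"
    example_inputs = (torch.randn(1, 3, 224, 224, device=device),)


    def timed(fn):
        # Stand-in timing helper: returns (result, wall-clock seconds) for one call.
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        result = fn()
        if device == "cuda":
            torch.cuda.synchronize()
        return result, time.perf_counter() - start


    # First inference with the ahead-of-time compiled package: no JIT compilation at run time.
    aoti_model = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "resnet18.pt2"))
    with torch.inference_mode():
        _, aoti_first = timed(lambda: aoti_model(example_inputs))

    # First inference with torch.compile: pays the JIT compilation cost on the first call.
    eager_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval().to(device)
    compiled = torch.compile(eager_model)
    with torch.inference_mode():
        _, compile_first = timed(lambda: compiled(*example_inputs))

    print(f"AOTInductor first inference:   {aoti_first:.4f} s")
    print(f"torch.compile first inference: {compile_first:.4f} s")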