feat: Saving modules using the AOTI format #3567

Open · wants to merge 1 commit into main
3 changes: 2 additions & 1 deletion .gitignore
@@ -78,4 +78,5 @@ MODULE.bazel.lock
*.whl
.coverage
coverage.xml
*.log
*.log
*.pt2
34 changes: 29 additions & 5 deletions docsrc/user_guide/saving_models.rst
@@ -14,12 +14,13 @@ Saving models compiled with Torch-TensorRT can be done using `torch_tensorrt.save`
Dynamo IR
-------------

The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) formats by
The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.fx.GraphModule` object by default.
We can save this object in either `TorchScript` (`torch.jit.ScriptModule`), `ExportedProgram` (`torch.export.ExportedProgram`), or `PT2` format by
specifying the `output_format` flag. Here are the options `output_format` will accept

* `exported_program` : This is the default. We perform transformations on the graphmodule first and use `torch.export.save` to save the module.
* `torchscript` : We trace the graphmodule via `torch.jit.trace` and save it via `torch.jit.save`.
* `aot_inductor` : Saves the module in the PT2 format. This is a next-generation runtime for PyTorch models, allowing them to run both in Python and in C++.

a) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -52,8 +53,8 @@ b) Torchscript
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", arg_inputs=inputs)

# Later, you can load it and run inference
model = torch.jit.load("trt.ts").cuda()
@@ -73,7 +74,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
trt_ts = torch_tensorrt.compile(model, ir="ts", arg_inputs=inputs) # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
@@ -98,3 +99,26 @@ Here's an example usage
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
model = torch_tensorrt.load(<file_path>).module()
model(*inputs)

c) PT2 Format
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PT2 is a new format that allows models to be run outside of Python. It utilizes `AOTInductor <https://docs.pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_
to generate kernels for components that will not be run in TensorRT.

Here's an example of how to save and load a Torch-TensorRT module using AOTInductor in Python

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.pt2", arg_inputs=inputs, output_format="aot_inductor", retrace=True)

# Later, you can load it and run inference
model = torch._inductor.aoti_load_package("trt.pt2")
model(*inputs)
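
The `save` API also accepts an optional `inductor_configs` keyword that is forwarded to `torch._inductor.aoti_compile_and_package` (see the `_compile.py` changes below). Here is a minimal sketch of passing Inductor options through; it reuses the hypothetical `MyModel` from above, and `max_autotune` is just an illustrative Inductor config key.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # trt_gm is a torch.fx.GraphModule object
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)

    # inductor_configs is passed through to torch._inductor.aoti_compile_and_package;
    # "max_autotune" is a standard Inductor option used here purely as an example.
    torch_tensorrt.save(
        trt_gm,
        "trt.pt2",
        output_format="aot_inductor",
        retrace=True,
        arg_inputs=inputs,
        inductor_configs={"max_autotune": True},
    )

    # Later, you can load it and run inference
    model = torch._inductor.aoti_load_package("trt.pt2")
    model(*inputs)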
11 changes: 8 additions & 3 deletions examples/torchtrt_runtime_example/Makefile
@@ -1,14 +1,19 @@
CXX=g++
DEP_DIR=$(PWD)/deps
INCLUDE_DIRS=-I$(DEP_DIR)/libtorch/include -I$(DEP_DIR)/torch_tensorrt/include
LIB_DIRS=-L$(DEP_DIR)/torch_tensorrt/lib -L$(DEP_DIR)/libtorch/lib # -Wl,-rpath $(DEP_DIR)/tensorrt/lib
LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -lbackend_with_compiler -lc10 -lc10_cuda
CUDA_HOME=/usr/local/cuda
INCLUDE_DIRS=-I$(DEP_DIR)/libtorch/include -I$(DEP_DIR)/torch_tensorrt/include -I$(CUDA_HOME)/include -I$(DEP_DIR)/libtorch/include/torch/csrc/api/include
LIB_DIRS=-L$(DEP_DIR)/torch_tensorrt/lib -L$(DEP_DIR)/libtorch/lib -Wl,-rpath $(DEP_DIR)/tensorrt/lib
LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -ltorchtrt_plugins -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -lbackend_with_compiler -lc10 -lc10_cuda
SRCS=main.cpp

TARGET=torchtrt_runtime_example

$(TARGET):
$(CXX) $(SRCS) $(INCLUDE_DIRS) $(LIB_DIRS) $(LIBS) -o $(TARGET)
echo "Add to LD_LIBRARY_PATH: $(DEP_DIR)/torch_tensorrt/lib:$(DEP_DIR)/libtorch/lib:$(DEP_DIR)/tensorrt/lib:$(CUDA_HOME)/lib64"

generate_pt2:
uv run network.py

clean:
$(RM) $(TARGET)
64 changes: 48 additions & 16 deletions examples/torchtrt_runtime_example/main.cpp
@@ -3,31 +3,63 @@
#include <memory>
#include <sstream>
#include <vector>
#include "torch/script.h"
#include "torch/torch.h"
#include "torch/csrc/inductor/aoti_package/model_package_loader.h"
#include "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h"

/*
* This example demonstrates how to load and run a pre-built Torch-TensorRT
* AOTInductor (AOTI) model package using the PyTorch C++ API.
*
* Usage:
* torchtrt_runtime_example <path-to-pre-built-trt-aoti module>
*
* Steps:
* 1. Parse the path to the AOTI model package from the command line.
* 2. Load the model package using AOTIModelPackageLoader.
* 3. Prepare a random CUDA tensor as input.
* 4. Run inference using the loaded model.
* 5. Print the output tensor(s) or an error message if inference fails.
*/

int main(int argc, const char* argv[]) {
// Check for correct number of command-line arguments
if (argc < 2) {
std::cerr << "usage: samplertapp <path-to-pre-built-trt-ts module>\n";
std::cerr << "usage: torchtrt_runtime_example <path-to-pre-built-trt-aoti module>\n";
return -1;
}

std::string trt_ts_module_path = argv[1];
// Get the path to the TRT AOTI model package from the command line
std::string trt_aoti_module_path = argv[1];

torch::jit::Module trt_ts_mod;
// Enable inference mode (a thread-local guard that disables autograd overhead during inference)
c10::InferenceMode mode;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
trt_ts_mod = torch::jit::load(trt_ts_module_path);
// Load the AOTI model package
torch::inductor::AOTIModelPackageLoader runner(trt_aoti_module_path);

// Create a random input tensor on CUDA with shape [1, 3, 5, 5] and type float32
std::vector<at::Tensor> inputs = {at::randn({1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32)};

// Run inference using the loaded model
std::vector<at::Tensor> outputs = runner.run(inputs);

// Process and print the output tensor(s)
if (!outputs.empty()) {
std::cout << "Model output: " << outputs[0] << std::endl;
} else {
std::cerr << "No output tensors received!" << std::endl;
}

} catch (const c10::Error& e) {
std::cerr << "error loading the model from : " << trt_ts_module_path << std::endl;
return -1;
// Handle errors from the PyTorch C++ API
std::cerr << "Error running model: " << e.what() << std::endl;
return 1;
} catch (const std::exception& e) {
// Handle other standard exceptions
std::cerr << "An unexpected error occurred: " << e.what() << std::endl;
return 1;
}

std::cout << "Running TRT engine" << std::endl;
std::vector<torch::jit::IValue> trt_inputs_ivalues;
trt_inputs_ivalues.push_back(at::randint(-5, 5, {1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32));
torch::jit::IValue trt_results_ivalues = trt_ts_mod.forward(trt_inputs_ivalues);
std::cout << "==================TRT outputs================" << std::endl;
std::cout << trt_results_ivalues << std::endl;
std::cout << "=============================================" << std::endl;
std::cout << "TRT engine execution completed. " << std::endl;
return 0;
}
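
Before running the C++ binary above, the generated package can be sanity-checked from Python. A small sketch, assuming the torchtrt_aoti_conv_gelu.pt2 package produced by network.py below:

import torch

# Load the same AOTI package the C++ example consumes and run it once on a
# CUDA input of the expected shape; this mirrors what AOTIModelPackageLoader does.
loaded = torch._inductor.aoti_load_package("torchtrt_aoti_conv_gelu.pt2")

x = torch.randn(1, 3, 5, 5, device="cuda")
with torch.inference_mode():
    out = loaded(x)

print(out.shape, out.dtype, out.device)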
40 changes: 30 additions & 10 deletions examples/torchtrt_runtime_example/network.py
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt
import torch_tensorrt


# create a simple norm layer.
@@ -29,21 +29,41 @@ def forward(self, x):

def main():
model = ConvGelu().eval().cuda()
scripted_model = torch.jit.script(model)

torch_ex_input = torch.randn([1, 3, 5, 5], device="cuda")
compile_settings = {
"inputs": [torchtrt.Input([1, 3, 5, 5])],
"arg_inputs": [torch_ex_input],
"ir": "dynamo",
"enabled_precisions": {torch.float32},
"min_block_size": 1,
}

trt_ts_module = torchtrt.compile(scripted_model, **compile_settings)
torch.jit.save(trt_ts_module, "conv_gelu.jit")
cg_trt_module = torch_tensorrt.compile(model, **compile_settings)
torch_tensorrt.save(
cg_trt_module,
file_path="torchtrt_aoti_conv_gelu.pt2",
output_format="aot_inductor",
retrace=True,
arg_inputs=[torch_ex_input],
)

norm_model = Norm().eval().cuda()
norm_ts_module = torch.jit.script(norm_model)
norm_trt_ts = torchtrt.compile(norm_ts_module, **compile_settings)
torch.jit.save(norm_trt_ts, "norm.jit")
print("Generated Torchscript-TRT models.")
norm_trt_module = torch_tensorrt.compile(norm_model, **compile_settings)
torch_tensorrt.save(
norm_trt_module,
file_path="torchtrt_aoti_norm.pt2",
output_format="aot_inductor",
retrace=True,
arg_inputs=[torch_ex_input],
)
print("Generated TorchTRT-AOTI models.")

loaded_cg_trt_module = torch._inductor.aoti_load_package(
"torchtrt_aoti_conv_gelu.pt2"
)
loaded_norm_trt_module = torch._inductor.aoti_load_package("torchtrt_aoti_norm.pt2")
with torch.inference_mode():
print(loaded_cg_trt_module(torch_ex_input))
print(loaded_norm_trt_module(torch_ex_input))


if __name__ == "__main__":
69 changes: 58 additions & 11 deletions py/torch_tensorrt/_compile.py
@@ -586,6 +586,7 @@ def save(
kwarg_inputs: Optional[dict[str, Any]] = None,
retrace: bool = False,
pickle_protocol: int = 2,
**kwargs: Any,
) -> None:
"""
Save the model to disk in the specified output format.
@@ -595,15 +596,15 @@ def save(
inputs (torch.Tensor): Torch input tensors
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
output_format (str): Format to save the model. Options include exported_program | torchscript.
output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor.
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
This flag is experimental for now.
pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models
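**kwargs (Any): Additional output-format specific options. For aot_inductor, an inductor_configs dictionary can be passed and is forwarded to torch._inductor.aoti_compile_and_package.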
"""
if isinstance(module, CudaGraphsTorchTensorRTModule):
module = module.compiled_module
module_type = _parse_module_type(module)
accepted_formats = {"exported_program", "torchscript"}
accepted_formats = {"exported_program", "torchscript", "aot_inductor"}
if arg_inputs is not None and not all(
isinstance(input, torch.Tensor) for input in arg_inputs
):
@@ -634,9 +635,9 @@ def save(
"Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
)
elif module_type == _ModuleType.ts:
if output_format == "exported_program":
if output_format in ["exported_program", "aot_inductor"]:
raise ValueError(
"Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
"Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
)
else:
if arg_inputs is not None:
@@ -654,7 +655,22 @@ def save(
logger.warning(
"Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
)
torch.export.save(module, file_path)
if output_format == "exported_program":
torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
elif output_format == "aot_inductor":
inductor_configs = {}
if "inductor_configs" in kwargs:
inductor_configs = kwargs["inductor_configs"]

torch._inductor.aoti_compile_and_package(
module,
inductor_configs=inductor_configs,
package_path=file_path,
)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
)
elif module_type == _ModuleType.fx:
# The module type is torch.fx.GraphModule
if output_format == "torchscript":
@@ -671,9 +687,24 @@ def save(
"Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
)
exp_program = export(module)
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
)
if output_format == "exported_program":
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
)
elif output_format == "aot_inductor":
inductor_configs = {}
if "inductor_configs" in kwargs:
inductor_configs = kwargs["inductor_configs"]

torch._inductor.aoti_compile_and_package(
exp_program,
inductor_configs=inductor_configs,
package_path=file_path,
)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
)
else:
if arg_inputs is None:
raise ValueError(
@@ -685,6 +716,22 @@ def save(
kwargs=kwarg_inputs,
strict=False,
)
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
)

if output_format == "exported_program":
torch.export.save(
exp_program, file_path, pickle_protocol=pickle_protocol
)
elif output_format == "aot_inductor":
inductor_configs = {}
if "inductor_configs" in kwargs:
inductor_configs = kwargs["inductor_configs"]

torch._inductor.aoti_compile_and_package(
exp_program,
inductor_configs=inductor_configs,
package_path=file_path,
)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
)