TRT-LLM loading mechanism tool #3398

Status: Open. Wants to merge 4 commits into base: main.
1 change: 1 addition & 0 deletions dev_dep_versions.yml
@@ -1,2 +1,3 @@
__cuda_version__: "12.8"
__tensorrt_version__: "10.11.0"
__tensorrt_llm_version__: "0.17.0.post1"
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -102,6 +102,7 @@ def cross_compile_for_windows(
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -177,6 +178,7 @@ def cross_compile_for_windows(
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. Enable this when DTensors or distributed tensors are present in the model.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -336,6 +338,7 @@ def cross_compile_for_windows(
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"use_distributed_mode_trace": use_distributed_mode_trace,
}

# disable the following settings is not supported for cross compilation for windows feature
@@ -437,6 +440,7 @@ def compile(
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -515,6 +519,7 @@ def compile(
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. Enable this when DTensors or distributed tensors are present in the model.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -683,6 +688,7 @@ def compile(
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
}

settings = CompilationSettings(**compilation_options)
@@ -988,6 +994,7 @@ def convert_exported_program_to_serialized_trt_engine(
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1053,6 +1060,7 @@ def convert_exported_program_to_serialized_trt_engine(
enable_weight_streaming (bool): Enable weight streaming.
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. Enable this when DTensors or distributed tensors are present in the model.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
@@ -1172,6 +1180,7 @@ def convert_exported_program_to_serialized_trt_engine(
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
}

settings = CompilationSettings(**compilation_options)
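The new use_distributed_mode_trace flag is threaded through compile(), cross_compile_for_windows(), and convert_exported_program_to_serialized_trt_engine(). Below is a minimal sketch of how a caller could opt in with this PR's build; the toy module and shapes are placeholders, and a real use case would involve a model containing DTensors.

import torch
import torch_tensorrt

# Placeholder single-GPU module; the flag is intended for models that contain
# DTensors / distributed tensors, which are not shown in this sketch.
model = torch.nn.Linear(64, 64).cuda().eval()
inputs = (torch.randn(8, 64).cuda(),)

exported = torch.export.export(model, inputs)
trt_module = torch_tensorrt.dynamo.compile(
    exported,
    inputs=list(inputs),
    # Route graph tracing through aot_autograd instead of the default tracer.
    use_distributed_mode_trace=True,
)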
65 changes: 0 additions & 65 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,8 +1,6 @@
import collections
import ctypes
import functools
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
@@ -1012,69 +1010,6 @@ def args_bounds_check(
return args[i] if len(args) > i and args[i] is not None else replacement


def load_tensorrt_llm() -> bool:
"""
Attempts to load the TensorRT-LLM plugin and initialize it.

Returns:
bool: True if the plugin was successfully loaded and initialized, False otherwise.
"""
try:
import tensorrt_llm as trt_llm # noqa: F401

_LOGGER.info("TensorRT-LLM successfully imported")
return True
except (ImportError, AssertionError) as e_import_error:
# Check for environment variable for the plugin library path
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
if not plugin_lib_path:
_LOGGER.warning(
"TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
)
return False

_LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
try:
# Load the shared library
handle = ctypes.CDLL(plugin_lib_path)
_LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
except OSError as e_os_error:
_LOGGER.error(
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
f"Ensure the path is correct and the library is compatible",
exc_info=e_os_error,
)
return False

try:
# Configure plugin initialization arguments
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
handle.initTrtLlmPlugins.restype = ctypes.c_bool
except AttributeError as e_plugin_unavailable:
_LOGGER.warning(
"Unable to initialize the TensorRT-LLM plugin library",
exc_info=e_plugin_unavailable,
)
return False

try:
# Initialize the plugin
TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
_LOGGER.info("TensorRT-LLM plugin successfully initialized")
return True
else:
_LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
return False
except Exception as e_initialization_error:
_LOGGER.warning(
"Exception occurred during TensorRT-LLM plugin library initialization",
exc_info=e_initialization_error,
)
return False
return False


def promote_trt_tensors_to_same_dtype(
ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str
) -> tuple[TRTTensor, TRTTensor]:
@@ -11,11 +11,11 @@
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
tensorrt_fused_nccl_all_gather_op,
tensorrt_fused_nccl_reduce_scatter_op,
)
from torch_tensorrt.dynamo.utils import load_tensorrt_llm

_LOGGER: logging.Logger = logging.getLogger(__name__)

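With the import now coming from torch_tensorrt.dynamo.utils, the NCCL converter registrations can gate on plugin availability. A rough, illustrative sketch of that wiring, assuming the converters are registered only when load_tensorrt_llm() succeeds; the guard shape and converter body are assumptions, not the exact file contents.

from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    tensorrt_fused_nccl_all_gather_op,
)
from torch_tensorrt.dynamo.utils import load_tensorrt_llm

if load_tensorrt_llm():
    # Register converters for the fused NCCL ops only after the TRT-LLM
    # plugin library has been loaded and initialized.
    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
    def fused_nccl_gather(ctx, target, args, kwargs, name):
        ...  # converter body omitted in this sketch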
194 changes: 192 additions & 2 deletions py/torch_tensorrt/dynamo/utils.py
@@ -1,11 +1,28 @@
from __future__ import annotations

import ctypes
import gc
import getpass
import logging
import os
import tempfile
import urllib.request
import warnings
from contextlib import contextmanager
from dataclasses import fields, replace
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np
import sympy
@@ -14,9 +31,10 @@
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._enums import Platform, dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt._version import __tensorrt_llm_version__
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
@@ -33,6 +51,7 @@
RTOL = 5e-3
ATOL = 5e-3
CPU_DEVICE = "cpu"
_WHL_CPYTHON_VERSION = "cp310"  # CPython wheel tag used in the TRT-LLM wheel filename


class Frameworks(Enum):
@@ -817,3 +836,174 @@ def is_tegra_platform() -> bool:
if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
return True
return False


@contextmanager
def download_plugin_lib_path(platform: str) -> Iterator[str]:
"""
Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
then yields the path to the extracted shared library (.so or .dll).

The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
Extraction happens in a temporary directory that is cleaned up after use.

Args:
platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.

Yields:
str: The full path to the extracted TensorRT-LLM shared library file.

Raises:
ImportError: If the 'zipfile' module is not available.
RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
"""
plugin_lib_path = None
username = getpass.getuser()
torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
torchtrt_cache_dir.mkdir(parents=True, exist_ok=True)
file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl"
torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name

if not torchtrt_cache_trtllm_whl.exists():
# Downloading TRT-LLM lib
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
download_url = base_url + file_name
downloaded_file_path = torchtrt_cache_trtllm_whl
try:
logger.debug(f"Downloading {download_url} ...")
urllib.request.urlretrieve(download_url, downloaded_file_path)
logger.debug("Download succeeded and TRT-LLM wheel is now present")
except urllib.error.HTTPError as e:
logger.error(
f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
)
except urllib.error.URLError as e:
logger.error(
f"URL error when trying to download {download_url}: {e.reason}"
)
except OSError as e:
logger.error(f"Local file write error: {e}")

# Proceeding with the unzip of the wheel file in tmpdir
if "linux" in platform:
lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
else:
lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"

with tempfile.TemporaryDirectory() as tmpdir:
try:
import zipfile
except ImportError:
raise ImportError(
"zipfile module is required but not found. Please install zipfile"
)
try:
with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref:
zip_ref.extractall(tmpdir) # Extract to a folder named 'tensorrt_llm'
except FileNotFoundError as e:
# This should capture the errors in the download failure above
logger.error(f"Wheel file not found at {downloaded_file_path}: {e}")
raise RuntimeError(
f"Failed to find downloaded wheel file at {downloaded_file_path}"
) from e
except zipfile.BadZipFile as e:
logger.error(f"Invalid or corrupted wheel file: {e}")
raise RuntimeError(
"Downloaded wheel file is corrupted or not a valid zip archive"
) from e
except Exception as e:
logger.error(f"Unexpected error while extracting wheel: {e}")
raise RuntimeError(
"Unexpected error during extraction of TensorRT-LLM wheel"
) from e
plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename)
yield plugin_lib_path
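Since extraction happens inside a TemporaryDirectory, the yielded path is only valid inside the with block, so the library has to be loaded before the context exits. A minimal usage sketch; the platform string is an example value.

from torch_tensorrt.dynamo.utils import (
    download_plugin_lib_path,
    load_and_initialize_trtllm_plugin,
)

# Load the extracted library while the temporary extraction dir still exists.
with download_plugin_lib_path("linux_x86_64") as plugin_lib_path:
    ok = load_and_initialize_trtllm_plugin(plugin_lib_path)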


def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
"""
Loads and initializes the TensorRT-LLM plugin from the given shared library path.

Args:
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.

Returns:
bool: True if successful, False otherwise.
"""
try:
handle = ctypes.CDLL(plugin_lib_path)
logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
except OSError as e_os_error:
if "libmpi" in str(e_os_error):
logger.warning(
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
f"The dependency libmpi.so is missing. "
f"Please install the packages libmpich-dev and libopenmpi-dev.",
exc_info=e_os_error,
)
else:
logger.warning(
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
f"Ensure the path is correct and the library is compatible.",
exc_info=e_os_error,
)
return False

try:
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
handle.initTrtLlmPlugins.restype = ctypes.c_bool
except AttributeError as e_plugin_unavailable:
logger.warning(
"Unable to initialize the TensorRT-LLM plugin library",
exc_info=e_plugin_unavailable,
)
return False

try:
if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
logger.info("TensorRT-LLM plugin successfully initialized")
return True
else:
logger.warning("TensorRT-LLM plugin library failed in initialization")
return False
except Exception as e_initialization_error:
logger.warning(
"Exception occurred during TensorRT-LLM plugin library initialization",
exc_info=e_initialization_error,
)
return False
return False


def load_tensorrt_llm() -> bool:
"""
Attempts to load the TensorRT-LLM plugin and initialize it.
The environment variable TRTLLM_PLUGINS_PATH can point directly to the plugin library;
alternatively, USE_TRTLLM_PLUGINS can be set to one of (1, true, yes, on) to download the TRT-LLM distribution and load the library from it.

Returns:
bool: True if the plugin was successfully loaded and initialized, False otherwise.
"""
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
if plugin_lib_path:
return load_and_initialize_trtllm_plugin(plugin_lib_path)
else:
# This option is used when TRTLLM_PLUGINS_PATH has not been set by the user
use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
"1",
"true",
"yes",
"on",
)
if not use_trtllm_plugin:
logger.warning(
"Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
)
return False
else:
platform = Platform.current_platform()
platform = str(platform).lower()

with download_plugin_lib_path(platform) as plugin_lib_path:
return load_and_initialize_trtllm_plugin(plugin_lib_path)
return False
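Taken together, load_tensorrt_llm() is driven entirely by environment variables. A hedged end-to-end sketch of the two supported paths; the .so path below is a placeholder.

import os
from torch_tensorrt.dynamo.utils import load_tensorrt_llm

# Path 1: point at an already-installed plugin library (placeholder path).
os.environ["TRTLLM_PLUGINS_PATH"] = "/opt/trtllm/libs/libnvinfer_plugin_tensorrt_llm.so"
loaded = load_tensorrt_llm()

# Path 2: let Torch-TensorRT download the TRT-LLM wheel and load the library from it.
os.environ.pop("TRTLLM_PLUGINS_PATH", None)
os.environ["USE_TRTLLM_PLUGINS"] = "true"
loaded = load_tensorrt_llm()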