diff --git a/dev_dep_versions.yml b/dev_dep_versions.yml
index 492035a76f..4f83678265 100644
--- a/dev_dep_versions.yml
+++ b/dev_dep_versions.yml
@@ -1,2 +1,3 @@
 __cuda_version__: "12.8"
 __tensorrt_version__: "10.11.0"
+__tensorrt_llm_version__: "0.17.0.post1"
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index c109e3fa3c..63838dd43d 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -102,6 +102,7 @@ def cross_compile_for_windows(
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -177,6 +178,7 @@ def cross_compile_for_windows(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. This should be enabled when DTensors or distributed tensors are present in the distributed model
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -336,6 +338,7 @@ def cross_compile_for_windows(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     # disable the following settings is not supported for cross compilation for windows feature
@@ -437,6 +440,7 @@ def compile(
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -515,6 +519,7 @@ def compile(
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
+        use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. This should be enabled when DTensors or distributed tensors are present in the distributed model
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -683,6 +688,7 @@ def compile(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     settings = CompilationSettings(**compilation_options)
@@ -988,6 +994,7 @@ def convert_exported_program_to_serialized_trt_engine(
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
+    use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1053,6 +1060,7 @@ def convert_exported_program_to_serialized_trt_engine(
         enable_weight_streaming (bool): Enable weight streaming.
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        use_distributed_mode_trace (bool): Whether to use aot_autograd to trace the graph. This should be enabled when DTensors or distributed tensors are present in the distributed model
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -1172,6 +1180,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
+        "use_distributed_mode_trace": use_distributed_mode_trace,
     }

     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 685f40b254..f4e975603b 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,8 +1,6 @@
 import collections
-import ctypes
 import functools
 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

 import numpy as np
@@ -1012,69 +1010,6 @@ def args_bounds_check(
     return args[i] if len(args) > i and args[i] is not None else replacement


-def load_tensorrt_llm() -> bool:
-    """
-    Attempts to load the TensorRT-LLM plugin and initialize it.
-
-    Returns:
-        bool: True if the plugin was successfully loaded and initialized, False otherwise.
-    """
-    try:
-        import tensorrt_llm as trt_llm  # noqa: F401
-
-        _LOGGER.info("TensorRT-LLM successfully imported")
-        return True
-    except (ImportError, AssertionError) as e_import_error:
-        # Check for environment variable for the plugin library path
-        plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
-        if not plugin_lib_path:
-            _LOGGER.warning(
-                "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
-            )
-            return False
-
-        _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
-        try:
-            # Load the shared library
-            handle = ctypes.CDLL(plugin_lib_path)
-            _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
-        except OSError as e_os_error:
-            _LOGGER.error(
-                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
-                f"Ensure the path is correct and the library is compatible",
-                exc_info=e_os_error,
-            )
-            return False
-
-        try:
-            # Configure plugin initialization arguments
-            handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-            handle.initTrtLlmPlugins.restype = ctypes.c_bool
-        except AttributeError as e_plugin_unavailable:
-            _LOGGER.warning(
-                "Unable to initialize the TensorRT-LLM plugin library",
-                exc_info=e_plugin_unavailable,
-            )
-            return False
-
-        try:
-            # Initialize the plugin
-            TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
-            if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
-                _LOGGER.info("TensorRT-LLM plugin successfully initialized")
-                return True
-            else:
-                _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
-                return False
-        except Exception as e_initialization_error:
-            _LOGGER.warning(
-                "Exception occurred during TensorRT-LLM plugin library initialization",
-                exc_info=e_initialization_error,
-            )
-            return False
-    return False
-
-
 def promote_trt_tensors_to_same_dtype(
     ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str
 ) -> tuple[TRTTensor, TRTTensor]:
diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
index 79611c7552..3e67457e54 100644
--- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -11,11 +11,11 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
-from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,
 )
+from torch_tensorrt.dynamo.utils import load_tensorrt_llm

 _LOGGER: logging.Logger = logging.getLogger(__name__)
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index e0b3af7e0b..e800de4e82 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -1,11 +1,28 @@
 from __future__ import annotations

+import ctypes
 import gc
+import getpass
 import logging
+import os
+import tempfile
+import urllib.request
 import warnings
+from contextlib import contextmanager
 from dataclasses import fields, replace
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from pathlib import Path
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)

 import numpy as np
 import sympy
@@ -14,9 +31,10 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
-from torch_tensorrt._enums import dtype
+from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
+from torch_tensorrt._version import __tensorrt_llm_version__
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
@@ -33,6 +51,7 @@
 RTOL = 5e-3
 ATOL = 5e-3
 CPU_DEVICE = "cpu"
+_WHL_CPYTHON_VERSION = "cp310"


 class Frameworks(Enum):
@@ -817,3 +836,174 @@ def is_tegra_platform() -> bool:
     if torch.cuda.get_device_capability() in [(8, 7), (7, 2)]:
         return True
     return False
+
+
+@contextmanager
+def download_plugin_lib_path(platform: str) -> Iterator[str]:
+    """
+    Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
+    then yields the path to the extracted shared library (.so or .dll).
+
+    The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
+    Extraction happens in a temporary directory that is cleaned up after use.
+
+    Args:
+        platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
+
+    Yields:
+        str: The full path to the extracted TensorRT-LLM shared library file.
+
+    Raises:
+        ImportError: If the 'zipfile' module is not available.
+        RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
+    """
+    plugin_lib_path = None
+    username = getpass.getuser()
+    torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
+    torchtrt_cache_dir.mkdir(parents=True, exist_ok=True)
+    file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl"
+    torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
+    downloaded_file_path = torchtrt_cache_trtllm_whl
+
+    if not torchtrt_cache_trtllm_whl.exists():
+        # Downloading TRT-LLM lib
+        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+        download_url = base_url + file_name
+        try:
+            logger.debug(f"Downloading {download_url} ...")
+            urllib.request.urlretrieve(download_url, downloaded_file_path)
+            logger.debug("Download succeeded and TRT-LLM wheel is now present")
+        except urllib.error.HTTPError as e:
+            logger.error(
+                f"HTTP error {e.code} when trying to download {download_url}: {e.reason}"
+            )
+        except urllib.error.URLError as e:
+            logger.error(
+                f"URL error when trying to download {download_url}: {e.reason}"
+            )
+        except OSError as e:
+            logger.error(f"Local file write error: {e}")
+
+    # Proceeding with the unzip of the wheel file in tmpdir
+    if "linux" in platform:
+        lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
+    else:
+        lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        try:
+            import zipfile
+        except ImportError:
+            raise ImportError(
+                "zipfile module is required but not found. Please install zipfile"
+            )
+        try:
+            with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref:
+                zip_ref.extractall(tmpdir)  # the wheel contains a 'tensorrt_llm/libs' directory
+        except FileNotFoundError as e:
+            # This should capture the errors in the download failure above
+            logger.error(f"Wheel file not found at {downloaded_file_path}: {e}")
+            raise RuntimeError(
+                f"Failed to find downloaded wheel file at {downloaded_file_path}"
+            ) from e
+        except zipfile.BadZipFile as e:
+            logger.error(f"Invalid or corrupted wheel file: {e}")
+            raise RuntimeError(
+                "Downloaded wheel file is corrupted or not a valid zip archive"
+            ) from e
+        except Exception as e:
+            logger.error(f"Unexpected error while extracting wheel: {e}")
+            raise RuntimeError(
+                "Unexpected error during extraction of TensorRT-LLM wheel"
+            ) from e
+        plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename)
+        yield plugin_lib_path
+
+
+def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
+    """
+    Loads and initializes the TensorRT-LLM plugin from the given shared library path.
+
+    Args:
+        plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
+
+    Returns:
+        bool: True if successful, False otherwise.
+    """
+    try:
+        handle = ctypes.CDLL(plugin_lib_path)
+        logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
+    except OSError as e_os_error:
+        if "libmpi" in str(e_os_error):
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"The dependency libmpi.so is missing. "
+                f"Please install the packages libmpich-dev and libopenmpi-dev.",
+                exc_info=e_os_error,
+            )
+        else:
+            logger.warning(
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
+                f"Ensure the path is correct and the library is compatible.",
+                exc_info=e_os_error,
+            )
+        return False
+
+    try:
+        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+        handle.initTrtLlmPlugins.restype = ctypes.c_bool
+    except AttributeError as e_plugin_unavailable:
+        logger.warning(
+            "Unable to initialize the TensorRT-LLM plugin library",
+            exc_info=e_plugin_unavailable,
+        )
+        return False
+
+    try:
+        if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
+            logger.info("TensorRT-LLM plugin successfully initialized")
+            return True
+        else:
+            logger.warning("TensorRT-LLM plugin library failed in initialization")
+            return False
+    except Exception as e_initialization_error:
+        logger.warning(
+            "Exception occurred during TensorRT-LLM plugin library initialization",
+            exc_info=e_initialization_error,
+        )
+        return False
+    return False
+
+
+def load_tensorrt_llm() -> bool:
+    """
+    Attempts to load the TensorRT-LLM plugin and initialize it.
+    The plugin library path can be specified via the TRTLLM_PLUGINS_PATH environment variable,
+    or USE_TRTLLM_PLUGINS can be set to one of (1, true, yes, on) to download the TRT-LLM distribution and load the plugin from it.
+
+    Returns:
+        bool: True if the plugin was successfully loaded and initialized, False otherwise.
+    """
+    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
+    if plugin_lib_path:
+        return load_and_initialize_trtllm_plugin(plugin_lib_path)
+    else:
+        # Fall back to downloading the plugin when TRTLLM_PLUGINS_PATH is not set
+        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
+            "1",
+            "true",
+            "yes",
+            "on",
+        )
+        if not use_trtllm_plugin:
+            logger.warning(
+                "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library. Please set one of the two to use TRT-LLM libraries in Torch-TensorRT"
+            )
+            return False
+        else:
+            platform = Platform.current_platform()
+            platform = str(platform).lower()
+
+            with download_plugin_lib_path(platform) as plugin_lib_path:
+                return load_and_initialize_trtllm_plugin(plugin_lib_path)
+    return False
diff --git a/setup.py b/setup.py
index a342967a92..067e961d54 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,7 @@
 __version__: str = "0.0.0"
 __cuda_version__: str = "0.0"
 __tensorrt_version__: str = "0.0"
+__tensorrt_llm_version__: str = "0.0"

 LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")

@@ -63,6 +64,7 @@ def get_base_version() -> str:
 def load_dep_info():
     global __cuda_version__
     global __tensorrt_version__
+    global __tensorrt_llm_version__
     with open("dev_dep_versions.yml", "r") as stream:
         versions = yaml.safe_load(stream)
         if (gpu_arch_version := os.environ.get("CU_VERSION")) is not None:
@@ -72,6 +74,7 @@ def load_dep_info():
         else:
             __cuda_version__ = versions["__cuda_version__"]
         __tensorrt_version__ = versions["__tensorrt_version__"]
+        __tensorrt_llm_version__ = versions["__tensorrt_llm_version__"]


 load_dep_info()
@@ -241,6 +244,7 @@ def gen_version_file():
         f.write('__version__ = "' + __version__ + '"\n')
         f.write('__cuda_version__ = "' + __cuda_version__ + '"\n')
         f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n')
+        f.write('__tensorrt_llm_version__ = "' + __tensorrt_llm_version__ + '"\n')


 def copy_libtorchtrt(multilinux=False, rt_only=False):
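
Usage sketch (a hedged illustration, not part of the patch): with the helpers added above, the converters for torch.distributed ops become usable either by pointing TRTLLM_PLUGINS_PATH at an existing libnvinfer_plugin_tensorrt_llm.so or by opting in to the wheel download via USE_TRTLLM_PLUGINS; the library path in Option 1 below is a placeholder, and the compile-side switch is the new use_distributed_mode_trace option.

import os

# Option 1 (hypothetical path): point at an already-available plugin library.
# os.environ["TRTLLM_PLUGINS_PATH"] = "/path/to/libnvinfer_plugin_tensorrt_llm.so"

# Option 2: allow the download/extraction path added in utils.py.
# Accepted values per load_tensorrt_llm(): "1", "true", "yes", "on".
os.environ["USE_TRTLLM_PLUGINS"] = "1"

from torch_tensorrt.dynamo.utils import load_tensorrt_llm

# Returns True only if the shared library was loaded and initTrtLlmPlugins()
# succeeded; custom_ops_converters.py imports this same helper for the fused
# NCCL (all_gather / reduce_scatter) converters.
if load_tensorrt_llm():
    print("TensorRT-LLM plugin initialized; distributed op converters are available")
else:
    print("TensorRT-LLM plugin unavailable; distributed ops will not be converted")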