Skip to content

Commit e8bc3a4

Browse files
committed
Addressing review comments - tmp dir for wheel download and wheel extraction, variable for py_version
1 parent 23d27b0 commit e8bc3a4

File tree

1 file changed

+121
-59
lines changed

1 file changed

+121
-59
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 121 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,27 @@
22

33
import ctypes
44
import gc
5+
import getpass
56
import logging
67
import os
8+
import tempfile
79
import urllib.request
810
import warnings
11+
from contextlib import contextmanager
912
from dataclasses import fields, replace
1013
from enum import Enum
11-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
14+
from pathlib import Path
15+
from typing import (
16+
Any,
17+
Callable,
18+
Dict,
19+
Iterator,
20+
List,
21+
Optional,
22+
Sequence,
23+
Tuple,
24+
Union,
25+
)
1226

1327
import numpy as np
1428
import sympy
@@ -37,6 +51,7 @@
3751
RTOL = 5e-3
3852
ATOL = 5e-3
3953
CPU_DEVICE = "cpu"
54+
_WHL_CPYTHON_VERSION = "3.10"
4055

4156

4257
class Frameworks(Enum):
@@ -823,17 +838,40 @@ def is_tegra_platform() -> bool:
823838
return False
824839

825840

826-
def download_plugin_lib_path(py_version: str, platform: str) -> str:
827-
plugin_lib_path = None
841+
@contextmanager
842+
def download_plugin_lib_path(platform: str) -> Iterator[str]:
843+
"""
844+
Downloads (if needed) and extracts the TensorRT-LLM plugin wheel for the specified platform,
845+
then yields the path to the extracted shared library (.so or .dll).
846+
847+
The wheel file is cached in a user-specific temporary directory to avoid repeated downloads.
848+
Extraction happens in a temporary directory that is cleaned up after use.
849+
850+
Args:
851+
platform (str): The platform identifier string (e.g., 'linux_x86_64') to select the correct wheel.
828852
829-
# Downloading TRT-LLM lib
830-
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
831-
file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{py_version}-{py_version}-{platform}.whl"
832-
download_url = base_url + file_name
833-
if not (os.path.exists(file_name)):
853+
Yields:
854+
str: The full path to the extracted TensorRT-LLM shared library file.
855+
856+
Raises:
857+
ImportError: If the 'zipfile' module is not available.
858+
RuntimeError: If the wheel file is missing, corrupted, or extraction fails.
859+
"""
860+
plugin_lib_path = None
861+
username = getpass.getuser()
862+
torchtrt_cache_dir = Path(tempfile.gettempdir()) / f"torch_tensorrt_{username}"
863+
torchtrt_cache_dir.mkdir(parents=True, exist_ok=True)
864+
file_name = f"tensorrt_llm-{__tensorrt_llm_version__}-{_WHL_CPYTHON_VERSION}-{_WHL_CPYTHON_VERSION}-{platform}.whl"
865+
torchtrt_cache_trtllm_whl = torchtrt_cache_dir / file_name
866+
867+
if not torchtrt_cache_trtllm_whl.exists():
868+
# Downloading TRT-LLM lib
869+
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
870+
download_url = base_url + file_name
871+
downloaded_file_path = torchtrt_cache_trtllm_whl
834872
try:
835873
logger.debug(f"Downloading {download_url} ...")
836-
urllib.request.urlretrieve(download_url, file_name)
874+
urllib.request.urlretrieve(download_url, downloaded_file_path)
837875
logger.debug("Download succeeded and TRT-LLM wheel is now present")
838876
except urllib.error.HTTPError as e:
839877
logger.error(
@@ -846,60 +884,53 @@ def download_plugin_lib_path(py_version: str, platform: str) -> str:
846884
except OSError as e:
847885
logger.error(f"Local file write error: {e}")
848886

849-
# Proceeding with the unzip of the wheel file
850-
# This will exist if the filename was already downloaded
887+
# Proceeding with the unzip of the wheel file in tmpdir
851888
if "linux" in platform:
852889
lib_filename = "libnvinfer_plugin_tensorrt_llm.so"
853890
else:
854891
lib_filename = "libnvinfer_plugin_tensorrt_llm.dll"
855-
plugin_lib_path = os.path.join("./tensorrt_llm/libs", lib_filename)
856-
if os.path.exists(plugin_lib_path):
857-
return plugin_lib_path
858-
try:
859-
import zipfile
860-
except ImportError as e:
861-
raise ImportError(
862-
"zipfile module is required but not found. Please install zipfile"
863-
)
864-
with zipfile.ZipFile(file_name, "r") as zip_ref:
865-
zip_ref.extractall(".") # Extract to a folder named 'tensorrt_llm'
866-
plugin_lib_path = "./tensorrt_llm/libs/" + lib_filename
867-
return plugin_lib_path
868892

893+
with tempfile.TemporaryDirectory() as tmpdir:
894+
try:
895+
import zipfile
896+
except ImportError:
897+
raise ImportError(
898+
"zipfile module is required but not found. Please install zipfile"
899+
)
900+
try:
901+
with zipfile.ZipFile(downloaded_file_path, "r") as zip_ref:
902+
zip_ref.extractall(tmpdir) # Extract to a folder named 'tensorrt_llm'
903+
except FileNotFoundError as e:
904+
# This should capture the errors in the download failure above
905+
logger.error(f"Wheel file not found at {downloaded_file_path}: {e}")
906+
raise RuntimeError(
907+
f"Failed to find downloaded wheel file at {downloaded_file_path}"
908+
) from e
909+
except zipfile.BadZipFile as e:
910+
logger.error(f"Invalid or corrupted wheel file: {e}")
911+
raise RuntimeError(
912+
"Downloaded wheel file is corrupted or not a valid zip archive"
913+
) from e
914+
except Exception as e:
915+
logger.error(f"Unexpected error while extracting wheel: {e}")
916+
raise RuntimeError(
917+
"Unexpected error during extraction of TensorRT-LLM wheel"
918+
) from e
919+
plugin_lib_path = os.path.join(tmpdir, "tensorrt_llm/libs", lib_filename)
920+
yield plugin_lib_path
921+
922+
923+
def load_and_initialize_trtllm_plugin(plugin_lib_path: str) -> bool:
924+
"""
925+
Loads and initializes the TensorRT-LLM plugin from the given shared library path.
869926
870-
def load_tensorrt_llm() -> bool:
871-
"""
872-
Attempts to load the TensorRT-LLM plugin and initialize it.
873-
Either the env variable TRTLLM_PLUGINS_PATH can specify the path
874-
Or the user can specify USE_TRTLLM_PLUGINS as either of (1, true, yes, on) to download the TRT-LLM distribution and load it
927+
Args:
928+
plugin_lib_path (str): Path to the shared TensorRT-LLM plugin library.
875929
876930
Returns:
877-
bool: True if the plugin was successfully loaded and initialized, False otherwise.
931+
bool: True if successful, False otherwise.
878932
"""
879-
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
880-
if not plugin_lib_path:
881-
# this option can be used by user if TRTLLM_PLUGINS_PATH is not set by user
882-
use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
883-
"1",
884-
"true",
885-
"yes",
886-
"on",
887-
)
888-
if not use_trtllm_plugin:
889-
logger.warning(
890-
"Neither TRTLLM_PLUGIN_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
891-
)
892-
return False
893-
else:
894-
# this is used as the default py version
895-
py_version = "cp310"
896-
platform = Platform.current_platform()
897-
898-
platform = str(platform).lower()
899-
plugin_lib_path = download_plugin_lib_path(py_version, platform)
900-
901933
try:
902-
# Load the shared TRT-LLM file
903934
handle = ctypes.CDLL(plugin_lib_path)
904935
logger.info(f"Successfully loaded plugin library: {plugin_lib_path}")
905936
except OSError as e_os_error:
@@ -912,14 +943,13 @@ def load_tensorrt_llm() -> bool:
912943
)
913944
else:
914945
logger.warning(
915-
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
916-
f"Ensure the path is correct and the library is compatible",
946+
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}. "
947+
f"Ensure the path is correct and the library is compatible.",
917948
exc_info=e_os_error,
918949
)
919950
return False
920951

921952
try:
922-
# Configure plugin initialization arguments
923953
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
924954
handle.initTrtLlmPlugins.restype = ctypes.c_bool
925955
except AttributeError as e_plugin_unavailable:
@@ -930,9 +960,7 @@ def load_tensorrt_llm() -> bool:
930960
return False
931961

932962
try:
933-
# Initialize the plugin
934-
TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
935-
if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
963+
if handle.initTrtLlmPlugins(None, b"tensorrt_llm"):
936964
logger.info("TensorRT-LLM plugin successfully initialized")
937965
return True
938966
else:
@@ -945,3 +973,37 @@ def load_tensorrt_llm() -> bool:
945973
)
946974
return False
947975
return False
976+
977+
978+
def load_tensorrt_llm() -> bool:
    """
    Attempts to load the TensorRT-LLM plugin and initialize it.

    Either the env variable TRTLLM_PLUGINS_PATH can specify the path to the
    shared library directly, or the user can set USE_TRTLLM_PLUGINS to one of
    (1, true, yes, on) to download the TRT-LLM distribution and load the
    plugin from it.

    Returns:
        bool: True if the plugin was successfully loaded and initialized, False otherwise.
    """
    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if plugin_lib_path:
        # An explicitly provided library path takes precedence over downloading.
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

    # This option can be used by the user if TRTLLM_PLUGINS_PATH is not set.
    use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if not use_trtllm_plugin:
        # Fixed: message previously said "TRTLLM_PLUGIN_PATH" while the code
        # actually reads "TRTLLM_PLUGINS_PATH".
        logger.warning(
            "Neither TRTLLM_PLUGINS_PATH is set nor is it directed to download the shared library. Please set either of the two to use TRT-LLM libraries in torchTRT"
        )
        return False

    platform = str(Platform.current_platform()).lower()
    # The context manager downloads/extracts the wheel into a temporary
    # directory and cleans it up once the library has been loaded.
    with download_plugin_lib_path(platform) as plugin_lib_path:
        return load_and_initialize_trtllm_plugin(plugin_lib_path)

0 commit comments

Comments
 (0)