Skip to content

Commit 2a93df3

Browse files
authored
Cross compile guard (#3486)
1 parent f09be72 commit 2a93df3

File tree

7 files changed

+60
-3
lines changed

7 files changed

+60
-3
lines changed

py/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ pybind11==2.6.2
55
torch>=2.8.0.dev,<2.9.0
66
torchvision>=0.22.0.dev,<0.23.0
77
--extra-index-url https://pypi.ngc.nvidia.com
8-
pyyaml
8+
pyyaml
9+
dllist

py/torch_tensorrt/_compile.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.fx
1111
from torch_tensorrt._enums import dtype
12-
from torch_tensorrt._features import ENABLED_FEATURES
12+
from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile
1313
from torch_tensorrt._Input import Input
1414
from torch_tensorrt.dynamo import _defaults
1515
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
@@ -301,6 +301,7 @@ def compile(
301301
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
302302

303303

304+
@needs_cross_compile
304305
def cross_compile_for_windows(
305306
module: torch.nn.Module,
306307
file_path: str,

py/torch_tensorrt/_features.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from collections import namedtuple
55
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
66

7-
from torch_tensorrt._utils import sanitized_torch_version
7+
from torch_tensorrt._utils import (
8+
check_cross_compile_trt_win_lib,
9+
sanitized_torch_version,
10+
)
811

912
from packaging import version
1013

@@ -17,6 +20,7 @@
1720
"fx_frontend",
1821
"refit",
1922
"qdp_plugin",
23+
"windows_cross_compile",
2024
],
2125
)
2226

@@ -40,6 +44,7 @@
4044
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
4145
_FX_FE_AVAIL = True
4246
_REFIT_AVAIL = True
47+
_WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib()
4348

4449
if importlib.util.find_spec("tensorrt.plugin"):
4550
_QDP_PLUGIN_AVAIL = True
@@ -53,6 +58,7 @@
5358
_FX_FE_AVAIL,
5459
_REFIT_AVAIL,
5560
_QDP_PLUGIN_AVAIL,
61+
_WINDOWS_CROSS_COMPILE,
5662
)
5763

5864

@@ -108,6 +114,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
108114
return wrapper
109115

110116

117+
def needs_cross_compile(f: Callable[..., Any]) -> Callable[..., Any]:
118+
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
119+
if ENABLED_FEATURES.windows_cross_compile:
120+
return f(*args, **kwargs)
121+
else:
122+
123+
def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
124+
raise NotImplementedError(
125+
"Windows cross compilation feature is not available"
126+
)
127+
128+
return not_implemented(*args, **kwargs)
129+
130+
return wrapper
131+
132+
111133
T = TypeVar("T")
112134

113135

py/torch_tensorrt/_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import Any
2+
import sys
3+
import platform
24

35
import torch
46

@@ -9,3 +11,16 @@ def sanitized_torch_version() -> Any:
911
if ".nv" not in torch.__version__
1012
else torch.__version__.split(".nv")[0]
1113
)
14+
15+
16+
def check_cross_compile_trt_win_lib() -> bool:
17+
# cross compile feature is only available on linux
18+
# build engine on linux and run on windows
19+
import dllist
20+
21+
if sys.platform.startswith("linux"):
22+
loaded_libs = dllist.dllist()
23+
target_lib = "libnvinfer_builder_resource_win.so.*"
24+
if target_lib in loaded_libs:
25+
return True
26+
return False

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch.fx.node import Target
1212
from torch_tensorrt._Device import Device
1313
from torch_tensorrt._enums import EngineCapability, dtype
14+
from torch_tensorrt._features import needs_cross_compile
1415
from torch_tensorrt._Input import Input
1516
from torch_tensorrt.dynamo import _defaults, partitioning
1617
from torch_tensorrt.dynamo._DryRunTracker import (
@@ -50,6 +51,7 @@
5051
logger = logging.getLogger(__name__)
5152

5253

54+
@needs_cross_compile
5355
def cross_compile_for_windows(
5456
exported_program: ExportedProgram,
5557
inputs: Optional[Sequence[Sequence[Any]]] = None,
@@ -1223,6 +1225,7 @@ def convert_exported_program_to_serialized_trt_engine(
12231225
return serialized_engine
12241226

12251227

1228+
@needs_cross_compile
12261229
def save_cross_compiled_exported_program(
12271230
gm: torch.fx.GraphModule,
12281231
file_path: str,

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ requires = [
1616
"numpy; 'tegra' not in platform_release",
1717
"numpy<2.0.0; 'tegra' in platform_release",
1818
"sympy",
19+
"dllist",
1920
]
2021
build-backend = "setuptools.build_meta"
2122

@@ -74,6 +75,7 @@ dependencies = [
7475
"numpy<2.0.0; 'tegra' in platform_release",
7576

7677
"typing-extensions>=4.7.0",
78+
"dllist",
7779
]
7880

7981
dynamic = ["version"]

tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch_tensorrt
99
from torch.testing._internal.common_utils import TestCase
1010
from torch_tensorrt.dynamo.utils import get_model_device
11+
from torch_tensorrt._utils import check_cross_compile_trt_win_lib
1112

1213
from ..testing_utilities import DECIMALS_OF_AGREEMENT
1314

@@ -17,6 +18,10 @@ class TestCrossCompileSaveForWindows(TestCase):
1718
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
1819
"Cross compile for windows can only be enabled on linux x86-64 platform",
1920
)
21+
@unittest.skipIf(
22+
not (check_cross_compile_trt_win_lib()),
23+
"TRT windows lib for cross compile not found",
24+
)
2025
@pytest.mark.unit
2126
def test_cross_compile_for_windows(self):
2227
class Add(torch.nn.Module):
@@ -41,6 +46,10 @@ def forward(self, a, b):
4146
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
4247
"Cross compile for windows can only be enabled on linux x86-64 platform",
4348
)
49+
@unittest.skipIf(
50+
not (check_cross_compile_trt_win_lib()),
51+
"TRT windows lib for cross compile not found",
52+
)
4453
@pytest.mark.unit
4554
def test_dynamo_cross_compile_for_windows(self):
4655
class Add(torch.nn.Module):
@@ -69,6 +78,10 @@ def forward(self, a, b):
6978
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
7079
"Cross compile for windows can only be enabled on linux x86-64 platform",
7180
)
81+
@unittest.skipIf(
82+
not (check_cross_compile_trt_win_lib()),
83+
"TRT windows lib for cross compile not found",
84+
)
7285
@pytest.mark.unit
7386
def test_dynamo_cross_compile_for_windows_cpu_offload(self):
7487
class Add(torch.nn.Module):

0 commit comments

Comments
 (0)