Skip to content

Commit 039a26f

Browse files
committed
add note and check for GPU capability
1 parent 07a7ae2 commit 039a26f

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

intermediate_source/torch_compile_tutorial.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,25 @@
3333
# - ``numpy``
3434
# - ``scipy``
3535
# - ``tabulate``
36-
#
37-
# Note: a modern NVIDIA GPU (Volta or Ampere) is recommended for this tutorial.
38-
#
36+
37+
######################################################################
38+
# NOTE: a modern NVIDIA GPU (A100 or V100) is recommended for this tutorial in
39+
# order to reproduce the speedup numbers shown below and documented elsewhere.
40+
41+
import torch
42+
import warnings
43+
44+
gpu_ok = False
45+
if torch.cuda.is_available():
46+
device = torch.cuda.get_device_name()
47+
if "V100" in device or "A100" in device:
48+
gpu_ok = True
49+
50+
if not gpu_ok:
51+
warnings.warn(
52+
"GPU is not NVIDIA V100 or A100. Speedup numbers may be lower than "
53+
"expected."
54+
)
3955

4056
######################################################################
4157
# Basic Usage
@@ -51,8 +67,6 @@
5167
# ``torch.compile``. We can then call the returned optimized
5268
# function in place of the original function.
5369

54-
import torch
55-
5670
def foo(x, y):
5771
a = torch.sin(x)
5872
b = torch.cos(x)

0 commit comments

Comments
 (0)