Skip to content

Commit 72102ee

Browse files
author
Svetlana Karslioglu
authored
Merge branch 'main' into patch-1
2 parents 98e05cf + c26d6a5 commit 72102ee

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

intermediate_source/torch_compile_tutorial.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -33,9 +33,25 @@
3333
# - ``numpy``
3434
# - ``scipy``
3535
# - ``tabulate``
36-
#
37-
# Note: a modern NVIDIA GPU (Volta or Ampere) is recommended for this tutorial.
38-
#
36+
37+
######################################################################
38+
# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
39+
# order to reproduce the speedup numbers shown below and documented elsewhere.
40+
41+
import torch
42+
import warnings
43+
44+
gpu_ok = False
45+
if torch.cuda.is_available():
46+
device_cap = torch.cuda.get_device_capability()
47+
if device_cap in ((7, 0), (8, 0), (9, 0)):
48+
gpu_ok = True
49+
50+
if not gpu_ok:
51+
warnings.warn(
52+
"GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
53+
"than expected."
54+
)
3955

4056
######################################################################
4157
# Basic Usage
@@ -51,8 +67,6 @@
5167
# ``torch.compile``. We can then call the returned optimized
5268
# function in place of the original function.
5369

54-
import torch
55-
5670
def foo(x, y):
5771
a = torch.sin(x)
5872
b = torch.cos(x)

0 commit comments

Comments
 (0)