|
34 | 34 |
|
35 | 35 | # exit cleanly if we are on a device that doesn't support torch.compile
|
36 | 36 | if torch.cuda.get_device_capability() < (7, 0):
|
37 |
| - print("Exiting because torch.compile is not supported on this device.") |
38 |
| - import sys |
| 37 | + print("Skipping because torch.compile is not supported on this device.") |
| 38 | +else: |
| 39 | + @torch.compile() |
| 40 | + def fn(x, y): |
| 41 | + z = x + y |
| 42 | + return z + 2 |
39 | 43 |
|
40 |
| - sys.exit(0) |
41 | 44 |
|
42 |
| - |
43 |
| -@torch.compile() |
44 |
| -def fn(x, y): |
45 |
| - z = x + y |
46 |
| - return z + 2 |
47 |
| - |
48 |
| - |
49 |
| -inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda")) |
| 45 | + inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda")) |
50 | 46 |
|
51 | 47 |
|
52 | 48 | # print separator and reset dynamo
|
53 | 49 | # between each example
|
54 |
| -def separator(name): |
55 |
| - print(f"==================={name}=========================") |
56 |
| - torch._dynamo.reset() |
| 50 | + def separator(name): |
| 51 | + print(f"==================={name}=========================") |
| 52 | + torch._dynamo.reset() |
57 | 53 |
|
58 | 54 |
|
59 |
| -separator("Dynamo Tracing") |
| 55 | + separator("Dynamo Tracing") |
60 | 56 | # View dynamo tracing
|
61 | 57 | # TORCH_LOGS="+dynamo"
|
62 |
| -torch._logging.set_logs(dynamo=logging.DEBUG) |
63 |
| -fn(*inputs) |
| 58 | + torch._logging.set_logs(dynamo=logging.DEBUG) |
| 59 | + fn(*inputs) |
64 | 60 |
|
65 |
| -separator("Traced Graph") |
| 61 | + separator("Traced Graph") |
66 | 62 | # View traced graph
|
67 | 63 | # TORCH_LOGS="graph"
|
68 |
| -torch._logging.set_logs(graph=True) |
69 |
| -fn(*inputs) |
| 64 | + torch._logging.set_logs(graph=True) |
| 65 | + fn(*inputs) |
70 | 66 |
|
71 |
| -separator("Fusion Decisions") |
| 67 | + separator("Fusion Decisions") |
72 | 68 | # View fusion decisions
|
73 | 69 | # TORCH_LOGS="fusion"
|
74 |
| -torch._logging.set_logs(fusion=True) |
75 |
| -fn(*inputs) |
| 70 | + torch._logging.set_logs(fusion=True) |
| 71 | + fn(*inputs) |
76 | 72 |
|
77 |
| -separator("Output Code") |
| 73 | + separator("Output Code") |
78 | 74 | # View output code generated by inductor
|
79 | 75 | # TORCH_LOGS="output_code"
|
80 |
| -torch._logging.set_logs(output_code=True) |
81 |
| -fn(*inputs) |
| 76 | + torch._logging.set_logs(output_code=True) |
| 77 | + fn(*inputs) |
82 | 78 |
|
83 |
| -separator("") |
| 79 | + separator("") |
84 | 80 |
|
85 | 81 | ######################################################################
|
86 | 82 | # Conclusion
|
|
0 commit comments