Commit 241ae0e

Turn on AOTAutogradCache by default on open source

ghstack-source-id: 9a82bec
Pull Request resolved: #141981

1 parent b59ef87 · commit 241ae0e
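
Note (not part of the diff): after this change, open-source builds default the cache flag below to on, while fbcode keeps the old default. A minimal sketch for checking the resulting default, assuming a PyTorch build that contains this commit:

# Minimal sketch (assumes a PyTorch build containing this commit).
# torch._functorch.config.enable_autograd_cache is the flag changed in
# torch/_functorch/config.py below.
import torch._functorch.config as functorch_config

# True on open-source builds unless TORCHINDUCTOR_AUTOGRAD_CACHE=0 was set
# before torch was imported; fbcode builds keep the old default of "0".
print(functorch_config.enable_autograd_cache)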

File tree: 5 files changed (+31, -28 lines)

test/dynamo/test_aot_autograd.py

Lines changed: 2 additions & 1 deletion

@@ -8,6 +8,7 @@
 import torch
 import torch._dynamo
 import torch._dynamo.test_case
+import torch._inductor.test_case
 import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
 from torch._dynamo.testing import (
@@ -45,7 +46,7 @@ def is_dynamic_shape_test(test_name):
 lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")


-class AotAutogradFallbackTests(torch._dynamo.test_case.TestCase):
+class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
     def test_LSTM(self):
         # https://github.com/pytorch/torchdynamo/issues/1147
         class Repro(torch.nn.Module):

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 1 addition & 1 deletion

@@ -96,7 +96,7 @@ def should_use_local_autograd_cache():


 def autograd_cache_enabled():
-    return should_use_local_autograd_cache() or should_use_remote_autograd_cache()
+    return (should_use_local_autograd_cache(), should_use_remote_autograd_cache())


 def check_node_safe(node: Node):

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py

Lines changed: 7 additions & 10 deletions

@@ -36,7 +36,6 @@
 from .autograd_cache import (
     AOTAutogradCache,
     AOTAutogradCacheEntry,
-    autograd_cache_enabled,
     CompiledBackward,
     CompiledForward,
     should_use_remote_autograd_cache,
@@ -149,14 +148,13 @@ def aot_dispatch_base(
     flat_fn, flat_args, fw_metadata = pre_compile(
         wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
     )
-
     fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph(  # type: ignore[misc]
         flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
     )
     # Save the forward_graph_str right after aot_dispatch_base_graph,
     # to save in the cache
     aot_forward_graph_str = None
-    if autograd_cache_enabled():
+    if aot_config.cache_info is not None:
         aot_forward_graph_str = fw_module.print_readable(
             print_output=False, include_stride=True, include_device=True
         )
@@ -218,7 +216,7 @@ def aot_dispatch_base(
         compiled_fw, aot_config, runtime_metadata=fw_metadata
     )
     cache_info = aot_config.cache_info
-    if autograd_cache_enabled() and cache_info:
+    if cache_info is not None:
         if fw_key := getattr(compiled_fw, "_fx_graph_cache_key", None):
             time_taken_ns = time.time_ns() - cache_info.start_time_ns
             entry = AOTAutogradCacheEntry(
@@ -824,13 +822,12 @@ def aot_dispatch_autograd(

     try_save_cache_entry: Optional[Callable] = None

-    if autograd_cache_enabled():
-        cache_info = aot_config.cache_info
-        if cache_info is not None:
-            forward_time_taken_ns = time.time_ns() - cache_info.start_time_ns
-        else:
-            forward_time_taken_ns = None
+    if aot_config.cache_info is not None:
+        forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns

+        # NB: aot_config here is technically not needed as an argument: we could just
+        # close over aot_config.cache_info, since aot_config never changes.
+        # But closing over random variables is confusing IMO, so I'm leaving it.
         def try_save_cache_entry(  # noqa: F811
             compiled_bw_func, _fw_metadata, aot_config
         ):
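
Note (not part of the diff): the recurring change in this file is the gate — instead of consulting the global autograd_cache_enabled() flag, the wrappers check whether aot_config.cache_info was populated. A simplified, self-contained sketch of that pattern; CacheInfo and Config here are hypothetical stand-ins, not the real AOTAutograd classes:

# Simplified sketch of the gating pattern above; CacheInfo and Config are
# hypothetical stand-ins for the real AOTAutograd structures.
import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheInfo:
    start_time_ns: int


@dataclass
class Config:
    cache_info: Optional[CacheInfo] = None


def forward_time_taken_ns(config: Config) -> Optional[int]:
    # Mirrors the new check: only record timing when cache info is attached,
    # rather than consulting a global "cache enabled" flag.
    if config.cache_info is not None:
        return time.time_ns() - config.cache_info.start_time_ns
    return None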

torch/_functorch/aot_autograd.py

Lines changed: 15 additions & 13 deletions

@@ -1137,20 +1137,22 @@ def dispatch_and_compile():
         )
         return compiled_fn

-    # Autograd cache stuff
-    remote = should_use_remote_autograd_cache()
-    local = should_use_local_autograd_cache()
     # We only care if the forward will return an OutputCode.
-    if (local or remote) and isinstance(fw_compiler, SerializableAOTDispatchCompiler):
-        compiled_fn = AOTAutogradCache.load(
-            dispatch_and_compile,
-            mod,
-            fake_flat_args,
-            aot_config,
-            cudagraphs,
-            local,
-            remote,
-        )
+    if isinstance(fw_compiler, SerializableAOTDispatchCompiler):
+        local = should_use_local_autograd_cache()
+        remote = should_use_remote_autograd_cache()
+        if local or remote:
+            compiled_fn = AOTAutogradCache.load(
+                dispatch_and_compile,
+                mod,
+                fake_flat_args,
+                aot_config,
+                cudagraphs,
+                local,
+                remote,
+            )
+        else:
+            compiled_fn = dispatch_and_compile()
     else:
         compiled_fn = dispatch_and_compile()


torch/_functorch/config.py

Lines changed: 6 additions & 3 deletions

@@ -32,8 +32,12 @@
 # Applies CSE to the graph before partitioning
 cse = True

+from torch._inductor.config import is_fbcode
+

-enable_autograd_cache = os.environ.get("TORCHINDUCTOR_AUTOGRAD_CACHE", "0") == "1"
+enable_autograd_cache = (
+    os.environ.get("TORCHINDUCTOR_AUTOGRAD_CACHE", "0" if is_fbcode() else "1") == "1"
+)


 def remote_autograd_cache_default() -> Optional[bool]:
@@ -63,13 +67,12 @@ def remote_autograd_cache_default() -> Optional[bool]:
 # eventually: either default this config to false completely
 # once XLA pin update works,
 # or default config to true and fix relevant bugs
-from torch._inductor.config import is_fbcode


 # View replay is currently not compatible with AOTAutogradCache, since
 # FunctionalTensors are not serializable. We'll need to make them
 # serializable before enabling warm cache with this config turned on.
-view_replay_for_aliased_outputs = (not is_fbcode()) and (not enable_autograd_cache)
+view_replay_for_aliased_outputs = not is_fbcode()

 # Restricts the amount of computation AOTAutograd can do.
 # NB: We have essentially disabled this heuristic now. However, this is kept
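
Note (not part of the diff): because the new default is computed from the environment when torch/_functorch/config.py is imported, an open-source user can still opt back out. A sketch, assuming the variable is set before torch is imported:

# Sketch: opting back out of AOTAutogradCache on an open-source build.
# The env var must be set before torch is imported, because the diff above
# reads it at module import time.
import os

os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "0"

import torch  # torch._functorch.config.enable_autograd_cache is now False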
