🐛 Describe the bug
Using assume_constant_result outside of an nn.Module fails when Dynamo tries to resolve the decorated function as a submodule (to look up its real value): the SubgraphTracer has no get_submodule.
import torch

@torch._dynamo.assume_constant_result
def const_fn(n, s):
    return torch.full([n], s)

def fn(B):
    B = const_fn(B.size(0), 13)
    X = B * 2
    return X.tolist()

B_list = [8] * 32
B = torch.tensor(B_list, dtype=torch.int32)
torch._dynamo.decorators.mark_static(B, 0)
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

print(fn(B))  # eager: prints a list of 32 entries, all 26
torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)  # fails
Full error:
1) torchrec.distributed.tests.test_test.TestTest: test_dynamo_constant_tensor
1) TorchRuntimeError: Failed running get_attr const_fn(*(), **{}):
'SubgraphTracer' object has no attribute 'get_submodule'
from user code:
File "/data/users/ivankobzarev/fbsource/buck-out/v2/gen/fbcode/680651077c79ba5d/torchrec/distributed/tests/__test_test__/test_test#link-tree/torchrec/distributed/tests/test_test.py", line 30, in forward
X = B * 2
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
File "torchrec/distributed/tests/test_test.py", line 46, in test_dynamo_constant_tensor
torch.compile(m, backend="eager", fullgraph=True, dynamic=True)(B)
File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "torch/_dynamo/eval_frame.py", line 403, in _fn
return fn(*args, **kwargs)
File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "torch/_dynamo/convert_frame.py", line 977, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
return _compile(
File "torch/_utils_internal.py", line 279, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "caffe2/fb/strobelight/compile_time_profiler.py", line 96, in profile_compile_time
return func(*args, **kwargs)
File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "torch/_dynamo/convert_frame.py", line 700, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "torch/_dynamo/utils.py", line 268, in time_wrapper
r = func(*args, **kwargs)
File "torch/_dynamo/convert_frame.py", line 568, in compile_inner
out_code = transform_code_object(code, transform)
File "torch/_dynamo/bytecode_transformation.py", line 1116, in transform_code_object
transformations(instructions, code_options)
File "torch/_dynamo/convert_frame.py", line 173, in _fn
return fn(*args, **kwargs)
File "torch/_dynamo/convert_frame.py", line 515, in transform
tracer.run()
File "torch/_dynamo/symbolic_convert.py", line 2237, in run
super().run()
File "torch/_dynamo/symbolic_convert.py", line 875, in run
while self.step():
File "torch/_dynamo/symbolic_convert.py", line 790, in step
self.dispatch_table[inst.opcode](self, inst)
File "torch/_dynamo/symbolic_convert.py", line 229, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "torch/_dynamo/variables/builtin.py", line 946, in call_function
return handler(tx, args, kwargs)
File "torch/_dynamo/variables/builtin.py", line 850, in _handle_insert_op_in_graph
return invoke_and_store_as_constant(
File "torch/_dynamo/variables/functions.py", line 388, in invoke_and_store_as_constant
args = [convert(x) for x in args]
File "torch/_dynamo/variables/functions.py", line 388, in <listcomp>
args = [convert(x) for x in args]
File "torch/_dynamo/variables/functions.py", line 384, in convert
return x.get_real_value()
File "torch/_dynamo/variables/tensor.py", line 113, in get_real_value
return get_real_value(self.proxy.node, self.proxy.tracer)
File "torch/_dynamo/utils.py", line 1924, in get_real_value
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
File "torch/_dynamo/utils.py", line 1921, in get_real_value
real_value = run_node(tracer, node, args, kwargs, nn_module)
File "torch/_dynamo/utils.py", line 1885, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "torch/_dynamo/utils.py", line 1874, in run_node
return tracer.get_submodule(node.target)
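For context, the failure is in resolving a get_attr node back to a real object: torch.fx resolves such targets via nn.Module.get_submodule, which is only available when there is a root module to look the target up on. A minimal standalone sketch of that resolution path (my illustration, not Dynamo internals):

import torch
import torch.fx

class WithSub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.sub(x)

gm = torch.fx.symbolic_trace(WithSub())
# A GraphModule is an nn.Module, so node targets can be resolved back to
# real objects via get_submodule.
print(gm.get_submodule("sub"))
# A bare fx Tracer (and likewise Dynamo's SubgraphTracer) is not an
# nn.Module and has no get_submodule, which matches the error above.
print(hasattr(torch.fx.Tracer(), "get_submodule"))  # False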
Inside an nn.Module, the same code works as expected:
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @torch._dynamo.assume_constant_result
    def const_fn(self, n, s):
        return torch.full([n], s)

    def forward(self, B):
        B = self.const_fn(B.size(0), 13)
        X = B * 2
        return X.tolist()

B_list = [8] * 32
B = torch.tensor(B_list, dtype=torch.int32)
torch._dynamo.decorators.mark_static(B, 0)
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

m = M()
torch.compile(m, backend="eager", fullgraph=True, dynamic=True)(B)
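As a quick sanity check (my addition, reusing m and B from the snippet above), the compiled result matches eager for the module case:

eager_out = m(B)
compiled_out = torch.compile(m, backend="eager", fullgraph=True, dynamic=True)(B)
assert eager_out == compiled_out  # both are plain Python lists from X.tolist()
print(eager_out[:4])  # [26, 26, 26, 26]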
Versions
fbcode/warm
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78