Skip to content

torch._dynamo.assume_constant_result does not work outside nn.Module #124858

Closed
@IvanKobzarev

Description

@IvanKobzarev

🐛 Describe the bug

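Using `torch._dynamo.assume_constant_result` on a function defined outside of an `nn.Module` fails: while looking up the real value of an argument, Dynamo calls `get_submodule` on the tracer, and `SubgraphTracer` has no such attribute.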

        # Free function (not an nn.Module method) marked with
        # assume_constant_result; this is the configuration the issue reports
        # as failing under torch.compile.
        @torch._dynamo.assume_constant_result
        def const_fn(n, s):
            # Builds a 1-D tensor of length `n` with every element set to `s`.
            return torch.full([n], s)

        # Function handed to torch.compile in the repro.
        def fn(B):
            # The input is used only for its leading dimension; `B` is then
            # rebound to the tensor returned by const_fn.
            B = const_fn(B.size(0), 13)
            # The traceback in this issue reports the failure against this
            # line ("X = B * 2").
            X = B * 2
            # Returns a plain Python list rather than a tensor.
            return X.tolist()

        # Input data: 32 elements, each equal to 8.
        B_list = [8] * 32

        B = torch.tensor(B_list, dtype=torch.int32)
        # Mark dimension 0 of B as static even though dynamic=True is passed
        # to torch.compile below.
        torch._dynamo.decorators.mark_static(B, 0)

        # Turn on both capture options before compiling (presumably needed for
        # `.tolist()` under fullgraph=True -- TODO confirm).
        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True

        # Eager (uncompiled) run, printed for reference.
        print(fn(B))
        # Compiled run; this is the call the issue reports as raising
        # TorchRuntimeError.
        torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)

Full error:

  1) torchrec.distributed.tests.test_test.TestTest: test_dynamo_constant_tensor
    1) TorchRuntimeError: Failed running get_attr const_fn(*(), **{}):
    'SubgraphTracer' object has no attribute 'get_submodule'
    
    from user code:
       File "/data/users/ivankobzarev/fbsource/buck-out/v2/gen/fbcode/680651077c79ba5d/torchrec/distributed/tests/__test_test__/test_test#link-tree/torchrec/distributed/tests/test_test.py", line 30, in forward
        X = B * 2
    
    Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
    
    
    You can suppress this exception and fall back to eager by setting:
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True
    
      File "torchrec/distributed/tests/test_test.py", line 46, in test_dynamo_constant_tensor
        torch.compile(m, backend="eager", fullgraph=True, dynamic=True)(B)
      File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "torch/_dynamo/eval_frame.py", line 403, in _fn
        return fn(*args, **kwargs)
      File "torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "torch/_dynamo/convert_frame.py", line 977, in catch_errors
        return callback(frame, cache_entry, hooks, frame_state, skip=1)
      File "torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
        return _compile(
      File "torch/_utils_internal.py", line 279, in wrapper_function
        return StrobelightCompileTimeProfiler.profile_compile_time(
      File "caffe2/fb/strobelight/compile_time_profiler.py", line 96, in profile_compile_time
        return func(*args, **kwargs)
      File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
        return func(*args, **kwds)
      File "torch/_dynamo/convert_frame.py", line 700, in _compile
        guarded_code = compile_inner(code, one_graph, hooks, transform)
      File "torch/_dynamo/utils.py", line 268, in time_wrapper
        r = func(*args, **kwargs)
      File "torch/_dynamo/convert_frame.py", line 568, in compile_inner
        out_code = transform_code_object(code, transform)
      File "torch/_dynamo/bytecode_transformation.py", line 1116, in transform_code_object
        transformations(instructions, code_options)
      File "torch/_dynamo/convert_frame.py", line 173, in _fn
        return fn(*args, **kwargs)
      File "torch/_dynamo/convert_frame.py", line 515, in transform
        tracer.run()
      File "torch/_dynamo/symbolic_convert.py", line 2237, in run
        super().run()
      File "torch/_dynamo/symbolic_convert.py", line 875, in run
        while self.step():
      File "torch/_dynamo/symbolic_convert.py", line 790, in step
        self.dispatch_table[inst.opcode](self, inst)
      File "torch/_dynamo/symbolic_convert.py", line 229, in impl
        self.push(fn_var.call_function(self, self.popn(nargs), {}))
      File "torch/_dynamo/variables/builtin.py", line 946, in call_function
        return handler(tx, args, kwargs)
      File "torch/_dynamo/variables/builtin.py", line 850, in _handle_insert_op_in_graph
        return invoke_and_store_as_constant(
      File "torch/_dynamo/variables/functions.py", line 388, in invoke_and_store_as_constant
        args = [convert(x) for x in args]
      File "torch/_dynamo/variables/functions.py", line 388, in <listcomp>
        args = [convert(x) for x in args]
      File "torch/_dynamo/variables/functions.py", line 384, in convert
        return x.get_real_value()
      File "torch/_dynamo/variables/tensor.py", line 113, in get_real_value
        return get_real_value(self.proxy.node, self.proxy.tracer)
      File "torch/_dynamo/utils.py", line 1924, in get_real_value
        raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
      File "torch/_dynamo/utils.py", line 1921, in get_real_value
        real_value = run_node(tracer, node, args, kwargs, nn_module)
      File "torch/_dynamo/utils.py", line 1885, in run_node
        raise RuntimeError(make_error_message(e)).with_traceback(
      File "torch/_dynamo/utils.py", line 1874, in run_node
        return tracer.get_submodule(node.target)

When the decorated function is a method of an `nn.Module`, it works as expected:

        class M(torch.nn.Module):
            """Module-based variant of the repro.

            The constant-producing function is a method of the module, which
            is the configuration the issue reports as working under
            torch.compile.
            """

            def __init__(self):
                super().__init__()

            @torch._dynamo.assume_constant_result
            def const_fn(self, n, s):
                """Return a 1-D tensor of length ``n`` filled with ``s``."""
                return torch.full([n], s)

            def forward(self, B):
                """Return ``[2 * 13] * B.size(0)`` as a Python list."""
                # Call the method through ``self``: a bare ``const_fn(...)``
                # does not resolve to this method (it is a NameError, or it
                # silently picks up an unrelated module-level ``const_fn``).
                B = self.const_fn(B.size(0), 13)
                X = B * 2
                return X.tolist()

        # Same input and configuration as the free-function repro above.
        B_list = [8] * 32
        B = torch.tensor(B_list, dtype=torch.int32)
        # Mark dimension 0 of B as static even though dynamic=True is used.
        torch._dynamo.decorators.mark_static(B, 0)
        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        m = M()
        # Compile the module *instance* `m`, not the class `M`: compiling the
        # class and calling it with B would try to construct M(B).
        torch.compile(m, backend="eager", fullgraph=True, dynamic=True)(B)

Versions

fbcode/warm

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78

Metadata

Metadata

Labels

high priority · module: dynamo · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions