diff --git a/advanced_source/torch_script_custom_classes.rst b/advanced_source/torch_script_custom_classes.rst index ff647e6ab0a..e85d722cd47 100644 --- a/advanced_source/torch_script_custom_classes.rst +++ b/advanced_source/torch_script_custom_classes.rst @@ -520,6 +520,50 @@ now run successfully: $ python ../export_attr.py testing +Defining Custom Operators that Take C++ Classes as Arguments +------------------------------------------------------------ + +Once you've defined a custom C++ class, you can also use that class +as an argument to custom operator (i.e. free functions). Here's an +example of how to do that: + +.. code-block:: cpp + + std::string take_an_instance(const c10::intrusive_ptr>& instance) { + return instance->pop(); + } + + static auto instance_registry = torch::RegisterOperators().op( + torch::RegisterOperators::options() + .schema( + "foo::take_an_instance(__torch__.torch.classes.Stack x) -> str Y") + .catchAllKernel()); + +Refer to the `custom op tutorial `_ +for more details on the registration API. + +Once this is done, you can use the op like the following example: + +.. code-block:: python + + class TryCustomOp(torch.nn.Module): + def __init__(self): + super(TryCustomOp, self).__init__() + self.f = torch.classes.Stack(["foo", "bar"]) + + def forward(self) -> str: + return torch.ops._TorchScriptTesting.take_an_instance(self.f) + +.. note:: + + Registration of an operator that takes a C++ class as an argument requires that + the custom class has already been registered. This is fine if your op is + registered after your class in a single compilation unit, however, if your + class is registered in a separate compilation unit from the op you will need + to enforce that dependency. One way to do this is to wrap the class registration + in a `Meyer's singleton `_, which can be + called from the compilation unit that does the operator registration. + Conclusion ----------