Skip to content

Commit 2ea3972

Browse files
authored
Merge pull request #880 from jamesr66a/cpp_class_op
Add section about C++ classes as op args
2 parents 8427cae + d7c5947 commit 2ea3972

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

advanced_source/torch_script_custom_classes.rst

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,50 @@ now run successfully:
520520
$ python ../export_attr.py
521521
testing
522522
523+
Defining Custom Operators that Take C++ Classes as Arguments
524+
------------------------------------------------------------
525+
526+
Once you've defined a custom C++ class, you can also use that class
527+
as an argument to custom operator (i.e. free functions). Here's an
528+
example of how to do that:
529+
530+
.. code-block:: cpp
531+
532+
std::string take_an_instance(const c10::intrusive_ptr<Stack<std::string>>& instance) {
533+
return instance->pop();
534+
}
535+
536+
static auto instance_registry = torch::RegisterOperators().op(
537+
torch::RegisterOperators::options()
538+
.schema(
539+
"foo::take_an_instance(__torch__.torch.classes.Stack x) -> str Y")
540+
.catchAllKernel<decltype(take_an_instance), &take_an_instance>());
541+
542+
Refer to the `custom op tutorial <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
543+
for more details on the registration API.
544+
545+
Once this is done, you can use the op like the following example:
546+
547+
.. code-block:: python
548+
549+
class TryCustomOp(torch.nn.Module):
550+
def __init__(self):
551+
super(TryCustomOp, self).__init__()
552+
self.f = torch.classes.Stack(["foo", "bar"])
553+
554+
def forward(self) -> str:
555+
return torch.ops._TorchScriptTesting.take_an_instance(self.f)
556+
557+
.. note::
558+
559+
Registration of an operator that takes a C++ class as an argument requires that
560+
the custom class has already been registered. This is fine if your op is
561+
registered after your class in a single compilation unit, however, if your
562+
class is registered in a separate compilation unit from the op you will need
563+
to enforce that dependency. One way to do this is to wrap the class registration
564+
in a `Meyer's singleton <https://stackoverflow.com/q/1661529>`_, which can be
565+
called from the compilation unit that does the operator registration.
566+
523567
Conclusion
524568
----------
525569

0 commit comments

Comments
 (0)