Skip to content

Some small updates to custom C++ class tutorial for clarity #883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 12, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 49 additions & 48 deletions advanced_source/torch_script_custom_classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ state in a member variable.
#include <vector>

template <class T>
struct Stack : torch::jit::CustomClassHolder {
struct MyStackClass : torch::jit::CustomClassHolder {
std::vector<T> stack_;
Stack(std::vector<T> init) : stack_(init.begin(), init.end()) {}
MyStackClass(std::vector<T> init) : stack_(init.begin(), init.end()) {}

void push(T x) {
stack_.push_back(x);
Expand All @@ -42,11 +42,11 @@ state in a member variable.
return val;
}

c10::intrusive_ptr<Stack> clone() const {
return c10::make_intrusive<Stack>(stack_);
c10::intrusive_ptr<MyStackClass> clone() const {
return c10::make_intrusive<MyStackClass>(stack_);
}

void merge(const c10::intrusive_ptr<Stack>& c) {
void merge(const c10::intrusive_ptr<MyStackClass>& c) {
for (auto& elem : c->stack_) {
push(elem);
}
Expand Down Expand Up @@ -74,19 +74,19 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
// Notice a few things:
// - We pass the class to be registered as a template parameter to
// `torch::jit::class_`. In this instance, we've passed the
// specialization of the Stack class ``Stack<std::string>``.
// specialization of the MyStackClass class ``MyStackClass<std::string>``.
// In general, you cannot register a non-specialized template
// class. For non-templated classes, you can just pass the
// class name directly as the template parameter.
// - The single parameter to ``torch::jit::class_()`` is a
// string indicating the name of the class. This is the name
// the class will appear as in both Python and TorchScript.
// For example, our Stack class would appear as ``torch.classes.Stack``.
// For example, our MyStackClass class would appear as ``torch.classes.MyStackClass``.
static auto testStack =
torch::jit::class_<Stack<std::string>>("Stack")
// The following line registers the contructor of our Stack
torch::jit::class_<MyStackClass<std::string>>("MyStackClass")
// The following line registers the contructor of our MyStackClass
// class that takes a single `std::vector<std::string>` argument,
// i.e. it exposes the C++ method `Stack(std::vector<T> init)`.
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
// Currently, we do not support registering overloaded
// constructors, so for now you can only `def()` one instance of
// `torch::jit::init`.
Expand All @@ -95,19 +95,19 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
// function as a method. Note that a lambda function must take a
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
// as the first argument. Other arguments can be whatever you want.
.def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
return self->stack_.back();
})
// The following four lines expose methods of the Stack<std::string>
// The following four lines expose methods of the MyStackClass<std::string>
// class as-is. `torch::jit::class_` will automatically examine the
// argument and return types of the passed-in method pointers and
// expose these to Python and TorchScript accordingly. Finally, notice
// that we must take the *address* of the fully-qualified method name,
// i.e. use the unary `&` operator, due to C++ typing rules.
.def("push", &Stack<std::string>::push)
.def("pop", &Stack<std::string>::pop)
.def("clone", &Stack<std::string>::clone)
.def("merge", &Stack<std::string>::merge);
.def("push", &MyStackClass<std::string>::push)
.def("pop", &MyStackClass<std::string>::pop)
.def("clone", &MyStackClass<std::string>::clone)
.def("merge", &MyStackClass<std::string>::merge);



Expand Down Expand Up @@ -215,9 +215,9 @@ demonstrates that:
# We can find and instantiate our custom C++ class in python by using the
# `torch.classes` namespace:
#
# This instantiation will invoke the Stack(std::vector<T> init) constructor
# This instantiation will invoke the MyStackClass(std::vector<T> init) constructor
# we registered earlier
s = torch.classes.Stack(["foo", "bar"])
s = torch.classes.MyStackClass(["foo", "bar"])

# We can call methods in Python
s.push("pushed")
Expand All @@ -233,16 +233,16 @@ demonstrates that:
# For now, we need to assign the class's type to a local in order to
# annotate the type on the TorchScript function. This may change
# in the future.
Stack = torch.classes.Stack
MyStackClass = torch.classes.MyStackClass

@torch.jit.script
def do_stacks(s : Stack): # We can pass a custom class instance to TorchScript
s2 = torch.classes.Stack(["hi", "mom"]) # We can instantiate the class
def do_stacks(s : MyStackClass): # We can pass a custom class instance to TorchScript
s2 = torch.classes.MyStackClass(["hi", "mom"]) # We can instantiate the class
s2.merge(s) # We can call a method on the class
return s2.clone(), s2.top() # We can also return instances of the class
# from TorchScript function/methods

stack, top = do_stacks(torch.classes.Stack(["wow"]))
stack, top = do_stacks(torch.classes.MyStackClass(["wow"]))
assert top == "wow"
for expected in ["wow", "mom", "hi"]:
assert stack.pop() == expected
Expand All @@ -252,7 +252,7 @@ Saving, Loading, and Running TorchScript Code Using Custom Classes

We can also use custom-registered C++ classes in a C++ process using
libtorch. As an example, let's define a simple ``nn.Module`` that
instantiates and calls a method on our Stack class:
instantiates and calls a method on our MyStackClass class:

.. code-block:: python

Expand All @@ -265,7 +265,7 @@ instantiates and calls a method on our Stack class:
super().__init__()

def forward(self, s : str) -> str:
stack = torch.classes.Stack(["hi", "mom"])
stack = torch.classes.MyStackClass(["hi", "mom"])
return stack.pop() + s

scripted_foo = torch.jit.script(Foo())
Expand Down Expand Up @@ -410,7 +410,7 @@ an attribute, you'll get the following error:
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.stack = torch.classes.Stack(["just", "testing"])
self.stack = torch.classes.MyStackClass(["just", "testing"])

def forward(self, s : str) -> str:
return self.stack.pop() + s
Expand All @@ -422,7 +422,7 @@ an attribute, you'll get the following error:
.. code-block:: shell

$ python export_attr.py
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.Stack. Please define serialization methods via torch::jit::pickle_ for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via torch::jit::pickle_ for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)

This is because TorchScript cannot automatically figure out what information
save from your C++ class. You must specify that manually. The way to do that
Expand All @@ -436,20 +436,20 @@ the special ``def_pickle`` method on ``class_``.
about how we use these methods.

Here is an example of how we can update the registration code for our
``Stack`` class to include serialization methods:
``MyStackClass`` class to include serialization methods:

.. code-block:: cpp

static auto testStack =
torch::jit::class_<Stack<std::string>>("Stack")
torch::jit::class_<MyStackClass<std::string>>("MyStackClass")
.def(torch::jit::init<std::vector<std::string>>())
.def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
return self->stack_.back();
})
.def("push", &Stack<std::string>::push)
.def("pop", &Stack<std::string>::pop)
.def("clone", &Stack<std::string>::clone)
.def("merge", &Stack<std::string>::merge)
.def("push", &MyStackClass<std::string>::push)
.def("pop", &MyStackClass<std::string>::pop)
.def("clone", &MyStackClass<std::string>::clone)
.def("merge", &MyStackClass<std::string>::merge)
// class_<>::def_pickle allows you to define the serialization
// and deserialization methods for your C++ class.
// Currently, we only support passing stateless lambda functions
Expand All @@ -464,7 +464,7 @@ Here is an example of how we can update the registration code for our
// custom operator API. In this instance, we've chosen to return
// a std::vector<std::string> as the salient data to preserve
// from the class.
[](const c10::intrusive_ptr<Stack<std::string>>& self)
[](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
-> std::vector<std::string> {
return self->stack_;
},
Expand All @@ -476,13 +476,13 @@ Here is an example of how we can update the registration code for our
// to a new instance of the C++ class, initialized however
// you would like given the serialized state.
[](std::vector<std::string> state)
-> c10::intrusive_ptr<Stack<std::string>> {
-> c10::intrusive_ptr<MyStackClass<std::string>> {
// A convenient way to instantiate an object and get an
// intrusive_ptr to it is via `make_intrusive`. We use
// that here to allocate an instance of Stack<std::string>
// that here to allocate an instance of MyStackClass<std::string>
// and call the single-argument std::vector<std::string>
// constructor with the serialized state.
return c10::make_intrusive<Stack<std::string>>(std::move(state));
return c10::make_intrusive<MyStackClass<std::string>>(std::move(state));
});

.. note::
Expand All @@ -503,7 +503,7 @@ now run successfully:
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.stack = torch.classes.Stack(["just", "testing"])
self.stack = torch.classes.MyStackClass(["just", "testing"])

def forward(self, s : str) -> str:
return self.stack.pop() + s
Expand All @@ -520,24 +520,25 @@ now run successfully:
$ python ../export_attr.py
testing

Defining Custom Operators that Take C++ Classes as Arguments
------------------------------------------------------------
Defining Custom Operators that Take or Return Bound C++ Classes
---------------------------------------------------------------

Once you've defined a custom C++ class, you can also use that class
as an argument to custom operator (i.e. free functions). Here's an
as an argument or return from a custom operator (i.e. free functions). Here's an
example of how to do that:

.. code-block:: cpp

std::string take_an_instance(const c10::intrusive_ptr<Stack<std::string>>& instance) {
return instance->pop();
c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance(const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
instance->pop();
return instance;
}

static auto instance_registry = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"foo::take_an_instance(__torch__.torch.classes.Stack x) -> str Y")
.catchAllKernel<decltype(take_an_instance), &take_an_instance>());
"foo::manipulate_instance(__torch__.torch.classes.MyStackClass x) -> __torch__.torch.classes.MyStackClass Y")
.catchAllKernel<decltype(manipulate_instance), &manipulate_instance>());

Refer to the `custom op tutorial <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
for more details on the registration API.
Expand All @@ -549,10 +550,10 @@ Once this is done, you can use the op like the following example:
class TryCustomOp(torch.nn.Module):
def __init__(self):
super(TryCustomOp, self).__init__()
self.f = torch.classes.Stack(["foo", "bar"])
self.f = torch.classes.MyStackClass(["foo", "bar"])

def forward(self) -> str:
return torch.ops._TorchScriptTesting.take_an_instance(self.f)
def forward(self):
return torch.ops.foo.manipulate_instance(self.f)

.. note::

Expand Down