Skip to content

More updates to custom class tutorial #890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 17, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions advanced_source/torch_script_custom_classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ state in a member variable.
#include <vector>

template <class T>
struct MyStackClass : torch::jit::CustomClassHolder {
struct MyStackClass : torch::CustomClassHolder {
std::vector<T> stack_;
MyStackClass(std::vector<T> init) : stack_(init.begin(), init.end()) {}

Expand Down Expand Up @@ -63,7 +63,7 @@ There are several things to note:
is to ensure consistent lifetime management of the object instances between languages
(C++, Python and TorchScript).
- The second thing to notice is that the user-defined class must inherit from
``torch::jit::CustomClassHolder``. This ensures that everything is set up to handle
``torch::CustomClassHolder``. This ensures that everything is set up to handle
the lifetime management system previously mentioned.

Now let's take a look at how we will make this class visible to TorchScript, a process called
Expand All @@ -73,24 +73,24 @@ Now let's take a look at how we will make this class visible to TorchScript, a p

// Notice a few things:
// - We pass the class to be registered as a template parameter to
// `torch::jit::class_`. In this instance, we've passed the
// `torch::class_`. In this instance, we've passed the
// specialization of the MyStackClass class ``MyStackClass<std::string>``.
// In general, you cannot register a non-specialized template
// class. For non-templated classes, you can just pass the
// class name directly as the template parameter.
// - The single parameter to ``torch::jit::class_()`` is a
// - The single parameter to ``torch::class_()`` is a
// string indicating the name of the class. This is the name
// the class will appear as in both Python and TorchScript.
// For example, our MyStackClass class would appear as ``torch.classes.MyStackClass``.
static auto testStack =
torch::jit::class_<MyStackClass<std::string>>("MyStackClass")
torch::class_<MyStackClass<std::string>>("MyStackClass")
// The following line registers the contructor of our MyStackClass
// class that takes a single `std::vector<std::string>` argument,
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
// Currently, we do not support registering overloaded
// constructors, so for now you can only `def()` one instance of
// `torch::jit::init`.
.def(torch::jit::init<std::vector<std::string>>())
// `torch::init`.
.def(torch::init<std::vector<std::string>>())
// The next line registers a stateless (i.e. no captures) C++ lambda
// function as a method. Note that a lambda function must take a
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
Expand All @@ -99,7 +99,7 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
return self->stack_.back();
})
// The following four lines expose methods of the MyStackClass<std::string>
// class as-is. `torch::jit::class_` will automatically examine the
// class as-is. `torch::class_` will automatically examine the
// argument and return types of the passed-in method pointers and
// expose these to Python and TorchScript accordingly. Finally, notice
// that we must take the *address* of the fully-qualified method name,
Expand Down Expand Up @@ -307,7 +307,7 @@ Let's populate ``infer.cpp`` with the following:
#include <memory>

int main(int argc, const char* argv[]) {
torch::jit::script::Module module;
torch::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load("foo.pt");
Expand Down Expand Up @@ -394,6 +394,31 @@ And now we can run our exciting C++ binary:

Incredible!

Moving Custom Classes To/From IValues
-------------------------------------

It's also possible that you may need to move custom classes into or out of
``IValue``s, such as when you take or return ``IValue``s from TorchScript methods
or you want to instantiate a custom class attribute in C++. For creating an
``IValue`` from a custom C++ class instance:

- ``torch::make_custom_class<T>()`` provides an API similar to c10::intrusive_ptr<T>
in that it will take whatever set of arguments you provide to it, call the constructor
of T that matches that set of arguments, and wrap that instance up and return it.
However, instead of returning just a pointer to a custom class object, it returns
an ``IValue`` wrapping the object. You can then pass this ``IValue`` directly to
TorchScript.
- In the event that you already have an ``intrusive_ptr`` pointing to your class, you
can directly construct an IValue from it using the constructor ``IValue(intrusive_ptr<T>)``.

For converting ``IValue``s back to custom classes:

- ``IValue::toCustomClass<T>()`` will return an ``intrusive_ptr<T>`` pointing to the
custom class that the ``IValue`` contains. Internally, this function is checking
that ``T`` is registered as a custom class and that the ``IValue`` does in fact contain
a custom class. You can check whether the ``IValue`` contains a custom class manually by
calling ``isCustomClass()``.

Defining Serialization/Deserialization Methods for Custom C++ Classes
---------------------------------------------------------------------

Expand Down Expand Up @@ -422,7 +447,7 @@ an attribute, you'll get the following error:
.. code-block:: shell

$ python export_attr.py
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via torch::jit::pickle_ for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)

This is because TorchScript cannot automatically figure out what information
save from your C++ class. You must specify that manually. The way to do that
Expand All @@ -441,8 +466,8 @@ Here is an example of how we can update the registration code for our
.. code-block:: cpp

static auto testStack =
torch::jit::class_<MyStackClass<std::string>>("MyStackClass")
.def(torch::jit::init<std::vector<std::string>>())
torch::class_<MyStackClass<std::string>>("MyStackClass")
.def(torch::init<std::vector<std::string>>())
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
return self->stack_.back();
})
Expand Down