diff --git a/advanced_source/torch_script_custom_classes.rst b/advanced_source/torch_script_custom_classes.rst index 6ef54a4a8d9..93a78348522 100644 --- a/advanced_source/torch_script_custom_classes.rst +++ b/advanced_source/torch_script_custom_classes.rst @@ -29,7 +29,7 @@ state in a member variable. #include template - struct MyStackClass : torch::jit::CustomClassHolder { + struct MyStackClass : torch::CustomClassHolder { std::vector stack_; MyStackClass(std::vector init) : stack_(init.begin(), init.end()) {} @@ -63,7 +63,7 @@ There are several things to note: is to ensure consistent lifetime management of the object instances between languages (C++, Python and TorchScript). - The second thing to notice is that the user-defined class must inherit from - ``torch::jit::CustomClassHolder``. This ensures that everything is set up to handle + ``torch::CustomClassHolder``. This ensures that everything is set up to handle the lifetime management system previously mentioned. Now let's take a look at how we will make this class visible to TorchScript, a process called @@ -73,24 +73,24 @@ Now let's take a look at how we will make this class visible to TorchScript, a p // Notice a few things: // - We pass the class to be registered as a template parameter to - // `torch::jit::class_`. In this instance, we've passed the + // `torch::class_`. In this instance, we've passed the // specialization of the MyStackClass class ``MyStackClass``. // In general, you cannot register a non-specialized template // class. For non-templated classes, you can just pass the // class name directly as the template parameter. - // - The single parameter to ``torch::jit::class_()`` is a + // - The single parameter to ``torch::class_()`` is a // string indicating the name of the class. This is the name // the class will appear as in both Python and TorchScript. // For example, our MyStackClass class would appear as ``torch.classes.MyStackClass``. static auto testStack = - torch::jit::class_>("MyStackClass") + torch::class_>("MyStackClass") // The following line registers the contructor of our MyStackClass // class that takes a single `std::vector` argument, // i.e. it exposes the C++ method `MyStackClass(std::vector init)`. // Currently, we do not support registering overloaded // constructors, so for now you can only `def()` one instance of - // `torch::jit::init`. - .def(torch::jit::init>()) + // `torch::init`. + .def(torch::init>()) // The next line registers a stateless (i.e. no captures) C++ lambda // function as a method. Note that a lambda function must take a // `c10::intrusive_ptr` (or some const/ref version of that) @@ -99,7 +99,7 @@ Now let's take a look at how we will make this class visible to TorchScript, a p return self->stack_.back(); }) // The following four lines expose methods of the MyStackClass - // class as-is. `torch::jit::class_` will automatically examine the + // class as-is. `torch::class_` will automatically examine the // argument and return types of the passed-in method pointers and // expose these to Python and TorchScript accordingly. Finally, notice // that we must take the *address* of the fully-qualified method name, @@ -307,7 +307,7 @@ Let's populate ``infer.cpp`` with the following: #include int main(int argc, const char* argv[]) { - torch::jit::script::Module module; + torch::script::Module module; try { // Deserialize the ScriptModule from a file using torch::jit::load(). module = torch::jit::load("foo.pt"); @@ -394,6 +394,31 @@ And now we can run our exciting C++ binary: Incredible! +Moving Custom Classes To/From IValues +------------------------------------- + +It's also possible that you may need to move custom classes into or out of +``IValue``s, such as when you take or return ``IValue``s from TorchScript methods +or you want to instantiate a custom class attribute in C++. For creating an +``IValue`` from a custom C++ class instance: + +- ``torch::make_custom_class()`` provides an API similar to c10::intrusive_ptr + in that it will take whatever set of arguments you provide to it, call the constructor + of T that matches that set of arguments, and wrap that instance up and return it. + However, instead of returning just a pointer to a custom class object, it returns + an ``IValue`` wrapping the object. You can then pass this ``IValue`` directly to + TorchScript. +- In the event that you already have an ``intrusive_ptr`` pointing to your class, you + can directly construct an IValue from it using the constructor ``IValue(intrusive_ptr)``. + +For converting ``IValue``s back to custom classes: + +- ``IValue::toCustomClass()`` will return an ``intrusive_ptr`` pointing to the + custom class that the ``IValue`` contains. Internally, this function is checking + that ``T`` is registered as a custom class and that the ``IValue`` does in fact contain + a custom class. You can check whether the ``IValue`` contains a custom class manually by + calling ``isCustomClass()``. + Defining Serialization/Deserialization Methods for Custom C++ Classes --------------------------------------------------------------------- @@ -422,7 +447,7 @@ an attribute, you'll get the following error: .. code-block:: shell $ python export_attr.py - RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via torch::jit::pickle_ for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) + RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) This is because TorchScript cannot automatically figure out what information save from your C++ class. You must specify that manually. The way to do that @@ -441,8 +466,8 @@ Here is an example of how we can update the registration code for our .. code-block:: cpp static auto testStack = - torch::jit::class_>("MyStackClass") - .def(torch::jit::init>()) + torch::class_>("MyStackClass") + .def(torch::init>()) .def("top", [](const c10::intrusive_ptr>& self) { return self->stack_.back(); })