diff --git a/advanced_source/torch_script_custom_classes.rst b/advanced_source/torch_script_custom_classes.rst index e85d722cd47..6ef54a4a8d9 100644 --- a/advanced_source/torch_script_custom_classes.rst +++ b/advanced_source/torch_script_custom_classes.rst @@ -29,9 +29,9 @@ state in a member variable. #include template - struct Stack : torch::jit::CustomClassHolder { + struct MyStackClass : torch::jit::CustomClassHolder { std::vector stack_; - Stack(std::vector init) : stack_(init.begin(), init.end()) {} + MyStackClass(std::vector init) : stack_(init.begin(), init.end()) {} void push(T x) { stack_.push_back(x); @@ -42,11 +42,11 @@ state in a member variable. return val; } - c10::intrusive_ptr clone() const { - return c10::make_intrusive(stack_); + c10::intrusive_ptr clone() const { + return c10::make_intrusive(stack_); } - void merge(const c10::intrusive_ptr& c) { + void merge(const c10::intrusive_ptr& c) { for (auto& elem : c->stack_) { push(elem); } @@ -74,19 +74,19 @@ Now let's take a look at how we will make this class visible to TorchScript, a p // Notice a few things: // - We pass the class to be registered as a template parameter to // `torch::jit::class_`. In this instance, we've passed the - // specialization of the Stack class ``Stack``. + // specialization of the MyStackClass class ``MyStackClass``. // In general, you cannot register a non-specialized template // class. For non-templated classes, you can just pass the // class name directly as the template parameter. // - The single parameter to ``torch::jit::class_()`` is a // string indicating the name of the class. This is the name // the class will appear as in both Python and TorchScript. - // For example, our Stack class would appear as ``torch.classes.Stack``. + // For example, our MyStackClass class would appear as ``torch.classes.MyStackClass``. static auto testStack = - torch::jit::class_>("Stack") - // The following line registers the contructor of our Stack + torch::jit::class_>("MyStackClass") + // The following line registers the contructor of our MyStackClass // class that takes a single `std::vector` argument, - // i.e. it exposes the C++ method `Stack(std::vector init)`. + // i.e. it exposes the C++ method `MyStackClass(std::vector init)`. // Currently, we do not support registering overloaded // constructors, so for now you can only `def()` one instance of // `torch::jit::init`. @@ -95,19 +95,19 @@ Now let's take a look at how we will make this class visible to TorchScript, a p // function as a method. Note that a lambda function must take a // `c10::intrusive_ptr` (or some const/ref version of that) // as the first argument. Other arguments can be whatever you want. - .def("top", [](const c10::intrusive_ptr>& self) { + .def("top", [](const c10::intrusive_ptr>& self) { return self->stack_.back(); }) - // The following four lines expose methods of the Stack + // The following four lines expose methods of the MyStackClass // class as-is. `torch::jit::class_` will automatically examine the // argument and return types of the passed-in method pointers and // expose these to Python and TorchScript accordingly. Finally, notice // that we must take the *address* of the fully-qualified method name, // i.e. use the unary `&` operator, due to C++ typing rules. - .def("push", &Stack::push) - .def("pop", &Stack::pop) - .def("clone", &Stack::clone) - .def("merge", &Stack::merge); + .def("push", &MyStackClass::push) + .def("pop", &MyStackClass::pop) + .def("clone", &MyStackClass::clone) + .def("merge", &MyStackClass::merge); @@ -215,9 +215,9 @@ demonstrates that: # We can find and instantiate our custom C++ class in python by using the # `torch.classes` namespace: # - # This instantiation will invoke the Stack(std::vector init) constructor + # This instantiation will invoke the MyStackClass(std::vector init) constructor # we registered earlier - s = torch.classes.Stack(["foo", "bar"]) + s = torch.classes.MyStackClass(["foo", "bar"]) # We can call methods in Python s.push("pushed") @@ -233,16 +233,16 @@ demonstrates that: # For now, we need to assign the class's type to a local in order to # annotate the type on the TorchScript function. This may change # in the future. - Stack = torch.classes.Stack + MyStackClass = torch.classes.MyStackClass @torch.jit.script - def do_stacks(s : Stack): # We can pass a custom class instance to TorchScript - s2 = torch.classes.Stack(["hi", "mom"]) # We can instantiate the class + def do_stacks(s : MyStackClass): # We can pass a custom class instance to TorchScript + s2 = torch.classes.MyStackClass(["hi", "mom"]) # We can instantiate the class s2.merge(s) # We can call a method on the class return s2.clone(), s2.top() # We can also return instances of the class # from TorchScript function/methods - stack, top = do_stacks(torch.classes.Stack(["wow"])) + stack, top = do_stacks(torch.classes.MyStackClass(["wow"])) assert top == "wow" for expected in ["wow", "mom", "hi"]: assert stack.pop() == expected @@ -252,7 +252,7 @@ Saving, Loading, and Running TorchScript Code Using Custom Classes We can also use custom-registered C++ classes in a C++ process using libtorch. As an example, let's define a simple ``nn.Module`` that -instantiates and calls a method on our Stack class: +instantiates and calls a method on our MyStackClass class: .. code-block:: python @@ -265,7 +265,7 @@ instantiates and calls a method on our Stack class: super().__init__() def forward(self, s : str) -> str: - stack = torch.classes.Stack(["hi", "mom"]) + stack = torch.classes.MyStackClass(["hi", "mom"]) return stack.pop() + s scripted_foo = torch.jit.script(Foo()) @@ -410,7 +410,7 @@ an attribute, you'll get the following error: class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.stack = torch.classes.Stack(["just", "testing"]) + self.stack = torch.classes.MyStackClass(["just", "testing"]) def forward(self, s : str) -> str: return self.stack.pop() + s @@ -422,7 +422,7 @@ an attribute, you'll get the following error: .. code-block:: shell $ python export_attr.py - RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.Stack. Please define serialization methods via torch::jit::pickle_ for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) + RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via torch::jit::pickle_ for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) This is because TorchScript cannot automatically figure out what information save from your C++ class. You must specify that manually. The way to do that @@ -436,20 +436,20 @@ the special ``def_pickle`` method on ``class_``. about how we use these methods. Here is an example of how we can update the registration code for our -``Stack`` class to include serialization methods: +``MyStackClass`` class to include serialization methods: .. code-block:: cpp static auto testStack = - torch::jit::class_>("Stack") + torch::jit::class_>("MyStackClass") .def(torch::jit::init>()) - .def("top", [](const c10::intrusive_ptr>& self) { + .def("top", [](const c10::intrusive_ptr>& self) { return self->stack_.back(); }) - .def("push", &Stack::push) - .def("pop", &Stack::pop) - .def("clone", &Stack::clone) - .def("merge", &Stack::merge) + .def("push", &MyStackClass::push) + .def("pop", &MyStackClass::pop) + .def("clone", &MyStackClass::clone) + .def("merge", &MyStackClass::merge) // class_<>::def_pickle allows you to define the serialization // and deserialization methods for your C++ class. // Currently, we only support passing stateless lambda functions @@ -464,7 +464,7 @@ Here is an example of how we can update the registration code for our // custom operator API. In this instance, we've chosen to return // a std::vector as the salient data to preserve // from the class. - [](const c10::intrusive_ptr>& self) + [](const c10::intrusive_ptr>& self) -> std::vector { return self->stack_; }, @@ -476,13 +476,13 @@ Here is an example of how we can update the registration code for our // to a new instance of the C++ class, initialized however // you would like given the serialized state. [](std::vector state) - -> c10::intrusive_ptr> { + -> c10::intrusive_ptr> { // A convenient way to instantiate an object and get an // intrusive_ptr to it is via `make_intrusive`. We use - // that here to allocate an instance of Stack + // that here to allocate an instance of MyStackClass // and call the single-argument std::vector // constructor with the serialized state. - return c10::make_intrusive>(std::move(state)); + return c10::make_intrusive>(std::move(state)); }); .. note:: @@ -503,7 +503,7 @@ now run successfully: class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.stack = torch.classes.Stack(["just", "testing"]) + self.stack = torch.classes.MyStackClass(["just", "testing"]) def forward(self, s : str) -> str: return self.stack.pop() + s @@ -520,24 +520,25 @@ now run successfully: $ python ../export_attr.py testing -Defining Custom Operators that Take C++ Classes as Arguments ------------------------------------------------------------- +Defining Custom Operators that Take or Return Bound C++ Classes +--------------------------------------------------------------- Once you've defined a custom C++ class, you can also use that class -as an argument to custom operator (i.e. free functions). Here's an +as an argument or return from a custom operator (i.e. free functions). Here's an example of how to do that: .. code-block:: cpp - std::string take_an_instance(const c10::intrusive_ptr>& instance) { - return instance->pop(); + c10::intrusive_ptr> manipulate_instance(const c10::intrusive_ptr>& instance) { + instance->pop(); + return instance; } static auto instance_registry = torch::RegisterOperators().op( torch::RegisterOperators::options() .schema( - "foo::take_an_instance(__torch__.torch.classes.Stack x) -> str Y") - .catchAllKernel()); + "foo::manipulate_instance(__torch__.torch.classes.MyStackClass x) -> __torch__.torch.classes.MyStackClass Y") + .catchAllKernel()); Refer to the `custom op tutorial `_ for more details on the registration API. @@ -549,10 +550,10 @@ Once this is done, you can use the op like the following example: class TryCustomOp(torch.nn.Module): def __init__(self): super(TryCustomOp, self).__init__() - self.f = torch.classes.Stack(["foo", "bar"]) + self.f = torch.classes.MyStackClass(["foo", "bar"]) - def forward(self) -> str: - return torch.ops._TorchScriptTesting.take_an_instance(self.f) + def forward(self): + return torch.ops.foo.manipulate_instance(self.f) .. note::