diff --git a/advanced_source/torch_script_custom_classes.rst b/advanced_source/torch_script_custom_classes.rst index 93a78348522..031e6c3f696 100644 --- a/advanced_source/torch_script_custom_classes.rst +++ b/advanced_source/torch_script_custom_classes.rst @@ -78,12 +78,13 @@ Now let's take a look at how we will make this class visible to TorchScript, a p // In general, you cannot register a non-specialized template // class. For non-templated classes, you can just pass the // class name directly as the template parameter. - // - The single parameter to ``torch::class_()`` is a - // string indicating the name of the class. This is the name - // the class will appear as in both Python and TorchScript. - // For example, our MyStackClass class would appear as ``torch.classes.MyStackClass``. + // - The arguments passed to the constructor make up the "qualified name" + // of the class. In this case, the registered class will appear in + // Python and C++ as `torch.classes.my_classes.MyStackClass`. We call + // the first argument the "namespace" and the second argument the + // actual class name. static auto testStack = - torch::class_<MyStackClass<std::string>>("MyStackClass") + torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass") // The following line registers the constructor of our MyStackClass // class that takes a single `std::vector<std::string>` argument, // i.e. it exposes the C++ method `MyStackClass(std::vector<std::string> init)`. @@ -217,7 +218,7 @@ demonstrates that: # # This instantiation will invoke the MyStackClass(std::vector<std::string> init) constructor # we registered earlier - s = torch.classes.MyStackClass(["foo", "bar"]) + s = torch.classes.my_classes.MyStackClass(["foo", "bar"]) # We can call methods in Python s.push("pushed") @@ -233,16 +234,16 @@ demonstrates that: # For now, we need to assign the class's type to a local in order to # annotate the type on the TorchScript function. This may change # in the future. 
- MyStackClass = torch.classes.MyStackClass + MyStackClass = torch.classes.my_classes.MyStackClass @torch.jit.script def do_stacks(s : MyStackClass): # We can pass a custom class instance to TorchScript - s2 = torch.classes.MyStackClass(["hi", "mom"]) # We can instantiate the class + s2 = torch.classes.my_classes.MyStackClass(["hi", "mom"]) # We can instantiate the class s2.merge(s) # We can call a method on the class return s2.clone(), s2.top() # We can also return instances of the class # from TorchScript function/methods - stack, top = do_stacks(torch.classes.MyStackClass(["wow"])) + stack, top = do_stacks(torch.classes.my_classes.MyStackClass(["wow"])) assert top == "wow" for expected in ["wow", "mom", "hi"]: assert stack.pop() == expected @@ -265,7 +266,7 @@ instantiates and calls a method on our MyStackClass class: super().__init__() def forward(self, s : str) -> str: - stack = torch.classes.MyStackClass(["hi", "mom"]) + stack = torch.classes.my_classes.MyStackClass(["hi", "mom"]) return stack.pop() + s scripted_foo = torch.jit.script(Foo()) @@ -435,7 +436,7 @@ an attribute, you'll get the following error: class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.stack = torch.classes.MyStackClass(["just", "testing"]) + self.stack = torch.classes.my_classes.MyStackClass(["just", "testing"]) def forward(self, s : str) -> str: return self.stack.pop() + s @@ -447,7 +448,7 @@ an attribute, you'll get the following error: .. code-block:: shell $ python export_attr.py - RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) + RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes.MyStackClass. Please define serialization methods via def_pickle for this class. 
(pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128) This is because TorchScript cannot automatically figure out what information to save from your C++ class. You must specify that manually. The way to do that @@ -466,7 +467,7 @@ Here is an example of how we can update the registration code for our .. code-block:: cpp static auto testStack = - torch::class_<MyStackClass<std::string>>("MyStackClass") + torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass") .def(torch::init<std::vector<std::string>>()) .def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) { return self->stack_.back(); @@ -528,7 +529,7 @@ now run successfully: class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.stack = torch.classes.MyStackClass(["just", "testing"]) + self.stack = torch.classes.my_classes.MyStackClass(["just", "testing"]) def forward(self, s : str) -> str: return self.stack.pop() + s @@ -562,7 +563,7 @@ example of how to do that: static auto instance_registry = torch::RegisterOperators().op( torch::RegisterOperators::options() .schema( - "foo::manipulate_instance(__torch__.torch.classes.MyStackClass x) -> __torch__.torch.classes.MyStackClass Y") + "foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y") .catchAllKernel()); Refer to the `custom op tutorial <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_ @@ -575,7 +576,7 @@ Once this is done, you can use the op like the following example: class TryCustomOp(torch.nn.Module): def __init__(self): super(TryCustomOp, self).__init__() - self.f = torch.classes.MyStackClass(["foo", "bar"]) + self.f = torch.classes.my_classes.MyStackClass(["foo", "bar"]) def forward(self): return torch.ops.foo.manipulate_instance(self.f)