Skip to content

Update custom classes tutorial to include namespaces #904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions advanced_source/torch_script_custom_classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,13 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
// In general, you cannot register a non-specialized template
// class. For non-templated classes, you can just pass the
// class name directly as the template parameter.
// - The single parameter to ``torch::class_()`` is a
// string indicating the name of the class. This is the name
// the class will appear as in both Python and TorchScript.
// For example, our MyStackClass class would appear as ``torch.classes.MyStackClass``.
// - The arguments passed to the constructor make up the "qualified name"
// of the class. In this case, the registered class will appear in
// Python and C++ as `torch.classes.my_classes.MyStackClass`. We call
// the first argument the "namespace" and the second argument the
// actual class name.
static auto testStack =
torch::class_<MyStackClass<std::string>>("MyStackClass")
torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass")
// The following line registers the contructor of our MyStackClass
// class that takes a single `std::vector<std::string>` argument,
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
Expand Down Expand Up @@ -217,7 +218,7 @@ demonstrates that:
#
# This instantiation will invoke the MyStackClass(std::vector<T> init) constructor
# we registered earlier
s = torch.classes.MyStackClass(["foo", "bar"])
s = torch.classes.my_classes.MyStackClass(["foo", "bar"])

# We can call methods in Python
s.push("pushed")
Expand All @@ -233,16 +234,16 @@ demonstrates that:
# For now, we need to assign the class's type to a local in order to
# annotate the type on the TorchScript function. This may change
# in the future.
MyStackClass = torch.classes.MyStackClass
MyStackClass = torch.classes.my_classes.MyStackClass

@torch.jit.script
def do_stacks(s : MyStackClass): # We can pass a custom class instance to TorchScript
s2 = torch.classes.MyStackClass(["hi", "mom"]) # We can instantiate the class
s2 = torch.classes.my_classes.MyStackClass(["hi", "mom"]) # We can instantiate the class
s2.merge(s) # We can call a method on the class
return s2.clone(), s2.top() # We can also return instances of the class
# from TorchScript function/methods

stack, top = do_stacks(torch.classes.MyStackClass(["wow"]))
stack, top = do_stacks(torch.classes.my_classes.MyStackClass(["wow"]))
assert top == "wow"
for expected in ["wow", "mom", "hi"]:
assert stack.pop() == expected
Expand All @@ -265,7 +266,7 @@ instantiates and calls a method on our MyStackClass class:
super().__init__()

def forward(self, s : str) -> str:
stack = torch.classes.MyStackClass(["hi", "mom"])
stack = torch.classes.my_classes.MyStackClass(["hi", "mom"])
return stack.pop() + s

scripted_foo = torch.jit.script(Foo())
Expand Down Expand Up @@ -435,7 +436,7 @@ an attribute, you'll get the following error:
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.stack = torch.classes.MyStackClass(["just", "testing"])
self.stack = torch.classes.my_classes.MyStackClass(["just", "testing"])

def forward(self, s : str) -> str:
return self.stack.pop() + s
Expand All @@ -447,7 +448,7 @@ an attribute, you'll get the following error:
.. code-block:: shell

$ python export_attr.py
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)

This is because TorchScript cannot automatically figure out what information
save from your C++ class. You must specify that manually. The way to do that
Expand All @@ -466,7 +467,7 @@ Here is an example of how we can update the registration code for our
.. code-block:: cpp

static auto testStack =
torch::class_<MyStackClass<std::string>>("MyStackClass")
torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass")
.def(torch::init<std::vector<std::string>>())
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
return self->stack_.back();
Expand Down Expand Up @@ -528,7 +529,7 @@ now run successfully:
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.stack = torch.classes.MyStackClass(["just", "testing"])
self.stack = torch.classes.my_classes.MyStackClass(["just", "testing"])

def forward(self, s : str) -> str:
return self.stack.pop() + s
Expand Down Expand Up @@ -562,7 +563,7 @@ example of how to do that:
static auto instance_registry = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"foo::manipulate_instance(__torch__.torch.classes.MyStackClass x) -> __torch__.torch.classes.MyStackClass Y")
"foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y")
.catchAllKernel<decltype(manipulate_instance), &manipulate_instance>());

Refer to the `custom op tutorial <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
Expand All @@ -575,7 +576,7 @@ Once this is done, you can use the op like the following example:
class TryCustomOp(torch.nn.Module):
def __init__(self):
super(TryCustomOp, self).__init__()
self.f = torch.classes.MyStackClass(["foo", "bar"])
self.f = torch.classes.my_classes.MyStackClass(["foo", "bar"])

def forward(self):
return torch.ops.foo.manipulate_instance(self.f)
Expand Down