@@ -78,12 +78,13 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
78
78
// In general, you cannot register a non-specialized template
79
79
// class. For non-templated classes, you can just pass the
80
80
// class name directly as the template parameter.
81
- // - The single parameter to ``torch::class_()`` is a
82
- // string indicating the name of the class. This is the name
83
- // the class will appear as in both Python and TorchScript.
84
- // For example, our MyStackClass class would appear as ``torch.classes.MyStackClass``.
81
+ // - The arguments passed to the constructor make up the "qualified name"
82
+ // of the class. In this case, the registered class will appear in
83
+ // Python and C++ as `torch.classes.my_classes.MyStackClass`. We call
84
+ // the first argument the "namespace" and the second argument the
85
+ // actual class name.
85
86
static auto testStack =
86
- torch::class_<MyStackClass<std::string>>("MyStackClass")
87
+ torch::class_<MyStackClass<std::string>>("my_classes", " MyStackClass")
87
88
// The following line registers the contructor of our MyStackClass
88
89
// class that takes a single `std::vector<std::string>` argument,
89
90
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
@@ -217,7 +218,7 @@ demonstrates that:
217
218
#
218
219
# This instantiation will invoke the MyStackClass(std::vector<T> init) constructor
219
220
# we registered earlier
220
- s = torch.classes.MyStackClass([" foo" , " bar" ])
221
+ s = torch.classes.my_classes. MyStackClass([" foo" , " bar" ])
221
222
222
223
# We can call methods in Python
223
224
s.push(" pushed" )
@@ -233,16 +234,16 @@ demonstrates that:
233
234
# For now, we need to assign the class's type to a local in order to
234
235
# annotate the type on the TorchScript function. This may change
235
236
# in the future.
236
- MyStackClass = torch.classes.MyStackClass
237
+ MyStackClass = torch.classes.my_classes. MyStackClass
237
238
238
239
@torch.jit.script
239
240
def do_stacks(s : MyStackClass): # We can pass a custom class instance to TorchScript
240
- s2 = torch.classes.MyStackClass([" hi" , " mom" ]) # We can instantiate the class
241
+ s2 = torch.classes.my_classes. MyStackClass([" hi" , " mom" ]) # We can instantiate the class
241
242
s2.merge(s) # We can call a method on the class
242
243
return s2.clone (), s2.top () # We can also return instances of the class
243
244
# from TorchScript function/methods
244
245
245
- stack, top = do_stacks(torch.classes.MyStackClass([" wow" ]))
246
+ stack, top = do_stacks(torch.classes.my_classes. MyStackClass([" wow" ]))
246
247
assert top == " wow"
247
248
for expected in [" wow" , " mom" , " hi" ]:
248
249
assert stack.pop () == expected
@@ -265,7 +266,7 @@ instantiates and calls a method on our MyStackClass class:
265
266
super().__init__()
266
267
267
268
def forward(self, s : str) -> str:
268
- stack = torch.classes.MyStackClass(["hi", "mom"])
269
+ stack = torch.classes.my_classes. MyStackClass(["hi", "mom"])
269
270
return stack.pop() + s
270
271
271
272
scripted_foo = torch.jit.script(Foo())
@@ -435,7 +436,7 @@ an attribute, you'll get the following error:
435
436
class Foo(torch.nn.Module):
436
437
def __init__(self):
437
438
super().__init__()
438
- self.stack = torch.classes.MyStackClass(["just", "testing"])
439
+ self.stack = torch.classes.my_classes. MyStackClass(["just", "testing"])
439
440
440
441
def forward(self, s : str) -> str:
441
442
return self.stack.pop() + s
@@ -447,7 +448,7 @@ an attribute, you'll get the following error:
447
448
.. code-block:: shell
448
449
449
450
$ python export_attr.py
450
- RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
451
+ RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.my_classes. MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
451
452
452
453
This is because TorchScript cannot automatically figure out what information
453
454
save from your C++ class. You must specify that manually. The way to do that
@@ -466,7 +467,7 @@ Here is an example of how we can update the registration code for our
466
467
.. code-block:: cpp
467
468
468
469
static auto testStack =
469
- torch::class_<MyStackClass<std::string>>("MyStackClass")
470
+ torch::class_<MyStackClass<std::string>>("my_classes", " MyStackClass")
470
471
.def(torch::init<std::vector<std::string>>())
471
472
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
472
473
return self->stack_.back();
@@ -528,7 +529,7 @@ now run successfully:
528
529
class Foo(torch.nn.Module):
529
530
def __init__(self):
530
531
super().__init__()
531
- self.stack = torch.classes.MyStackClass(["just", "testing"])
532
+ self.stack = torch.classes.my_classes. MyStackClass(["just", "testing"])
532
533
533
534
def forward(self, s : str) -> str:
534
535
return self.stack.pop() + s
@@ -562,7 +563,7 @@ example of how to do that:
562
563
static auto instance_registry = torch::RegisterOperators().op(
563
564
torch::RegisterOperators::options()
564
565
.schema(
565
- "foo::manipulate_instance(__torch__.torch.classes.MyStackClass x) -> __torch__.torch.classes.MyStackClass Y")
566
+ "foo::manipulate_instance(__torch__.torch.classes.my_classes. MyStackClass x) -> __torch__.torch.classes.my_classes .MyStackClass Y")
566
567
.catchAllKernel<decltype(manipulate_instance), &manipulate_instance>());
567
568
568
569
Refer to the ` custom op tutorial < https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html> ` _
@@ -575,7 +576,7 @@ Once this is done, you can use the op like the following example:
575
576
class TryCustomOp(torch.nn.Module):
576
577
def __init__(self):
577
578
super(TryCustomOp, self).__init__()
578
- self.f = torch.classes.MyStackClass(["foo", "bar"])
579
+ self.f = torch.classes.my_classes. MyStackClass(["foo", "bar"])
579
580
580
581
def forward(self):
581
582
return torch.ops.foo.manipulate_instance(self.f)
0 commit comments