@@ -29,9 +29,9 @@ state in a member variable.
29
29
#include <vector>
30
30
31
31
template <class T>
32
- struct Stack : torch::jit ::CustomClassHolder {
32
+ struct MyStackClass : torch::CustomClassHolder {
33
33
std::vector<T> stack_;
34
- Stack (std::vector<T> init) : stack_(init.begin(), init.end()) {}
34
+ MyStackClass (std::vector<T> init) : stack_(init.begin(), init.end()) {}
35
35
36
36
void push(T x) {
37
37
stack_.push_back(x);
@@ -42,11 +42,11 @@ state in a member variable.
42
42
return val;
43
43
}
44
44
45
- c10::intrusive_ptr<Stack > clone() const {
46
- return c10::make_intrusive<Stack >(stack_);
45
+ c10::intrusive_ptr<MyStackClass > clone() const {
46
+ return c10::make_intrusive<MyStackClass >(stack_);
47
47
}
48
48
49
- void merge(const c10::intrusive_ptr<Stack >& c) {
49
+ void merge(const c10::intrusive_ptr<MyStackClass >& c) {
50
50
for (auto& elem : c->stack_) {
51
51
push(elem);
52
52
}
@@ -63,7 +63,7 @@ There are several things to note:
63
63
is to ensure consistent lifetime management of the object instances between languages
64
64
(C++, Python and TorchScript).
65
65
- The second thing to notice is that the user-defined class must inherit from
66
- ``torch::jit:: CustomClassHolder ``. This ensures that everything is set up to handle
66
+ ``torch::CustomClassHolder ``. This ensures that everything is set up to handle
67
67
the lifetime management system previously mentioned.
68
68
69
69
Now let's take a look at how we will make this class visible to TorchScript, a process called
@@ -73,41 +73,41 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
73
73
74
74
// Notice a few things:
75
75
// - We pass the class to be registered as a template parameter to
76
- // `torch::jit:: class_`. In this instance, we've passed the
77
- // specialization of the Stack class ``Stack <std::string>``.
76
+ // `torch::class_`. In this instance, we've passed the
77
+ // specialization of the MyStackClass class ``MyStackClass <std::string>``.
78
78
// In general, you cannot register a non-specialized template
79
79
// class. For non-templated classes, you can just pass the
80
80
// class name directly as the template parameter.
81
- // - The single parameter to ``torch::jit:: class_()`` is a
81
+ // - The single parameter to ``torch::class_()`` is a
82
82
// string indicating the name of the class. This is the name
83
83
// the class will appear as in both Python and TorchScript.
84
- // For example, our Stack class would appear as ``torch.classes.Stack ``.
84
+ // For example, our MyStackClass class would appear as ``torch.classes.MyStackClass ``.
85
85
static auto testStack =
86
- torch::jit:: class_<Stack <std::string>>("Stack ")
87
- // The following line registers the contructor of our Stack
86
+ torch::class_<MyStackClass <std::string>>("MyStackClass ")
87
+ // The following line registers the contructor of our MyStackClass
88
88
// class that takes a single `std::vector<std::string>` argument,
89
- // i.e. it exposes the C++ method `Stack (std::vector<T> init)`.
89
+ // i.e. it exposes the C++ method `MyStackClass (std::vector<T> init)`.
90
90
// Currently, we do not support registering overloaded
91
91
// constructors, so for now you can only `def()` one instance of
92
- // `torch::jit:: init`.
93
- .def(torch::jit:: init<std::vector<std::string>>())
92
+ // `torch::init`.
93
+ .def(torch::init<std::vector<std::string>>())
94
94
// The next line registers a stateless (i.e. no captures) C++ lambda
95
95
// function as a method. Note that a lambda function must take a
96
96
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
97
97
// as the first argument. Other arguments can be whatever you want.
98
- .def("top", [](const c10::intrusive_ptr<Stack <std::string>>& self) {
98
+ .def("top", [](const c10::intrusive_ptr<MyStackClass <std::string>>& self) {
99
99
return self->stack_.back();
100
100
})
101
- // The following four lines expose methods of the Stack <std::string>
102
- // class as-is. `torch::jit:: class_` will automatically examine the
101
+ // The following four lines expose methods of the MyStackClass <std::string>
102
+ // class as-is. `torch::class_` will automatically examine the
103
103
// argument and return types of the passed-in method pointers and
104
104
// expose these to Python and TorchScript accordingly. Finally, notice
105
105
// that we must take the *address* of the fully-qualified method name,
106
106
// i.e. use the unary `&` operator, due to C++ typing rules.
107
- .def("push", &Stack <std::string>::push)
108
- .def("pop", &Stack <std::string>::pop)
109
- .def("clone", &Stack <std::string>::clone)
110
- .def("merge", &Stack <std::string>::merge);
107
+ .def("push", &MyStackClass <std::string>::push)
108
+ .def("pop", &MyStackClass <std::string>::pop)
109
+ .def("clone", &MyStackClass <std::string>::clone)
110
+ .def("merge", &MyStackClass <std::string>::merge);
111
111
112
112
113
113
@@ -215,9 +215,9 @@ demonstrates that:
215
215
# We can find and instantiate our custom C++ class in python by using the
216
216
# `torch.classes` namespace:
217
217
#
218
- # This instantiation will invoke the Stack (std::vector<T> init) constructor
218
+ # This instantiation will invoke the MyStackClass (std::vector<T> init) constructor
219
219
# we registered earlier
220
- s = torch.classes.Stack ([" foo" , " bar" ])
220
+ s = torch.classes.MyStackClass ([" foo" , " bar" ])
221
221
222
222
# We can call methods in Python
223
223
s.push(" pushed" )
@@ -233,16 +233,16 @@ demonstrates that:
233
233
# For now, we need to assign the class's type to a local in order to
234
234
# annotate the type on the TorchScript function. This may change
235
235
# in the future.
236
- Stack = torch.classes.Stack
236
+ MyStackClass = torch.classes.MyStackClass
237
237
238
238
@torch.jit.script
239
- def do_stacks(s : Stack ): # We can pass a custom class instance to TorchScript
240
- s2 = torch.classes.Stack ([" hi" , " mom" ]) # We can instantiate the class
239
+ def do_stacks(s : MyStackClass ): # We can pass a custom class instance to TorchScript
240
+ s2 = torch.classes.MyStackClass ([" hi" , " mom" ]) # We can instantiate the class
241
241
s2.merge(s) # We can call a method on the class
242
242
return s2.clone (), s2.top () # We can also return instances of the class
243
243
# from TorchScript function/methods
244
244
245
- stack, top = do_stacks(torch.classes.Stack ([" wow" ]))
245
+ stack, top = do_stacks(torch.classes.MyStackClass ([" wow" ]))
246
246
assert top == " wow"
247
247
for expected in [" wow" , " mom" , " hi" ]:
248
248
assert stack.pop () == expected
@@ -252,7 +252,7 @@ Saving, Loading, and Running TorchScript Code Using Custom Classes
252
252
253
253
We can also use custom-registered C++ classes in a C++ process using
254
254
libtorch. As an example, let' s define a simple ``nn.Module`` that
255
- instantiates and calls a method on our Stack class:
255
+ instantiates and calls a method on our MyStackClass class:
256
256
257
257
.. code-block:: python
258
258
@@ -265,7 +265,7 @@ instantiates and calls a method on our Stack class:
265
265
super().__init__()
266
266
267
267
def forward(self, s : str) -> str:
268
- stack = torch.classes.Stack (["hi", "mom"])
268
+ stack = torch.classes.MyStackClass (["hi", "mom"])
269
269
return stack.pop() + s
270
270
271
271
scripted_foo = torch.jit.script(Foo())
@@ -307,7 +307,7 @@ Let's populate ``infer.cpp`` with the following:
307
307
# include <memory>
308
308
309
309
int main(int argc, const char* argv[]) {
310
- torch::jit:: script::Module module;
310
+ torch::script::Module module;
311
311
try {
312
312
// Deserialize the ScriptModule from a file using torch::jit::load ().
313
313
module = torch::jit::load(" foo.pt" );
@@ -394,6 +394,31 @@ And now we can run our exciting C++ binary:
394
394
395
395
Incredible!
396
396
397
+ Moving Custom Classes To/From IValues
398
+ -------------------------------------
399
+
400
+ It's also possible that you may need to move custom classes into or out of
401
+ ` ` IValue` ` s, such as when you take or return ` ` IValue` ` s from TorchScript methods
402
+ or you want to instantiate a custom class attribute in C++. For creating an
403
+ ` ` IValue` ` from a custom C++ class instance:
404
+
405
+ - ` ` torch::make_custom_class<T>()` ` provides an API similar to c10::intrusive_ptr<T>
406
+ in that it will take whatever set of arguments you provide to it, call the constructor
407
+ of T that matches that set of arguments, and wrap that instance up and return it.
408
+ However, instead of returning just a pointer to a custom class object, it returns
409
+ an ` ` IValue` ` wrapping the object. You can then pass this ` ` IValue` ` directly to
410
+ TorchScript.
411
+ - In the event that you already have an ` ` intrusive_ptr` ` pointing to your class, you
412
+ can directly construct an IValue from it using the constructor ` ` IValue(intrusive_ptr<T>)` ` .
413
+
414
+ For converting ` ` IValue` ` s back to custom classes:
415
+
416
+ - ` ` IValue::toCustomClass<T>()` ` will return an ` ` intrusive_ptr<T>` ` pointing to the
417
+ custom class that the ` ` IValue` ` contains. Internally, this function is checking
418
+ that ` ` T` ` is registered as a custom class and that the ` ` IValue` ` does in fact contain
419
+ a custom class. You can check whether the ` ` IValue` ` contains a custom class manually by
420
+ calling ` ` isCustomClass()` ` .
421
+
397
422
Defining Serialization/Deserialization Methods for Custom C++ Classes
398
423
---------------------------------------------------------------------
399
424
@@ -410,7 +435,7 @@ an attribute, you'll get the following error:
410
435
class Foo(torch.nn.Module):
411
436
def __init__(self):
412
437
super().__init__()
413
- self.stack = torch.classes.Stack (["just", "testing"])
438
+ self.stack = torch.classes.MyStackClass (["just", "testing"])
414
439
415
440
def forward(self, s : str) -> str:
416
441
return self.stack.pop() + s
@@ -422,7 +447,7 @@ an attribute, you'll get the following error:
422
447
.. code-block:: shell
423
448
424
449
$ python export_attr.py
425
- RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.Stack . Please define serialization methods via torch::jit::pickle_ for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
450
+ RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass . Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
426
451
427
452
This is because TorchScript cannot automatically figure out what information
428
453
save from your C++ class. You must specify that manually. The way to do that
@@ -436,20 +461,20 @@ the special ``def_pickle`` method on ``class_``.
436
461
about how we use these methods.
437
462
438
463
Here is an example of how we can update the registration code for our
439
- ` ` Stack ` ` class to include serialization methods:
464
+ ` ` MyStackClass ` ` class to include serialization methods:
440
465
441
466
.. code-block:: cpp
442
467
443
468
static auto testStack =
444
- torch::jit:: class_<Stack <std::string>>("Stack ")
445
- .def(torch::jit:: init<std::vector<std::string>>())
446
- .def("top", [](const c10::intrusive_ptr<Stack <std::string>>& self) {
469
+ torch::class_<MyStackClass <std::string>>("MyStackClass ")
470
+ .def(torch::init<std::vector<std::string>>())
471
+ .def("top", [](const c10::intrusive_ptr<MyStackClass <std::string>>& self) {
447
472
return self->stack_.back();
448
473
})
449
- .def("push", &Stack <std::string>::push)
450
- .def("pop", &Stack <std::string>::pop)
451
- .def("clone", &Stack <std::string>::clone)
452
- .def("merge", &Stack <std::string>::merge)
474
+ .def("push", &MyStackClass <std::string>::push)
475
+ .def("pop", &MyStackClass <std::string>::pop)
476
+ .def("clone", &MyStackClass <std::string>::clone)
477
+ .def("merge", &MyStackClass <std::string>::merge)
453
478
// class_<>::def_pickle allows you to define the serialization
454
479
// and deserialization methods for your C++ class.
455
480
// Currently, we only support passing stateless lambda functions
@@ -464,7 +489,7 @@ Here is an example of how we can update the registration code for our
464
489
// custom operator API. In this instance, we've chosen to return
465
490
// a std::vector<std::string> as the salient data to preserve
466
491
// from the class.
467
- [](const c10::intrusive_ptr<Stack <std::string>>& self)
492
+ [](const c10::intrusive_ptr<MyStackClass <std::string>>& self)
468
493
-> std::vector<std::string> {
469
494
return self->stack_;
470
495
},
@@ -476,13 +501,13 @@ Here is an example of how we can update the registration code for our
476
501
// to a new instance of the C++ class, initialized however
477
502
// you would like given the serialized state.
478
503
[](std::vector<std::string> state)
479
- -> c10::intrusive_ptr<Stack <std::string>> {
504
+ -> c10::intrusive_ptr<MyStackClass <std::string>> {
480
505
// A convenient way to instantiate an object and get an
481
506
// intrusive_ptr to it is via ` make_intrusive` . We use
482
- // that here to allocate an instance of Stack <std::string>
507
+ // that here to allocate an instance of MyStackClass <std::string>
483
508
// and call the single-argument std::vector<std::string>
484
509
// constructor with the serialized state.
485
- return c10::make_intrusive<Stack <std::string>>(std::move(state));
510
+ return c10::make_intrusive<MyStackClass <std::string>>(std::move(state));
486
511
});
487
512
488
513
.. note::
@@ -503,7 +528,7 @@ now run successfully:
503
528
class Foo(torch.nn.Module):
504
529
def __init__(self):
505
530
super().__init__()
506
- self.stack = torch.classes.Stack (["just", "testing"])
531
+ self.stack = torch.classes.MyStackClass (["just", "testing"])
507
532
508
533
def forward(self, s : str) -> str:
509
534
return self.stack.pop() + s
@@ -520,24 +545,25 @@ now run successfully:
520
545
$ python ../export_attr.py
521
546
testing
522
547
523
- Defining Custom Operators that Take C++ Classes as Arguments
524
- ------------------------------------------------------------
548
+ Defining Custom Operators that Take or Return Bound C++ Classes
549
+ ---------------------------------------------------------------
525
550
526
551
Once you've defined a custom C++ class, you can also use that class
527
- as an argument to custom operator (i.e. free functions). Here's an
552
+ as an argument or return from a custom operator (i.e. free functions). Here's an
528
553
example of how to do that:
529
554
530
555
.. code-block:: cpp
531
556
532
- std::string take_an_instance(const c10::intrusive_ptr<Stack<std::string>>& instance) {
533
- return instance->pop();
557
+ c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance(const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
558
+ instance->pop();
559
+ return instance;
534
560
}
535
561
536
562
static auto instance_registry = torch::RegisterOperators().op(
537
563
torch::RegisterOperators::options()
538
564
.schema(
539
- "foo::take_an_instance (__torch__.torch.classes.Stack x) -> str Y")
540
- .catchAllKernel<decltype(take_an_instance ), &take_an_instance >());
565
+ "foo::manipulate_instance (__torch__.torch.classes.MyStackClass x) -> __torch__.torch.classes.MyStackClass Y")
566
+ .catchAllKernel<decltype(manipulate_instance ), &manipulate_instance >());
541
567
542
568
Refer to the ` custom op tutorial < https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html> ` _
543
569
for more details on the registration API.
@@ -549,10 +575,10 @@ Once this is done, you can use the op like the following example:
549
575
class TryCustomOp(torch.nn.Module):
550
576
def __init__(self):
551
577
super(TryCustomOp, self).__init__()
552
- self.f = torch.classes.Stack (["foo", "bar"])
578
+ self.f = torch.classes.MyStackClass (["foo", "bar"])
553
579
554
- def forward(self) -> str :
555
- return torch.ops._TorchScriptTesting.take_an_instance (self.f)
580
+ def forward(self):
581
+ return torch.ops.foo.manipulate_instance (self.f)
556
582
557
583
.. note::
558
584
0 commit comments