@@ -29,7 +29,7 @@ state in a member variable.
29
29
#include <vector>
30
30
31
31
template <class T>
32
- struct MyStackClass : torch::jit:: CustomClassHolder {
32
+ struct MyStackClass : torch::CustomClassHolder {
33
33
std::vector<T> stack_;
34
34
MyStackClass(std::vector<T> init) : stack_(init.begin(), init.end()) {}
35
35
@@ -63,7 +63,7 @@ There are several things to note:
63
63
is to ensure consistent lifetime management of the object instances between languages
64
64
(C++, Python and TorchScript).
65
65
- The second thing to notice is that the user-defined class must inherit from
66
- ``torch::jit:: CustomClassHolder ``. This ensures that everything is set up to handle
66
+ ``torch::CustomClassHolder ``. This ensures that everything is set up to handle
67
67
the lifetime management system previously mentioned.
68
68
69
69
Now let's take a look at how we will make this class visible to TorchScript, a process called
@@ -73,24 +73,24 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
73
73
74
74
// Notice a few things:
75
75
// - We pass the class to be registered as a template parameter to
76
- // `torch::jit:: class_`. In this instance, we've passed the
76
+ // `torch::class_`. In this instance, we've passed the
77
77
// specialization of the MyStackClass class ``MyStackClass<std::string>``.
78
78
// In general, you cannot register a non-specialized template
79
79
// class. For non-templated classes, you can just pass the
80
80
// class name directly as the template parameter.
81
- // - The single parameter to ``torch::jit:: class_()`` is a
81
+ // - The single parameter to ``torch::class_()`` is a
82
82
// string indicating the name of the class. This is the name
83
83
// the class will appear as in both Python and TorchScript.
84
84
// For example, our MyStackClass class would appear as ``torch.classes.MyStackClass``.
85
85
static auto testStack =
86
- torch::jit:: class_<MyStackClass<std::string>>("MyStackClass")
86
+ torch::class_<MyStackClass<std::string>>("MyStackClass")
87
87
// The following line registers the contructor of our MyStackClass
88
88
// class that takes a single `std::vector<std::string>` argument,
89
89
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
90
90
// Currently, we do not support registering overloaded
91
91
// constructors, so for now you can only `def()` one instance of
92
- // `torch::jit:: init`.
93
- .def(torch::jit:: init<std::vector<std::string>>())
92
+ // `torch::init`.
93
+ .def(torch::init<std::vector<std::string>>())
94
94
// The next line registers a stateless (i.e. no captures) C++ lambda
95
95
// function as a method. Note that a lambda function must take a
96
96
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
@@ -99,7 +99,7 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
99
99
return self->stack_.back();
100
100
})
101
101
// The following four lines expose methods of the MyStackClass<std::string>
102
- // class as-is. `torch::jit:: class_` will automatically examine the
102
+ // class as-is. `torch::class_` will automatically examine the
103
103
// argument and return types of the passed-in method pointers and
104
104
// expose these to Python and TorchScript accordingly. Finally, notice
105
105
// that we must take the *address* of the fully-qualified method name,
@@ -307,7 +307,7 @@ Let's populate ``infer.cpp`` with the following:
307
307
# include <memory>
308
308
309
309
int main(int argc, const char* argv[]) {
310
- torch::jit:: script::Module module;
310
+ torch::script::Module module;
311
311
try {
312
312
// Deserialize the ScriptModule from a file using torch::jit::load ().
313
313
module = torch::jit::load(" foo.pt" );
@@ -394,6 +394,31 @@ And now we can run our exciting C++ binary:
394
394
395
395
Incredible!
396
396
397
+ Moving Custom Classes To/From IValues
398
+ -------------------------------------
399
+
400
+ It's also possible that you may need to move custom classes into or out of
401
+ ` ` IValue` ` s, such as when you take or return ` ` IValue` ` s from TorchScript methods
402
+ or you want to instantiate a custom class attribute in C++. For creating an
403
+ ` ` IValue` ` from a custom C++ class instance:
404
+
405
+ - ` ` torch::make_custom_class<T>()` ` provides an API similar to c10::intrusive_ptr<T>
406
+ in that it will take whatever set of arguments you provide to it, call the constructor
407
+ of T that matches that set of arguments, and wrap that instance up and return it.
408
+ However, instead of returning just a pointer to a custom class object, it returns
409
+ an ` ` IValue` ` wrapping the object. You can then pass this ` ` IValue` ` directly to
410
+ TorchScript.
411
+ - In the event that you already have an ` ` intrusive_ptr` ` pointing to your class, you
412
+ can directly construct an IValue from it using the constructor ` ` IValue(intrusive_ptr<T>)` ` .
413
+
414
+ For converting ` ` IValue` ` s back to custom classes:
415
+
416
+ - ` ` IValue::toCustomClass<T>()` ` will return an ` ` intrusive_ptr<T>` ` pointing to the
417
+ custom class that the ` ` IValue` ` contains. Internally, this function is checking
418
+ that ` ` T` ` is registered as a custom class and that the ` ` IValue` ` does in fact contain
419
+ a custom class. You can check whether the ` ` IValue` ` contains a custom class manually by
420
+ calling ` ` isCustomClass()` ` .
421
+
397
422
Defining Serialization/Deserialization Methods for Custom C++ Classes
398
423
---------------------------------------------------------------------
399
424
@@ -422,7 +447,7 @@ an attribute, you'll get the following error:
422
447
.. code-block:: shell
423
448
424
449
$ python export_attr.py
425
- RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via torch::jit::pickle_ for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
450
+ RuntimeError: Cannot serialize custom bound C++ class __torch__.torch.classes.MyStackClass. Please define serialization methods via def_pickle for this class. (pushIValueImpl at ../torch/csrc/jit/pickler.cpp:128)
426
451
427
452
This is because TorchScript cannot automatically figure out what information
428
453
save from your C++ class. You must specify that manually. The way to do that
@@ -441,8 +466,8 @@ Here is an example of how we can update the registration code for our
441
466
.. code-block:: cpp
442
467
443
468
static auto testStack =
444
- torch::jit:: class_<MyStackClass<std::string>>("MyStackClass")
445
- .def(torch::jit:: init<std::vector<std::string>>())
469
+ torch::class_<MyStackClass<std::string>>("MyStackClass")
470
+ .def(torch::init<std::vector<std::string>>())
446
471
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
447
472
return self->stack_.back();
448
473
})
0 commit comments