Skip to content

Commit 32e5407

Browse files
authored
Update torch_script_custom_classes to use TORCH_LIBRARY (#1062)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
1 parent c6059ec commit 32e5407

File tree

2 files changed

+106
-106
lines changed

2 files changed

+106
-106
lines changed

advanced_source/torch_script_custom_classes.rst

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ There are several things to note:
2525
with your custom class.
2626
- Notice that whenever we are working with instances of the custom
2727
class, we do it via instances of ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr``
28-
as a smart pointer like ``std::shared_ptr``. The reason for using this smart pointer
29-
is to ensure consistent lifetime management of the object instances between languages
30-
(C++, Python and TorchScript).
28+
as a smart pointer like ``std::shared_ptr``, but the reference count is stored
29+
directly in the object, as opposed to a separate metadata block (as is done in
30+
``std::shared_ptr``. ``torch::Tensor`` internally uses the same pointer type;
31+
and custom classes have to also use this pointer type so that we can
32+
consistently manage different object types.
3133
- The second thing to notice is that the user-defined class must inherit from
32-
``torch::CustomClassHolder``. This ensures that everything is set up to handle
33-
the lifetime management system previously mentioned.
34+
``torch::CustomClassHolder``. This ensures that the custom class has space to
35+
store the reference count.
3436

3537
Now let's take a look at how we will make this class visible to TorchScript, a process called
3638
*binding* the class:
@@ -39,6 +41,9 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
3941
:language: cpp
4042
:start-after: BEGIN binding
4143
:end-before: END binding
44+
:append:
45+
;
46+
}
4247

4348

4449

@@ -269,13 +274,13 @@ the special ``def_pickle`` method on ``class_``.
269274
`read more <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md#getstate-and-setstate>`_
270275
about how we use these methods.
271276
272-
Here is an example of how we can update the registration code for our
273-
``MyStackClass`` class to include serialization methods:
277+
Here is an example of the ``def_pickle`` call we can add to the registration of
278+
``MyStackClass`` to include serialization methods:
274279
275280
.. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp
276281
:language: cpp
277-
:start-after: BEGIN pickle_binding
278-
:end-before: END pickle_binding
282+
:start-after: BEGIN def_pickle
283+
:end-before: END def_pickle
279284
280285
.. note::
281286
We take a different approach from pybind11 in the pickle API. Whereas pybind11
@@ -295,14 +300,22 @@ Defining Custom Operators that Take or Return Bound C++ Classes
295300
---------------------------------------------------------------
296301
297302
Once you've defined a custom C++ class, you can also use that class
298-
as an argument or return from a custom operator (i.e. free functions). Here's an
299-
example of how to do that:
303+
as an argument or return from a custom operator (i.e. free functions). Suppose
304+
you have the following free function:
300305
301306
.. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp
302307
:language: cpp
303308
:start-after: BEGIN free_function
304309
:end-before: END free_function
305310
311+
You can register it running the following code inside your ``TORCH_LIBRARY``
312+
block:
313+
314+
.. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp
315+
:language: cpp
316+
:start-after: BEGIN def_free
317+
:end-before: END def_free
318+
306319
Refer to the `custom op tutorial <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
307320
for more details on the registration API.
308321
@@ -321,12 +334,12 @@ Once this is done, you can use the op like the following example:
321334
.. note::
322335
323336
Registration of an operator that takes a C++ class as an argument requires that
324-
the custom class has already been registered. This is fine if your op is
325-
registered after your class in a single compilation unit, however, if your
326-
class is registered in a separate compilation unit from the op you will need
327-
to enforce that dependency. One way to do this is to wrap the class registration
328-
in a `Meyer's singleton <https://stackoverflow.com/q/1661529>`_, which can be
329-
called from the compilation unit that does the operator registration.
337+
the custom class has already been registered. You can enforce this by
338+
making sure the custom class registration and your free function definitions
339+
are in the same ``TORCH_LIBRARY`` block, and that the custom class
340+
registration comes first. In the future, we may relax this requirement,
341+
so that these can be registered in any order.
342+
330343
331344
Conclusion
332345
----------

advanced_source/torch_script_custom_classes/custom_class_project/class.cpp

Lines changed: 76 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ struct MyStackClass : torch::CustomClassHolder {
3737
};
3838
// END class
3939

40-
#ifdef NO_PICKLE
40+
// BEGIN free_function
41+
c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance(const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
42+
instance->pop();
43+
return instance;
44+
}
45+
// END free_function
4146

4247
// BEGIN binding
4348
// Notice a few things:
@@ -52,94 +57,76 @@ struct MyStackClass : torch::CustomClassHolder {
5257
// Python and C++ as `torch.classes.my_classes.MyStackClass`. We call
5358
// the first argument the "namespace" and the second argument the
5459
// actual class name.
55-
static auto testStack =
56-
torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass")
57-
// The following line registers the contructor of our MyStackClass
58-
// class that takes a single `std::vector<std::string>` argument,
59-
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
60-
// Currently, we do not support registering overloaded
61-
// constructors, so for now you can only `def()` one instance of
62-
// `torch::init`.
63-
.def(torch::init<std::vector<std::string>>())
64-
// The next line registers a stateless (i.e. no captures) C++ lambda
65-
// function as a method. Note that a lambda function must take a
66-
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
67-
// as the first argument. Other arguments can be whatever you want.
68-
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
69-
return self->stack_.back();
70-
})
71-
// The following four lines expose methods of the MyStackClass<std::string>
72-
// class as-is. `torch::class_` will automatically examine the
73-
// argument and return types of the passed-in method pointers and
74-
// expose these to Python and TorchScript accordingly. Finally, notice
75-
// that we must take the *address* of the fully-qualified method name,
76-
// i.e. use the unary `&` operator, due to C++ typing rules.
77-
.def("push", &MyStackClass<std::string>::push)
78-
.def("pop", &MyStackClass<std::string>::pop)
79-
.def("clone", &MyStackClass<std::string>::clone)
80-
.def("merge", &MyStackClass<std::string>::merge);
60+
TORCH_LIBRARY(my_classes, m) {
61+
m.class_<MyStackClass<std::string>>("MyStackClass")
62+
// The following line registers the contructor of our MyStackClass
63+
// class that takes a single `std::vector<std::string>` argument,
64+
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
65+
// Currently, we do not support registering overloaded
66+
// constructors, so for now you can only `def()` one instance of
67+
// `torch::init`.
68+
.def(torch::init<std::vector<std::string>>())
69+
// The next line registers a stateless (i.e. no captures) C++ lambda
70+
// function as a method. Note that a lambda function must take a
71+
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
72+
// as the first argument. Other arguments can be whatever you want.
73+
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
74+
return self->stack_.back();
75+
})
76+
// The following four lines expose methods of the MyStackClass<std::string>
77+
// class as-is. `torch::class_` will automatically examine the
78+
// argument and return types of the passed-in method pointers and
79+
// expose these to Python and TorchScript accordingly. Finally, notice
80+
// that we must take the *address* of the fully-qualified method name,
81+
// i.e. use the unary `&` operator, due to C++ typing rules.
82+
.def("push", &MyStackClass<std::string>::push)
83+
.def("pop", &MyStackClass<std::string>::pop)
84+
.def("clone", &MyStackClass<std::string>::clone)
85+
.def("merge", &MyStackClass<std::string>::merge)
8186
// END binding
87+
#ifndef NO_PICKLE
88+
// BEGIN def_pickle
89+
// class_<>::def_pickle allows you to define the serialization
90+
// and deserialization methods for your C++ class.
91+
// Currently, we only support passing stateless lambda functions
92+
// as arguments to def_pickle
93+
.def_pickle(
94+
// __getstate__
95+
// This function defines what data structure should be produced
96+
// when we serialize an instance of this class. The function
97+
// must take a single `self` argument, which is an intrusive_ptr
98+
// to the instance of the object. The function can return
99+
// any type that is supported as a return value of the TorchScript
100+
// custom operator API. In this instance, we've chosen to return
101+
// a std::vector<std::string> as the salient data to preserve
102+
// from the class.
103+
[](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
104+
-> std::vector<std::string> {
105+
return self->stack_;
106+
},
107+
// __setstate__
108+
// This function defines how to create a new instance of the C++
109+
// class when we are deserializing. The function must take a
110+
// single argument of the same type as the return value of
111+
// `__getstate__`. The function must return an intrusive_ptr
112+
// to a new instance of the C++ class, initialized however
113+
// you would like given the serialized state.
114+
[](std::vector<std::string> state)
115+
-> c10::intrusive_ptr<MyStackClass<std::string>> {
116+
// A convenient way to instantiate an object and get an
117+
// intrusive_ptr to it is via `make_intrusive`. We use
118+
// that here to allocate an instance of MyStackClass<std::string>
119+
// and call the single-argument std::vector<std::string>
120+
// constructor with the serialized state.
121+
return c10::make_intrusive<MyStackClass<std::string>>(std::move(state));
122+
});
123+
// END def_pickle
124+
#endif // NO_PICKLE
82125

83-
#else
84-
85-
// BEGIN pickle_binding
86-
static auto testStack =
87-
torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass")
88-
.def(torch::init<std::vector<std::string>>())
89-
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
90-
return self->stack_.back();
91-
})
92-
.def("push", &MyStackClass<std::string>::push)
93-
.def("pop", &MyStackClass<std::string>::pop)
94-
.def("clone", &MyStackClass<std::string>::clone)
95-
.def("merge", &MyStackClass<std::string>::merge)
96-
// class_<>::def_pickle allows you to define the serialization
97-
// and deserialization methods for your C++ class.
98-
// Currently, we only support passing stateless lambda functions
99-
// as arguments to def_pickle
100-
.def_pickle(
101-
// __getstate__
102-
// This function defines what data structure should be produced
103-
// when we serialize an instance of this class. The function
104-
// must take a single `self` argument, which is an intrusive_ptr
105-
// to the instance of the object. The function can return
106-
// any type that is supported as a return value of the TorchScript
107-
// custom operator API. In this instance, we've chosen to return
108-
// a std::vector<std::string> as the salient data to preserve
109-
// from the class.
110-
[](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
111-
-> std::vector<std::string> {
112-
return self->stack_;
113-
},
114-
// __setstate__
115-
// This function defines how to create a new instance of the C++
116-
// class when we are deserializing. The function must take a
117-
// single argument of the same type as the return value of
118-
// `__getstate__`. The function must return an intrusive_ptr
119-
// to a new instance of the C++ class, initialized however
120-
// you would like given the serialized state.
121-
[](std::vector<std::string> state)
122-
-> c10::intrusive_ptr<MyStackClass<std::string>> {
123-
// A convenient way to instantiate an object and get an
124-
// intrusive_ptr to it is via `make_intrusive`. We use
125-
// that here to allocate an instance of MyStackClass<std::string>
126-
// and call the single-argument std::vector<std::string>
127-
// constructor with the serialized state.
128-
return c10::make_intrusive<MyStackClass<std::string>>(std::move(state));
129-
});
130-
// END pickle_binding
131-
132-
// BEGIN free_function
133-
c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance(const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
134-
instance->pop();
135-
return instance;
126+
// BEGIN def_free
127+
m.def(
128+
"foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y",
129+
manipulate_instance
130+
);
131+
// END def_free
136132
}
137-
138-
static auto instance_registry = torch::RegisterOperators().op(
139-
torch::RegisterOperators::options()
140-
.schema(
141-
"foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y")
142-
.catchAllKernel<decltype(manipulate_instance), &manipulate_instance>());
143-
// END free_function
144-
145-
#endif

0 commit comments

Comments
 (0)