@@ -37,7 +37,12 @@ struct MyStackClass : torch::CustomClassHolder {
37
37
};
38
38
// END class
39
39
40
- #ifdef NO_PICKLE
40
+ // BEGIN free_function
41
+ c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance (const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
42
+ instance->pop ();
43
+ return instance;
44
+ }
45
+ // END free_function
41
46
42
47
// BEGIN binding
43
48
// Notice a few things:
@@ -52,94 +57,76 @@ struct MyStackClass : torch::CustomClassHolder {
52
57
// Python and C++ as `torch.classes.my_classes.MyStackClass`. We call
53
58
// the first argument the "namespace" and the second argument the
54
59
// actual class name.
55
- static auto testStack =
56
- torch:: class_<MyStackClass<std::string>>(" my_classes " , " MyStackClass" )
57
- // The following line registers the contructor of our MyStackClass
58
- // class that takes a single `std::vector<std::string>` argument,
59
- // i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
60
- // Currently, we do not support registering overloaded
61
- // constructors, so for now you can only `def()` one instance of
62
- // `torch::init`.
63
- .def(torch::init<std::vector<std::string>>())
64
- // The next line registers a stateless (i.e. no captures) C++ lambda
65
- // function as a method. Note that a lambda function must take a
66
- // `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
67
- // as the first argument. Other arguments can be whatever you want.
68
- .def(" top" , [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
69
- return self->stack_ .back ();
70
- })
71
- // The following four lines expose methods of the MyStackClass<std::string>
72
- // class as-is. `torch::class_` will automatically examine the
73
- // argument and return types of the passed-in method pointers and
74
- // expose these to Python and TorchScript accordingly. Finally, notice
75
- // that we must take the *address* of the fully-qualified method name,
76
- // i.e. use the unary `&` operator, due to C++ typing rules.
77
- .def(" push" , &MyStackClass<std::string>::push)
78
- .def(" pop" , &MyStackClass<std::string>::pop)
79
- .def(" clone" , &MyStackClass<std::string>::clone)
80
- .def(" merge" , &MyStackClass<std::string>::merge);
60
+ TORCH_LIBRARY (my_classes, m) {
61
+ m. class_ <MyStackClass<std::string>>(" MyStackClass" )
62
+ // The following line registers the contructor of our MyStackClass
63
+ // class that takes a single `std::vector<std::string>` argument,
64
+ // i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
65
+ // Currently, we do not support registering overloaded
66
+ // constructors, so for now you can only `def()` one instance of
67
+ // `torch::init`.
68
+ .def (torch::init<std::vector<std::string>>())
69
+ // The next line registers a stateless (i.e. no captures) C++ lambda
70
+ // function as a method. Note that a lambda function must take a
71
+ // `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
72
+ // as the first argument. Other arguments can be whatever you want.
73
+ .def (" top" , [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
74
+ return self->stack_ .back ();
75
+ })
76
+ // The following four lines expose methods of the MyStackClass<std::string>
77
+ // class as-is. `torch::class_` will automatically examine the
78
+ // argument and return types of the passed-in method pointers and
79
+ // expose these to Python and TorchScript accordingly. Finally, notice
80
+ // that we must take the *address* of the fully-qualified method name,
81
+ // i.e. use the unary `&` operator, due to C++ typing rules.
82
+ .def (" push" , &MyStackClass<std::string>::push)
83
+ .def (" pop" , &MyStackClass<std::string>::pop)
84
+ .def (" clone" , &MyStackClass<std::string>::clone)
85
+ .def (" merge" , &MyStackClass<std::string>::merge)
81
86
// END binding
87
+ #ifndef NO_PICKLE
88
+ // BEGIN def_pickle
89
+ // class_<>::def_pickle allows you to define the serialization
90
+ // and deserialization methods for your C++ class.
91
+ // Currently, we only support passing stateless lambda functions
92
+ // as arguments to def_pickle
93
+ .def_pickle (
94
+ // __getstate__
95
+ // This function defines what data structure should be produced
96
+ // when we serialize an instance of this class. The function
97
+ // must take a single `self` argument, which is an intrusive_ptr
98
+ // to the instance of the object. The function can return
99
+ // any type that is supported as a return value of the TorchScript
100
+ // custom operator API. In this instance, we've chosen to return
101
+ // a std::vector<std::string> as the salient data to preserve
102
+ // from the class.
103
+ [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
104
+ -> std::vector<std::string> {
105
+ return self->stack_ ;
106
+ },
107
+ // __setstate__
108
+ // This function defines how to create a new instance of the C++
109
+ // class when we are deserializing. The function must take a
110
+ // single argument of the same type as the return value of
111
+ // `__getstate__`. The function must return an intrusive_ptr
112
+ // to a new instance of the C++ class, initialized however
113
+ // you would like given the serialized state.
114
+ [](std::vector<std::string> state)
115
+ -> c10::intrusive_ptr<MyStackClass<std::string>> {
116
+ // A convenient way to instantiate an object and get an
117
+ // intrusive_ptr to it is via `make_intrusive`. We use
118
+ // that here to allocate an instance of MyStackClass<std::string>
119
+ // and call the single-argument std::vector<std::string>
120
+ // constructor with the serialized state.
121
+ return c10::make_intrusive<MyStackClass<std::string>>(std::move (state));
122
+ });
123
+ // END def_pickle
124
+ #endif // NO_PICKLE
82
125
83
- #else
84
-
85
- // BEGIN pickle_binding
86
- static auto testStack =
87
- torch::class_<MyStackClass<std::string>>(" my_classes" , " MyStackClass" )
88
- .def(torch::init<std::vector<std::string>>())
89
- .def(" top" , [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
90
- return self->stack_ .back ();
91
- })
92
- .def(" push" , &MyStackClass<std::string>::push)
93
- .def(" pop" , &MyStackClass<std::string>::pop)
94
- .def(" clone" , &MyStackClass<std::string>::clone)
95
- .def(" merge" , &MyStackClass<std::string>::merge)
96
- // class_<>::def_pickle allows you to define the serialization
97
- // and deserialization methods for your C++ class.
98
- // Currently, we only support passing stateless lambda functions
99
- // as arguments to def_pickle
100
- .def_pickle(
101
- // __getstate__
102
- // This function defines what data structure should be produced
103
- // when we serialize an instance of this class. The function
104
- // must take a single `self` argument, which is an intrusive_ptr
105
- // to the instance of the object. The function can return
106
- // any type that is supported as a return value of the TorchScript
107
- // custom operator API. In this instance, we've chosen to return
108
- // a std::vector<std::string> as the salient data to preserve
109
- // from the class.
110
- [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
111
- -> std::vector<std::string> {
112
- return self->stack_ ;
113
- },
114
- // __setstate__
115
- // This function defines how to create a new instance of the C++
116
- // class when we are deserializing. The function must take a
117
- // single argument of the same type as the return value of
118
- // `__getstate__`. The function must return an intrusive_ptr
119
- // to a new instance of the C++ class, initialized however
120
- // you would like given the serialized state.
121
- [](std::vector<std::string> state)
122
- -> c10::intrusive_ptr<MyStackClass<std::string>> {
123
- // A convenient way to instantiate an object and get an
124
- // intrusive_ptr to it is via `make_intrusive`. We use
125
- // that here to allocate an instance of MyStackClass<std::string>
126
- // and call the single-argument std::vector<std::string>
127
- // constructor with the serialized state.
128
- return c10::make_intrusive<MyStackClass<std::string>>(std::move (state));
129
- });
130
- // END pickle_binding
131
-
132
- // BEGIN free_function
133
- c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance (const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
134
- instance->pop ();
135
- return instance;
126
+ // BEGIN def_free
127
+ m.def (
128
+ " foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y" ,
129
+ manipulate_instance
130
+ );
131
+ // END def_free
136
132
}
137
-
138
- static auto instance_registry = torch::RegisterOperators().op(
139
- torch::RegisterOperators::options ()
140
- .schema(
141
- " foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y" )
142
- .catchAllKernel<decltype(manipulate_instance), &manipulate_instance>());
143
- // END free_function
144
-
145
- #endif
0 commit comments