Commit c6059ec

Make torch_script_custom_classes tutorial runnable
I also fixed some warnings in the tutorial and some minor bitrot (e.g., torch::script::Module to torch::jit::Module), and added some missing quotes around bash expansions.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
1 parent f90f773 commit c6059ec
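
A quick note on why the added quotes matter (a minimal illustration for context, not part of the commit; the actual change is in the torch_script_custom_ops.rst diff at the end of this page): without quotes, the output of the command substitution is word-split by the shell, so a cmake prefix path containing spaces would reach cmake as several broken arguments.

# Unquoted: the output of $(...) undergoes word splitting and globbing
cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)') ..

# Quoted (as the commit now does): the expansion is passed to cmake as a single argument
cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..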

12 files changed, +353 −301 lines changed


.gitignore

Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,7 @@ intermediate
 advanced
 pytorch_basics
 recipes
+prototype
 
 #data things
 _data/
@@ -117,3 +118,6 @@ ENV/
 .DS_Store
 cleanup.sh
 *.swp
+
+# PyTorch things
+*.pt

advanced_source/torch_script_custom_classes.rst

Lines changed: 35 additions & 298 deletions
Large diffs are not rendered by default.

Lines changed: 15 additions & 0 deletions (new file: top-level CMakeLists.txt for the infer example app)

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(infer)

find_package(Torch REQUIRED)

add_subdirectory(custom_class_project)

# Define our executable target
add_executable(infer infer.cpp)
set(CMAKE_CXX_STANDARD 14)
# Link against LibTorch
target_link_libraries(infer "${TORCH_LIBRARIES}")
# This is where we link in our libcustom_class code, making our
# custom class available in our binary.
target_link_libraries(infer -Wl,--no-as-needed custom_class)

Lines changed: 10 additions & 0 deletions (new file: CMakeLists.txt in the custom_class_project subdirectory)

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(custom_class)

find_package(Torch REQUIRED)

# Define our library target
add_library(custom_class SHARED class.cpp)
set(CMAKE_CXX_STANDARD 14)
# Link against LibTorch
target_link_libraries(custom_class "${TORCH_LIBRARIES}")

Lines changed: 145 additions & 0 deletions (new file: class.cpp, built into libcustom_class by the custom_class_project CMakeLists above)

// BEGIN class
// This header is all you need to do the C++ portions of this
// tutorial
#include <torch/script.h>
// This header is what defines the custom class registration
// behavior specifically. script.h already includes this, but
// we include it here so you know it exists in case you want
// to look at the API or implementation.
#include <torch/custom_class.h>

#include <string>
#include <vector>

template <class T>
struct MyStackClass : torch::CustomClassHolder {
  std::vector<T> stack_;
  MyStackClass(std::vector<T> init) : stack_(init.begin(), init.end()) {}

  void push(T x) {
    stack_.push_back(x);
  }
  T pop() {
    auto val = stack_.back();
    stack_.pop_back();
    return val;
  }

  c10::intrusive_ptr<MyStackClass> clone() const {
    return c10::make_intrusive<MyStackClass>(stack_);
  }

  void merge(const c10::intrusive_ptr<MyStackClass>& c) {
    for (auto& elem : c->stack_) {
      push(elem);
    }
  }
};
// END class

#ifdef NO_PICKLE

// BEGIN binding
// Notice a few things:
// - We pass the class to be registered as a template parameter to
//   `torch::class_`. In this instance, we've passed the
//   specialization of the MyStackClass class ``MyStackClass<std::string>``.
//   In general, you cannot register a non-specialized template
//   class. For non-templated classes, you can just pass the
//   class name directly as the template parameter.
// - The arguments passed to the constructor make up the "qualified name"
//   of the class. In this case, the registered class will appear in
//   Python and C++ as `torch.classes.my_classes.MyStackClass`. We call
//   the first argument the "namespace" and the second argument the
//   actual class name.
static auto testStack =
    torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass")
        // The following line registers the constructor of our MyStackClass
        // class that takes a single `std::vector<std::string>` argument,
        // i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
        // Currently, we do not support registering overloaded
        // constructors, so for now you can only `def()` one instance of
        // `torch::init`.
        .def(torch::init<std::vector<std::string>>())
        // The next line registers a stateless (i.e. no captures) C++ lambda
        // function as a method. Note that a lambda function must take a
        // `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
        // as the first argument. Other arguments can be whatever you want.
        .def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
          return self->stack_.back();
        })
        // The following four lines expose methods of the MyStackClass<std::string>
        // class as-is. `torch::class_` will automatically examine the
        // argument and return types of the passed-in method pointers and
        // expose these to Python and TorchScript accordingly. Finally, notice
        // that we must take the *address* of the fully-qualified method name,
        // i.e. use the unary `&` operator, due to C++ typing rules.
        .def("push", &MyStackClass<std::string>::push)
        .def("pop", &MyStackClass<std::string>::pop)
        .def("clone", &MyStackClass<std::string>::clone)
        .def("merge", &MyStackClass<std::string>::merge);
// END binding

#else

// BEGIN pickle_binding
static auto testStack =
    torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass")
        .def(torch::init<std::vector<std::string>>())
        .def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
          return self->stack_.back();
        })
        .def("push", &MyStackClass<std::string>::push)
        .def("pop", &MyStackClass<std::string>::pop)
        .def("clone", &MyStackClass<std::string>::clone)
        .def("merge", &MyStackClass<std::string>::merge)
        // class_<>::def_pickle allows you to define the serialization
        // and deserialization methods for your C++ class.
        // Currently, we only support passing stateless lambda functions
        // as arguments to def_pickle
        .def_pickle(
            // __getstate__
            // This function defines what data structure should be produced
            // when we serialize an instance of this class. The function
            // must take a single `self` argument, which is an intrusive_ptr
            // to the instance of the object. The function can return
            // any type that is supported as a return value of the TorchScript
            // custom operator API. In this instance, we've chosen to return
            // a std::vector<std::string> as the salient data to preserve
            // from the class.
            [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
                -> std::vector<std::string> {
              return self->stack_;
            },
            // __setstate__
            // This function defines how to create a new instance of the C++
            // class when we are deserializing. The function must take a
            // single argument of the same type as the return value of
            // `__getstate__`. The function must return an intrusive_ptr
            // to a new instance of the C++ class, initialized however
            // you would like given the serialized state.
            [](std::vector<std::string> state)
                -> c10::intrusive_ptr<MyStackClass<std::string>> {
              // A convenient way to instantiate an object and get an
              // intrusive_ptr to it is via `make_intrusive`. We use
              // that here to allocate an instance of MyStackClass<std::string>
              // and call the single-argument std::vector<std::string>
              // constructor with the serialized state.
              return c10::make_intrusive<MyStackClass<std::string>>(std::move(state));
            });
// END pickle_binding

// BEGIN free_function
c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance(const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
  instance->pop();
  return instance;
}

static auto instance_registry = torch::RegisterOperators().op(
    torch::RegisterOperators::options()
        .schema(
            "foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y")
        .catchAllKernel<decltype(manipulate_instance), &manipulate_instance>());
// END free_function

#endif

Lines changed: 48 additions & 0 deletions (new file: custom_test.py, run by the build script below)

import torch

# `torch.classes.load_library()` allows you to pass the path to your .so file
# to load it in and make the custom C++ classes available to both Python and
# TorchScript
torch.classes.load_library("build/libcustom_class.so")
# You can query the loaded libraries like this:
print(torch.classes.loaded_libraries)
# prints {'/custom_class_project/build/libcustom_class.so'}

# We can find and instantiate our custom C++ class in Python by using the
# `torch.classes` namespace:
#
# This instantiation will invoke the MyStackClass(std::vector<T> init)
# constructor we registered earlier
s = torch.classes.my_classes.MyStackClass(["foo", "bar"])

# We can call methods in Python
s.push("pushed")
assert s.pop() == "pushed"

# Returning and passing instances of custom classes works as you'd expect
s2 = s.clone()
s.merge(s2)
for expected in ["bar", "foo", "bar", "foo"]:
    assert s.pop() == expected

# We can also use the class in TorchScript
# For now, we need to assign the class's type to a local in order to
# annotate the type on the TorchScript function. This may change
# in the future.
MyStackClass = torch.classes.my_classes.MyStackClass


@torch.jit.script
def do_stacks(s: MyStackClass):  # We can pass a custom class instance
    # We can instantiate the class
    s2 = torch.classes.my_classes.MyStackClass(["hi", "mom"])
    s2.merge(s)  # We can call a method on the class
    # We can also return instances of the class
    # from TorchScript functions/methods
    return s2.clone(), s2.top()


stack, top = do_stacks(torch.classes.my_classes.MyStackClass(["wow"]))
assert top == "wow"
for expected in ["wow", "mom", "hi"]:
    assert stack.pop() == expected

Lines changed: 21 additions & 0 deletions (new file: export_attr.py)

# export_attr.py
import torch

torch.classes.load_library('build/libcustom_class.so')


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stack = torch.classes.my_classes.MyStackClass(["just", "testing"])

    def forward(self, s: str) -> str:
        return self.stack.pop() + s


scripted_foo = torch.jit.script(Foo())

scripted_foo.save('foo.pt')
loaded = torch.jit.load('foo.pt')

print(loaded.stack.pop())

Lines changed: 18 additions & 0 deletions (new file: save.py, run by the build script below)

import torch

torch.classes.load_library('build/libcustom_class.so')


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, s: str) -> str:
        stack = torch.classes.my_classes.MyStackClass(["hi", "mom"])
        return stack.pop() + s


scripted_foo = torch.jit.script(Foo())
print(scripted_foo.graph)

scripted_foo.save('foo.pt')

Lines changed: 20 additions & 0 deletions (new file: infer.cpp, the executable built by the top-level CMakeLists above)

#include <torch/script.h>

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  torch::jit::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load("foo.pt");
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::vector<c10::IValue> inputs = {"foobarbaz"};
  auto output = module.forward(inputs).toString();
  std::cout << output->string() << std::endl;
}

Lines changed: 21 additions & 0 deletions (new shell script: builds both projects and runs the end-to-end tests and the infer binary)

#!/bin/bash

set -ex

rm -rf build
rm -rf custom_class_project/build

pushd custom_class_project
mkdir build
# Configure the custom class project; NO_PICKLE is meant to select the bindings
# in class.cpp that are registered without def_pickle.
(cd build && cmake CXXFLAGS="-DNO_PICKLE" -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..)
(cd build && make)
python custom_test.py
python save.py
# `!` inverts the exit status: export_attr.py is expected to fail here, since
# saving the class as a module attribute needs the def_pickle bindings.
! python export_attr.py
popd

mkdir build
(cd build && cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..)
(cd build && make)
mv custom_class_project/foo.pt build/foo.pt
(cd build && ./infer)

Lines changed: 13 additions & 0 deletions (new shell script: builds the custom class project with the pickle bindings and runs export_attr.py)

#!/bin/bash

set -ex

rm -rf build
rm -rf custom_class_project/build

pushd custom_class_project
mkdir build
(cd build && cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..)
(cd build && make)
python export_attr.py
popd

advanced_source/torch_script_custom_ops.rst

Lines changed: 3 additions & 3 deletions

@@ -212,7 +212,7 @@ To now build our operator, we can run the following commands from our

   $ mkdir build
   $ cd build
-  $ cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)') ..
+  $ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
   -- The C compiler identification is GNU 5.4.0
   -- The CXX compiler identification is GNU 5.4.0
   -- Check for working C compiler: /usr/bin/cc
@@ -609,7 +609,7 @@ At this point, we should be able to build the application:

   $ mkdir build
   $ cd build
-  $ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
+  $ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
   -- The C compiler identification is GNU 5.4.0
   -- The CXX compiler identification is GNU 5.4.0
   -- Check for working C compiler: /usr/bin/cc
@@ -752,7 +752,7 @@ library. In the top level ``example_app`` directory:

   $ mkdir build
   $ cd build
-  $ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
+  $ cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" ..
   -- The C compiler identification is GNU 5.4.0
   -- The CXX compiler identification is GNU 5.4.0
   -- Check for working C compiler: /usr/bin/cc
