Commit 5c44056

fine tuned intel_extension_for_pytorch. 1. moved intel_extension_for_pytorch.py to intel_extension_for_pytorch.rst to make it compatible with cpp source code. 2. fixed bug of Float32 level missing in intel_extension_for_pytorch (#1704)
3 files changed: +340 −353 lines
Intel® Extension for PyTorch*
=============================

Intel® Extension for PyTorch* extends PyTorch with optimizations for an extra
performance boost on Intel hardware. Most of the optimizations will eventually be
included in stock PyTorch releases; the intention of the extension is to deliver
up-to-date features and optimizations for PyTorch on Intel hardware. Examples
include AVX-512 Vector Neural Network Instructions (AVX512 VNNI) and Intel®
Advanced Matrix Extensions (Intel® AMX).

Intel® Extension for PyTorch* has been released as an open-source project
on `GitHub <https://github.com/intel/intel-extension-for-pytorch>`_.

Features
--------

- **Ease-of-use Python API:** Intel® Extension for PyTorch* provides simple
  frontend Python APIs and utilities for users to get performance optimizations
  such as graph optimization and operator optimization with minor code changes.
  Typically, only two to three lines need to be added to the original code.
- **Channels Last:** Compared with the default NCHW memory format, the channels_last
  (NHWC) memory format can further accelerate convolutional neural networks.
  In Intel® Extension for PyTorch*, the NHWC memory format has been enabled for
  most key CPU operators, though not all of them have been merged into the PyTorch
  master branch yet. They are expected to be fully landed in PyTorch upstream
  soon. A short conversion sketch follows this list.
- **Auto Mixed Precision (AMP):** The low-precision data type BFloat16 is
  natively supported on 3rd Generation Intel® Xeon® Scalable Processors (codenamed
  Cooper Lake) with the AVX-512 instruction set, and will be supported on the next
  generation of Intel® Xeon® Scalable Processors with the Intel® Advanced Matrix
  Extensions (Intel® AMX) instruction set, with further boosted performance. Support
  for Auto Mixed Precision (AMP) with BFloat16 for CPU, and BFloat16
  optimization of operators, has been extensively enabled in Intel® Extension
  for PyTorch* and partially upstreamed to the PyTorch master branch. Most of
  these optimizations will land in PyTorch master through PRs that are
  being submitted and reviewed.
- **Graph Optimization:** To further optimize performance with TorchScript,
  Intel® Extension for PyTorch* supports fusion of frequently used operator
  patterns, such as Conv2D+ReLU and Linear+ReLU. The benefit of the fusions is
  delivered to users in a transparent fashion. The detailed fusion patterns
  supported can be found `here <https://github.com/intel/intel-extension-for-pytorch>`_.
  The graph optimization will be upstreamed to PyTorch with the introduction
  of the oneDNN Graph API.
- **Operator Optimization:** Intel® Extension for PyTorch* also optimizes
  operators and implements several customized operators for performance. A few
  ATen operators are replaced by their optimized counterparts in Intel®
  Extension for PyTorch* via the ATen registration mechanism. Moreover, some
  customized operators are implemented for several popular topologies. For
  instance, ROIAlign and NMS are defined in Mask R-CNN. To improve the performance
  of these topologies, Intel® Extension for PyTorch* also optimizes these
  customized operators.
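As a quick illustration of the Channels Last feature, the sketch below converts a model and a 4D tensor to the channels_last memory format; the convolution layer and tensor shape here are illustrative assumptions, not taken from the examples later in this document.

.. code:: python3

  import torch
  import torch.nn as nn

  # Convert both the model weights and the input activations to NHWC (channels_last).
  model = nn.Conv2d(3, 64, kernel_size=3).to(memory_format=torch.channels_last)
  x = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)
  y = model(x)  # runs with channels_last propagation on supported CPU operators
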
Getting Started
---------------

Minor code changes are required for users to get started with Intel® Extension
for PyTorch*. Both PyTorch imperative mode and TorchScript mode are
supported. This section introduces usage of Intel® Extension for PyTorch* API
functions for both imperative mode and TorchScript mode, covering the data types
Float32 and BFloat16. C++ usage is introduced at the end.

You just need to import the Intel® Extension for PyTorch* package and apply its
optimize function against the model object. If it is a training workload, the
optimize function also needs to be applied against the optimizer object.
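As a compact sketch of the two call forms (the placeholder model and hyperparameters here are illustrative, not from the original examples; full runnable examples follow in the subsections below):

.. code:: python3

  import torch
  import torch.nn as nn
  import intel_extension_for_pytorch as ipex

  model = nn.Linear(4, 5)  # placeholder model

  # Training workload: optimize the model together with its optimizer.
  model.train()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
  model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.float32)

  # Inference workload: optimize the model object only.
  model.eval()
  model = ipex.optimize(model, dtype=torch.float32)
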
For training and inference with the BFloat16 data type, torch.cpu.amp has been
enabled in upstream PyTorch to support mixed precision conveniently, and the
BFloat16 data type has been enabled extensively for CPU operators in upstream
PyTorch and Intel® Extension for PyTorch*. Running with torch.cpu.amp, each
operator is matched to its appropriate data type, returning the best possible
performance.
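As a minimal sketch of this dispatch behavior (the layer and input shapes are illustrative assumptions), operators run inside the autocast context execute in BFloat16 where eligible:

.. code:: python3

  import torch
  import torch.nn as nn

  layer = nn.Linear(4, 5)
  x = torch.randn(2, 4)
  with torch.cpu.amp.autocast():
      y = layer(x)
  # Linear is on the autocast BFloat16 list, so the output comes back in bfloat16.
  print(y.dtype)  # torch.bfloat16
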
The code changes required for Intel® Extension for PyTorch* are
highlighted with a comment on the line above them.

Training
~~~~~~~~

Float32
^^^^^^^

.. code:: python3

  import torch
  import torch.nn as nn
  # Import intel_extension_for_pytorch
  import intel_extension_for_pytorch as ipex

  class Model(nn.Module):
      def __init__(self):
          super(Model, self).__init__()
          self.linear = nn.Linear(4, 5)

      def forward(self, input):
          return self.linear(input)

  model = Model()
  model.train()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
  # Invoke optimize function against the model object and optimizer object
  model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.float32)

  # train_loader, criterion and PATH are assumed to be defined by the surrounding training script
  for images, label in train_loader:
      # Setting memory_format to torch.channels_last could improve performance with 4D input data. This is optional.
      images = images.to(memory_format=torch.channels_last)
      optimizer.zero_grad()
      loss = criterion(model(images), label)
      loss.backward()
      optimizer.step()
  torch.save({'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict()}, PATH)

BFloat16
^^^^^^^^

.. code:: python3

  import torch
  import torch.nn as nn
  # Import intel_extension_for_pytorch
  import intel_extension_for_pytorch as ipex

  class Model(nn.Module):
      def __init__(self):
          super(Model, self).__init__()
          self.linear = nn.Linear(4, 5)

      def forward(self, input):
          return self.linear(input)

  model = Model()
  model.train()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
  # Invoke optimize function against the model object and optimizer object with data type set to torch.bfloat16
  model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

  # train_loader, criterion and PATH are assumed to be defined by the surrounding training script
  for images, label in train_loader:
      # Setting memory_format to torch.channels_last could improve performance with 4D input data. This is optional.
      images = images.to(memory_format=torch.channels_last)
      optimizer.zero_grad()
      with torch.cpu.amp.autocast():
          loss = criterion(model(images), label)
      loss.backward()
      optimizer.step()
  torch.save({'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict()}, PATH)

Inference - Imperative Mode
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Float32
^^^^^^^

.. code:: python3

  import torch
  import torch.nn as nn
  # Import intel_extension_for_pytorch
  import intel_extension_for_pytorch as ipex

  class Model(nn.Module):
      def __init__(self):
          super(Model, self).__init__()
          self.linear = nn.Linear(4, 5)

      def forward(self, input):
          return self.linear(input)

  input = torch.randn(2, 4)
  model = Model()
  model.eval()
  # Invoke optimize function against the model object
  model = ipex.optimize(model, dtype=torch.float32)
  with torch.no_grad():
      res = model(input)

BFloat16
^^^^^^^^

.. code:: python3

  import torch
  import torch.nn as nn
  # Import intel_extension_for_pytorch
  import intel_extension_for_pytorch as ipex

  class Model(nn.Module):
      def __init__(self):
          super(Model, self).__init__()
          self.linear = nn.Linear(4, 5)

      def forward(self, input):
          return self.linear(input)

  input = torch.randn(2, 4)
  model = Model()
  model.eval()
  # Invoke optimize function against the model object with data type set to torch.bfloat16
  model = ipex.optimize(model, dtype=torch.bfloat16)
  with torch.no_grad(), torch.cpu.amp.autocast():
      res = model(input)

Inference - TorchScript Mode
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TorchScript mode makes graph optimization possible and hence improves
performance for some topologies. Intel® Extension for PyTorch* enables the most
commonly used operator pattern fusions, and users can get the performance
benefit without additional code changes.

Float32
^^^^^^^

.. code:: python3

  import torch
  import torch.nn as nn
  # Import intel_extension_for_pytorch
  import intel_extension_for_pytorch as ipex

  # oneDNN graph fusion is enabled by default; uncomment the line below to disable it explicitly
  # ipex.enable_onednn_fusion(False)

  class Model(nn.Module):
      def __init__(self):
          super(Model, self).__init__()
          self.linear = nn.Linear(4, 5)

      def forward(self, input):
          return self.linear(input)

  input = torch.randn(2, 4)
  model = Model()
  model.eval()
  # Invoke optimize function against the model object
  model = ipex.optimize(model, dtype=torch.float32)
  with torch.no_grad():
      model = torch.jit.trace(model, torch.randn(2, 4))
      model = torch.jit.freeze(model)
      res = model(input)

BFloat16
^^^^^^^^

.. code:: python3

  import torch
  import torch.nn as nn
  # Import intel_extension_for_pytorch
  import intel_extension_for_pytorch as ipex

  # oneDNN graph fusion is enabled by default; uncomment the line below to disable it explicitly
  # ipex.enable_onednn_fusion(False)

  class Model(nn.Module):
      def __init__(self):
          super(Model, self).__init__()
          self.linear = nn.Linear(4, 5)

      def forward(self, input):
          return self.linear(input)

  input = torch.randn(2, 4)
  model = Model()
  model.eval()
  # Invoke optimize function against the model object with data type set to torch.bfloat16
  model = ipex.optimize(model, dtype=torch.bfloat16)
  with torch.no_grad(), torch.cpu.amp.autocast():
      model = torch.jit.trace(model, torch.randn(2, 4))
      model = torch.jit.freeze(model)
      res = model(input)

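To hand a traced model to the C++ library described in the next section, it can be serialized first. This saving step is an addition for illustration, not part of the original examples; the file name *model.pt* is a placeholder.

.. code:: python3

  # Serialize the frozen TorchScript model so it can be loaded via torch::jit::load in C++.
  torch.jit.save(model, "model.pt")
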
C++
~~~

To work with libtorch, the C++ library of PyTorch, Intel® Extension for PyTorch*
provides its C++ dynamic library as well. The C++ library is intended to handle
inference workloads only, such as service deployment. For regular development,
please use the Python interface. Compared with regular libtorch usage, no specific
code changes are required, except for converting input data into the channels last
data format. Compilation follows the recommended methodology with CMake. Detailed
instructions can be found in the `PyTorch tutorial <https://pytorch.org/tutorials/advanced/cpp_export.html#depending-on-libtorch-and-building-the-application>`_.
During compilation, Intel optimizations will be activated automatically
once the C++ dynamic library of Intel® Extension for PyTorch* is linked.

**example-app.cpp**

.. code:: cpp

  #include <torch/script.h>
  #include <iostream>
  #include <memory>

  int main(int argc, const char* argv[]) {
      torch::jit::script::Module module;
      try {
          module = torch::jit::load(argv[1]);
      }
      catch (const c10::Error& e) {
          std::cerr << "error loading the model\n";
          return -1;
      }
      std::vector<torch::jit::IValue> inputs;
      // make sure input data are converted to channels last format
      inputs.push_back(torch::ones({1, 3, 224, 224}).to(c10::MemoryFormat::ChannelsLast));

      at::Tensor output = module.forward(inputs).toTensor();

      return 0;
  }

**CMakeLists.txt**

::

  cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
  project(example-app)

  find_package(Torch REQUIRED)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -Wl,--no-as-needed")

  add_executable(example-app example-app.cpp)
  # Link the binary against the C++ dynamic library file of Intel® Extension for PyTorch*
  target_link_libraries(example-app "${TORCH_LIBRARIES}" "${INTEL_EXTENSION_FOR_PYTORCH_PATH}/lib/libintel-ext-pt-cpu.so")

  set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

**Note:** Since Intel® Extension for PyTorch* is still under development, the name
of the C++ dynamic library in the master branch may differ from
*libintel-ext-pt-cpu.so* shown above. Please check the actual name in the
installation folder; the .so file name starts with *libintel-*.

**Command for compilation**

::

  $ cmake -DCMAKE_PREFIX_PATH=<LIBPYTORCH_PATH> -DINTEL_EXTENSION_FOR_PYTORCH_PATH=<INTEL_EXTENSION_FOR_PYTORCH_INSTALLATION_PATH> ..
  $ make

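The cmake command above is run from a separate build directory, as indicated by the trailing ``..``. Once make succeeds, the binary takes the path of a serialized TorchScript model file as its first argument (argv[1] in example-app.cpp above); the file name below is a placeholder:

::

  $ ./example-app model.pt
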
Tutorials
---------

Please visit the `Intel® Extension for PyTorch* GitHub repo <https://github.com/intel/intel-extension-for-pytorch>`_ for more tutorials.
