Commit f6d75f8

jingxu10, MichaelHsu170, and malfet authored
introduction of intel extension for pytorch (#1702)

* init for recipe for intel extension for pytorch
* update intel_extension_for_pytorch.py
* update intel_extension_for_pytorch.py for c++ part
* update c++ so file name
* Fix typos
* fixed issue for inference sample codes

Co-authored-by: michaelhsu <michaelhsu170@gmail.com>
Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>

1 parent d516108 commit f6d75f8

2 files changed: +362 −0 lines changed
recipes_source/recipes/intel_extension_for_pytorch.py

Lines changed: 352 additions & 0 deletions
@@ -0,0 +1,352 @@
"""
Intel® Extension for PyTorch*
*******************************
**Author**: `Jing Xu <https://github.com/jingxu10>`_

Intel® Extension for PyTorch* extends PyTorch with optimizations for an extra
performance boost on Intel hardware. Most of the optimizations will be
included in stock PyTorch releases eventually, and the intention of the
extension is to deliver up-to-date features and optimizations for PyTorch
on Intel hardware. Examples include AVX-512 Vector Neural Network
Instructions (AVX512 VNNI) and Intel® Advanced Matrix Extensions (Intel® AMX).

Intel® Extension for PyTorch* has been released as an open-source project
on `GitHub <https://github.com/intel/intel-extension-for-pytorch>`_.

Features
--------

* Ease-of-use Python API: Intel® Extension for PyTorch* provides simple
  frontend Python APIs and utilities for users to get performance optimizations
  such as graph optimization and operator optimization with minor code changes.
  Typically, only two to three lines need to be added to the original code.
* Channels Last: Compared to the default NCHW memory format, the channels_last
  (NHWC) memory format can further accelerate convolutional neural networks.
  In Intel® Extension for PyTorch*, the NHWC memory format has been enabled for
  most key CPU operators, though not all of them have been merged to the
  PyTorch master branch yet. They are expected to be fully landed in PyTorch
  upstream soon. (A short channels_last sketch follows this overview.)
* Auto Mixed Precision (AMP): The low-precision data type BFloat16 has been
  natively supported on the 3rd Generation Intel® Xeon® Scalable Processors
  (codenamed Cooper Lake) with the AVX512 instruction set, and will be
  supported on the next generation of Intel® Xeon® Scalable Processors with
  the Intel® Advanced Matrix Extensions (Intel® AMX) instruction set for
  further boosted performance. Support for Auto Mixed Precision (AMP) with
  BFloat16 on CPU and BFloat16 optimization of operators have been enabled
  extensively in Intel® Extension for PyTorch*, and partially upstreamed to
  the PyTorch master branch. Most of these optimizations will land in PyTorch
  master through PRs that are being submitted and reviewed.
* Graph Optimization: To further optimize performance with TorchScript,
  Intel® Extension for PyTorch* supports fusion of frequently used operator
  patterns, such as Conv2D+ReLU and Linear+ReLU. The benefit of the fusions is
  delivered to users in a transparent fashion. The detailed fusion patterns
  supported can be found `here <https://github.com/intel/intel-extension-for-pytorch>`_.
  The graph optimization will be upstreamed to PyTorch with the introduction
  of the oneDNN Graph API.
* Operator Optimization: Intel® Extension for PyTorch* also optimizes
  operators and implements several customized operators for performance. A few
  ATen operators are replaced by their optimized counterparts in Intel®
  Extension for PyTorch* via the ATen registration mechanism. Moreover, some
  customized operators are implemented for several popular topologies. For
  instance, ROIAlign and NMS are used in Mask R-CNN. To improve performance
  of these topologies, Intel® Extension for PyTorch* also optimizes these
  customized operators.
"""

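###############################################################################
# The channels_last sketch referenced above: a minimal, illustrative example
# (not part of the original recipe) of converting a 4D tensor and a
# convolutional module to the channels_last (NHWC) memory format.

import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)
conv = nn.Conv2d(3, 8, kernel_size=3).to(memory_format=torch.channels_last)
y = conv(x)
# The convolution output stays in channels_last format
print(y.is_contiguous(memory_format=torch.channels_last))
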
###############################################################################
# Getting Started
# ---------------

###############################################################################
# Minor code changes are required for users to get started with Intel®
# Extension for PyTorch*. Both PyTorch imperative mode and TorchScript mode
# are supported. This section introduces usage of Intel® Extension for
# PyTorch* API functions for both imperative mode and TorchScript mode,
# covering the data types Float32 and BFloat16. C++ usage will also be
# introduced at the end.

###############################################################################
# You just need to import the Intel® Extension for PyTorch* package and apply
# its optimize function to the model object. For a training workload, the
# optimize function also needs to be applied to the optimizer object.
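
###############################################################################
# A minimal sketch of that change for an inference workload (illustrative
# only; the names below are made up for this example, and relying on
# ``ipex.optimize`` defaulting to Float32 when no dtype is given is an
# assumption based on the examples that follow):

import torch
import torch.nn as nn
# Change 1: import the extension
import intel_extension_for_pytorch as ipex

toy_model = nn.Linear(4, 5).eval()
# Change 2: apply the optimize function to the model object
toy_model = ipex.optimize(toy_model)
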
###############################################################################
# For training and inference with the BFloat16 data type, torch.cpu.amp has
# been enabled in PyTorch upstream to support mixed precision conveniently,
# and the BFloat16 datatype has been enabled extensively for CPU operators in
# PyTorch upstream and Intel® Extension for PyTorch*. Running with
# torch.cpu.amp matches each operator to its appropriate datatype and delivers
# the best possible performance.
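
###############################################################################
# For illustration (a minimal sketch, not part of the original recipe), the
# effect of torch.cpu.amp.autocast can be observed on an autocast-eligible
# operator such as matrix multiplication:

import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)
with torch.cpu.amp.autocast():
    c = torch.mm(a, b)
# Eligible operators run in BFloat16 inside the autocast region
print(c.dtype)
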
###############################################################################
# The code changes that are required for Intel® Extension for PyTorch* are
# highlighted with comments on the line above.

###############################################################################
# Training
# ~~~~~~~~

###############################################################################
# Float32
# ^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

model = Model()
# Define an optimizer for the training workload
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Load checkpoints saved by a previous run, if any. PATH is a placeholder.
model.load_state_dict(torch.load(PATH))
optimizer.load_state_dict(torch.load(PATH))
# Invoke optimize function against the model object and optimizer object
model, optimizer = ipex.optimize(model, optimizer, dtype=torch.float32)

for images, label in train_loader:
    # Setting memory_format to torch.channels_last could improve performance with 4D input data. This is optional.
    images = images.to(memory_format=torch.channels_last)
    optimizer.zero_grad()
    loss = criterion(model(images), label)
    loss.backward()
    optimizer.step()
torch.save(model.state_dict(), PATH)
torch.save(optimizer.state_dict(), PATH)
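
###############################################################################
# The training snippets in this recipe leave ``PATH``, ``criterion``, and
# ``train_loader`` as placeholders. A hypothetical minimal setup (illustrative
# only, not part of the original recipe) could look like the following; note
# that the channels_last conversion targets 4D image inputs, so it should be
# skipped for 2D toy data like this.

from torch.utils.data import DataLoader, TensorDataset

PATH = 'checkpoint.pt'
criterion = nn.MSELoss()
train_loader = DataLoader(
    TensorDataset(torch.randn(8, 4), torch.randn(8, 5)),
    batch_size=2,
)
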
###############################################################################
# BFloat16
# ^^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

model = Model()
# Define an optimizer for the training workload
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Load checkpoints saved by a previous run, if any. PATH is a placeholder.
model.load_state_dict(torch.load(PATH))
optimizer.load_state_dict(torch.load(PATH))
# Invoke optimize function against the model object and optimizer object with data type set to torch.bfloat16
model, optimizer = ipex.optimize(model, optimizer, dtype=torch.bfloat16)

for images, label in train_loader:
    optimizer.zero_grad()
    with torch.cpu.amp.autocast():
        # Setting memory_format to torch.channels_last could improve performance with 4D input data. This is optional.
        images = images.to(memory_format=torch.channels_last)
        loss = criterion(model(images), label)
    # Run the backward pass and the optimizer step outside the autocast region
    loss.backward()
    optimizer.step()
torch.save(model.state_dict(), PATH)
torch.save(optimizer.state_dict(), PATH)

###############################################################################
# Inference - Imperative Mode
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~

###############################################################################
# Float32
# ^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

input = torch.randn(2, 4)
model = Model()
model.eval()
# Invoke optimize function against the model object
model = ipex.optimize(model, dtype=torch.float32)
res = model(input)

###############################################################################
# BFloat16
# ^^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

input = torch.randn(2, 4)
model = Model()
model.eval()
# Invoke optimize function against the model object with data type set to torch.bfloat16
model = ipex.optimize(model, dtype=torch.bfloat16)
with torch.cpu.amp.autocast():
    res = model(input)
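
###############################################################################
# As a quick sanity check (an illustrative note, not part of the original
# recipe), operators that autocast runs in lower precision return BFloat16
# tensors inside the autocast region:

print(res.dtype)
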
###############################################################################
# Inference - TorchScript Mode
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

###############################################################################
# TorchScript mode makes graph optimization possible, and hence improves
# performance for some topologies. Intel® Extension for PyTorch* enables
# fusion of the most commonly used operator patterns, and users can get the
# performance benefit without additional code changes.

###############################################################################
# Float32
# ^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

# oneDNN graph fusion is enabled by default, uncomment the line below to disable it explicitly
# ipex.enable_onednn_fusion(False)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

input = torch.randn(2, 4)
model = Model()
model.eval()
# Invoke optimize function against the model object
model = ipex.optimize(model, dtype=torch.float32)
model = torch.jit.trace(model, torch.randn(2, 4))
model = torch.jit.freeze(model)
res = model(input)
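
###############################################################################
# A small usage note (an assumption, not from the original recipe): the JIT
# profiling executor applies graph optimizations, including operator fusion,
# after the first few invocations, so warm the model up before measuring
# performance:

with torch.no_grad():
    for _ in range(2):
        model(input)  # warm-up runs trigger profiling and fusion
    res = model(input)
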
###############################################################################
# BFloat16
# ^^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

# oneDNN graph fusion is enabled by default, uncomment the line below to disable it explicitly
# ipex.enable_onednn_fusion(False)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

input = torch.randn(2, 4)
model = Model()
model.eval()
# Invoke optimize function against the model object with data type set to torch.bfloat16
model = ipex.optimize(model, dtype=torch.bfloat16)
with torch.cpu.amp.autocast():
    model = torch.jit.trace(model, torch.randn(2, 4))
    model = torch.jit.freeze(model)
    res = model(input)

###############################################################################
# C++
# ~~~

###############################################################################
# To work with libtorch, the C++ library of PyTorch, Intel® Extension for
# PyTorch* provides its own C++ dynamic library as well. The C++ library is
# intended to handle inference workloads only, such as service deployment. For
# regular development, please use the Python interface. Compared to the usual
# libtorch workflow, no specific code changes are required, except for
# converting input data to the channels last data format. Compilation follows
# the recommended methodology with CMake. Detailed instructions can be found
# in the `PyTorch tutorial <https://pytorch.org/tutorials/advanced/cpp_export.html#depending-on-libtorch-and-building-the-application>`_.
# During compilation, Intel optimizations will be activated automatically
# once the C++ dynamic library of Intel® Extension for PyTorch* is linked.

###############################################################################
# **example-app.cpp**

'''
#include <torch/script.h>
#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  torch::jit::script::Module module;
  try {
    module = torch::jit::load(argv[1]);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }
  std::vector<torch::jit::IValue> inputs;
  // make sure input data are converted to channels last format
  inputs.push_back(torch::ones({1, 3, 224, 224}).to(c10::MemoryFormat::ChannelsLast));

  at::Tensor output = module.forward(inputs).toTensor();

  return 0;
}
'''

###############################################################################
# **CMakeLists.txt**

'''
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -Wl,--no-as-needed")

add_executable(example-app example-app.cpp)
# Link the binary against the C++ dynamic library file of Intel® Extension for PyTorch*
target_link_libraries(example-app "${TORCH_LIBRARIES}" "${INTEL_EXTENSION_FOR_PYTORCH_PATH}/lib/libintel-ext-pt-cpu.so")

set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
'''

###############################################################################
# **Note:** Since Intel® Extension for PyTorch* is still under development,
# the name of the C++ dynamic library in the master branch may differ from
# *libintel-ext-pt-cpu.so* shown above. Please check the actual name in the
# installation folder. The .so file name starts with *libintel-*.

###############################################################################
# **Command for compilation**

'''
cmake -DCMAKE_PREFIX_PATH=<LIBPYTORCH_PATH> -DINTEL_EXTENSION_FOR_PYTORCH_PATH=<INTEL_EXTENSION_FOR_PYTORCH_INSTALLATION_PATH> ..
'''

###############################################################################
# Tutorials
# ---------

###############################################################################
# Please visit the `Intel® Extension for PyTorch* GitHub repo <https://github.com/intel/intel-extension-for-pytorch>`_ for more tutorials.

recipes_source/recipes_index.rst

Lines changed: 10 additions & 0 deletions

@@ -239,6 +239,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/recipes/tuning_guide.html
    :tags: Model-Optimization

+.. Intel(R) Extension for PyTorch*
+
+.. customcarditem::
+   :header: Intel® Extension for PyTorch*
+   :card_description: Introduction of Intel® Extension for PyTorch*
+   :image: ../_static/img/thumbnails/cropped/profiler.png
+   :link: ../recipes/recipes/intel_extension_for_pytorch.html
+   :tags: Model-Optimization
+
 .. Distributed Training

 .. customcarditem::

@@ -295,6 +304,7 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/recipes/dynamic_quantization
    /recipes/recipes/amp_recipe
    /recipes/recipes/tuning_guide
+   /recipes/recipes/intel_extension_for_pytorch
    /recipes/torchscript_inference
    /recipes/deployment_with_flask
    /recipes/distributed_rpc_profiling