"""
Intel® Extension for PyTorch*
*******************************
**Author**: `Jing Xu <https://github.com/jingxu10>`_

Intel® Extension for PyTorch* extends PyTorch with optimizations for an extra
performance boost on Intel hardware. Most of the optimizations will eventually
be included in stock PyTorch releases; the intention of the extension is to
deliver up-to-date features and optimizations for PyTorch on Intel hardware.
Examples include AVX-512 Vector Neural Network Instructions (AVX512 VNNI) and
Intel® Advanced Matrix Extensions (Intel® AMX).

Intel® Extension for PyTorch* has been released as an open-source project
on `GitHub <https://github.com/intel/intel-extension-for-pytorch>`_.

Features
--------

* Ease-of-use Python API: Intel® Extension for PyTorch* provides simple
  frontend Python APIs and utilities for users to get performance optimizations
  such as graph optimization and operator optimization with minor code changes.
  Typically, only two to three lines of code need to be added to the original
  code.
* Channels Last: Compared to the default NCHW memory format, the channels_last
  (NHWC) memory format can further accelerate convolutional neural networks.
  In Intel® Extension for PyTorch*, the NHWC memory format has been enabled
  for most key CPU operators, though not all of them have been merged to the
  PyTorch master branch yet. They are expected to be fully landed in PyTorch
  upstream soon (a short usage sketch follows this list).
* Auto Mixed Precision (AMP): The low-precision data type BFloat16 has been
  natively supported on 3rd Generation Intel® Xeon® Scalable processors (aka
  Cooper Lake) with the AVX512 instruction set, and will be supported on the
  next generation of Intel® Xeon® Scalable Processors with the Intel® Advanced
  Matrix Extensions (Intel® AMX) instruction set for a further performance
  boost. Support for Auto Mixed Precision (AMP) with BFloat16 on CPU and
  BFloat16 optimization of operators have been extensively enabled in Intel®
  Extension for PyTorch*, and partially upstreamed to the PyTorch master
  branch. Most of these optimizations will land in PyTorch master through PRs
  that are being submitted and reviewed.
* Graph Optimization: To further optimize performance with TorchScript,
  Intel® Extension for PyTorch* supports fusion of frequently used operator
  patterns, such as Conv2D+ReLU and Linear+ReLU. The benefit of the fusions is
  delivered to users in a transparent fashion. The detailed fusion patterns
  supported can be found `here <https://github.com/intel/intel-extension-for-pytorch>`_.
  The graph optimization will be upstreamed to PyTorch with the introduction
  of the oneDNN Graph API.
* Operator Optimization: Intel® Extension for PyTorch* also optimizes
  operators and implements several customized operators for performance. A few
  ATen operators are replaced by their optimized counterparts in Intel®
  Extension for PyTorch* via the ATen registration mechanism. Moreover, some
  customized operators are implemented for several popular topologies. For
  instance, ROIAlign and NMS are defined in Mask R-CNN. To improve performance
  of these topologies, Intel® Extension for PyTorch* also optimizes these
  customized operators.
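
As a quick illustration of the Channels Last feature mentioned above,
converting a model and its input to the channels_last memory format takes one
line each. The following is a minimal sketch; note that the conversion only
applies to 4D (NCHW) tensors::

    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 64, kernel_size=3)
    input = torch.randn(1, 3, 224, 224)
    # convert both the model weights and the input to the NHWC layout
    model = model.to(memory_format=torch.channels_last)
    input = input.to(memory_format=torch.channels_last)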
| 54 | +""" |
| 55 | + |
| 56 | +############################################################################### |
| 57 | +# Getting Started |
| 58 | +# --------------- |
| 59 | + |
###############################################################################
# Minor code changes are required for users to get started with Intel®
# Extension for PyTorch*. Both PyTorch imperative mode and TorchScript mode
# are supported. This section introduces usage of the Intel® Extension for
# PyTorch* API functions for both imperative mode and TorchScript mode,
# covering the Float32 and BFloat16 data types. C++ usage will also be
# introduced at the end.

###############################################################################
# You just need to import the Intel® Extension for PyTorch* package and apply
# its optimize function against the model object. If it is a training
# workload, the optimize function also needs to be applied against the
# optimizer object.

###############################################################################
# For training and inference with the BFloat16 data type, torch.cpu.amp has
# been enabled in PyTorch upstream to support mixed precision conveniently,
# and the BFloat16 datatype has been enabled extensively for CPU operators in
# PyTorch upstream and Intel® Extension for PyTorch*. Running with
# torch.cpu.amp, each operator is matched to its appropriate datatype to
# deliver the best possible performance.

###############################################################################
# The code changes that are required for Intel® Extension for PyTorch* are
# highlighted with comments on the line above.

###############################################################################
# Training
# ~~~~~~~~

###############################################################################
# Float32
# ^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

model = Model()
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Optionally resume model and optimizer states from a checkpoint at the placeholder path PATH
model.load_state_dict(torch.load(PATH))
optimizer.load_state_dict(torch.load(PATH))
# Invoke optimize function against the model object and optimizer object
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.float32)

for images, label in train_loader:
    # Setting memory_format to torch.channels_last could improve performance with 4D input data. This is optional.
    images = images.to(memory_format=torch.channels_last)
    optimizer.zero_grad()
    loss = criterion(model(images), label)
    loss.backward()
    optimizer.step()
torch.save(model.state_dict(), PATH)
torch.save(optimizer.state_dict(), PATH)

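###############################################################################
# The snippet above leaves PATH, train_loader, and criterion as placeholders.
# A minimal, purely hypothetical set of definitions that makes the loop
# runnable could look like the sketch below; the 4D sample shape keeps the
# channels_last conversion valid while still matching nn.Linear(4, 5), which
# applies to the last dimension. The same definitions work for the BFloat16
# variant that follows.

import torch.utils.data as data

PATH = 'checkpoint.pth'                 # hypothetical checkpoint file path
criterion = nn.MSELoss()                # any loss matching the model output
# hypothetical random dataset; samples are 4D so channels_last applies
dataset = data.TensorDataset(torch.randn(100, 3, 8, 4),
                             torch.randn(100, 3, 8, 5))
train_loader = data.DataLoader(dataset, batch_size=10)
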
###############################################################################
# BFloat16
# ^^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

model = Model()
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Optionally resume model and optimizer states from a checkpoint at the placeholder path PATH
model.load_state_dict(torch.load(PATH))
optimizer.load_state_dict(torch.load(PATH))
# Invoke optimize function against the model object and optimizer object with data type set to torch.bfloat16
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

for images, label in train_loader:
    # Setting memory_format to torch.channels_last could improve performance with 4D input data. This is optional.
    images = images.to(memory_format=torch.channels_last)
    optimizer.zero_grad()
    # Run the forward pass under autocast; keep the backward pass and the
    # optimizer step outside of it
    with torch.cpu.amp.autocast():
        loss = criterion(model(images), label)
    loss.backward()
    optimizer.step()
torch.save(model.state_dict(), PATH)
torch.save(optimizer.state_dict(), PATH)

###############################################################################
# Inference - Imperative Mode
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~

###############################################################################
# Float32
# ^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

input = torch.randn(2, 4)
model = Model()
model.eval()
# Invoke optimize function against the model object
model = ipex.optimize(model, dtype=torch.float32)
# Disable autograd for inference
with torch.no_grad():
    res = model(input)

###############################################################################
# BFloat16
# ^^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

input = torch.randn(2, 4)
model = Model()
model.eval()
# Invoke optimize function against the model object with data type set to torch.bfloat16
model = ipex.optimize(model, dtype=torch.bfloat16)
# Disable autograd, and run the forward pass under autocast
with torch.no_grad(), torch.cpu.amp.autocast():
    res = model(input)

###############################################################################
# Inference - TorchScript Mode
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

###############################################################################
# TorchScript mode makes graph optimization possible, and hence improves
# performance for some topologies. Intel® Extension for PyTorch* enables the
# most commonly used operator pattern fusions, and users can get the
# performance benefit without additional code changes.

###############################################################################
# Float32
# ^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

# oneDNN graph fusion is enabled by default; uncomment the line below to disable it explicitly
# ipex.enable_onednn_fusion(False)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

input = torch.randn(2, 4)
model = Model()
model.eval()
# Invoke optimize function against the model object
model = ipex.optimize(model, dtype=torch.float32)
# Trace and freeze the model with autograd disabled
with torch.no_grad():
    model = torch.jit.trace(model, torch.randn(2, 4))
    model = torch.jit.freeze(model)
    res = model(input)

###############################################################################
# BFloat16
# ^^^^^^^^

import torch
import torch.nn as nn
# Import intel_extension_for_pytorch
import intel_extension_for_pytorch as ipex

# oneDNN graph fusion is enabled by default; uncomment the line below to disable it explicitly
# ipex.enable_onednn_fusion(False)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)

input = torch.randn(2, 4)
model = Model()
model.eval()
# Invoke optimize function against the model object with data type set to torch.bfloat16
model = ipex.optimize(model, dtype=torch.bfloat16)
# Disable autograd, and trace and freeze the model under autocast
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, torch.randn(2, 4))
    model = torch.jit.freeze(model)
    res = model(input)

###############################################################################
# C++
# ~~~

###############################################################################
# To work with libtorch, the C++ library of PyTorch, Intel® Extension for
# PyTorch* provides its C++ dynamic library as well. The C++ library is
# intended to handle inference workloads only, such as service deployment. For
# regular development, please use the Python interface. Compared to regular
# libtorch usage, no specific code changes are required, except for converting
# input data into the channels last data format. Compilation follows the
# recommended methodology with CMake. Detailed instructions can be found in
# the `PyTorch tutorial <https://pytorch.org/tutorials/advanced/cpp_export.html#depending-on-libtorch-and-building-the-application>`_.
# During compilation, Intel optimizations will be activated automatically
# once the C++ dynamic library of Intel® Extension for PyTorch* is linked.

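###############################################################################
# The C++ application below loads a serialized TorchScript model from disk, so
# a traced and frozen model first needs to be saved with torch.jit.save. The
# following is a minimal sketch: the file name model.pt is an assumption, and
# since the C++ example feeds a 1x3x224x224 input, it assumes an image model
# (e.g. a traced ResNet) rather than the toy Linear model used above.

# Save a traced and frozen model for the C++ application to load
torch.jit.save(model, 'model.pt')
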
###############################################################################
# **example-app.cpp**

'''
#include <torch/script.h>
#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }
  torch::jit::script::Module module;
  try {
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }
  std::vector<torch::jit::IValue> inputs;
  // make sure input data are converted to channels last format
  inputs.push_back(torch::ones({1, 3, 224, 224}).to(c10::MemoryFormat::ChannelsLast));

  at::Tensor output = module.forward(inputs).toTensor();

  return 0;
}
'''

###############################################################################
# **CMakeLists.txt**

'''
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -Wl,--no-as-needed")

add_executable(example-app example-app.cpp)
# Link the binary against the C++ dynamic library file of Intel® Extension for PyTorch*
target_link_libraries(example-app "${TORCH_LIBRARIES}" "${INTEL_EXTENSION_FOR_PYTORCH_PATH}/lib/libintel-ext-pt-cpu.so")

set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
'''

###############################################################################
# **Note:** Since Intel® Extension for PyTorch* is still under development,
# the name of the C++ dynamic library in the master branch may differ from
# *libintel-ext-pt-cpu.so* shown above. Please check the actual name in the
# installation folder. The .so file name starts with *libintel-*.

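###############################################################################
# For example, the library file can be located with a simple listing (the
# installation path placeholder is the same one passed to CMake in the
# compilation command below):

'''
ls <INTEL_EXTENSION_FOR_PYTORCH_INSTALLATION_PATH>/lib/libintel-*
'''
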
###############################################################################
# **Command for compilation**

'''
cmake -DCMAKE_PREFIX_PATH=<LIBPYTORCH_PATH> -DINTEL_EXTENSION_FOR_PYTORCH_PATH=<INTEL_EXTENSION_FOR_PYTORCH_INSTALLATION_PATH> ..
'''

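###############################################################################
# After CMake configuration succeeds, the application can be built and run as
# usual (a sketch; model.pt is the hypothetical file name used when saving the
# TorchScript model above):

'''
cmake --build .
./example-app model.pt
'''
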
###############################################################################
# Tutorials
# ---------

###############################################################################
# Please visit the `Intel® Extension for PyTorch* GitHub repo <https://github.com/intel/intel-extension-for-pytorch>`_ for more tutorials.