You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.md
+111-3Lines changed: 111 additions & 3 deletions
Original file line number
Diff line number
Diff line change
@@ -3,10 +3,12 @@
3
3
Intel Extension for PyTorch is a Python package to extend official PyTorch. It is designed to make the Out-of-Box user experience of PyTorch CPU better while achieving good performance. The extension also will be the PR(Pull-Request) buffer for the Intel PyTorch framework dev team. The PR buffer will not only contain functions, but also optimization (for example, take advantage of Intel's new hardware features).
4
4
5
5
-[Installation](#installation)
6
-
- [Install PyTorch from Source](#install-pytorch-from-source)
7
-
- [Install Intel Extension for PyTorch from Source](#install-intel-extension-for-pytorch-from-source)
6
+
-[Install PyTorch from Source](#install-pytorch-from-source)
7
+
-[Install Intel Extension for PyTorch from Source](#install-intel-extension-for-pytorch-from-source)
In addition, Intel Extension for PyTorch supports the mixed precision. It means that some operators of a model may run with Float32 and some other operators may run with BFloat16 or INT8.
110
112
In traditional, if you want to run a model with a low precision type, you need to convert the parameters and the input tensors to the low precision type manually. And if the model contains some operators that do not support the low precision type, then you have to convert back to Float32. Round after round until the model can run normally.
111
113
The extension can simply the case, you just need to enable the auto-mix-precision as follows, then you can benefit from the low precision. Currently, the extension only supports BFloat16.
114
+
115
+
#### BFloat16
112
116
```python
113
117
import torch
114
118
import torch.nn as nn
@@ -130,6 +134,110 @@ model = Model().to(ipex.DEVICE)
130
134
131
135
res = model(input)
132
136
```
137
+
#### INT8 Quantization
138
+
Currently, Intel Extension for PyTorch has supported static and symmetric quantization. Development of dynamic quantization is undergoing. And asymmetric quantization will be enabled once oneDNN is upgraded to v2.0 or higher versions.
139
+
140
+
How to quantize the following model:
141
+
```python
142
+
import torch
143
+
import torch.nn as nn
144
+
145
+
class Model(nn.Module):
146
+
def __init__(self):
147
+
super(Model, self).__init__()
148
+
self.conv = nn.Conv2d(3, 64, 7, stride=2)
149
+
150
+
def forward(self, input):
151
+
returnself.conv(input).relu()
152
+
```
153
+
Firstly we need to do calibration step against a representative dataset (set ```running_mode``` to ```calibration```):
154
+
```python
155
+
# Convert the model to the Extension device
156
+
model = Model().to(ipex.DEVICE)
157
+
158
+
# Create a configuration file to save quantization parameters.
159
+
conf = ipex.AmpConf(torch.int8)
160
+
with torch.no_grad():
161
+
forxin cali_dataset:
162
+
# Run the model under calibration mode to collect quantization parameters
163
+
with ipex.AutoMixPrecision(conf, running_mode='calibration'):
164
+
y = model(x.to(ipex.DEVICE))
165
+
# Save the configuration file
166
+
conf.save('configure.json')
167
+
```
168
+
The content of the configuration file is as follows.
169
+
170
+
```json
171
+
[
172
+
{
173
+
"id": 0,
174
+
"name": "Convolution",
175
+
"algorithm": "min_max",
176
+
"weight_granularity": "per_channel",
177
+
"inputs_scale": [
178
+
25.05583953857422
179
+
],
180
+
"outputs_scale": [
181
+
43.98969650268555
182
+
],
183
+
"inputs_uint8_used": [
184
+
false
185
+
],
186
+
"outputs_uint8_used": [
187
+
false
188
+
],
189
+
"quantized": true
190
+
},
191
+
{
192
+
"id": 1,
193
+
"name": "Relu",
194
+
"algorithm": "min_max",
195
+
"weight_granularity": "per_channel",
196
+
"inputs_scale": [
197
+
43.98969650268555
198
+
],
199
+
"outputs_scale": [
200
+
43.98969650268555
201
+
],
202
+
"inputs_uint8_used": [
203
+
false
204
+
],
205
+
"outputs_uint8_used": [
206
+
false
207
+
],
208
+
"quantized": true
209
+
}
210
+
]
211
+
```
212
+
- ```id``` is a sequence number of operators which were quantized statically in the calibration step.
213
+
**Manually changing this value will cause unexpected behaviors**.
214
+
- ```name``` is the name of the operator to be quantized.
215
+
- ```algorithm``` indicates how to calculate the scales of the observed tensors. Currently only ```min_max``` is supported.
216
+
- ```weight_granularity``` controls how to quantize the operator weights. The ```Convolution``` and ```Linear``` both supports ```per_channel``` and ```per_tensor```. And the other operators only supports ```per_tensor```.
217
+
- ```inputs_scale``` and ```outputs_scale``` are the scales to quantize the input tensors and output tensors respectively.
218
+
- ```inputs_uint8_used``` and ```outputs_uint8_used``` indicate whether to use ```int8``` or ```uint8```. Default value is ```false```, indicating that ```int8``` is used.
219
+
- ```quantized``` determines whether this operator should be quantized or not during inference.
220
+
221
+
After doing calibration step, we can use the saved configuration json file to do evalution (set ```running_mode``` to ```inference```):
222
+
```python
223
+
conf = ipex.AmpConf(torch.int8, 'configure.json')
224
+
with torch.no_grad():
225
+
forxin cali_dataset:
226
+
with ipex.AutoMixPrecision(conf, running_mode='inference'):
0 commit comments