Commit 3fe5370

Refine README and API (#82)
1. Expose API enable_auto_optimization to the end user
2. Refine README
1 parent dac86fe commit 3fe5370

File tree: 10 files changed (+139 / -132 lines)


README.md

Lines changed: 28 additions & 24 deletions
@@ -4,8 +4,9 @@ Intel Extension for PyTorch is a Python package to extend official PyTorch. It i

 - [Installation](#installation)
 - [Install PyTorch from Source](#install-pytorch-from-source)
-- [Install Intel PyTorch Extension from Source](#install-intel-pytorch-extension-from-source)
+- [Install Intel Extension for PyTorch from Source](#install-intel-extension-for-pytorch-from-source)
 - [Getting Started](#getting-started)
+- [Automatically Mix Precision](#automatically-mix-precision)
 - [Contribution](#contribution)
 - [License](#license)

@@ -20,8 +21,8 @@ Intel Extension for PyTorch is a Python package to extend official PyTorch. It i

 # checkout source code to the specified version
 git checkout v1.5.0-rc3
-
-# update submodules for the specified pytorch version
+
+# update submodules for the specified PyTorch version
 git submodule sync
 git submodule update --init --recursive
 ```
@@ -30,40 +31,40 @@ Intel Extension for PyTorch is a Python package to extend official PyTorch. It i
 ```bash
 git clone --recursive https://github.com/intel/intel-extension-for-pytorch
 cd intel-extension-for-pytorch
-
+
 # if you are updating an existing checkout
 git submodule sync
 git submodule update --init --recursive
 ```

-3. Add an new backend for Intel PyTorch Extension
+3. Add a new backend for Intel Extension for PyTorch
 ```bash
 # Apply git patch to pytorch code
 cd ${pytorch_directory}
-git apply ${intel_pytorch_extension_directory}/torch_patches/dpcpp-v1.5-rc3.patch
+git apply ${intel_extension_for_pytorch_directory}/torch_patches/dpcpp-v1.5-rc3.patch
 ```
-
+
 4. Build and install PyTorch (Refer to [PyTorch guide](https://github.com/pytorch/pytorch#install-pytorch) for more details)
 ```bash
 cd ${pytorch_directory}
 python setup.py install
 ```

-### Install Intel PyTorch Extension from Source
+### Install Intel Extension for PyTorch from Source
 Install dependencies
 ```bash
 pip install lark-parser hypothesis
 ```

 Install the extension
 ```bash
-cd ${intel_pytorch_extension_directory}
+cd ${intel_extension_for_pytorch_directory}
 python setup.py install
 ```

 ## Getting Started

-The user just needs to convert the model and input tensors to the extension device, then the extension will be enabled automatically. Take an example, the code as follows is a model without the extension.
+If you want to explore Intel Extension for PyTorch, you just need to convert the model and input tensors to the extension device; the extension will then be enabled automatically. For example, the following code is a model without the extension.
 ```python
 import torch
 import torch.nn as nn
@@ -80,13 +81,13 @@ input = torch.randn(2, 4)
 model = Model()
 res = model(input)
 ```
-If you want to explore the Intel PyTorch Extension, you just need to transform the above python script as follows.
+You just need to transform the above Python script as follows; the extension will then be enabled and accelerate the computation automatically.
 ```python
 import torch
 import torch.nn as nn

-# Import Intel PyTorch Extension
-import intel_pytorch_extension
+# Import Extension
+import intel_pytorch_extension as ipex

 class Model(nn.Module):
     def __init__(self):
@@ -96,20 +97,25 @@ class Model(nn.Module):
     def forward(self, input):
         return self.linear(input)

-# Convert the input tensor to Intel PyTorch Extension device
-input = torch.randn(2, 4).to('dpcpp')
-# Convert the model to Intel PyTorch Extension device
-model = Model().to('dpcpp')
+# Convert the input tensor to the Extension device
+input = torch.randn(2, 4).to(ipex.DEVICE)
+# Convert the model to the Extension device
+model = Model().to(ipex.DEVICE)

 res = model(input)
 ```
-In addition, Intel PyTorch Extension can auto dispatch an OP to DNNL if the OP is supported with DNNL. Currently, the feature is not enabled by default. If you want to enable the feature, you can refine the above code as follows.
+
+### Automatically Mix Precision
+In addition, Intel Extension for PyTorch supports mixed precision: some operators of a model may run with Float32 while other operators run with BFloat16 or INT8.
+Traditionally, if you want to run a model with a low-precision data type, you need to convert the parameters and the input tensors to that type manually, and if the model contains operators that do not support the low-precision type, you have to convert them back to Float32, round after round, until the model runs normally.
+The extension simplifies this: you just need to enable auto-mixed-precision as follows, and you can then benefit from the low precision. Currently, the extension supports BFloat16 only.
 ```python
 import torch
 import torch.nn as nn

-# Import Intel PyTorch Extension
 import intel_pytorch_extension as ipex
+# Automatically mix precision
+ipex.enable_auto_optimization(mixed_dtype = torch.bfloat16)

 class Model(nn.Module):
     def __init__(self):
@@ -119,15 +125,13 @@ class Model(nn.Module):
     def forward(self, input):
         return self.linear(input)

-# Convert the input tensor to Intel PyTorch Extension device
-input = torch.randn(2, 4).to('dpcpp')
-# Convert the model to Intel PyTorch Extension device
-model = Model().to('dpcpp')
+input = torch.randn(2, 4).to(ipex.DEVICE)
+model = Model().to(ipex.DEVICE)

-ipex.core.enable_auto_dnnl()
 res = model(input)
 ```

+
 ## Contribution

 Please submit PR or issue to communicate with us or contribute code.
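To make the "no manual conversion" point above concrete, here is a minimal sketch (not part of the commit; the bare nn.Linear stands in for a real model): the script never casts anything to BFloat16 itself, since the extension reorders tensors to the low-precision type internally.

```python
import torch
import torch.nn as nn
import intel_pytorch_extension as ipex

# Enable BFloat16 auto-mixed-precision, as the new README section shows.
ipex.enable_auto_optimization(mixed_dtype=torch.bfloat16)

model = nn.Linear(4, 5).to(ipex.DEVICE)
input = torch.randn(2, 4).to(ipex.DEVICE)

# No manual .to(torch.bfloat16) casts anywhere: operators that support
# BFloat16 run in low precision, the rest stay in Float32.
res = model(input)
```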

intel_pytorch_extension_py/__init__.py

Lines changed: 25 additions & 11 deletions
@@ -8,20 +8,34 @@

 DEVICE = 'dpcpp'

-def get_auto_optimization():
-    return core.get_auto_dnnl()
+def enable_auto_optimization(mixed_dtype = None):
+    r"""Enable auto-mixed-precision to improve performance.

-def enable_auto_optimization(enable = True):
-    if enable:
-        core.enable_auto_dnnl()
-    else:
-        core.disable_auto_dnnl()
+    Auto-mixed-precision automatically reorders tensors to the specified low-precision data type.
+    You don't need to convert the input tensors and the model to the specified data type manually;
+    the extension will do it automatically and then dispatch to the extension backend to accelerate
+    computation.

-def get_auto_mix_precision(bf16 = True):
-    return core.get_mix_bf16_fp32()
+    Args:
+        mixed_dtype(torch.dtype): Auto-reorder the input tensors to the specified low-precision
+            data type and dispatch to the oneDNN backend for computation.

-def enable_auto_mix_precision(bf16 = True):
-    if bf16:
+    """
+    if mixed_dtype != None:
+        core.enable_auto_dnnl(True)
+    enable_auto_mix_precision(mixed_dtype)
+
+def get_auto_optimization():
+    return get_auto_mix_precision()
+
+def enable_auto_mix_precision(mixed_dtype = torch.bfloat16):
+    if mixed_dtype == torch.bfloat16:
         core.enable_mix_bf16_fp32()
     else:
         core.disable_mix_bf16_fp32()
+
+def get_auto_mix_precision():
+    if core.get_mix_bf16_fp32():
+        return torch.bfloat16
+    else:
+        return None
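Reading the new functions together clarifies the call semantics: passing a dtype enables both the oneDNN backend and BF16/FP32 mixing, while the default mixed_dtype=None leaves enable_auto_dnnl untouched and turns the mixing off via the else branch of enable_auto_mix_precision. A minimal sketch of that behavior (an illustration inferred from this diff, not code from the commit):

```python
import torch
import intel_pytorch_extension as ipex

# mixed_dtype given: enables the oneDNN backend and BF16/FP32 mixing.
ipex.enable_auto_optimization(mixed_dtype=torch.bfloat16)
print(ipex.get_auto_mix_precision())  # torch.bfloat16

# mixed_dtype defaulted to None: enable_auto_dnnl(True) is skipped and
# enable_auto_mix_precision(None) disables the mixing again.
ipex.enable_auto_optimization()
print(ipex.get_auto_mix_precision())  # None
```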

intel_pytorch_extension_py/ops/jit_script.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
     jit_m = orig_script(obj, optimize=optimize, _frames_up=_frames_up+1, _rcb=_rcb)
     torch.jit.script = script_

-    if core.get_jit():
+    if core.get_jit_opt():
         # bypass buggy broadcastable ops in dnnl during folding
         core.disable_auto_dnnl()
         jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
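Because script_ wraps torch.jit.script, the conv+bn folding path is exercised simply by scripting a model while the JIT optimization toggle is on. A hedged sketch (the ConvBN module is illustrative, not from the repository):

```python
import torch
import torch.nn as nn
import intel_pytorch_extension as ipex

class ConvBN(nn.Module):  # illustrative model, not part of the commit
    def __init__(self):
        super(ConvBN, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

ipex.enable_jit_opt()  # the renamed toggle, see tests/cpu/test_conf.py below
model = ConvBN().to(ipex.DEVICE).eval()
with torch.no_grad():
    # Runs through the patched script_ above, which folds conv+bn
    # when core.get_jit_opt() is true.
    scripted = torch.jit.script(model)
    out = scripted(torch.randn(1, 3, 32, 32).to(ipex.DEVICE))
```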

tests/cpu/common_ipex_conf.py

Lines changed: 9 additions & 8 deletions
@@ -1,3 +1,4 @@
+import torch
 import intel_pytorch_extension as ipex

 class AutoMixPrecision(object):
@@ -7,15 +8,15 @@ def __init__(self, enable_or_not = False):

     def __enter__(self):
         if self.enable_or_not:
-            ipex.enable_auto_mix_precision(bf16=True)
+            ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16)
         else:
-            ipex.enable_auto_mix_precision(bf16=False)
+            ipex.enable_auto_mix_precision(mixed_dtype=None)

     def __exit__(self, *args, **kwargs):
         if self.old_value:
-            ipex.enable_auto_mix_precision(bf16=True)
+            ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16)
         else:
-            ipex.enable_auto_mix_precision(bf16=False)
+            ipex.enable_auto_mix_precision(mixed_dtype=None)

 class AutoDNNL(object):
     def __init__(self, enable_or_not = False):
@@ -24,12 +25,12 @@ def __init__(self, enable_or_not = False):

     def __enter__(self):
         if self.enable_or_not:
-            ipex.enable_auto_optimization()
+            ipex.core.enable_auto_dnnl()
         else:
-            ipex.enable_auto_optimization(False)
+            ipex.core.disable_auto_dnnl()

     def __exit__(self, *args, **kwargs):
         if self.old_value:
-            ipex.enable_auto_optimization()
+            ipex.core.enable_auto_dnnl()
         else:
-            ipex.enable_auto_optimization(False)
+            ipex.core.disable_auto_dnnl()
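Both helpers save the previous setting in __init__ and restore it in __exit__, so a test can flip the backend for one block only. A minimal usage sketch (assumes the test file sits next to common_ipex_conf.py; not part of the commit):

```python
import torch
import intel_pytorch_extension as ipex
from common_ipex_conf import AutoDNNL, AutoMixPrecision

model = torch.nn.Linear(4, 5).to(ipex.DEVICE)
x = torch.randn(2, 4).to(ipex.DEVICE)

# oneDNN on with BF16 mixing for this block; prior state restored on exit.
with AutoDNNL(True), AutoMixPrecision(True):
    res_bf16 = model(x)

# oneDNN on, mixing off: the same model runs in Float32.
with AutoDNNL(True), AutoMixPrecision(False):
    res_fp32 = model(x)
```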

tests/cpu/test_conf.py

Lines changed: 9 additions & 9 deletions
@@ -22,25 +22,25 @@

 class TestOptConf(TestCase):
     def test_auto_dnnl(self):
-        self.assertFalse(ipex.get_auto_dnnl())
-        ipex.enable_auto_dnnl()
         self.assertTrue(ipex.get_auto_dnnl())
         ipex.disable_auto_dnnl()
         self.assertFalse(ipex.get_auto_dnnl())
-
+        ipex.enable_auto_dnnl()
+        self.assertTrue(ipex.get_auto_dnnl())
+
     def test_mix_bf16_fp32(self):
         self.assertFalse(ipex.get_mix_bf16_fp32())
         ipex.enable_mix_bf16_fp32()
         self.assertTrue(ipex.get_mix_bf16_fp32())
         ipex.disable_mix_bf16_fp32()
         self.assertFalse(ipex.get_mix_bf16_fp32())

-    def test_pure_bf16(self):
-        self.assertFalse(ipex.get_pure_bf16())
-        ipex.enable_pure_bf16()
-        self.assertTrue(ipex.get_pure_bf16())
-        ipex.disable_pure_bf16()
-        self.assertFalse(ipex.get_pure_bf16())
+    def test_jit_fuse(self):
+        self.assertTrue(ipex.get_jit_opt())
+        ipex.disable_jit_opt()
+        self.assertFalse(ipex.get_jit_opt())
+        ipex.enable_jit_opt()
+        self.assertTrue(ipex.get_jit_opt())

 if __name__ == '__main__':
     test = unittest.main()

tests/cpu/test_jit.py

Lines changed: 10 additions & 10 deletions
@@ -207,7 +207,7 @@ class Tester(TestCase):

     def _test_output(self, model, x, kind=None):
         modelName = model.__class__.__name__
-        core.disable_jit()
+        core.disable_jit_opt()
         core.disable_mix_bf16_fp32()

         model = model.to(device).eval()
@@ -222,7 +222,7 @@ def _test_output(self, model, x, kind=None):

         self.assertEqual(result, sresult)

-        core.enable_jit()
+        core.enable_jit_opt()
         fused_model = torch.jit.script(model)
         with torch.no_grad():
             # conv relu fusion, conv sum fusion or conv sum relu fusion
@@ -244,7 +244,7 @@ def _test_output_bf16(self, model, x, kind=None, prec=None):
         modelName = model.__class__.__name__

         core.enable_auto_dnnl()
-        core.enable_jit()
+        core.enable_jit_opt()
         core.disable_mix_bf16_fp32()

         model = model.to(ipex.DEVICE).eval()
@@ -300,7 +300,7 @@ def test_output_conv_bn_3d(self):
             torch.randn(32, 3, 112, 112, 112),
             kind="aten::conv3d",
             prec=0.02)
-
+

     def test_output_conv_relu_2d(self):
         self._test_output(
@@ -333,8 +333,8 @@ def test_output_conv_sum_2d(self):
             Conv2dSum(3, 32, kernel_size=3, stride=1),
             torch.randn(32, 3, 224, 224),
             kind="ipex::conv2d_sum",
-            prec=0.02)
-
+            prec=0.04)
+

     def test_output_conv_sum_3d(self):
         self._test_output(
@@ -345,8 +345,8 @@ def test_output_conv_sum_3d(self):
             Conv3dSum(3, 32, kernel_size=3, stride=1),
             torch.randn(32, 3, 112, 112, 112),
             kind="ipex::conv3d_sum",
-            prec=0.02)
-
+            prec=0.04)
+

     def test_output_cascaded_conv_bn_sum_relu_2d(self):
         self._test_output(
@@ -358,7 +358,7 @@ def test_output_cascaded_conv_bn_sum_relu_2d(self):
             torch.rand(32, 3, 224, 224),
             kind="ipex::conv2d_sum_relu",
             prec=0.02)
-
+

     def test_output_cascaded_conv_bn_sum_relu_3d(self):
         self._test_output(
@@ -370,7 +370,7 @@ def test_output_cascaded_conv_bn_sum_relu_3d(self):
             torch.rand(32, 3, 112, 112, 112),
             kind="ipex::conv3d_sum_relu",
             prec=0.02)
-
+

     def test_output_linear_relu(self):
         self._test_output(