Skip to content

Commit 8d52321

Browse files
authored
Jingxu10/cpu device verbose 2 (#155)
* Enable oneDNN (MKLDNN) on-demand verbose functionality
* Make IPEX verbose compatible with non-verbose-enabled PyTorch
* Fine-tune for clang-format
* Apply clang-format; remove the apply_to_pytorch parameter of the ipex verbose class
* Add unit test for the verbose functionality
* Remove unused packages from test_verbose.py
* Add verbose-on and verbose-off scenarios to test_verbose.py
* Remove magic numbers
* Fix a bug in test/cpu/verbose.py
* Add absolute path to verbose.py for test_verbose.py
1 parent a986bd3 commit 8d52321

File tree

8 files changed

+116
-8
lines changed

8 files changed

+116
-8
lines changed

ideep/ideep/utils.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ inline void array_set(T *arr, const U &val, size_t size) {
298298
arr[i] = static_cast<T>(val);
299299
}
300300

301+
// Sets the oneDNN verbose level; returns 1 when oneDNN accepted the
// level, 0 otherwise.
inline int set_verbose(int level) {
  const auto status = dnnl::set_verbose(level);
  return status == dnnl::status::success ? 1 : 0;
}
301305
}
302306
}
303307
#endif

intel_pytorch_extension_py/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .weight_prepack import _weight_prepack_with_ipex
99
from .weight_cast import _weight_dtype_convert_with_ipex
1010
from .optimizer_utils import _ipex_optimizer
11+
import _torch_ipex as core
1112

1213
def _replace_dropout_with_identity(model):
1314
# replace dropout with identity during inference, so that aten::dropout won't be on the JIT graph.
@@ -100,3 +101,36 @@ def optimize(model, dtype=torch.bfloat16, optimizer=None, level='O1', inplace=Fa
100101
return optimized_model
101102
else:
102103
return optimized_model, optimized_optimizer
104+
105+
# oneDNN (MKLDNN) verbose levels.
VERBOSE_OFF = 0          # no verbose output
VERBOSE_ON = 1           # print primitive execution information
VERBOSE_ON_CREATION = 2  # additionally print primitive creation information

class verbose(object):
    """Context manager that enables oneDNN (MKLDNN) verbose output for the
    enclosed scope and disables it again on exit.

    Enables verbose both in stock PyTorch (when the build exposes
    ``torch._C._verbose``) and in IPEX.

    Args:
        level (int): one of VERBOSE_OFF, VERBOSE_ON, VERBOSE_ON_CREATION.
    """
    def __init__(self, level):
        self.level = level

    def __enter__(self):
        if self.level == VERBOSE_OFF:
            # Nothing to enable; still return self so `with ... as v` is usable.
            return self
        try:
            # Best-effort: also enable verbose inside stock PyTorch.
            st = torch._C._verbose.mkldnn_set_verbose(self.level)
            assert bool(st), "Failed to set Verbose mode of MKLDNN in PyTorch. Please consider to disable this verbose scope."
        except AttributeError:
            # PyTorch build without the _verbose binding -- IPEX verbose still works.
            pass
        st = core.mkldnn_set_verbose(self.level)
        assert bool(st), "Failed to set Verbose mode of MKLDNN in IPEX. Please consider to disable this verbose scope."
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always restore verbose-off; never suppress exceptions from the body.
        core.mkldnn_set_verbose(VERBOSE_OFF)
        try:
            torch._C._verbose.mkldnn_set_verbose(VERBOSE_OFF)
        except AttributeError:
            # PyTorch build without the _verbose binding.
            pass
        return False
131+
132+
# Best-effort monkeypatch: point `torch.backends.mkldnn.verbose` at the
# IPEX-aware context manager above so user code written against the stock
# PyTorch API picks up IPEX verbose as well. The original attribute is kept
# in `verbose_torch`.
try:
    verbose_torch=torch.backends.mkldnn.verbose
    torch.backends.mkldnn.verbose = verbose
except:
    # PyTorch build without torch.backends.mkldnn.verbose -- skip the patch.
    pass

tests/cpu/test_verbose.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import unittest
2+
from common_utils import TestCase
3+
import os
4+
import subprocess
5+
6+
class TestVerbose(TestCase):
    """Checks that ipex.utils.verbose controls oneDNN verbose output."""

    def _count_verbose_lines(self, level):
        # Run verbose.py in a child process with the given verbose level and
        # count the oneDNN verbose lines ("dnnl_verbose ...") it prints.
        num = 0
        loc = os.path.dirname(os.path.abspath(__file__))
        cmd = 'python -u {}/verbose.py --verbose-level={}'.format(loc, level)
        with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
            for line in p.stdout.readlines():
                if line.decode('utf-8').strip().startswith("dnnl_verbose"):
                    num += 1
        return num

    def test_verbose_on(self):
        # With verbose enabled, oneDNN must print at least one message.
        assert self._count_verbose_lines(1) > 0, 'oneDNN verbose messages not found.'

    def test_verbose_off(self):
        # With verbose disabled, no oneDNN messages may appear.
        assert self._count_verbose_lines(0) == 0, 'unexpected oneDNN verbose messages found.'

if __name__ == '__main__':
    test = unittest.main()

tests/cpu/verbose.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import argparse
2+
import torch
3+
import intel_pytorch_extension as ipex
4+
5+
class Module(torch.nn.Module):
    """Minimal convolution model used to trigger oneDNN kernels."""

    def __init__(self):
        super(Module, self).__init__()
        # 1 input channel -> 10 output channels, 5x5 kernel, stride 1.
        self.conv = torch.nn.Conv2d(1, 10, 5, 1)

    def forward(self, x):
        return self.conv(x)
13+
14+
def run_model(level):
    """Run one forward pass under the given ipex verbose level."""
    model = ipex.optimize(Module(), dtype=torch.float32, level="O1")
    data = torch.rand(1, 1, 112, 112)
    with ipex.utils.verbose(level):
        model(data)
20+
21+
if __name__ == '__main__':
    # Parse the requested oneDNN verbose level and run the model once.
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose-level", default=0, type=int)
    cli_args = parser.parse_args()
    run_model(cli_args.verbose_level)

torch_ipex/csrc/cpu/mkldnn/MKLDNNCommon.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,5 @@ at::Tensor empty_aten_tensor_from_desc(const ideep::tensor::desc& desc, const at
121121
return at::empty(at_sizes, options);
122122
}
123123

124+
// Sets the oneDNN (MKLDNN) verbose level via ideep; returns 1 on success, 0 otherwise.
int mkldnn_set_verbose(int level) { return ideep::utils::set_verbose(level); }
124125
}}

torch_ipex/csrc/cpu/mkldnn/MKLDNNCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ ideep::tensor itensor_from_tensor(const at::Tensor& tensor);
2525

2626
at::Tensor empty_aten_tensor_from_desc(const ideep::tensor::desc& desc, const at::TensorOptions& options);
2727

28+
// Sets the oneDNN (MKLDNN) verbose level; returns 1 on success, 0 otherwise.
int mkldnn_set_verbose(int level);
2829
}}

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
#include <string>
2020
#include <vector>
2121

22-
#include "utils.h"
2322
#include "auto_opt_config.hpp"
24-
#include "quantization/Observer.hpp"
25-
#include "quantization/Config.hpp"
2623
#include "quantization/AutoCast.hpp"
24+
#include "quantization/Config.hpp"
25+
#include "quantization/Observer.hpp"
26+
#include "utils.h"
27+
#include "verbose.hpp"
2728

2829
//#include "ProcessGroupCCL.hpp"
2930
#include <pybind11/chrono.h>
@@ -47,6 +48,7 @@ py::object GetRevisions() {
4748

4849
void InitIpexModuleBindings(py::module m) {
4950
m.def("_get_git_revs", []() { return GetRevisions(); });
51+
m.def("mkldnn_set_verbose", &torch_ipex::verbose::_mkldnn_set_verbose);
5052
// ipex amp autocast
5153
m.def("is_autocast_enabled", &torch_ipex::autocast::is_autocast_enabled);
5254
m.def("set_autocast_enabled", &torch_ipex::autocast::set_autocast_enabled);
@@ -65,7 +67,7 @@ void InitIpexModuleBindings(py::module m) {
6567
m.def("autocast_decrement_nesting",
6668
&torch_ipex::autocast::autocast_decrement_nesting);
6769
m.def("clear_autocast_cache", &torch_ipex::autocast::clear_autocast_cache);
68-
70+
6971
// llga path
7072
m.def("_jit_set_llga_enabled", &torch::jit::RegisterLlgaFuseGraph::setEnabled);
7173
m.def("_jit_llga_enabled", &torch::jit::RegisterLlgaFuseGraph::isEnabled);
@@ -108,7 +110,7 @@ void InitIpexModuleBindings(py::module m) {
108110
d["weight_granularity"] = indicator.get_indicator_weight_granularity();
109111
std::vector<float> x_scales, y_scales;
110112
std::vector<int64_t> x_zero_points, y_zero_points;
111-
std::vector<quant_utils::TensorQuantizationParams> x_params, y_params;
113+
std::vector<quant_utils::TensorQuantizationParams> x_params, y_params;
112114
std::tie(x_params, y_params) = indicator.get_indicator_scales();
113115
for (auto& p: x_params) {
114116
x_scales.push_back(p.scale);
@@ -123,13 +125,14 @@ void InitIpexModuleBindings(py::module m) {
123125
d["input_zero_points"] = x_zero_points;
124126
d["output_scales"] = y_scales;
125127
d["output_zero_points"] = y_zero_points;
126-
d["weight_scales"] = w_scales;
128+
d["weight_scales"] = w_scales;
127129
std::vector<std::string> i_quantized_dtypes, o_quantized_dtypes;
128130
std::tie(i_quantized_dtypes, o_quantized_dtypes)= indicator.get_indicator_quantized_dtypes();
129131
d["input_quantized_dtypes"] = i_quantized_dtypes;
130132
d["output_quantized_dtypes"] = o_quantized_dtypes;
131133
std::vector<bool> inputs_quantized, outputs_quantized;
132-
std::tie(inputs_quantized, outputs_quantized) = indicator.get_indicator_insert_quantized_status();
134+
std::tie(inputs_quantized, outputs_quantized) =
135+
indicator.get_indicator_insert_quantized_status();
133136
d["inputs_quantized"] = inputs_quantized;
134137
d["outputs_quantized"] = outputs_quantized;
135138
std::vector<std::string> inputs_flow, outputs_flow;
@@ -188,7 +191,7 @@ using namespace torch::jit;
188191

189192
void InitIpexBindings(py::module m) {
190193
InitIpexModuleBindings(m);
191-
194+
192195
// // llga jit fusion pass
193196
// torch::jit::registerPrePass([](std::shared_ptr<Graph>& g) {
194197
// if (torch::jit::RegisterLlgaFuseGraph::isEnabled()) {

torch_ipex/csrc/verbose.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once

#include "cpu/mkldnn/MKLDNNCommon.h"

namespace torch_ipex {

namespace verbose {

// Sets the oneDNN (MKLDNN) verbose level for IPEX.
// Returns 1 on success, 0 on failure.
// `inline` is required: this is a definition in a header and would
// otherwise violate the one-definition rule once the header is included
// from more than one translation unit.
inline int _mkldnn_set_verbose(int level) {
  return torch_ipex::cpu::mkldnn_set_verbose(level);
}

} // namespace verbose
} // namespace torch_ipex

0 commit comments

Comments
 (0)