
Commit 04eebae

input quantization parameters propagate to the output for max_pool2d and adaptive_avg_pool2d (#12)
* input quantization parameters propagate to the output for max_pool2d and adaptive_avg_pool2d
* [LLGA] add UT for int8 max_pool2d
* [LLGA] add skipped UT for adap_avg_pool

Co-authored-by: chunyuan <chunyuan.wu@intel.com>
1 parent 750d619 commit 04eebae
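
For context, the propagation the commit title describes matches what eager-mode PyTorch quantized tensors already do for max_pool2d: the pooled output simply reuses the input's scale and zero_point, so no separate quantization parameters are needed for the output. A minimal sketch in plain PyTorch (not IPEX/LLGA code), assuming a build with eager-mode quantization support:

import torch
import torch.nn.functional as F

# Quantize an input tensor with an arbitrary per-tensor scale/zero_point.
x = torch.rand(1, 3, 15, 15)
qx = torch.quantize_per_tensor(x, 0.05, 0, torch.quint8)

# max_pool2d only selects values that already exist in the input, so the
# quantized output carries the input's quantization parameters unchanged.
qy = F.max_pool2d(qx, kernel_size=3, stride=2)
assert qy.q_scale() == qx.q_scale()
assert qy.q_zero_point() == qx.q_zero_point()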

File tree

  tests/cpu/test_jit_llga_quantization_fuser.py
  torch_ipex/csrc/quantization/AutoCast.cpp

2 files changed: +56 -2 lines changed

tests/cpu/test_jit_llga_quantization_fuser.py

Lines changed: 54 additions & 0 deletions
@@ -104,6 +104,60 @@ def test_linear(self):
             ]
             self.checkPatterns(graph, patterns)
 
+    @llga_test_env
+    def test_max_pool2d(self):
+        for [
+                spatial,
+                kernel,
+                padding,
+                stride,
+                dilation,
+                ceil_mode
+        ] in itertools.product(
+                [15],  # [15, 16], TODO: check backend
+                [3, 5],  # [3, 4, 5], TODO: check backend
+                [0, 1],
+                [1, 2],  # [1, 2, 4], TODO: fix issue in pad calculation
+                [1, 2],
+                [True, False]):
+
+            m = nn.MaxPool2d(kernel_size=kernel,
+                             stride=stride,
+                             padding=padding,
+                             dilation=dilation,
+                             ceil_mode=ceil_mode)
+            x = torch.rand(1, 3, spatial, spatial)
+
+            graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="max_pool2d")
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.assertFused(graph, ['aten::max_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
+
+            patterns = [
+                ["aten::quantize_per_tensor"],
+                ["aten::dequantize", "aten::max_pool2d", "aten::quantize_per_tensor"],
+                ["aten::dequantize"]
+            ]
+            self.checkPatterns(graph, patterns)
+
+    @llga_test_env
+    @unittest.skipIf(True, 'int8 adaptive_avg_pool2d is not supported in the backend')
+    def test_adaptive_avg_pool2d(self):
+        m = nn.AdaptiveAvgPool2d((1, 1))
+        N = torch.randint(3, 10, (1,)).item()
+        C = torch.randint(3, 10, (1,)).item()
+        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
+
+        graph = self.checkQuantizeTrace(m, [x], atol=1e-1, config_name="adaptive_avg_pool2d")
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+        self.assertFused(graph, ['aten::adaptive_avg_pool2d', 'aten::quantize_per_tensor', 'aten::dequantize'])
+
+        patterns = [
+            ["aten::quantize_per_tensor"],
+            ["aten::dequantize", "aten::adaptive_avg_pool2d", "aten::quantize_per_tensor"],
+            ["aten::dequantize"]
+        ]
+        self.checkPatterns(graph, patterns)
+
 class TestFusionPattern(JitLlgaTestCase):
     @llga_test_env
     def test_conv2d_eltwise(self):

torch_ipex/csrc/quantization/AutoCast.cpp

Lines changed: 2 additions & 2 deletions
@@ -274,7 +274,7 @@ at::Tensor max_pool2d(const at::Tensor &input, at::IntArrayRef kernel_size, at::
   op_outputs.push_back(op_output);
   tensors_flow.emplace(output.unsafeGetTensorImpl(),
                        val_name{weakref_scales(output.getIntrusivePtr()), op_output});
-  torch_ipex::insert_or_updata_observer({input}, {output}, "max_pool2d",
+  torch_ipex::insert_or_updata_observer({input}, {input}, "max_pool2d",
                                         op_id, op_inputs, op_outputs);
   return output;
 }
@@ -318,7 +318,7 @@ at::Tensor adaptive_avg_pool2d(const at::Tensor &input, at::IntArrayRef output_s
   op_outputs.push_back(op_output);
   tensors_flow.emplace(output.unsafeGetTensorImpl(),
                        val_name{weakref_scales(output.getIntrusivePtr()), op_output});
-  torch_ipex::insert_or_updata_observer({input}, {output}, "adaptive_avg_pool2d",
+  torch_ipex::insert_or_updata_observer({input}, {input}, "adaptive_avg_pool2d",
                                         op_id, op_inputs, op_outputs);
   return output;
 }
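
The two AutoCast.cpp hunks above change the tensor passed as the observed output from {output} to {input}, so calibration records the input's statistics for both sides of the op and the same scale/zero_point is reused for the pooled result. A rough sketch of that calibration idea in plain PyTorch, using the stock MinMaxObserver; the names and flow here are illustrative, not the IPEX observer API:

import torch
import torch.nn.functional as F
from torch.ao.quantization import MinMaxObserver

x = torch.rand(1, 3, 15, 15)

# Observe only the input; max_pool2d never produces values outside the
# input's range, so the input's qparams are also valid for the output.
obs = MinMaxObserver(dtype=torch.quint8)
obs(x)
scale, zero_point = obs.calculate_qparams()

y = F.max_pool2d(x, kernel_size=3, stride=2)
qy = torch.quantize_per_tensor(y, float(scale), int(zero_point), torch.quint8)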
