Skip to content

Commit 41caea8

Browse files
committed
Enable upsample_bilinear2d to support the scale factor is vector
1 parent 054697e commit 41caea8

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@
9595
'aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> Tensor',
9696
'aten::upsample_linear1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size, bool align_corners, float? scales=None) -> Tensor',
9797
'aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor',
98+
'aten::upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor',
9899
'aten::upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor',
100+
'aten::upsample_bilinear2d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, bool align_corners, float[]? scale_factors) -> Tensor',
99101
'aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor',
100102
'aten::upsample_trilinear3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor',
101103
'aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)',
@@ -578,7 +580,7 @@ def is_conv_overrideable_func(fname):
578580

579581
# Gen OP Name
580582
code += '#if defined(IPEX_DISP_OP)\n'
581-
code += ' printf("{}::{}\\n");\n'.format(_IPEX_OP_FUNC_NS, cpp_sig.def_name)
583+
code += ' printf("{}::{}\\n");\n'.format(_IPEX_OP_FUNC_NS, new_cpp_func_name)
582584
code += '#endif\n'
583585

584586
# Gen profile info
@@ -587,7 +589,7 @@ def is_conv_overrideable_func(fname):
587589
if param.core_type in ['Tensor', 'Scalar']:
588590
profiler_inputs.append(param.name)
589591
code += '#if defined(IPEX_PROFILE_OP)\n'
590-
code += ' RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{}}));\n'.format(ns=_IPEX_OP_FUNC_NS, name=cpp_sig.def_name)
592+
code += ' RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{}}));\n'.format(ns=_IPEX_OP_FUNC_NS, name=new_cpp_func_name)
591593
code += '#endif\n'
592594

593595
if is_conv_overrideable_func(cpp_sig.def_name):

tests/cpu/test_lazy_reorder.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,6 +1966,17 @@ def test_upsample_bilinear2d_scale_factor(self):
19661966
y_dpcpp.sum().backward()
19671967
self.assertEqual(x_cpu.grad, x_dpcpp.grad)
19681968

1969+
with AutoDNNL(True):
1970+
x = torch.randn(2, 2, 4, 4)
1971+
x_cpu = x.clone().requires_grad_()
1972+
x_dpcpp = x.clone().to(device=device).requires_grad_()
1973+
y_cpu = F.interpolate(x_cpu, scale_factor = [2, 3], mode='bilinear', align_corners=False, recompute_scale_factor=False)
1974+
y_dpcpp = F.interpolate(x_dpcpp, scale_factor = [2, 3], mode='bilinear', align_corners=False, recompute_scale_factor=False)
1975+
self.assertEqual(y_cpu, y_dpcpp)
1976+
y_cpu.sum().backward()
1977+
y_dpcpp.sum().backward()
1978+
self.assertEqual(x_cpu.grad, x_dpcpp.grad)
1979+
19691980
def test_upsample_bilinear2d_size(self):
19701981
rand_seed = int(get_rand_seed())
19711982
print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2708,6 +2708,17 @@ at::Tensor AtenIpexCPUDev::dil_upsample_linear1d_backward(const at::Tensor & gra
27082708
return dbl::upsample::dil_upsample_backward(grad_output, input_size, dil::algorithm::resampling_linear, scales);
27092709
}
27102710

2711+
at::Tensor AtenIpexCPUDev::dil_upsample_bilinear2d(const at::Tensor & self, c10::optional<at::IntArrayRef> output_size, bool align_corners, c10::optional<c10::ArrayRef<double>> scale_factors) {
2712+
DEBUG("AtenIpexCPUDev::dil_upsample_bilinear2d_vec\n");
2713+
auto scale_h = c10::optional<double>(1.0);
2714+
auto scale_w = c10::optional<double>(1.0);
2715+
if (scale_factors.has_value()) {
2716+
scale_h = c10::optional<double>(scale_factors->at(0));
2717+
scale_w = c10::optional<double>(scale_factors->at(1));
2718+
}
2719+
return dbl::upsample::dil_upsample(self, output_size.value(), dil::algorithm::resampling_linear, scale_h, scale_w);
2720+
}
2721+
27112722
at::Tensor AtenIpexCPUDev::dil_upsample_bilinear2d(const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w) {
27122723
DEBUG("AtenIpexCPUDev::dil_upsample_bilinear2d\n");
27132724
IPEX_CHECK(align_corners == false, "dil_upsample_bilinear2d not support align_corners mode yet");
@@ -2722,6 +2733,19 @@ at::Tensor AtenIpexCPUDev::dil_upsample_bilinear2d_backward(const at::Tensor & g
27222733
return dbl::upsample::dil_upsample_backward(grad_output, input_size, dil::algorithm::resampling_linear, scales_h, scales_w);
27232734
}
27242735

2736+
at::Tensor AtenIpexCPUDev::dil_upsample_bilinear2d_backward(const at::Tensor & grad_output, c10::optional<at::IntArrayRef> output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<c10::ArrayRef<double>> scale_factors) {
2737+
DEBUG("AtenIpexCPUDev::dil_upsample_bilinear2d_backward_vec\n");
2738+
IPEX_CHECK(align_corners == false, "dil_upsample_bilinear2d_backward_vec not support align_corners mode yet");
2739+
CHECK_DNNL_OP_PRE_COND(grad_output);
2740+
auto scales_h = c10::optional<double>(1.0);
2741+
auto scales_w = c10::optional<double>(1.0);
2742+
if (scale_factors.has_value()) {
2743+
scales_h = c10::optional<double>(scale_factors->at(0));
2744+
scales_w = c10::optional<double>(scale_factors->at(1));
2745+
}
2746+
return dbl::upsample::dil_upsample_backward(grad_output, input_size, dil::algorithm::resampling_linear, scales_h, scales_w);
2747+
}
2748+
27252749
at::Tensor AtenIpexCPUDev::dil_upsample_trilinear3d(const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w) {
27262750
DEBUG("AtenIpexCPUDev::dil_upsample_trilinear3d\n");
27272751
IPEX_CHECK(align_corners == false, "dil_upsample_trilinear3d not support align_corners mode yet");

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,10 @@ class AtenIpexCPUDev {
104104
static at::Tensor dil_upsample_nearest3d_backward(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w);
105105
static at::Tensor dil_upsample_linear1d(const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales);
106106
static at::Tensor dil_upsample_linear1d_backward(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales);
107+
static at::Tensor dil_upsample_bilinear2d(const at::Tensor & self, c10::optional<at::IntArrayRef> output_size, bool align_corners, c10::optional<c10::ArrayRef<double>> scale_factors);
107108
static at::Tensor dil_upsample_bilinear2d(const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w);
108109
static at::Tensor dil_upsample_bilinear2d_backward(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w);
110+
static at::Tensor dil_upsample_bilinear2d_backward(const at::Tensor & grad_output, c10::optional<at::IntArrayRef> output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<c10::ArrayRef<double>> scale_factors);
109111
static at::Tensor dil_upsample_trilinear3d(const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w);
110112
static at::Tensor dil_upsample_trilinear3d_backward(const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, c10::optional<double> scales_d, c10::optional<double> scales_h, c10::optional<double> scales_w);
111113
static at::Tensor dil_unsqueeze(const at::Tensor& self, int64_t dim);

0 commit comments

Comments
 (0)