Skip to content

Commit 8c9fb3d

Browse files
committed
only jit fusion for extension path
1 parent 23990f7 commit 8c9fb3d

15 files changed

+198
-441
lines changed

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -67,41 +67,6 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
6767
return dbl::comm::gen_aten_tensor_by(dil_output);
6868
}
6969

70-
at::Tensor AtenIpexCPUDev::dil_convolution_relu(
71-
const at::Tensor & input,
72-
const at::Tensor & weight,
73-
const at::Tensor & bias,
74-
at::IntArrayRef stride,
75-
at::IntArrayRef padding,
76-
at::IntArrayRef dilation,
77-
int64_t groups) {
78-
DEBUG("AtenIpexCPUDev::dil_convolution\n");
79-
dil::tensor dil_input;
80-
dil::tensor dil_weight;
81-
c10::optional<dil::tensor> dil_bias{c10::nullopt};
82-
83-
CHECK_DNNL_OP_PRE_COND(input);
84-
CHECK_DNNL_OP_PRE_COND(weight);
85-
dil_input = dbl::comm::try_gen_dil_tensor(input);
86-
dil_weight = dbl::comm::try_gen_dil_tensor(weight);
87-
if (bias.defined()) {
88-
CHECK_DNNL_OP_PRE_COND(bias);
89-
dil_bias = dbl::comm::try_gen_dil_tensor(bias);
90-
}
91-
92-
dil::tensor dil_output = dbl::conv::conv2d_impl(
93-
dil_input,
94-
dil_weight,
95-
dil_bias,
96-
padding,
97-
stride,
98-
dilation,
99-
groups,
100-
dil::attr_t::fuse_relu());
101-
102-
return dbl::comm::gen_aten_tensor_by(dil_output);
103-
}
104-
10570
at::Tensor dil_convolution_backward_input(
10671
at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
10772
at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool bias_defined)

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ class AtenIpexCPUDev {
6868
static std::vector<at::Tensor> dil_split_with_sizes(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim);
6969
static std::vector<at::Tensor> dil_split(const at::Tensor& self, int64_t split_size, int64_t dim);
7070

71-
// for JIT ops
72-
static at::Tensor dil_convolution_relu(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
73-
7471
};
7572

7673
} // namespace cpu

torch_ipex/csrc/cpu/FusionOPs.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "torch_ipex/csrc/cpu/FusionOPs.h"
2+
3+
#include <ATen/Context.h>
4+
#include <ATen/CPUGenerator.h>
5+
#include <ATen/InferSize.h>
6+
#include <c10/util/Exception.h>
7+
#include <c10/util/Logging.h>
8+
9+
#include <limits>
10+
11+
#include "torch_ipex/csrc/aten_ipex_bridge.h"
12+
#include "torch_ipex/csrc/ipex_tensor_impl.h"
13+
#include "torch_ipex/csrc/utils.h"
14+
#include "dbl/Common.h"
15+
#include "dbl/Conv.h"
16+
#include "ShadeDataContext.h"
17+
18+
#include "dil/dil.hpp"
19+
20+
namespace torch_ipex {
21+
namespace cpu {
22+
23+
at::Tensor AtenIpexJITDev::dil_convolution_relu(
24+
const at::Tensor & input,
25+
const at::Tensor & weight,
26+
const at::Tensor & bias,
27+
at::IntArrayRef stride,
28+
at::IntArrayRef padding,
29+
at::IntArrayRef dilation,
30+
int64_t groups) {
31+
dil::tensor dil_input;
32+
dil::tensor dil_weight;
33+
c10::optional<dil::tensor> dil_bias{c10::nullopt};
34+
35+
auto input_contiguous = input.contiguous();
36+
auto weight_contiguous = weight.contiguous();
37+
38+
dil_input = dbl::comm::try_gen_dil_tensor(input_contiguous);
39+
dil_weight = dbl::comm::try_gen_dil_tensor(weight_contiguous);
40+
if (bias.defined()) {
41+
auto bias_contiguous = bias.contiguous();
42+
dil_bias = dbl::comm::try_gen_dil_tensor(bias_contiguous);
43+
}
44+
45+
dil::tensor dil_output = dbl::conv::conv2d_impl(
46+
dil_input,
47+
dil_weight,
48+
dil_bias,
49+
padding,
50+
stride,
51+
dilation,
52+
groups,
53+
dil::attr_t::fuse_relu());
54+
55+
return dbl::comm::gen_aten_tensor_by(dil_output);
56+
}
57+
58+
} // namespace cpu
59+
} // namespace torch_ipex

torch_ipex/csrc/cpu/FusionOPs.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
3+
#include <ATen/Tensor.h>
4+
5+
#include <torch/csrc/jit/runtime/custom_operator.h>
6+
7+
#include "dil/dil.hpp"
8+
9+
namespace torch { namespace jit {
10+
11+
// XXX: PyTorch does not support nesting namespace
12+
// And the alias analysis is not working for namespace other than aten ...
13+
// So we fake some op namespaces to workaround that.
14+
namespace dnnl {
15+
static auto conv2d_relu = Symbol::fromQualString("dnnl::conv2d_relu");
16+
static auto conv2d_sum = Symbol::fromQualString("dnnl::conv2d_sum");
17+
static auto conv2d_relu_sum = Symbol::fromQualString("dnnl::conv2d_relu_sum");
18+
static auto conv2d_sum_relu = Symbol::fromQualString("dnnl::conv2d_sum_relu");
19+
20+
}
21+
22+
}} // namespace torch::jit
23+
24+
namespace torch_ipex {
25+
namespace cpu {
26+
27+
class AtenIpexJITDev {
28+
public:
29+
// for JIT ops
30+
static at::Tensor dil_convolution_relu(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
31+
32+
};
33+
34+
} // namespace cpu
35+
} // namespace torch_ipex

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <torch/csrc/jit/runtime/operator_options.h>
1111
#include <torch/csrc/jit/passes/pass_manager.h>
1212
#include "jit/fusion_pass.h"
13-
#include "jit/op_rewrite.h"
1413

1514
#include <cstring>
1615
#include <sstream>
@@ -141,23 +140,8 @@ using namespace torch::jit;
141140

142141
void InitIpexBindings(py::module m) {
143142
InitIpexModuleBindings(m);
144-
145-
// fro jit path
146-
RegisterPass pass_1([](std::shared_ptr<Graph>& g) {
147-
if (AutoOptConfig::singleton().get_jit_fuse()) {
148-
torch::jit::OpRewritePass(g);
149-
}
150-
});
151-
/*
152-
RegisterPass pass_2([](std::shared_ptr<Graph>& g) {
153-
if (AutoOptConfig::singleton().get_jit_fuse()) {
154-
std::cout<<"uisng pass2"<<std::endl;
155-
torch::jit::FormatOptimize(g);
156-
}
157-
});
158-
*/
159143
// jit fusion pass
160-
RegisterPass pass3([](std::shared_ptr<Graph>& g) {
144+
RegisterPass pass([](std::shared_ptr<Graph>& g) {
161145
if (AutoOptConfig::singleton().get_jit_fuse()) {
162146
torch::jit::FusionPass(g);
163147
}

torch_ipex/csrc/jit/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
LIST(APPEND DPCPP_JIT_SRCS
22
${DPCPP_ROOT}/jit/fusion_pass.cpp
3-
${DPCPP_ROOT}/jit/graph_ext.cpp
4-
${DPCPP_ROOT}/jit/op_rewrite.cpp
53
${DPCPP_ROOT}/jit/register_dnnl_jit_ops.cpp
64

75
)

torch_ipex/csrc/jit/accelerated_ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include "cpu/dil/dil.hpp"
3+
#include <ideep.hpp>
44
#include <torch/csrc/jit/runtime/custom_operator.h>
55

66
namespace torch { namespace jit {

torch_ipex/csrc/jit/dnnl_ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include "cpu/dil/dil.hpp"
3+
#include <ideep.hpp>
44
#include <ATen/ATen.h>
55
#include <ATen/NativeFunctions.h>
66

0 commit comments

Comments
 (0)