Skip to content

Commit ed5e9dd

Browse files
committed
make rewrited linear op can be traced
1 parent a262ee1 commit ed5e9dd

File tree

9 files changed

+82
-37
lines changed

9 files changed

+82
-37
lines changed

intel_pytorch_extension_py/ops/linear.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,4 @@
33
import torch.nn.functional as F
44
import _torch_ipex as core
55

6-
F_linear = F.linear
7-
8-
class LinearFunction(Function):
9-
@staticmethod
10-
def forward(ctx, input, weight, bias):
11-
output = core.linear(input, weight, bias)
12-
ctx.save_for_backward(input, weight, bias)
13-
return output
14-
15-
@staticmethod
16-
def backward(ctx, grad_output):
17-
input, weight, bias = ctx.saved_tensors
18-
grad_output = grad_output.contiguous()
19-
if bias == None:
20-
output_mask = (input.requires_grad, weight.requires_grad, 0)
21-
else:
22-
output_mask = (input.requires_grad, weight.requires_grad, bias.requires_grad)
23-
grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
24-
return (grad_input, grad_weight, grad_bias)
25-
26-
def linear(input, weight, bias=None):
27-
if input.device.type == 'dpcpp' and core.get_auto_dnnl():
28-
return LinearFunction.apply(input, weight, bias)
29-
return F_linear(input, weight, bias)
30-
31-
F.linear = linear
6+
F.linear = torch.ops.torch_ipex.linear

torch_ipex/csrc/cpu/CustomerOps.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#pragma once
2+
3+
#include <torch/csrc/autograd/variable.h>
4+
#include <torch/csrc/autograd/custom_function.h>
5+
#include <torch/csrc/autograd/function.h>
6+
#include <ATen/Tensor.h>
7+
#include <torch/script.h>
8+
#include <c10/util/Optional.h>
9+
#include "torch_ipex/csrc/utils.h"
10+
#include "DevOPs.h"
11+
12+
using namespace at;
13+
14+
class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
15+
public:
16+
static at::Tensor forward(
17+
torch::autograd::AutogradContext* ctx,
18+
at::Tensor input,
19+
at::Tensor weight,
20+
at::Tensor bias) {
21+
ctx->save_for_backward({input, weight, bias});
22+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
23+
return torch_ipex::cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
24+
} else {
25+
return at::linear(input, weight, bias);
26+
}
27+
}
28+
29+
static torch::autograd::tensor_list backward(
30+
torch::autograd::AutogradContext* ctx,
31+
torch::autograd::tensor_list grad_outputs) {
32+
auto saved = ctx->get_saved_variables();
33+
at::Tensor input = saved[0];
34+
at::Tensor weight = saved[1];
35+
at::Tensor bias = saved[2];
36+
37+
at::Tensor grad_output = grad_outputs[0];
38+
at::Tensor grad_input, grad_weight;
39+
at::Tensor grad_bias = torch::Tensor();
40+
41+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
42+
grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_input(
43+
input.sizes(), grad_output, weight);
44+
std::tie(grad_weight, grad_bias) = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights(
45+
grad_output, input, weight, bias.defined());
46+
} else {
47+
auto grad_input = grad_output.mm(weight);
48+
auto grad_weight = grad_output.t().mm(input);
49+
if (bias.defined()) {
50+
grad_bias = grad_output.sum(0);
51+
}
52+
}
53+
return {grad_input, grad_weight, grad_bias};
54+
}
55+
};

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ at::Tensor& AtenIpexCPUDev::dil_addbmm_(
526526
at::Tensor AtenIpexCPUDev::dil_linear(
527527
const at::Tensor& self,
528528
const at::Tensor& weight,
529-
const c10::optional<at::Tensor>& bias) {
529+
const at::Tensor& bias) {
530530
DEBUG("AtenIpexCPUDev::dil_linear\n");
531531
CHECK_DNNL_OP_PRE_COND(self);
532532
CHECK_DNNL_OP_PRE_COND(weight);
@@ -539,9 +539,8 @@ at::Tensor AtenIpexCPUDev::dil_linear(
539539
const dil::tensor w = dbl::comm::try_gen_dil_tensor(weight);
540540

541541
dil::tensor y;
542-
if (bias.has_value()) {
543-
at::Tensor bias_vec = bias.value();
544-
const dil::tensor b = dbl::comm::try_gen_dil_tensor(bias_vec);
542+
if (bias.defined()) {
543+
const dil::tensor b = dbl::comm::try_gen_dil_tensor(bias);
545544
dil::inner_product_forward::compute(x, w, b, y);
546545
} else {
547546
dil::inner_product_forward::compute(x, w, y);
@@ -557,7 +556,7 @@ at::Tensor AtenIpexCPUDev::dil_linear(
557556
return dbl::comm::gen_aten_tensor_by(y);
558557
}
559558

560-
at::Tensor dil_linear_backward_input(
559+
at::Tensor AtenIpexCPUDev::dil_linear_backward_input(
561560
at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight){
562561
DEBUG("AtenIpexCPUDev::dil_linear_backward_input\n");
563562
auto grad_output_reshaped = grad_output.dim() > 2 ?
@@ -579,7 +578,7 @@ at::Tensor dil_linear_backward_input(
579578
return dbl::comm::gen_aten_tensor_by(gradx);
580579
}
581580

582-
std::tuple<at::Tensor, at::Tensor> dil_linear_backward_weights(
581+
std::tuple<at::Tensor, at::Tensor> AtenIpexCPUDev::dil_linear_backward_weights(
583582
const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, bool bias_defined) {
584583
DEBUG("AtenIpexCPUDev::dil_linear_backward_weights\n");
585584
auto grad_output_reshaped = grad_output.dim() > 2 ?

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ class AtenIpexCPUDev {
3838
static at::Tensor dil_addbmm(const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha);
3939
static at::Tensor& dil_addbmm_(at::Tensor& self, const at::Tensor& batch1, const at::Tensor& batch2, at::Scalar beta, at::Scalar alpha);
4040
static at::Tensor& dil_addbmm_out(at::Tensor& result, const at::Tensor &self, const at::Tensor &batch1, const at::Tensor &batch2, at::Scalar beta, at::Scalar alpha);
41-
static at::Tensor dil_linear(const at::Tensor& self, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
41+
static at::Tensor dil_linear(const at::Tensor& self, const at::Tensor& weight, const at::Tensor& bias);
42+
static at::Tensor dil_linear_backward_input(at::IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight);
43+
static std::tuple<at::Tensor, at::Tensor> dil_linear_backward_weights(const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, bool bias_defined);
4244
static std::tuple<at::Tensor, at::Tensor, at::Tensor> dil_linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
4345
static at::Tensor dil_dropout(const at::Tensor& self, double ratio, bool train);
4446
static at::Tensor dil_dropout_backward(const at::Tensor& grady, const at::Tensor& mask, double ratio);

torch_ipex/csrc/cpu/ExtendOPs.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "xsmm/libxsmm_utils.h"
1010
#include "../utils.h"
1111
#include "DevOPs.h"
12+
#include "CustomerOps.h"
1213

1314
namespace torch_ipex {
1415

@@ -423,8 +424,9 @@ at::Tensor AtenIpexTypeExt::embedding_bag_backward(const at::Tensor &grad_out,
423424
}
424425
}
425426

426-
at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
427-
return cpu::AtenIpexCPUDev::dil_linear(input, weight, bias);
427+
428+
at::Tensor AtenIpexTypeExt::linear(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) {
429+
return NewLinearOp::apply(input, weight, bias);
428430
}
429431

430432
std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexTypeExt::linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask) {

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class AtenIpexTypeExt {
1212
static std::vector<at::Tensor> interaction_backward(const at::Tensor & grad_out, const std::vector<at::Tensor> & input);
1313
static at::Tensor embedding_bag_forward(const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
1414
static at::Tensor embedding_bag_backward(const at::Tensor &grad_out, const at::Tensor &weights, const at::Tensor &inputs, const at::Tensor &offsets);
15-
static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias);
15+
static at::Tensor linear(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias = at::Tensor());
1616
static std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward(const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, std::array<bool,3> output_mask);
1717
static at::Tensor adaptive_avg_pool2d(at::Tensor const& input, at::IntArrayRef output_size);
1818
static at::Tensor adaptive_avg_pool2d_backward(const at::Tensor& grad_output, const at::Tensor& input);

torch_ipex/csrc/cpu/RegisterOps.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#include <torch/script.h>
2+
#include "ExtendOPs.h"
3+
4+
static auto registry =
5+
torch::RegisterOperators()
6+
.op("torch_ipex::linear",
7+
[](const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) {
8+
return torch_ipex::AtenIpexTypeExt::linear(input, weight, bias);
9+
});
10+
11+

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tenso
9191
dil::dims sizes = dil_tensor.get_dims();
9292
if (dil_tensor.is_public_format()) {
9393
dil::dims strides = dil_tensor.get_strides();
94+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ipex_tensor.device().type() == at::DeviceType::DPCPP);
9495
auto* _tensor_impl = (IPEXTensorImpl *)ipex_tensor.unsafeGetTensorImpl();
9596
_tensor_impl->force_set_strided(sizes, strides);
9697
} else {

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ void InitIpexModuleBindings(py::module m) {
9595
return AtenIpexTypeExt::embedding_bag_backward(grad_out, weights, inputs, offsets);
9696
});
9797
m.def("linear",
98-
[](const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) {
98+
[](const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) {
9999
return AtenIpexTypeExt::linear(input, weight, bias);
100100
});
101101
m.def("linear_backward",

0 commit comments

Comments
 (0)