Skip to content

Commit b0e83a0

Browse files
committed
support size() to avoid fallback to public format
1 parent 609a583 commit b0e83a0

File tree

3 files changed

+10
-0
lines changed

3 files changed

+10
-0
lines changed

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
'aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor',
6565
'aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)',
6666
'aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)',
67+
'aten::size.int(Tensor self, int dim) -> int',
6768
'aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor',
6869
'aten::gelu(Tensor self) -> Tensor',
6970
'aten::gelu_backward(Tensor grad, Tensor self) -> Tensor',

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,14 @@ at::Tensor& AtenIpexCPUDev::dil_resize_(at::Tensor& self, at::IntArrayRef size,
13951395
return self;
13961396
}
13971397

1398+
int64_t AtenIpexCPUDev::dil_size(const at::Tensor & self, int64_t dim) {
1399+
DEBUG("AtenIpexCPUDev::dil_size\n");
1400+
CHECK_DNNL_OP_PRE_COND(self);
1401+
1402+
dim = at::maybe_wrap_dim(dim, self.dim(), false);
1403+
return self.sizes()[dim];
1404+
}
1405+
13981406
at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
13991407
DEBUG("AtenIpexCPUDev::dil_clone\n");
14001408
CHECK_DNNL_OP_PRE_COND(self);

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class AtenIpexCPUDev {
6666
static at::Tensor dil_sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& output);
6767
static at::Tensor dil_reshape(const at::Tensor& self, at::IntArrayRef size);
6868
static at::Tensor& dil_resize_(at::Tensor& self, at::IntArrayRef size, c10::optional<c10::MemoryFormat> memory_format);
69+
static int64_t dil_size(const at::Tensor & self, int64_t dim);
6970
static at::Tensor dil_clone(const at::Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format);
7071
static at::Tensor dil_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
7172
static at::Tensor& dil_cat_out(at::Tensor& result, at::TensorList tensors, int64_t dim);

0 commit comments

Comments
 (0)