Commit a79be1b

Merge pull request #24 from pinzhenx/optional
use optional dil tensor & move semantics
2 parents fbc3b7a + 5099b96 commit a79be1b

7 files changed: 62 additions & 58 deletions


torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 8 additions & 8 deletions

@@ -66,13 +66,13 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
   void *data_ctx = ipexTensor.unsafeGetTensorImpl()->storage().data_ptr().get_context();
   cpu::ShadeDataContext *shade_data_context = (cpu::ShadeDataContext*)data_ctx;
 #if defined(_DEBUG)
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_empty()));
 #endif
-  dil::tensor &dil_tensor = shade_data_context->dil_tensor;
+  dil::tensor &dil_tensor = *shade_data_context->dil_tensor;

   if (dil_tensor.is_public_format()) {
 #if defined(_DEBUG)
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor.get_data_handle());
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data == shade_data_context->dil_tensor->get_data_handle());
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_raw_data != nullptr);
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->cpu_del_fun != nullptr);
 #endif
@@ -106,7 +106,7 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
 }


-void attachShadeDataConext(const at::Tensor& tensor) {
+void attachShadeDataContext(const at::Tensor& tensor) {
   auto tensor_storage_impl = tensor.storage().unsafeGetStorageImpl();
   auto& data_ptr = tensor_storage_impl->data_ptr();

@@ -272,7 +272,7 @@ at::Tensor shallowUpgradeToDPCPPTensor(const at::Tensor& cpuTensor) {
     CHECK_TENSOR_CRITICAL(_tensor, cpuTensor, true);
     //TODO: Cannot set reserved_
     // dest_impl->reserved_ = src_impl->reserved_;
-    attachShadeDataConext(_tensor);
+    attachShadeDataContext(_tensor);
     return _tensor;
   }
 }
@@ -303,7 +303,7 @@ at::Tensor shallowUpgradeToDPCPPTensorA(const at::Tensor& ipexTensor, const at::
   ipex_impl->copy_auto_grad(cpuTensor.unsafeGetTensorImpl());
   CHECK_TENSOR_CRITICAL(_tensor, cpuTensor, true);

-  attachShadeDataConext(_tensor);
+  attachShadeDataContext(_tensor);
   return _tensor;
 }

@@ -388,7 +388,7 @@ const at::Tensor& shallowUpgradeToDPCPPTensorAW(const at::Tensor& ipexTensor, co
     ipex_tensor_impl->copy_meta_info(cpuTensor.unsafeGetTensorImpl());
     ipex_tensor_impl->copy_auto_grad(cpuTensor.unsafeGetTensorImpl());
     CHECK_TENSOR_CRITICAL(ipexTensor, cpuTensor, true);
-    attachShadeDataConext(ipexTensor);
+    attachShadeDataContext(ipexTensor);
     return ipexTensor;
   }
 }
@@ -417,7 +417,7 @@ void reorderTensorToScalarTypeForDNNL(const at::Tensor& ipexTensor, at::ScalarTy
   // Shade data context has been attached
   if (cpu::ShadeDataContext::isDilTensor(ipexTensor)) {
     cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(ipexTensor.storage().data_ptr().get_context());
-    shade_context->dil_tensor.to_type(get_dil_data_type(dstScalarType));
+    shade_context->dil_tensor->to_type(get_dil_data_type(dstScalarType));
     IPEXTensorImpl* ipex_tensor_impl = (IPEXTensorImpl *)ipexTensor.unsafeGetTensorImpl();
     ipex_tensor_impl->reset_data_type(dstScalarType);
     ipex_tensor_impl->storage().unsafeGetStorageImpl()->set_dtype(at::scalarTypeToTypeMeta(dstScalarType));
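Since ShadeDataContext::dil_tensor is now a c10::optional, the bridge code switches from member access (.) on the tensor to optional access (-> and *). A minimal standalone sketch of that accessor change, using std::optional and a placeholder tensor type in place of c10::optional and dil::tensor (the names below are illustrative, not the real IPEX/DIL API):

#include <cassert>
#include <optional>

struct FakeDilTensor {            // stand-in for dil::tensor
  bool is_empty() const { return false; }
};

struct Context {                  // stand-in for cpu::ShadeDataContext
  std::optional<FakeDilTensor> dil_tensor;
};

int main() {
  Context ctx;
  ctx.dil_tensor = FakeDilTensor{};
  assert(ctx.dil_tensor.has_value());
  // operator-> forwards the call to the tensor stored inside the optional ...
  assert(!ctx.dil_tensor->is_empty());
  // ... and operator* yields a reference that aliases the stored tensor, no copy.
  FakeDilTensor &t = *ctx.dil_tensor;
  (void)t;
  return 0;
}

operator* is what lets `dil::tensor &dil_tensor = *shade_data_context->dil_tensor;` above bind a reference to the tensor held by the context rather than a copy of it.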

torch_ipex/csrc/aten_ipex_bridge.h

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ namespace bridge {
 at::Tensor shallowFallbackToCPUTensor(const at::Tensor& ipexTensor);
 std::vector<at::Tensor> shallowFallbackToCPUTensorList(const at::TensorList&);

-void attachShadeDataConext(const at::Tensor& tensor);
+void attachShadeDataContext(const at::Tensor& tensor);

 /**
  * Reorder the DNNL tensor to the public format if the input tensor contains DNNL tensor.

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 34 additions & 34 deletions

@@ -64,7 +64,7 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
     dilation,
     groups);

-  return dbl::comm::gen_aten_tensor_by(dil_output);
+  return dbl::comm::gen_aten_tensor_by(std::move(dil_output));
 }

 at::Tensor dil_convolution_backward_input(
@@ -87,7 +87,7 @@ at::Tensor dil_convolution_backward_input(
     padding.vec(),
     padding.vec(),
     groups);
-  return dbl::comm::gen_aten_tensor_by(dil_grad_input);
+  return dbl::comm::gen_aten_tensor_by(std::move(dil_grad_input));
 }

 std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
@@ -117,8 +117,8 @@ std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
       groups,
       diff_weight_type);
     return std::make_tuple(
-      dbl::comm::gen_aten_tensor_by(dil_grad_weight),
-      dbl::comm::gen_aten_tensor_by(dil_grad_bias));
+      dbl::comm::gen_aten_tensor_by(std::move(dil_grad_weight)),
+      dbl::comm::gen_aten_tensor_by(std::move(dil_grad_bias)));
   } else {
     dil::convolution_backward_weights::compute(
       dil_input,
@@ -132,7 +132,7 @@ std::tuple<at::Tensor, at::Tensor> dil_convolution_backward_weights(
       groups,
       diff_weight_type);
     return std::make_tuple(
-      dbl::comm::gen_aten_tensor_by(dil_grad_weight),
+      dbl::comm::gen_aten_tensor_by(std::move(dil_grad_weight)),
       at::Tensor());
   }
 }
@@ -255,7 +255,7 @@ at::Tensor AtenIpexCPUDev::dil_add(const at::Tensor& self, const at::Tensor& oth
   const std::vector<float> scales{1.0, alpha.to<float>()};
   dil::sum::compute(scales, {x, y}, z);

-  return dbl::comm::gen_aten_tensor_by(z);
+  return dbl::comm::gen_aten_tensor_by(std::move(z));
 }

 at::Tensor & AtenIpexCPUDev::dil_add_(at::Tensor& self, const at::Tensor& other, at::Scalar alpha) {
@@ -552,9 +552,9 @@ at::Tensor AtenIpexCPUDev::dil_linear(
   output_size.push_back(weight.size(0));

   if (self.dim() > 2) {
-    return dbl::comm::gen_aten_tensor_by(y).reshape(output_size);
+    return dbl::comm::gen_aten_tensor_by(std::move(y)).reshape(output_size);
   }
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor dil_linear_backward_input(
@@ -574,9 +574,9 @@ at::Tensor dil_linear_backward_input(
     grady, w, {input_reshaped_size.begin(), input_reshaped_size.end()}, gradx);

   if (input_size.size() > 2) {
-    return dbl::comm::gen_aten_tensor_by(gradx).reshape(input_size);
+    return dbl::comm::gen_aten_tensor_by(std::move(gradx)).reshape(input_size);
   }
-  return dbl::comm::gen_aten_tensor_by(gradx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gradx));
 }

 std::tuple<at::Tensor, at::Tensor> dil_linear_backward_weights(
@@ -593,12 +593,12 @@ std::tuple<at::Tensor, at::Tensor> dil_linear_backward_weights(
   if (bias_defined) {
     dil::inner_product_backward_weights::compute(x, grady, gradw, gradb, diff_weight_type);
     return std::tuple<at::Tensor, at::Tensor>{
-      dbl::comm::gen_aten_tensor_by(gradw),
-      dbl::comm::gen_aten_tensor_by(gradb)};
+      dbl::comm::gen_aten_tensor_by(std::move(gradw)),
+      dbl::comm::gen_aten_tensor_by(std::move(gradb))};
   } else {
     dil::inner_product_backward_weights::compute(x, grady, gradw, diff_weight_type);
     return std::tuple<at::Tensor, at::Tensor>{
-      dbl::comm::gen_aten_tensor_by(gradw),
+      dbl::comm::gen_aten_tensor_by(std::move(gradw)),
       at::Tensor()};
   }
 }
@@ -632,8 +632,8 @@ std::tuple<at::Tensor, at::Tensor> _dil_dropout(
   dil::tensor y;
   dil::dropout_forward::compute(x, ratio, y, mask);
   return std::tuple<at::Tensor, at::Tensor>{
-    dbl::comm::gen_aten_tensor_by(y),
-    dbl::comm::gen_aten_tensor_by(mask)};
+    dbl::comm::gen_aten_tensor_by(std::move(y)),
+    dbl::comm::gen_aten_tensor_by(std::move(mask))};
 }

 at::Tensor AtenIpexCPUDev::dil_dropout(const at::Tensor& self, double ratio, bool train) {
@@ -657,7 +657,7 @@ at::Tensor AtenIpexCPUDev::dil_dropout_backward(

   dil::tensor dX;
   dil::dropout_backward::compute(mask_dil, dY, dX);
-  return dbl::comm::gen_aten_tensor_by(dX);
+  return dbl::comm::gen_aten_tensor_by(std::move(dX));
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_norm(
@@ -696,9 +696,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
       dil::sum::compute(scales_var, {v, saved_var}, v);
     }
     return std::make_tuple(
-      dbl::comm::gen_aten_tensor_by(y),
-      dbl::comm::gen_aten_tensor_by(saved_mean),
-      dbl::comm::gen_aten_tensor_by(saved_var));
+      dbl::comm::gen_aten_tensor_by(std::move(y)),
+      dbl::comm::gen_aten_tensor_by(std::move(saved_mean)),
+      dbl::comm::gen_aten_tensor_by(std::move(saved_var)));
   } else {
     if (use_running_stat) {
       dil::tensor m = dbl::comm::try_gen_dil_tensor(running_mean);
@@ -710,7 +710,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
         x, w, b, y, eps);
     }
     return std::make_tuple(
-      dbl::comm::gen_aten_tensor_by(y),
+      dbl::comm::gen_aten_tensor_by(std::move(y)),
       at::Tensor(),
       at::Tensor());
   }
@@ -742,9 +742,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
     x, m, v, grady, w, gradx, gradw, gradb, eps);

   return std::make_tuple(
-    dbl::comm::gen_aten_tensor_by(gradx),
-    dbl::comm::gen_aten_tensor_by(gradw),
-    dbl::comm::gen_aten_tensor_by(gradb));
+    dbl::comm::gen_aten_tensor_by(std::move(gradx)),
+    dbl::comm::gen_aten_tensor_by(std::move(gradw)),
+    dbl::comm::gen_aten_tensor_by(std::move(gradb)));
 }

 at::Tensor AtenIpexCPUDev::dil_max_pooling(
@@ -969,7 +969,7 @@ at::Tensor AtenIpexCPUDev::dil_relu(const at::Tensor& input) {
   dil::tensor y;
   dil::eltwise_forward::compute(
     x, y, dil::algorithm::eltwise_relu, dil::prop_kind::forward_training, /*alpha*/ 0.0);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
@@ -998,7 +998,7 @@ at::Tensor AtenIpexCPUDev::dil_threshold_backward(const at::Tensor& grad_output,
   dil::tensor gradx;
   dil::eltwise_backward::compute(x, grady, gradx,
     dil::algorithm::eltwise_relu, /*alpha*/ 0.0);
-  return dbl::comm::gen_aten_tensor_by(gradx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gradx));
 }

 at::Tensor AtenIpexCPUDev::dil__softmax(
@@ -1014,7 +1014,7 @@ at::Tensor AtenIpexCPUDev::dil__softmax(
   dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
   dil::tensor y;
   dil::softmax_forward::compute(x, y, wrapped_dim);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor AtenIpexCPUDev::dil__softmax_backward_data(
@@ -1032,7 +1032,7 @@ at::Tensor AtenIpexCPUDev::dil__softmax_backward_data(
   dil::tensor grady = dbl::comm::try_gen_dil_tensor(grad_output_contiguous);
   dil::tensor gradx;
   dil::softmax_backward::compute(y, grady, gradx, wrapped_dim);
-  return dbl::comm::gen_aten_tensor_by(gradx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gradx));
 }

 at::Tensor AtenIpexCPUDev::dil_sigmoid(const at::Tensor& self) {
@@ -1042,7 +1042,7 @@ at::Tensor AtenIpexCPUDev::dil_sigmoid(const at::Tensor& self) {
   dil::tensor y;
   dil::eltwise_forward::compute(
     x, y, dil::algorithm::eltwise_logistic_use_dst_for_bwd, dil::prop_kind::forward);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor& AtenIpexCPUDev::dil_sigmoid_(at::Tensor& self) {
@@ -1069,7 +1069,7 @@ at::Tensor AtenIpexCPUDev::dil_sigmoid_backward(
   dil::tensor gx;
   dil::eltwise_backward::compute(y, gy, gx,
     dil::algorithm::eltwise_logistic_use_dst_for_bwd);
-  return dbl::comm::gen_aten_tensor_by(gx);
+  return dbl::comm::gen_aten_tensor_by(std::move(gx));
 }

 at::Tensor AtenIpexCPUDev::dil_reshape(const at::Tensor& self, at::IntArrayRef size) {
@@ -1082,7 +1082,7 @@ at::Tensor AtenIpexCPUDev::dil_reshape(const at::Tensor& self, at::IntArrayRef s
   const dil::tensor x = dbl::comm::try_gen_dil_tensor(self);
   dil::tensor y{x};
   y.reshape(inferred_size);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
@@ -1095,7 +1095,7 @@ at::Tensor AtenIpexCPUDev::dil_clone(const at::Tensor& self, c10::optional<c10::
   dil::tensor src = dbl::comm::try_gen_dil_tensor(self);
   dil::tensor dst;
   dil::direct_copy::compute(src, dst);
-  return dbl::comm::gen_aten_tensor_by(dst);
+  return dbl::comm::gen_aten_tensor_by(std::move(dst));
 }

 at::Tensor AtenIpexCPUDev::dil_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1) {
@@ -1110,7 +1110,7 @@ at::Tensor AtenIpexCPUDev::dil_transpose(const at::Tensor & self, int64_t dim0,
   dim1 = at::maybe_wrap_dim(dim1, self.dim());
   std::swap(axes[dim0], axes[dim1]);
   y.transpose_from(x, axes);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 inline void check_cat_no_zero_dim(at::TensorList tensors) {
@@ -1154,7 +1154,7 @@ at::Tensor AtenIpexCPUDev::dil_cat(at::TensorList tensors, int64_t dim) {
   }
   dil::tensor y;
   dil::concat::compute(x, dim, y);
-  return dbl::comm::gen_aten_tensor_by(y);
+  return dbl::comm::gen_aten_tensor_by(std::move(y));
 }

 std::vector<at::Tensor> AtenIpexCPUDev::dil_split_with_sizes(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim) {
@@ -1175,7 +1175,7 @@ std::vector<at::Tensor> AtenIpexCPUDev::dil_split_with_sizes(const at::Tensor& s
   dim = at::maybe_wrap_dim(dim, self.dim());
   auto y = dil::spliter::compute(x, sizes, dim, false);
   for (auto j = 0; j < num_splits; j++) {
-    splits[j] = dbl::comm::gen_aten_tensor_by(y[j]);
+    splits[j] = dbl::comm::gen_aten_tensor_by(std::move(y[j]));
   }
   return splits;
 }
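Every call site in DevOPs.cpp now hands the freshly computed dil::tensor to dbl::comm::gen_aten_tensor_by through std::move instead of as an lvalue. A minimal standalone sketch of why the move helps, under the assumption that the wrapper takes its argument by value (or rvalue reference) and that copying a dil::tensor is not free; FakeDilTensor and wrap() are illustrative stand-ins, not the real IPEX API:

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Stand-in for dil::tensor: copying shares/duplicates the buffer handle, moving steals it.
struct FakeDilTensor {
  std::shared_ptr<std::vector<float>> buf = std::make_shared<std::vector<float>>(1024);
};

// Stand-in for dbl::comm::gen_aten_tensor_by taking its argument by value:
// the caller decides whether the tensor is copied or moved into the parameter.
FakeDilTensor wrap(FakeDilTensor t) { return t; }

int main() {
  FakeDilTensor y;                  // freshly computed result, about to go out of scope
  auto copied = wrap(y);            // old style: y is copied into the parameter
  std::cout << "after copy, y still owns its buffer: " << (y.buf != nullptr) << "\n";
  auto moved = wrap(std::move(y));  // new style: ownership is transferred, no copy
  std::cout << "after move, y's buffer was stolen:   " << (y.buf == nullptr) << "\n";
  (void)copied; (void)moved;
  return 0;
}

The locals being moved are about to go out of scope anyway, so transferring ownership is safe and skips the extra copy or reference-count traffic on every operator return.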

torch_ipex/csrc/cpu/ShadeDataContext.h

Lines changed: 13 additions & 9 deletions

@@ -2,6 +2,7 @@

 #include <ATen/Tensor.h>
 #include <c10/util/Exception.h>
+#include <c10/util/Optional.h>

 #include "dil/dil.hpp"

@@ -13,9 +14,9 @@ namespace cpu {
 enum SHADE_DATA_TYPE {CPU_RAW, DIL};

 struct ShadeDataContext {
+  c10::optional<dil::tensor> dil_tensor; ///< DNNL memory buffer for lazy reorder
   void *cpu_raw_data; ///< The raw memory buffer of storage
   c10::DeleterFnPtr cpu_del_fun; ///< Delete function to release cpu_raw_data
-  dil::tensor dil_tensor; ///< DNNL memory buffer for lazy reorder

   SHADE_DATA_TYPE data_type; ///< Memory buffer type

@@ -26,9 +27,10 @@ struct ShadeDataContext {

   ~ShadeDataContext() {
     if (this->data_type == SHADE_DATA_TYPE::DIL) { // DIL Tensor
-      if (this->dil_tensor.is_public_format()) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor.has_value());
+      if (this->dil_tensor->is_public_format()) {
         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->cpu_raw_data != nullptr);
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor.get_data_handle() == this->cpu_raw_data);
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->dil_tensor->get_data_handle() == this->cpu_raw_data);
         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(this->cpu_del_fun == &(c10::detail::deleteNothing));
       } else {
         // If dil tensor is block format, the cpu raw data means nothing here.
@@ -90,16 +92,17 @@ struct ShadeDataContext {
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY((data_type == SHADE_DATA_TYPE::CPU_RAW) || (data_type == SHADE_DATA_TYPE::DIL));

     if (data_type == SHADE_DATA_TYPE::DIL) {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.has_value());
       auto raw_cpu_data = tensor.storage().data_ptr().get();
       if (raw_cpu_data == nullptr) {
         // the dnnl tensor does not share data with raw tensor data.
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_empty()));
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor.is_public_format()));
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_empty()));
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(! (shade_data_context->dil_tensor->is_public_format()));
         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(check_tensor_own_whole_storage(tensor));
         return true;
       } else {
         // The dnnl tensor shares some data with raw tensor.
-        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.is_public_format());
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor->is_public_format());

         // For the case:
         //   1. There is a tensor named A
@@ -113,7 +116,7 @@ struct ShadeDataContext {
         // All these tensors share same buffer of Tensor A with different storge offsets and elements.
         // So the context modification will impact all these tensors.
         if (check_tensor_own_whole_storage(tensor)) {
-          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.get_size() == tensor.storage().capacity());
+          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor->get_size() == tensor.storage().capacity());
           return true;
         }
       }
@@ -139,13 +142,14 @@ struct ShadeDataContext {
   * @return If the input tensor does not contain DNNL buffer, the function will return
   * an empty DNNL buffer. The caller should check the return buffer is empty or not.
   */
-  static inline dil::tensor getDilTensor(const at::Tensor &tensor) {
+  static inline dil::tensor& getDilTensor(const at::Tensor &tensor) {
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.has_storage());
     void *raw_context = tensor.storage().data_ptr().get_context();
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(raw_context != nullptr);
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isDilTensor(tensor));
     ShadeDataContext *shade_data_context = (ShadeDataContext*)raw_context;
-    return shade_data_context->dil_tensor;
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_data_context->dil_tensor.has_value());
+    return *(shade_data_context->dil_tensor);
   }

   /**
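The structural change here is that ShadeDataContext now holds a c10::optional<dil::tensor> instead of an always-constructed dil::tensor, so a plain CPU_RAW context carries no DIL tensor at all and every DIL-path access first asserts has_value(). A minimal standalone sketch of that invariant, using std::optional and placeholder types in place of c10::optional and dil::tensor (all names below are illustrative, not the real IPEX API):

#include <cassert>
#include <optional>
#include <vector>

// Stand-in for dil::tensor.
struct FakeDilTensor {
  std::vector<float> buf;
  bool is_public_format() const { return true; }
};

enum class ShadeDataType { CPU_RAW, DIL };

// Stand-in for cpu::ShadeDataContext after this commit: the DIL tensor is optional.
struct FakeShadeDataContext {
  std::optional<FakeDilTensor> dil_tensor;  // disengaged for CPU_RAW contexts
  void* cpu_raw_data = nullptr;
  ShadeDataType data_type = ShadeDataType::CPU_RAW;
};

// Mirrors the new getDilTensor(): assert the optional is engaged, return a reference.
FakeDilTensor& get_dil_tensor(FakeShadeDataContext& ctx) {
  assert(ctx.data_type == ShadeDataType::DIL);
  assert(ctx.dil_tensor.has_value());
  return *ctx.dil_tensor;  // no copy: the caller aliases the tensor stored in the context
}

int main() {
  FakeShadeDataContext cpu_ctx;                 // CPU_RAW: no DIL tensor is constructed at all
  assert(!cpu_ctx.dil_tensor.has_value());

  FakeShadeDataContext dil_ctx;
  dil_ctx.data_type = ShadeDataType::DIL;
  dil_ctx.dil_tensor.emplace();                 // DIL: tensor is created only when needed
  get_dil_tensor(dil_ctx).buf.push_back(1.0f);  // mutates the tensor held by the context
  assert(dil_ctx.dil_tensor->buf.size() == 1);
  return 0;
}

Returning a reference from getDilTensor matches the bridge change above, where the caller binds a dil::tensor& to the stored tensor instead of working on a copy.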
