@@ -81,18 +81,6 @@ class tensor : public memory {
81
81
return static_cast <data_type>(data.data_type );
82
82
}
83
83
84
- inline dims get_strides () const {
85
- DIL_ENFORCE (is_plain (), " Call to_public() before get_strides()" );
86
- const auto & strides = blocking_strides ();
87
- if (!is_grouped ()) {
88
- return dims (strides, strides + data.ndims );
89
- } else {
90
- auto ret = dims (strides + 1 , strides + data.ndims );
91
- ret[0 ] = std::min (strides[0 ], strides[1 ]);
92
- return ret;
93
- }
94
- }
95
-
96
84
/* * returns true if memory descriptor is zero */
97
85
bool is_zero () const { return data.ndims == 0 ; }
98
86
@@ -379,6 +367,17 @@ class tensor : public memory {
379
367
return const_cast <dnnl_memory_desc_t &>(data).format_desc .blocking .strides ;
380
368
}
381
369
370
+ inline dims get_strides () const {
371
+ const auto & strides = blocking_strides ();
372
+ if (!is_grouped ()) {
373
+ return dims (strides, strides + data.ndims );
374
+ } else {
375
+ auto ret = dims (strides + 1 , strides + data.ndims );
376
+ ret[0 ] = std::min (strides[0 ], strides[1 ]);
377
+ return ret;
378
+ }
379
+ }
380
+
382
381
void set_g (dim groups) {
383
382
auto reserved_size = sizeof (((dnnl_memory_extra_desc_t *)0 )->reserved );
384
383
auto offset = reserved_size / sizeof (dim) - 1 ;
@@ -582,7 +581,20 @@ class tensor : public memory {
582
581
// / Returns dimension vector
583
582
inline dims get_dims () const { return get_desc ().get_dims (); }
584
583
585
- inline dims get_strides () const { return get_desc ().get_strides (); }
584
+ inline dims get_strides () const {
585
+ DIL_ENFORCE (is_public_format (), " Call to_public() before get_strides()" );
586
+ return get_desc ().get_strides ();
587
+ }
588
+
589
+ inline void set_dims_and_strides (const dims &adims, const dims &astrides) {
590
+ DIL_ENFORCE (is_public_format (), " Call to_public() before set_dims_and_strides()" );
591
+ DIL_ENFORCE (adims.size () == astrides.size (), " Dims and strides must have the same size" );
592
+ if (get_dims () == adims && get_strides () == astrides)
593
+ return ;
594
+ auto new_desc = desc (adims, get_data_type (), astrides);
595
+ DIL_ENFORCE (get_size () == new_desc.get_size (), " Invalid dims and strides for the original desc" );
596
+ set_desc (new_desc);
597
+ }
586
598
587
599
// / Return element number of the param.
588
600
// / The number is the meaning values for a tensor, instead of whole buffer.
0 commit comments