@@ -164,7 +164,6 @@ struct inner_product_forward : public dnnl::inner_product_forward {
       }
     } else {
       op_attr = attr;
-      src_desc = {src.get_dims(), data_type::f32, format_tag::any};
       if (src.has_scale()) {
         auto src_scale = src.get_scale();
         src_scale[0] = 1.f / src_scale[0];
@@ -178,56 +177,50 @@ struct inner_product_forward : public dnnl::inner_product_forward {
       // align weights data type with src
       dst_data_type = src.get_data_type() == data_type::bf16 ? data_type::bf16
                                                               : data_type::f32;
-      src_desc = src.get_desc().to_type(dst_data_type).to_format_any();
-      weights_desc = weights.get_desc().to_type(dst_data_type).to_format_any();
+      src_desc = src.get_desc().to_type(dst_data_type);
+      weights_desc = weights.get_desc().to_type(dst_data_type);
       if (with_bias) {
         IDEEP_ENFORCE(utils::one_of(bias.get_data_type(),
                                     data_type::f32, data_type::bf16),
                       "Incorrect data type in bias");
-        bias_desc = bias.get_desc().to_format_any();
+        bias_desc = bias.get_desc();
       }
     }

-    tensor::desc dst_desc(dst_dims, dst_data_type, format_tag::any);
+    tensor::desc dst_desc = dst.get_desc().to_type(dst_data_type);
     auto pd = with_bias
         ? primitive_desc({aprop_kind, src_desc, weights_desc, bias_desc,
                           dst_desc}, op_attr, aengine)
         : primitive_desc({aprop_kind, src_desc, weights_desc, dst_desc},
                          op_attr, aengine);

-    auto expected_src = src.reorder_if_differ_in(pd.src_desc(), src_attr);
-    auto expected_weights = weights.reorder_if_differ_in(pd.weights_desc(), weights_attr);
     // [ Note output buffer ]
     // In this case, dst is an empty ideep tensor, can be re-init
     // If dst is not empty, ideep must write result to dst's memory and it is caller's duty to
     // make sure dst is big enough to hold the result
     if (dst.is_empty()) {
       dst.init(pd.dst_desc());
     }
-    auto expected_dst = dst.reorder_if_differ_in(pd.dst_desc());
-    if (!dst_scales.empty() && utils::one_of(dst.get_data_type(), data_type::u8, data_type::s8)) {
-      expected_dst.set_scale(dst_scales_in);
+
+    if (!dst_scales.empty() &&
+        utils::one_of(dst.get_data_type(), data_type::u8, data_type::s8)) {
+      dst.set_scale(dst_scales_in);
     }

     if (with_bias){
-      auto expected_bias = bias.reorder_if_differ_in(pd.bias_desc(), bias_attr);
-      super(pd).execute(stream::default_stream(),
-                        {{DNNL_ARG_SRC, expected_src},
-                         {DNNL_ARG_WEIGHTS, expected_weights},
-                         {DNNL_ARG_BIAS, expected_bias},
-                         {DNNL_ARG_DST, expected_dst}});
+      super(pd).execute(stream::default_stream(), {{DNNL_ARG_SRC, src},
+                                                   {DNNL_ARG_WEIGHTS, weights},
+                                                   {DNNL_ARG_BIAS, bias},
+                                                   {DNNL_ARG_DST, dst}});
     } else {
-      super(pd).execute(stream::default_stream(),
-                        {{DNNL_ARG_SRC, expected_src},
-                         {DNNL_ARG_WEIGHTS, expected_weights},
-                         {DNNL_ARG_DST, expected_dst}});
+      super(pd).execute(stream::default_stream(), {{DNNL_ARG_SRC, src},
+                                                   {DNNL_ARG_WEIGHTS, weights},
+                                                   {DNNL_ARG_DST, dst}});
     }

-    if (attr.non_negitive_output() && expected_dst.get_data_type() == data_type::s8) {
-      expected_dst.to_type(data_type::u8);
+    if (attr.non_negitive_output() && dst.get_data_type() == data_type::s8) {
+      dst.to_type(data_type::u8);
     }
-    // reorder back to dst's buffer if needed
-    expected_dst.reorder_to_if_differ_from(dst);
   }
 };

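Note on the forward hunks above: they all follow one pattern, stop asking oneDNN for its preferred layout and run the primitive on the caller's tensors as given. A minimal before/after sketch, built only from calls that already appear in this diff and meant as annotation rather than code from the patch:

    // Before: descriptors use format_tag::any, so primitive_desc may pick a blocked
    // layout; inputs are staged with reorder_if_differ_in() and the result is
    // copied back with reorder_to_if_differ_from().
    src_desc = src.get_desc().to_type(dst_data_type).to_format_any();
    auto expected_src = src.reorder_if_differ_in(pd.src_desc(), src_attr);

    // After: descriptors come straight from the caller's tensors (layout kept,
    // only the data type adjusted), so src/weights/dst are bound to execute()
    // directly and no "expected_*" staging tensors or reorders are created.
    src_desc = src.get_desc().to_type(dst_data_type);
    tensor::desc dst_desc = dst.get_desc().to_type(dst_data_type);
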
@@ -242,11 +235,6 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
                       tensor& diff_src,
                       const engine& aengine = engine::cpu_engine()) {
     auto weights_ = weights;
-    if (diff_dst.get_data_type() == data_type::bf16) {
-      weights_.init(weights.get_desc().to_type(data_type::bf16));
-      weights_.reorder_from(weights);
-    }
-
     // workaround: diff_src and weights from caffe2 may have different dims.
     // It would be better for caffe2 to do this reshape anyway.
     if (diff_src_dims.size() != weights.ndims()) {
@@ -255,10 +243,9 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
       weights_.reshape(new_dims);
     }

-    auto diff_dst_desc = diff_dst.get_desc().to_format_any();
-    auto weights_desc = weights_.get_desc().to_format_any();
-    auto diff_src_desc =
-        tensor::desc(diff_src_dims, diff_dst.get_data_type(), tag::any);
+    auto diff_dst_desc = diff_dst.get_desc();
+    auto weights_desc = weights_.get_desc();
+    auto diff_src_desc = diff_src.get_desc().to_type(diff_dst.get_data_type());

     auto forward_hints =
         inner_product_forward::primitive_desc(
@@ -268,8 +255,6 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
     auto pd = primitive_desc(
         {diff_src_desc, weights_desc, diff_dst_desc}, aengine, forward_hints);

-    auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc());
-    auto expected_weights = weights_.reorder_if_differ_in(pd.weights_desc());
     // diff_src's origin content are not used, so it can be re-init directly
     // It's caller's duty to make sure diff_src's buffer size is same with it actually needed
     // Here we dose not support to write to given strided buffer since we know the grad is always contiguous
@@ -280,8 +265,8 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
     }

     super(pd).execute(stream::default_stream(),
-                      {{DNNL_ARG_DIFF_DST, expected_diff_dst},
-                       {DNNL_ARG_WEIGHTS, expected_weights},
+                      {{DNNL_ARG_DIFF_DST, diff_dst},
+                       {DNNL_ARG_WEIGHTS, weights_},
                        {DNNL_ARG_DIFF_SRC, diff_src}});
   }
 };
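Aside on the backward_data hunks, quoting only lines from the diff above: besides dropping the expected_* staging, the patch also removes the explicit bf16 widening of weights_, so the weights tensor is used with whatever data type it already has, and diff_src_desc is taken from the caller's diff_src instead of a tag::any descriptor.

    // Removed up-front bf16 copy of the weights (old lines 245-248):
    //   if (diff_dst.get_data_type() == data_type::bf16) {
    //     weights_.init(weights.get_desc().to_type(data_type::bf16));
    //     weights_.reorder_from(weights);
    //   }
    // weights_ is now bound directly as DNNL_ARG_WEIGHTS without conversion.
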
@@ -319,18 +304,17 @@ struct inner_product_backward_weights
                       tensor& diff_bias,
                       const data_type diff_weight_type,
                       const engine& aengine = engine::cpu_engine()) {
-    auto src_desc = src.get_desc().to_format_any();
-    auto diff_dst_desc = diff_dst.get_desc().to_format_any();
+    auto src_desc = src.get_desc();
+    auto diff_dst_desc = diff_dst.get_desc();
     auto diff_weights_dims = src.get_dims();
     diff_weights_dims[0] = diff_dst.get_dim(1);
     data_type diff_dst_type = diff_dst.get_data_type();
     data_type diff_weight_type_in = data_type::undef == diff_weight_type ?
                                     diff_dst_type : diff_weight_type;
-    auto diff_weights_desc =
-        tensor::desc(diff_weights_dims, diff_weight_type_in, tag::any);

-    auto diff_bias_desc =
-        tensor::desc({diff_dst.get_dim(1)}, diff_weight_type_in, tag::any);
+    auto diff_weights_desc =
+        diff_weights.get_desc().to_type(diff_weight_type_in);
+    auto diff_bias_desc = diff_bias.get_desc().to_type(diff_weight_type_in);

     // for forward hint, weights_desc should have same data_type
     // with other input desc, expect for bias_desc
@@ -349,18 +333,13 @@ struct inner_product_backward_weights
         : primitive_desc({src_desc, diff_weights_desc, diff_dst_desc},
                          aengine, forward_hints);

-    auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc());
-    auto expected_src = src.reorder_if_differ_in(pd.src_desc());
     if (diff_weights.is_empty()){
       diff_weights.init(pd.diff_weights_desc());
     }
-    // Here we need to write to given strided buffer, so if given buffer is different with the best format
-    // We need to firstly init a new buffer to store the output, and copy the output to a given buffer
-    tensor expected_diff_weights = diff_weights.get_desc() == pd.diff_weights_desc() ? diff_weights : tensor(pd.diff_weights_desc());

-    exec_args args {{DNNL_ARG_DIFF_DST, expected_diff_dst},
-                    {DNNL_ARG_SRC, expected_src},
-                    {DNNL_ARG_DIFF_WEIGHTS, expected_diff_weights}};
+    exec_args args{{DNNL_ARG_DIFF_DST, diff_dst},
+                   {DNNL_ARG_SRC, src},
+                   {DNNL_ARG_DIFF_WEIGHTS, diff_weights}};

     if (with_diff_bias) {
       if (diff_bias.is_empty()){
@@ -373,7 +352,6 @@ struct inner_product_backward_weights
     }

     super(pd).execute(stream::default_stream(), args);
-    expected_diff_weights.reorder_to_if_differ_from(diff_weights);
   }
 };

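The backward_weights hunks make the same trade as the forward ones. The removed staging logic let the gradient be produced in oneDNN's preferred layout and copied back afterwards; now pd.diff_weights_desc() is derived from diff_weights.get_desc(), so the primitive writes straight into the caller's buffer. A sketch of the removed path for contrast, assembled from the deleted lines above rather than from any new API:

    // Old backward_weights path: stage into the primitive's preferred layout if it
    // differs from the caller's, then copy back at the end.
    tensor expected_diff_weights = diff_weights.get_desc() == pd.diff_weights_desc()
        ? diff_weights
        : tensor(pd.diff_weights_desc());
    // ... execute with {DNNL_ARG_DIFF_WEIGHTS, expected_diff_weights} ...
    expected_diff_weights.reorder_to_if_differ_from(diff_weights);

    // New path: pd was built from diff_weights.get_desc().to_type(diff_weight_type_in),
    // so pd.diff_weights_desc() already matches diff_weights and no copy-back is needed.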