
Commit 6d0ff45

Merge remote-tracking branch 'gitlab/master'
2 parents 593f4b9 + b154e6d

File tree

9 files changed, +206 -86 lines changed

intel_pytorch_extension_py/ops/lstm.py

Lines changed: 12 additions & 11 deletions
@@ -3,11 +3,12 @@
 
 VF_lstm = _VF.lstm
 
-def ipex_lstm(input, hx, _flat_weights, bias, num_layers, dropout, training, bidirectional, batch_first):
-    if input.device.type == 'xpu' and (dropout == 0 or training == False):
-        return torch.ops.torch_ipex.lstm(input, hx, _flat_weights, bias, num_layers, dropout, training, bidirectional, batch_first)
+def ipex_lstm(input, hx, _flat_weights, bias, num_layers, dropout, training, bidirectional, batch_first, device):
+    # For LSTM training with dropout, fallback to cpu due to performance issue in oneDNN mode
+    if training and dropout != 0:
+        return fallback_lstm(input, hx, _flat_weights, bias, num_layers, dropout, training, bidirectional, batch_first, device=device)
     else:
-        return VF_lstm(input, hx, _flat_weights, bias, num_layers, dropout, training, bidirectional, batch_first)
+        return torch.ops.torch_ipex.lstm(input, hx, _flat_weights, bias, num_layers, dropout, training, bidirectional, batch_first)
 
 # users may only transfer the data but not the module to IPEX device, need to check if every item in the args is on "cpu" device
 def get_device(*args):

@@ -45,14 +46,14 @@ def fallback_lstm(*args, device):
     return tuple(output_device)
 
 def lstm(*args):
+    device = get_device(*args)
+    if device == "cpu":
+        return VF_lstm(*args)
+
+    # For LSTM with pack_padded_sequence as input, fallback to cpu due to performance issue in oneDNN mode
     if isinstance(args[1], torch.Tensor):
-        # For LSTM with pack_padded_sequence as input, fallback to cpu due to performance issue in oneDNN mode
-        device = get_device(*args)
-        if device == "cpu":
-            return VF_lstm(*args)
-        else:
-            return fallback_lstm(*args, device=device)
+        return fallback_lstm(*args, device=device)
     else:
-        return ipex_lstm(*args)
+        return ipex_lstm(*args, device=device)
 
 _VF.lstm = lstm
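Taken together, the dispatch after this commit is: resolve the device once up front, short-circuit pure-CPU calls to the stock VF_lstm, send packed-sequence inputs to fallback_lstm, and route everything else through ipex_lstm (which itself now falls back to CPU for training with dropout). A minimal sketch of exercising the patched entry point; the intel_pytorch_extension import name and ipex.DEVICE string are assumptions, not taken from this diff:

import torch
import intel_pytorch_extension as ipex  # assumed package name; importing it leaves _VF.lstm patched

lstm_mod = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2, dropout=0.5).train()
x = torch.randn(4, 3, 8)  # (seq_len, batch, input_size)

# CPU tensors: get_device(*args) returns "cpu", so the new lstm() short-circuits to VF_lstm.
out, (h, c) = lstm_mod(x)

# Data moved to the IPEX device (module weights left on CPU) reaches ipex_lstm();
# since this module trains with dropout != 0, that path now falls back to CPU as well.
out_xpu, _ = lstm_mod(x.to(ipex.DEVICE))  # ipex.DEVICE is an assumed device string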

scripts/cpu/gen-sparse-cpu-ops.py

Lines changed: 1 addition & 1 deletion
@@ -406,7 +406,7 @@ def gen_code(self):
             if param.core_type in ['Tensor', 'Scalar']:
                 profiler_inputs.append(param.name)
         code += '#if defined(IPEX_PROFILE_OP)\n'
-        code += '  RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{{input_names}}}));\n'.format(ns=_IPEX_OP_FUNC_NS, name=cpp_sparse_sig.def_name, input_names=', '.join(profiler_inputs))
+        code += '  RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{{input_names}}}));\n'.format(ns=_IPEX_OP_FUNC_NS, name=cpp_sparse_sig.def_name, input_names='')
         code += '#endif\n'

         code += self.gen_fallback_prepare_code(cpp_sparse_sig)
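The behavioral change is easiest to see by rendering the format string by hand: with input_names='' the generated RECORD_FUNCTION call now always receives an empty std::vector<c10::IValue>, so the profiler records the op name but no longer captures its arguments. A quick illustration (the namespace torch_ipex, the op name add, and the input names are stand-in values, not taken from this diff):

# Rendering the same template the generator uses, with stand-in values.
template = '  RECORD_FUNCTION("{ns}::{name}", std::vector<c10::IValue>({{{input_names}}}));\n'

print(template.format(ns='torch_ipex', name='add', input_names='self, other'))
# before:   RECORD_FUNCTION("torch_ipex::add", std::vector<c10::IValue>({self, other}));

print(template.format(ns='torch_ipex', name='add', input_names=''))
# after:    RECORD_FUNCTION("torch_ipex::add", std::vector<c10::IValue>({}));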

tests/cpu/test_bf16_lazy_reorder.py

Lines changed: 84 additions & 0 deletions
@@ -486,6 +486,90 @@ def test_batch_norm3d_backward(self):
         self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_bf16.grad))
         self.assertEqual(x_man_bf16.grad.float(), x_auto_mix_bf16.grad)
 
+class TestLayerNorm(TestCase):
+    def test_layer_norm(self):
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+
+        x_cpu, x_auto_mix_inference, x_auto_mix_train, x_man_bf16, x_auto_mix_train_bf16 = _gen_tensor(
+            rand_seed, (2, 5, 10, 10))
+
+        op_cpu, op_auto_mix_inference, op_auto_mix_train, op_man_bf16, op_auto_mix_train_bf16 = _gen_op(
+            rand_seed, torch.nn.LayerNorm([10, 10]), is_bn=True)
+
+        ref_cpu = op_cpu(x_cpu)
+        with AutoDNNL(True), AutoMixPrecision(False):
+            res_bf16 = op_man_bf16(x_man_bf16)
+            self.assertEqual(res_bf16.dtype, torch.bfloat16)
+
+        # FW inference
+        with AutoMixPrecision(True, train=False):
+            self.assertEqual(x_auto_mix_inference.dtype, torch.float)
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_inference))
+            res_auto_mix_inference = op_auto_mix_inference(x_auto_mix_inference)
+            self.assertEqual(res_auto_mix_inference.dtype, torch.float)
+            self.assertEqual(x_auto_mix_inference.dtype, torch.float)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix_inference))
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_inference))
+            self.assertEqual(res_bf16.float(), res_auto_mix_inference)
+
+        # FW train (input is not bf16 dil tensor)
+        with AutoMixPrecision(True, train=True):
+            self.assertEqual(x_auto_mix_train.dtype, torch.float)
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_train))
+            res_auto_mix_train = op_auto_mix_train(x_auto_mix_train)
+            self.assertEqual(res_auto_mix_train.dtype, torch.float)
+            self.assertEqual(x_auto_mix_train.dtype, torch.float)
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(res_auto_mix_train))
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix_train))
+            self.assertEqual(ref_cpu, res_auto_mix_train)
+
+        # FW train (input is bf16 dil tensor)
+        with AutoMixPrecision(True, train=True):
+            self.assertEqual(x_auto_mix_train_bf16.dtype, torch.float)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_train_bf16))
+            res_auto_mix_train_bf16 = op_auto_mix_train_bf16(x_auto_mix_train_bf16)
+            self.assertEqual(res_auto_mix_train_bf16.dtype, torch.float)
+            self.assertEqual(x_auto_mix_train_bf16.dtype, torch.float)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix_train_bf16))
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_train_bf16))
+            self.assertEqual(res_bf16.float(), res_auto_mix_train_bf16)
+
+    def test_layer_norm_backward(self):
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        x_cpu, _, x_auto_mix, x_man_bf16, x_auto_mix_bf16 = _gen_tensor(rand_seed, (2, 5, 10, 10), is_forward=False)
+
+        op_cpu, _, op_auto_mix, op_man_bf16, op_auto_mix_bf16 = _gen_op(rand_seed, torch.nn.LayerNorm([10, 10]), is_bn=True, is_forward=False)
+
+        out_cpu = op_cpu(x_cpu).sum()
+        out_cpu.backward()
+        with AutoDNNL(True), AutoMixPrecision(False, train=True):
+            out_man_bf16 = op_man_bf16(x_man_bf16).sum()
+            out_man_bf16.backward()
+            self.assertEqual(x_man_bf16.grad.dtype, torch.bfloat16)
+            self.assertEqual(x_cpu.grad.bfloat16().float(), x_man_bf16.grad, 1e-2)
+
+        # BW train (input is not bf16 dil tensor)
+        with AutoMixPrecision(True, train=True):
+            self.assertEqual(x_auto_mix.dtype, torch.float)
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix))
+            out_auto_mix = op_auto_mix(x_auto_mix).sum()
+            out_auto_mix.backward()
+            self.assertEqual(x_auto_mix.grad.dtype, torch.float)
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(x_auto_mix.grad))
+            self.assertEqual(x_cpu.grad, x_auto_mix.grad)
+
+        # BW train (input is bf16 dil tensor)
+        with AutoMixPrecision(True, train=True):
+            self.assertEqual(x_auto_mix_bf16.dtype, torch.float)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_bf16))
+            out_auto_mix_bf16 = op_auto_mix_bf16(x_auto_mix_bf16).sum()
+            out_auto_mix_bf16.backward()
+            self.assertEqual(x_auto_mix_bf16.grad.dtype, torch.float)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(x_auto_mix_bf16.grad))
+            self.assertEqual(x_man_bf16.grad.float(), x_auto_mix_bf16.grad)
+
 class TestRelu(TestCase):
     def test_relu(self):
         rand_seed = int(get_rand_seed())
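The new TestLayerNorm cases mirror the surrounding batch-norm tests: compute an fp32 CPU reference, repeat the op under manual bf16 and auto-mix-precision modes, and assert on dtypes and numerical agreement. A stand-alone sketch of the underlying bf16-vs-fp32 comparison in stock PyTorch (no IPEX helpers; the tolerance is an arbitrary choice reflecting bf16's ~8-bit mantissa):

import torch

torch.manual_seed(42)
ln = torch.nn.LayerNorm([10, 10])
x = torch.randn(2, 5, 10, 10)

ref = ln(x)                             # fp32 reference
res_bf16 = ln.bfloat16()(x.bfloat16())  # same op with bf16 weights and input
assert res_bf16.dtype == torch.bfloat16

# Compare loosely: bf16 keeps only ~8 mantissa bits.
assert torch.allclose(ref, res_bf16.float(), rtol=1e-2, atol=1e-2)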

torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ namespace bridge {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.unsafeGetTensorImpl()->dtype() == b.unsafeGetTensorImpl()->dtype()); \
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.unsafeGetTensorImpl()->is_contiguous() == b.unsafeGetTensorImpl()->is_contiguous()); \
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.unsafeGetTensorImpl()->is_contiguous(at::MemoryFormat::ChannelsLast) == b.unsafeGetTensorImpl()->is_contiguous(at::MemoryFormat::ChannelsLast)); \
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.unsafeGetTensorImpl()->is_contiguous(at::MemoryFormat::ChannelsLast3d) == b.unsafeGetTensorImpl()->is_contiguous(at::MemoryFormat::ChannelsLast3d)); \
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.unsafeGetTensorImpl()->is_strides_like_channels_last() == b.unsafeGetTensorImpl()->is_strides_like_channels_last()); \
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.unsafeGetTensorImpl()->is_non_overlapping_and_dense() == b.unsafeGetTensorImpl()->is_non_overlapping_and_dense()); \
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.unsafeGetTensorImpl()->is_wrapped_number() == b.unsafeGetTensorImpl()->is_wrapped_number()); \
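The added line extends the debug attribute check to the 5-D channels-last layout, which PyTorch exposes in Python as torch.channels_last_3d (in recent PyTorch builds). A small sketch of what the new assertion compares:

import torch

a = torch.randn(2, 3, 4, 5, 6)                  # contiguous NCDHW layout
b = a.to(memory_format=torch.channels_last_3d)  # same data, NDHWC strides

print(a.is_contiguous(memory_format=torch.channels_last_3d))  # False
print(b.is_contiguous(memory_format=torch.channels_last_3d))  # True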
