
Commit ae3c00a

Enable lstm bf16 and fp32 in cpu device (#24)
* enable fp32 lstm in cpu device
* lstm enable bf16
* Implement unit test
* add gather into black list
* Remove unnecessary lines and move test case position
* hook at module level
* copy _flat_weights into IpexLSTM # model.bias_ih_l0 will be incorrect
* add fp32 unit test
* refactor LSTM UT
* update comments

Co-authored-by: chunyuan <chunyuan.wu@intel.com>
1 parent 268a6df commit ae3c00a

File tree

9 files changed: +658 -11 lines changed


ideep/ideep/abstract_types.hpp

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ using batch_normalization_flag = dnnl::normalization_flags;
 using query = dnnl::query;
 using scale_t = std::vector<float>;
 using exec_args = std::unordered_map<int, memory>;
+using rnn_direction = dnnl::rnn_direction;
 
 // for computation cache
 using key_t = std::string;

ideep/ideep/operators/lstm.hpp

Lines changed: 47 additions & 2 deletions
@@ -4,8 +4,53 @@
 namespace ideep {
 
 struct lstm_forward : public dnnl::lstm_forward {
-  static void compute() {
-  }
+  using super = dnnl::lstm_forward;
+
+  static void compute(const tensor& src_layer,
+                      const tensor& src_iter,
+                      const tensor& src_iter_c,
+                      const tensor& weights_layer,
+                      const tensor& weights_iter,
+                      const tensor& bias,
+                      tensor& dst_layer,
+                      tensor& dst_iter,
+                      tensor& dst_iter_c,
+                      const bool reverse = false,
+                      const prop_kind aprop = prop_kind::forward_inference,
+                      const engine& aengine = engine::cpu_engine()) {
+    auto direction = reverse ? rnn_direction::unidirectional_right2left
+                             : rnn_direction::unidirectional_left2right;
+    auto src_layer_desc = src_layer.get_desc();
+    auto src_iter_desc = src_iter.get_desc();
+    auto src_iter_c_desc = src_iter_c.get_desc();
+    // use any format for weights
+    auto weights_layer_desc = weights_layer.get_desc().to_format_any();
+    auto weights_iter_desc = weights_iter.get_desc().to_format_any();
+    auto bias_desc = bias.get_desc();
+    auto dst_layer_desc = dst_layer.get_desc();
+    auto dst_iter_desc = dst_iter.get_desc();
+    auto dst_iter_c_desc = dst_iter_c.get_desc();
+
+    auto pd = primitive_desc(
+        {aprop, direction, src_layer_desc, src_iter_desc, src_iter_c_desc,
+         weights_layer_desc, weights_iter_desc, bias_desc,
+         dst_layer_desc, dst_iter_desc, dst_iter_c_desc},
+        aengine);
+
+    auto expected_weights_layer = weights_layer.reorder_if_differ_in(pd.weights_desc());
+    auto expected_weights_iter = weights_iter.reorder_if_differ_in(pd.weights_iter_desc());
+
+    super(pd).execute(stream::default_stream(),
+                      {{DNNL_ARG_SRC_LAYER, src_layer},
+                       {DNNL_ARG_SRC_ITER, src_iter},
+                       {DNNL_ARG_SRC_ITER_C, src_iter_c},
+                       {DNNL_ARG_WEIGHTS_LAYER, expected_weights_layer},
+                       {DNNL_ARG_WEIGHTS_ITER, expected_weights_iter},
+                       {DNNL_ARG_BIAS, bias},
+                       {DNNL_ARG_DST_LAYER, dst_layer},
+                       {DNNL_ARG_DST_ITER, dst_iter},
+                       {DNNL_ARG_DST_ITER_C, dst_iter_c}});
+  }
 };
 
 struct lstm_backward : public dnnl::lstm_backward {

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 from .roi_align import *
 from .nms import *
+from .lstm import *
 from .interaction import *
 from .embeddingbag import *
 from .jit import *

intel_pytorch_extension_py/ops/lstm.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+import torch
+from torch.nn.utils.rnn import PackedSequence
+
+# This is a solution to swap the lstm module with the ipex counterpart
+# and will upstream this operator to PyTorch when oneDNN supports
+# bias and src_iter_c in bf16 for bf16 inference. Will keep this
+# for better support of blocked-format weight, e.g. for training.
+
+
+class IpexLSTM(torch.nn.LSTM):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    # port from torch/nn/modules/rnn.py
+    # replace the _VF.lstm with torch.ops.torch_ipex.lstm when the input is not PackedSequence
+    def forward(self, input, hx=None):  # noqa: F811
+        orig_input = input
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            # fall back to PyTorch LSTM since PackedSequence is unsupported in oneDNN
+            return super(IpexLSTM, self).forward(input, hx)
+        else:
+            batch_sizes = None
+            max_batch_size = input.size(0) if self.batch_first else input.size(1)
+            sorted_indices = None
+            unsorted_indices = None
+
+            if hx is None:
+                num_directions = 2 if self.bidirectional else 1
+                real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
+                h_zeros = torch.zeros(self.num_layers * num_directions,
+                                      max_batch_size, real_hidden_size,
+                                      dtype=input.dtype, device=input.device)
+                c_zeros = torch.zeros(self.num_layers * num_directions,
+                                      max_batch_size, self.hidden_size,
+                                      dtype=input.dtype, device=input.device)
+                hx = (h_zeros, c_zeros)
+            else:
+                # Each batch of the hidden state should match the input sequence that
+                # the user believes he/she is passing in.
+                hx = self.permute_hidden(hx, sorted_indices)
+
+            self.check_forward_args(input, hx, batch_sizes)
+            result = torch.ops.torch_ipex.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
+                                               self.dropout, self.training, self.bidirectional, self.batch_first)
+            output = result[0]
+            hidden = result[1:]
+
+            return output, self.permute_hidden(hidden, unsorted_indices)
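
Since IpexLSTM subclasses torch.nn.LSTM, it can also be constructed directly as a drop-in replacement at inference time. A minimal sketch of both dispatch paths (hypothetical usage, not part of this commit; the import path and the availability of torch.ops.torch_ipex.lstm are assumptions based on the file above):

import torch
from torch.nn.utils.rnn import pack_padded_sequence
# import path assumed from intel_pytorch_extension_py/ops/lstm.py above;
# loading the extension itself is what registers torch.ops.torch_ipex.lstm
from intel_pytorch_extension_py.ops.lstm import IpexLSTM

lstm = IpexLSTM(input_size=32, hidden_size=64, num_layers=1, batch_first=True).eval()
x = torch.randn(8, 16, 32)  # (batch, seq_len, input_size)

with torch.no_grad():
    # plain tensor input: forward() calls torch.ops.torch_ipex.lstm
    y, (h, c) = lstm(x)

    # PackedSequence input: forward() falls back to the stock nn.LSTM path
    lengths = torch.full((8,), 16, dtype=torch.long)
    packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    y_packed, _ = lstm(packed)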

intel_pytorch_extension_py/utils.py

Lines changed: 20 additions & 0 deletions
@@ -1,4 +1,8 @@
+import copy
+
 import torch
+
+from .ops.lstm import IpexLSTM
 from .fx import *
 
 def _replace_dropout_with_identity(model):
@@ -11,6 +15,22 @@ def _replace_dropout_with_identity(model):
             else:
                 _replace_dropout_with_identity(child)
 
+def _replace_lstm_with_ipex_lstm(model):
+    # replace lstm with ipex lstm during inference
+    # does not support the case where model itself is torch.nn.LSTM
+    if not model.training:
+        for child_name, child in model.named_children():
+            if isinstance(child, torch.nn.LSTM):
+                assert hasattr(child, "weight_ih_l0"), "torch.nn.LSTM should have weight_ih_l0"
+                ipex_lstm = IpexLSTM(child.input_size, child.hidden_size,
+                                     child.num_layers, child.bias, child.batch_first,
+                                     child.dropout, child.bidirectional, child.proj_size,
+                                     child.weight_ih_l0.device, child.weight_ih_l0.dtype)
+                ipex_lstm.__dict__ = copy.deepcopy(child.__dict__)
+                setattr(model, child_name, ipex_lstm)
+            else:
+                _replace_lstm_with_ipex_lstm(child)
+
 def convert_module_data_type(module, dtype):
     if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
         weight_data = module.weight.detach().clone().to(dtype)
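
For models that contain LSTM submodules, the helper above is the intended entry point: it walks named_children(), swaps each torch.nn.LSTM for an IpexLSTM, and copies the original module's __dict__ (including _flat_weights) so the trained parameters stay intact. A rough usage sketch, assuming the package imports as intel_pytorch_extension (matching the ipex alias used in the tests below):

import torch
import intel_pytorch_extension as ipex  # import name assumed

class Tagger(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # the swap handles LSTM submodules, not a model that is itself nn.LSTM
        self.lstm = torch.nn.LSTM(input_size=128, hidden_size=256, num_layers=2)
        self.fc = torch.nn.Linear(256, 10)

    def forward(self, x):
        y, _ = self.lstm(x)
        return self.fc(y)

model = Tagger().eval()  # must be in eval mode; the swap is skipped for training
ipex.utils._replace_lstm_with_ipex_lstm(model)
print(type(model.lstm).__name__)  # expected: IpexLSTM

with torch.no_grad():
    out = model(torch.randn(16, 4, 128))  # (seq_len, batch, input_size)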

tests/cpu/test_autocast.py

Lines changed: 139 additions & 0 deletions
@@ -5,6 +5,7 @@
 from common_utils import TestCase
 import time, sys
 from torch.testing._core import _get_default_tolerance
+import itertools
 
 def get_rand_seed():
     return int(time.time() * 1000000000)
@@ -212,5 +213,143 @@ def test_embeddingbag_op(self):
         self.assertEqual(traininig_out.dtype, torch.float)
         self.assertEqual(cpu_out, traininig_out)
 
+class M(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, bidirectional, bias, dropout, batch_first):
+        super(M, self).__init__()
+        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first)
+
+    def forward(self, x, h = None):
+        x, h = self.lstm(x, h)
+        return x, h
+class TestLSTM(TestCase):
+    def _lstm_params_list(self):
+        params_dict = {
+            "input_size": [1, 2],
+            "hidden_size": [5],
+            "num_layers": [1, 3],
+            "bidirectional": [False, True],
+            "bias": [False, True],
+            "empty_state": [False, True],
+            "batch_first": [False, True],
+            "dropout": [0, 1],
+            "batch_size": [1, 2],
+            "seq_len": [1, 3]
+        }
+
+        params_list = []
+        for key, value in params_dict.items():
+            params_list.append(value)
+        return params_list
+
+    def _cast_dtype(self, input, bf16):
+        if bf16:
+            input = input.to(torch.bfloat16)
+        return input
+
+    def _test_lstm(self, training, bf16, prec = 1e-5):
+        rand_seed = int(get_rand_seed())
+        print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        with torch.set_grad_enabled(training):
+            params_list = self._lstm_params_list()
+            for input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, dropout, batch_size, seq_len in itertools.product(*params_list):
+                # dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1
+                if dropout > 0 and num_layers == 1:
+                    continue
+
+                num_directions = 2 if bidirectional else 1
+
+                if batch_first:
+                    input = torch.randn(batch_size, seq_len, input_size)
+                else:
+                    input = torch.randn(seq_len, batch_size, input_size)
+                h = torch.randn(num_layers * num_directions, batch_size, hidden_size)
+                c = torch.randn(num_layers * num_directions, batch_size, hidden_size)
+
+                input_ipex = copy.deepcopy(input)
+                h_ipex = copy.deepcopy(h)
+                c_ipex = copy.deepcopy(c)
+
+                model = M(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first)
+                model.train() if training else model.eval()
+
+                model_ipex = copy.deepcopy(model)
+                model_ipex.train() if training else model_ipex.eval()
+                ipex.utils._replace_lstm_with_ipex_lstm(model_ipex)
+
+                with ipex.amp.autocast(enabled=bf16, configure=ipex.conf.AmpConf(torch.bfloat16)):
+                    if empty_state:
+                        y, hy = model(self._cast_dtype(input, bf16))
+                        y_ipex, hy_ipex = model_ipex(input)
+                    else:
+                        y, hy = model(input, (self._cast_dtype(h, bf16), self._cast_dtype(c, bf16)))
+                        y_ipex, hy_ipex = model_ipex(input, (h, c))
+
+                if not training and bf16:
+                    self.assertEqual(input_ipex.dtype, torch.float)
+                    self.assertEqual(h_ipex.dtype, torch.float)
+                    self.assertEqual(c_ipex.dtype, torch.float)
+
+                    # with mkldnn LSTM, y, hy[0] is bf16 and hy[1] is fp32
+                    self.assertEqual(y_ipex.dtype, torch.bfloat16)
+                    self.assertEqual(hy_ipex[0].dtype, torch.bfloat16)
+                    self.assertEqual(hy_ipex[1].dtype, torch.float)
+                self.assertEqual(y, y_ipex, prec=prec)
+                self.assertEqual(hy[0], hy_ipex[0], prec=prec)
+
+                self.assertEqual(hy[1], self._cast_dtype(hy_ipex[1], bf16), prec=prec)
+
+    def _test_lstm_pack_padded_sequence(self):
+        embedding_dim = 1024
+        hidden_dim = 10
+        batch_size = 24
+        num_layers = 1
+        bidirectional = True
+        num_direc = 2 if bidirectional else 1
+        max_lens = 96
+
+        sent = torch.randn(batch_size, max_lens, embedding_dim)
+        hid_0 = torch.rand(num_layers * num_direc, batch_size, hidden_dim)
+        hid_1 = torch.randn(num_layers * num_direc, batch_size, hidden_dim)
+
+        sentences = sent.clone().requires_grad_(False)
+        sent_lens = torch.Tensor([1, 2, 3, 4, 5, 1, 3, 2, 96, 5, 3, 1, 1, 2, 1, 2, 3, 6, \
+                                  1, 2, 4, 6, 2, 1])
+
+        assert sent_lens.shape[0] == batch_size
+        assert sent_lens.max().item() == max_lens
+
+        hidden_0 = hid_0.clone().requires_grad_(False)
+        hidden_1 = hid_1.clone().requires_grad_(False)
+        embeds = torch.nn.utils.rnn.pack_padded_sequence(sentences, sent_lens, batch_first=True, enforce_sorted=False)
+
+        model = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, batch_first=True)
+
+        model_ipex = copy.deepcopy(model)
+        ipex.utils._replace_lstm_with_ipex_lstm(model_ipex)
+
+        lstm_out, hidden_out = model(embeds, (hidden_0, hidden_1))
+        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
+
+        lstm_out_ipex, hidden_out_ipex = model_ipex(embeds, (hidden_0, hidden_1))
+        lstm_out_ipex, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out_ipex, batch_first=True)
+
+        self.assertEqual(lstm_out, lstm_out_ipex)
+        self.assertEqual(hidden_out[0], hidden_out_ipex[0])
+        self.assertEqual(hidden_out[1], hidden_out_ipex[1])
+
+    def test_lstm_inference(self):
+        self._test_lstm(training=False, bf16=False)
+
+        self._test_lstm(training=False, bf16=True, prec=2e-2)
+
+        self._test_lstm(training=True, bf16=False)
+
+        # TODO: autocast does not support LSTM bf16 training
+        # self._test_lstm(training=True, bf16=True)
+
+    def test_lstm_pack_padded_sequence(self):
+        self._test_lstm_pack_padded_sequence()
+
 if __name__ == '__main__':
     test = unittest.main()

torch_ipex/csrc/autocast_mode.cpp

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ MAKE_REGISTER_FUNC(ADD_NS(std), "std", Tensor (const Tensor &, bool), fp32)
 MAKE_REGISTER_FUNC(ADD_NS(std), "std.dim", Tensor (const Tensor &, IntArrayRef, bool, bool), fp32)
 MAKE_REGISTER_FUNC(ADD_NS(instance_norm), "instance_norm", Tensor (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, const c10::optional<Tensor>&, const c10::optional<Tensor>&, bool, double, double, bool), fp32)
 MAKE_REGISTER_FUNC(ADD_NS(grid_sampler), "grid_sampler", Tensor (const Tensor &, const Tensor &, int64_t, int64_t, bool), fp32)
+MAKE_REGISTER_FUNC(ADD_NS(gather), "gather", Tensor (const Tensor &, int64_t, const Tensor &, bool), fp32)
 
 // promote
 MAKE_REGISTER_FUNC(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote)
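
The new entry pins aten::gather to the fp32 policy of the IPEX autocast dispatcher, so bf16 floating-point inputs are cast up before the kernel runs. A rough illustration under the same ipex.amp.autocast front end exercised in the tests (import name assumed; the expected dtype follows from the fp32 registration above):

import torch
import intel_pytorch_extension as ipex  # import name assumed

src = torch.randn(4, 8).to(torch.bfloat16)
index = torch.zeros(4, 8, dtype=torch.long)

with ipex.amp.autocast(enabled=True, configure=ipex.conf.AmpConf(torch.bfloat16)):
    out = torch.gather(src, 1, index)

# gather is registered as fp32, so the bf16 input is expected to be cast
# to float and the result should come back as torch.float32 rather than bf16
print(out.dtype)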

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 4 additions & 9 deletions
@@ -3,7 +3,7 @@
 
 #include <ATen/Tensor.h>
 #include <torch/extension.h>
-
+#include "ideep/ideep.hpp"
 
 namespace torch_ipex {
 
@@ -86,14 +86,9 @@ class AtenIpexTypeExt {
       const at::Tensor& dboxes_xywh,
       const double scale_xy,
      const double scale_wh);
+  static std::tuple<at::Tensor, at::Tensor, at::Tensor> lstm(
+      const at::Tensor& input, std::vector<at::Tensor> hx, std::vector<at::Tensor> params, bool has_biases,
+      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first);
 };
 
 } // namespace torch_ipex
-
-// namespace {
-//   static auto dispatch =
-//       torch::RegisterOperators()
-//           // .op("torch_ipex::embedding_bag", &torch_ipex::AtenIpexTypeExt::embedding_bag)
-//           .op("torch_ipex::interaction_forward", &torch_ipex::AtenIpexTypeExt::interaction_forward)
-//           .op("torch_ipex::interaction_backward", &torch_ipex::AtenIpexTypeExt::interaction_backward);
-// }
