 from common_utils import TestCase
 import time, sys
 from torch.testing._core import _get_default_tolerance
+import itertools
 
 def get_rand_seed():
     return int(time.time() * 1000000000)
@@ -212,5 +213,143 @@ def test_embeddingbag_op(self):
         self.assertEqual(traininig_out.dtype, torch.float)
         self.assertEqual(cpu_out, traininig_out)
 
+class M(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, bidirectional, bias, dropout, batch_first):
+        super(M, self).__init__()
+        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first)
+
+    def forward(self, x, h=None):
+        x, h = self.lstm(x, h)
+        return x, h
+class TestLSTM(TestCase):
+    def _lstm_params_list(self):
+        params_dict = {
+            "input_size": [1, 2],
+            "hidden_size": [5],
+            "num_layers": [1, 3],
+            "bidirectional": [False, True],
+            "bias": [False, True],
+            "empty_state": [False, True],
+            "batch_first": [False, True],
+            "dropout": [0, 1],
+            "batch_size": [1, 2],
+            "seq_len": [1, 3]
+        }
+
+        # collect the value lists in dict insertion order; _test_lstm unpacks the product in the same order
+        params_list = []
+        for key, value in params_dict.items():
+            params_list.append(value)
+        return params_list
+
+    def _cast_dtype(self, input, bf16):
+        if bf16:
+            input = input.to(torch.bfloat16)
+        return input
+
+    def _test_lstm(self, training, bf16, prec=1e-5):
+        rand_seed = int(get_rand_seed())
+        print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        with torch.set_grad_enabled(training):
+            params_list = self._lstm_params_list()
+            for input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, dropout, batch_size, seq_len in itertools.product(*params_list):
+                # the dropout option adds dropout after all but the last recurrent layer, so non-zero dropout expects num_layers greater than 1
+                if dropout > 0 and num_layers == 1:
+                    continue
+
+                num_directions = 2 if bidirectional else 1
+
+                if batch_first:
+                    input = torch.randn(batch_size, seq_len, input_size)
+                else:
+                    input = torch.randn(seq_len, batch_size, input_size)
+                h = torch.randn(num_layers * num_directions, batch_size, hidden_size)
+                c = torch.randn(num_layers * num_directions, batch_size, hidden_size)
+
+                input_ipex = copy.deepcopy(input)
+                h_ipex = copy.deepcopy(h)
+                c_ipex = copy.deepcopy(c)
+
+                model = M(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first)
+                model.train() if training else model.eval()
+
+                model_ipex = copy.deepcopy(model)
+                model_ipex.train() if training else model_ipex.eval()
+                # swap the nn.LSTM module in the copied model for the IPEX LSTM implementation
+                ipex.utils._replace_lstm_with_ipex_lstm(model_ipex)
+
+                with ipex.amp.autocast(enabled=bf16, configure=ipex.conf.AmpConf(torch.bfloat16)):
+                    if empty_state:
+                        y, hy = model(self._cast_dtype(input, bf16))
+                        y_ipex, hy_ipex = model_ipex(input)
+                    else:
+                        y, hy = model(input, (self._cast_dtype(h, bf16), self._cast_dtype(c, bf16)))
+                        y_ipex, hy_ipex = model_ipex(input, (h, c))
+
+                if not training and bf16:
+                    self.assertEqual(input_ipex.dtype, torch.float)
+                    self.assertEqual(h_ipex.dtype, torch.float)
+                    self.assertEqual(c_ipex.dtype, torch.float)
+
+                    # with mkldnn LSTM, y and hy[0] are bf16 while hy[1] stays fp32
+                    self.assertEqual(y_ipex.dtype, torch.bfloat16)
+                    self.assertEqual(hy_ipex[0].dtype, torch.bfloat16)
+                    self.assertEqual(hy_ipex[1].dtype, torch.float)
+                self.assertEqual(y, y_ipex, prec=prec)
+                self.assertEqual(hy[0], hy_ipex[0], prec=prec)
+
+                self.assertEqual(hy[1], self._cast_dtype(hy_ipex[1], bf16), prec=prec)
+
+    def _test_lstm_pack_padded_sequence(self):
+        embedding_dim = 1024
+        hidden_dim = 10
+        batch_size = 24
+        num_layers = 1
+        bidirectional = True
+        num_direc = 2 if bidirectional else 1
+        max_lens = 96
+
+        sent = torch.randn(batch_size, max_lens, embedding_dim)
+        hid_0 = torch.rand(num_layers * num_direc, batch_size, hidden_dim)
+        hid_1 = torch.randn(num_layers * num_direc, batch_size, hidden_dim)
+
+        sentences = sent.clone().requires_grad_(False)
+        sent_lens = torch.Tensor([1, 2, 3, 4, 5, 1, 3, 2, 96, 5, 3, 1, 1, 2, 1, 2, 3, 6,
+                                  1, 2, 4, 6, 2, 1])
+
+        assert sent_lens.shape[0] == batch_size
+        assert sent_lens.max().item() == max_lens
+
+        hidden_0 = hid_0.clone().requires_grad_(False)
+        hidden_1 = hid_1.clone().requires_grad_(False)
+        embeds = torch.nn.utils.rnn.pack_padded_sequence(sentences, sent_lens, batch_first=True, enforce_sorted=False)
+
+        model = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=bidirectional, batch_first=True)
+
+        model_ipex = copy.deepcopy(model)
+        ipex.utils._replace_lstm_with_ipex_lstm(model_ipex)
+
+        lstm_out, hidden_out = model(embeds, (hidden_0, hidden_1))
+        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
+
+        lstm_out_ipex, hidden_out_ipex = model_ipex(embeds, (hidden_0, hidden_1))
+        lstm_out_ipex, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out_ipex, batch_first=True)
+
+        self.assertEqual(lstm_out, lstm_out_ipex)
+        self.assertEqual(hidden_out[0], hidden_out_ipex[0])
+        self.assertEqual(hidden_out[1], hidden_out_ipex[1])
+
+    def test_lstm_inference(self):
+        self._test_lstm(training=False, bf16=False)
+
+        self._test_lstm(training=False, bf16=True, prec=2e-2)
+
+        self._test_lstm(training=True, bf16=False)
+
+        # TODO: autocast does not support LSTM bf16 training
+        # self._test_lstm(training=True, bf16=True)
+
+    def test_lstm_pack_padded_sequence(self):
+        self._test_lstm_pack_padded_sequence()
+
 if __name__ == '__main__':
     test = unittest.main()