    SampleInput,
)

-from torch.masked import masked_tensor
+from torch.masked import MaskedTensor, masked_bmm

from torch.masked.maskedtensor.core import _masks_match, _tensors_match
from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
-
from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS

@@ -126,7 +125,7 @@ def _get_sample_kwargs(self, fn_name):

    def _get_sample_args(self, fn_name, data, mask):
        fn_name = _fix_fn_name(fn_name)
-        mt = masked_tensor(data, mask)
+        mt = MaskedTensor(data, mask)
        t_args = [data]
        mt_args = [mt]
        if fn_name in ["pow"]:
@@ -185,8 +184,8 @@ def _yield_sample_args(self, fn_name, data0, data1, mask):
        while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor)
        """
        fn_name = _fix_fn_name(fn_name)
-        mt0 = masked_tensor(data0, mask)
-        mt1 = masked_tensor(data1, mask)
+        mt0 = MaskedTensor(data0, mask)
+        mt1 = MaskedTensor(data1, mask)

        t_args = [data0, data1]
        mt_args = [mt0, mt1]
@@ -227,8 +226,8 @@ def test_masks_match(self, fn_name):
        data0, data1, mask = self._get_test_data(fn_name)
        mask0 = mask
        mask1 = torch.rand(mask.size()) > 0.5
-        mt0 = masked_tensor(data0, mask0)
-        mt1 = masked_tensor(data1, mask1)
+        mt0 = MaskedTensor(data0, mask0)
+        mt1 = MaskedTensor(data1, mask1)
        try:
            fn(mt0, mt1)
            raise AssertionError()
@@ -238,8 +237,327 @@ def test_masks_match(self, fn_name):
                == str(e)
            )

+class TestReductions(TestCase):
+    def test_max_not_implemented(self):
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor(d, m)
+        with self.assertRaisesRegex(TypeError, "no implementation found for 'torch.ops.aten.max'"):
+            mt.max()
+
+    def test_sum(self):
+        d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]])
+        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
+        mt = MaskedTensor(d, m)
+        _compare_mts(MaskedTensor(torch.tensor(17.0), torch.tensor(True)), mt.sum())
+        _compare_mts(
+            MaskedTensor(
+                torch.tensor([0.0, 4.0, 1.0, 13]),
+                torch.tensor([True, True, False, True]),
+            ),
+            mt.sum(dim=0),
+        )
+
+    def test_sum_grad(self):
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor(d, m, requires_grad=True)
+        mt.sum().backward()
+        _compare_mts(mt.grad, MaskedTensor(torch.tensor(1.0).expand_as(m), m))
+
+    def test_mean(self):
+        d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]])
+        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
+        mt = MaskedTensor(d, m)
+        _compare_mts(MaskedTensor(torch.tensor(2.5), torch.tensor(True)), mt.mean())
+        _compare_mts(
+            MaskedTensor(
+                torch.tensor([0.0, 4.0, 1.0, 3]),
+                torch.tensor([True, True, False, True]),
+            ),
+            mt.mean(dim=0),
+        )
+
+    """
+    The following block of tests, test_mean_grad_case_1[a through f], is used to test the functionality of
+    the two different ways of constructing MaskedTensors:
+        MaskedTensor(data, mask, requires_grad=True/False) -- NOT a differentiable constructor and always a leaf
+        MaskedTensor.from_values(data, mask) -- differentiable constructor
+
+    Like torch.tensor(data), MaskedTensor(data, mask) will provide a UserWarning if data.requires_grad=True
+    MaskedTensor.from_values does not take in requires_grad -- it just takes on the requires_grad from data
+
+    Therefore, there are 6 cases to test and we use `mean` as a proxy to test the different combinations
+
+    Assuming mt.mean().backward() is run after each constructor:
+
+    Case 1a:
+        values.requires_grad = True
+        mt = MaskedTensor(values, mask, requires_grad=True)
+    yields
+        - a UserWarning because values.requires_grad=True
+        - values.grad = None
+        - mt.grad is a MaskedTensor with the correct gradient
+
+    Case 1b:
+        values.requires_grad = False
+        mt = MaskedTensor(values, mask, requires_grad=True)
+    yields
+        - values.grad = None
+        - mt.grad is a MaskedTensor with the correct gradient
+
+    Case 1c/1d:
+        values.requires_grad = True/False
+        mt = MaskedTensor(values, mask, requires_grad=False)
+
+    will both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
+    as expected. When values.requires_grad=True, we will also get a UserWarning
+
+    Case 1e:
+        values.requires_grad = True
+        mt = MaskedTensor.from_values(values, mask)
+    yields
+        - values.grad is a MaskedTensor with the correct gradient
+        - mt.grad is None and gives a UserWarning that
+          "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
+
+    Case 1f:
+        values.requires_grad = False
+        mt = MaskedTensor.from_values(values, mask)
+
+    will yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
+    as expected.
+    """
+    def test_mean_grad_case_1a(self):
+        """ values.requires_grad = True
+            mt = MaskedTensor(values, mask, requires_grad=True)
+        """
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
+            mt = MaskedTensor(d, m, requires_grad=True)
+        mt.mean().backward()
+        self.assertIsNone(d.grad)
+        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
+
+    def test_mean_grad_case_1b(self):
+        """ values.requires_grad = False
+            mt = MaskedTensor(values, mask, requires_grad=True)
+        """
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor(d, m, requires_grad=True)
+        mt.mean().backward()
+        self.assertIsNone(d.grad)
+        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
+
+    def test_mean_grad_case_1c(self):
+        """ values.requires_grad = True
+            mt = MaskedTensor(values, mask, requires_grad=False)
+        """
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
+            mt = MaskedTensor(d, m, requires_grad=False)
+        result = mt.mean()
+        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            result.backward()
+
+
+    def test_mean_grad_case_1d(self):
+        """ values.requires_grad = False
+            mt = MaskedTensor(values, mask, requires_grad=False)
+        """
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor(d, m, requires_grad=False)
+        result = mt.mean()
+        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            result.backward()
+
+    def test_mean_grad_case_1e(self):
+        """ values.requires_grad = True
+            mt = MaskedTensor.from_values(values, mask)
+        """
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor.from_values(d, m)
+        mt.mean().backward()
+        _compare_mts(d.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
+        msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
+        with self.assertWarnsRegex(UserWarning, msg):
+            self.assertIsNone(mt.grad)
+
+    def test_mean_grad_case_1f(self):
+        """ values.requires_grad = False
+            mt = MaskedTensor.from_values(values, mask)
+        """
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor.from_values(d, m)
+        result = mt.mean()
+        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            result.backward()
+
+    def test_mean_dim_grad(self):
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, True, False], [False, True, False]])
+        mt = MaskedTensor(d, m, requires_grad=True)
+        mt.mean(1).sum().backward()
+        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m))
+
+    def test_amax(self):
+        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
+        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
+        mt = MaskedTensor(d, m)
+        _compare_mts(MaskedTensor(torch.tensor(3.0), torch.tensor(True)), mt.amax())
+        _compare_mts(
+            MaskedTensor(
+                torch.tensor([0.0, -4.0, 1.0, 3]),
+                torch.tensor([True, True, False, True]),
+            ),
+            mt.amax(dim=0),
+        )
+
+    def test_amax_grad(self):
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor(d, m, requires_grad=True)
+        mt.amax().backward()
+        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m))
+
+    def test_amin(self):
+        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
+        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
+        mt = MaskedTensor(d, m)
+        _compare_mts(MaskedTensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin())
+        _compare_mts(
+            MaskedTensor(
+                torch.tensor([0.0, -4.0, 1.0, -3]),
+                torch.tensor([True, True, False, True]),
+            ),
+            mt.amin(dim=0),
+        )
+
+    def test_amin_grad(self):
+        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor(d, m, requires_grad=True)
+        mt.amin().backward()
+        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m))
+
+    def test_prod(self):
+        d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]])
+        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
+        mt = MaskedTensor(d, m)
+        _compare_mts(MaskedTensor(torch.tensor(0.0), torch.tensor(True)), mt.prod())
+        _compare_mts(
+            MaskedTensor(
+                torch.tensor([0.0, 4.0, 1.0, 0.0]),
+                torch.tensor([True, True, False, True]),
+            ),
+            mt.prod(dim=0),
+        )
+
+    def test_prod_grad(self):
+        d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        mt = MaskedTensor(d, m, requires_grad=True)
+        mt.prod().backward()
+        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m))
+
+    def test_all(self):
+        d = torch.tensor([[True, True, False, False], [False, True, True, True]])
+        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
+        mt = MaskedTensor(d, m)
+        _compare_mts(MaskedTensor(torch.tensor(False), torch.tensor(True)), mt.all())
+        _compare_mts(
+            MaskedTensor(
+                torch.tensor([True, True, True, False]),
+                torch.tensor([True, True, False, True]),
+            ),
+            mt.all(dim=0),
+        )
+
+        m = torch.tensor([[True, False, True, False], [False, True, False, False]])
+        mt = MaskedTensor(d, m)
+        _compare_mts(
+            MaskedTensor(
+                torch.tensor([True, True, False, True]),
+                torch.tensor([True, True, True, False]),
+            ),
+            mt.all(dim=0),
+        )
+
+    def test_grad_dtype(self):
+        d = torch.tensor([[True, True, False], [False, True, True]])
+        m = torch.tensor([[True, False, False], [False, True, False]])
+        msg = "Only Tensors of floating point and complex dtype can require gradients"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            MaskedTensor(d, m, requires_grad=True)
+
+class TestMatMul(TestCase):
+    def test_bmm(self):
+        x = torch.rand(3, 2, 1)
+        key_padding_mask = torch.tensor(
+            [
+                [False, False, False],
+                [False, True, True],
+            ]
+        )
+        x_mt = MaskedTensor(x, ~(key_padding_mask.transpose(0, 1).unsqueeze(-1)))
+        x = x.masked_fill(~x_mt.get_mask(), 0)
+        attn_2 = torch.bmm(x, x.transpose(-2, -1))
+        attn_3 = torch.bmm(x_mt, x_mt.transpose(-2, -1))
+        self.assertEqual(attn_3.get_data().masked_fill(~attn_3.get_mask(), 0), attn_2)  # type: ignore[attr-defined]
+
+    def test_masked_bmm(self):
+        key_padding_mask = torch.tensor(
+            [
+                [False, False, False, True],
+                [False, True, True, True],
+                [False, True, False, True],
+            ]
+        )
+        x = torch.arange(4 * 3 * 2).reshape(4, 3, 2).float()
+        x_mt = MaskedTensor(
+            x,
+            ~(key_padding_mask.transpose(0, 1).unsqueeze(-1).expand_as(x)),
+            requires_grad=True,
+        )
+        attn_mask_bool = torch.tensor(
+            [
+                [False, True, True],
+                [False, False, True],
+                [True, False, False],
+            ]
+        )
+        attn_mask = attn_mask_bool.float().masked_fill_(attn_mask_bool, float("-inf"))
+        v = masked_bmm(x, x_mt.transpose(1, 2), attn_mask)
+        v.sum().backward()
+
+    def test_linear(self):
+        x = torch.arange(4 * 3 * 2).reshape(4, 3, 2)
+        w_x = torch.arange(10).reshape(5, 2) + x.amax()
+        linear = torch.nn.functional.linear
+        key_padding_mask = torch.tensor(
+            [
+                [False, False, False, True],
+                [False, True, True, True],
+                [False, True, False, True],
+            ]
+        )
+        x_mt = MaskedTensor(
+            x, ~(key_padding_mask.transpose(0, 1).unsqueeze(-1).expand_as(x))
+        )
+
instantiate_parametrized_tests(TestUnary)
instantiate_parametrized_tests(TestBinary)
+instantiate_parametrized_tests(TestReductions)
+instantiate_parametrized_tests(TestMatMul)

if __name__ == '__main__':
    run_tests()
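For quick reference, here is a minimal sketch of the two construction paths that the TestReductions docstring and the `test_mean_grad_case_1*` tests above exercise, assuming only the prototype `torch.masked` API used in this diff (`MaskedTensor` and `MaskedTensor.from_values`); it simply mirrors the behavior the tests assert.

```python
import torch
from torch.masked import MaskedTensor

data = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
mask = torch.tensor([[True, False, False], [False, True, False]])

# Path 1: non-differentiable constructor. The MaskedTensor is a leaf, so the
# gradient lands on mt.grad (itself a MaskedTensor), not on data.grad.
mt = MaskedTensor(data, mask, requires_grad=True)
mt.mean().backward()
print(mt.grad)      # 0.5 at each of the two unmasked positions

# Path 2: differentiable constructor. requires_grad is taken from the values
# tensor, so the gradient flows back to data2.grad instead of mt2.grad.
data2 = data.clone().requires_grad_(True)
mt2 = MaskedTensor.from_values(data2, mask)
mt2.mean().backward()
print(data2.grad)   # gradient lands on the underlying values tensor
```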