Skip to content

Commit 5e9c26c

Browse files
george-qi authored and pytorchmergebot committed
[maskedtensor] adding reductions (#82839)
Pull Request resolved: #82839 Approved by: https://github.com/bhosmer
1 parent f125bd2 commit 5e9c26c

File tree

11 files changed

+882
-136
lines changed

11 files changed

+882
-136
lines changed

test/test_maskedtensor.py

Lines changed: 325 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
SampleInput,
1313
)
1414

15-
from torch.masked import masked_tensor
15+
from torch.masked import MaskedTensor, masked_bmm
1616
from torch.masked.maskedtensor.core import _masks_match, _tensors_match
1717
from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
18-
1918
from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
2019

2120

@@ -126,7 +125,7 @@ def _get_sample_kwargs(self, fn_name):
126125

127126
def _get_sample_args(self, fn_name, data, mask):
128127
fn_name = _fix_fn_name(fn_name)
129-
mt = masked_tensor(data, mask)
128+
mt = MaskedTensor(data, mask)
130129
t_args = [data]
131130
mt_args = [mt]
132131
if fn_name in ["pow"]:
@@ -185,8 +184,8 @@ def _yield_sample_args(self, fn_name, data0, data1, mask):
185184
while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor)
186185
"""
187186
fn_name = _fix_fn_name(fn_name)
188-
mt0 = masked_tensor(data0, mask)
189-
mt1 = masked_tensor(data1, mask)
187+
mt0 = MaskedTensor(data0, mask)
188+
mt1 = MaskedTensor(data1, mask)
190189

191190
t_args = [data0, data1]
192191
mt_args = [mt0, mt1]
@@ -227,8 +226,8 @@ def test_masks_match(self, fn_name):
227226
data0, data1, mask = self._get_test_data(fn_name)
228227
mask0 = mask
229228
mask1 = torch.rand(mask.size()) > 0.5
230-
mt0 = masked_tensor(data0, mask0)
231-
mt1 = masked_tensor(data1, mask1)
229+
mt0 = MaskedTensor(data0, mask0)
230+
mt1 = MaskedTensor(data1, mask1)
232231
try:
233232
fn(mt0, mt1)
234233
raise AssertionError()
@@ -238,8 +237,327 @@ def test_masks_match(self, fn_name):
238237
== str(e)
239238
)
240239

240+
class TestReductions(TestCase):
    """Tests for MaskedTensor reductions (sum, mean, amax, amin, prod, all),
    including forward values, masks, and gradients.
    """

    def test_max_not_implemented(self):
        # max (unlike amax) has no MaskedTensor dispatch yet; it must fail loudly.
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor(d, m)
        with self.assertRaisesRegex(TypeError, "no implementation found for 'torch.ops.aten.max'"):
            mt.max()

    def test_sum(self):
        d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = MaskedTensor(d, m)
        # Full reduction: only unmasked elements (0, 6, 4, 7) contribute.
        _compare_mts(MaskedTensor(torch.tensor(17.0), torch.tensor(True)), mt.sum())
        # Per-dim reduction: a column with no unmasked entries is itself masked out.
        _compare_mts(
            MaskedTensor(
                torch.tensor([0.0, 4.0, 1.0, 13]),
                torch.tensor([True, True, False, True]),
            ),
            mt.sum(dim=0),
        )

    def test_sum_grad(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor(d, m, requires_grad=True)
        mt.sum().backward()
        # d(sum)/d(x) == 1 everywhere; the gradient carries the input mask.
        _compare_mts(mt.grad, MaskedTensor(torch.tensor(1.0).expand_as(m), m))

    def test_mean(self):
        d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = MaskedTensor(d, m)
        # Mean over the four unmasked elements (0, 2, 4, 4) is 2.5.
        _compare_mts(MaskedTensor(torch.tensor(2.5), torch.tensor(True)), mt.mean())
        _compare_mts(
            MaskedTensor(
                torch.tensor([0.0, 4.0, 1.0, 3]),
                torch.tensor([True, True, False, True]),
            ),
            mt.mean(dim=0),
        )

    """
    The following block of tests "test_mean_grad_case_1[a through f]" are used to test the functionality of
    the two different ways of constructing MaskedTensors:
        MaskedTensor(data, mask, requires_grad=True/False) -- NO differentiable constructor and always a leaf
        MaskedTensor.from_values(data, mask) -- differentiable constructor

    Like torch.tensor(data), MaskedTensor(data, mask) will provide a UserWarning if data.requires_grad=True
    MaskedTensor.from_values does not take in requires_grad -- it just takes on the requires_grad from data

    Therefore, there are 6 cases to test and we use `mean` as a proxy to test the different combinations

    Assuming mt.mean().backward() is run after each constructor:

    Case 1a:
        values.requires_grad = True
        mt = MaskedTensor(values, mask, requires_grad=True)
    yields
        - Provide a UserWarning because values.requires_grad=True
        - values.grad = None
        - mt.grad is a MaskedTensor with the correct gradient

    Case 1b:
        values.requires_grad = False
        mt = MaskedTensor(values, mask, requires_grad=True)
    yields
        - values.grad = None
        - mt.grad is a MaskedTensor with the correct gradient

    Case 1c / 1d:
        values.requires_grad = True/False
        mt = MaskedTensor(values, mask, requires_grad=False)

    will both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
    as expected. When values.requires_grad=True, we will also get a UserWarning

    Case 1e:
        values.requires_grad = True
        mt = MaskedTensor.from_values(values, mask)
    yields
        - values.grad is a MaskedTensor with the correct gradient
        - mt.grad is None and gives a UserWarning that
          "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"

    Case 1f:
        values.requires_grad = False
        mt = MaskedTensor.from_values(values, mask)

    will yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
    as expected.
    """
    def test_mean_grad_case_1a(self):
        """ values.requires_grad = True
            mt = MaskedTensor(values, mask, requires_grad=True)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
        m = torch.tensor([[True, False, False], [False, True, False]])
        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
            mt = MaskedTensor(d, m, requires_grad=True)
        mt.mean().backward()
        self.assertIsNone(d.grad)
        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))

    def test_mean_grad_case_1b(self):
        """ values.requires_grad = False
            mt = MaskedTensor(values, mask, requires_grad=True)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor(d, m, requires_grad=True)
        mt.mean().backward()
        self.assertIsNone(d.grad)
        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))

    def test_mean_grad_case_1c(self):
        """ values.requires_grad = True
            mt = MaskedTensor(values, mask, requires_grad=False)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
        m = torch.tensor([[True, False, False], [False, True, False]])
        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
            mt = MaskedTensor(d, m, requires_grad=False)
        result = mt.mean()
        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
        with self.assertRaisesRegex(RuntimeError, msg):
            result.backward()

    def test_mean_grad_case_1d(self):
        """ values.requires_grad = False
            mt = MaskedTensor(values, mask, requires_grad=False)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor(d, m, requires_grad=False)
        result = mt.mean()
        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
        with self.assertRaisesRegex(RuntimeError, msg):
            result.backward()

    def test_mean_grad_case_1e(self):
        """ values.requires_grad = True
            mt = MaskedTensor.from_values(values, mask)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor.from_values(d, m)
        mt.mean().backward()
        # The differentiable constructor routes the gradient back to the values tensor.
        _compare_mts(d.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
        msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
        with self.assertWarnsRegex(UserWarning, msg):
            self.assertIsNone(mt.grad)

    def test_mean_grad_case_1f(self):
        """ values.requires_grad = False
            mt = MaskedTensor.from_values(values, mask)
        """
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor.from_values(d, m)
        result = mt.mean()
        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
        with self.assertRaisesRegex(RuntimeError, msg):
            result.backward()

    def test_mean_dim_grad(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, True, False], [False, True, False]])
        mt = MaskedTensor(d, m, requires_grad=True)
        mt.mean(1).sum().backward()
        # Row 0 averages two unmasked entries (grad 0.5 each); row 1 has one (grad 1).
        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m))

    def test_amax(self):
        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = MaskedTensor(d, m)
        _compare_mts(MaskedTensor(torch.tensor(3.0), torch.tensor(True)), mt.amax())
        _compare_mts(
            MaskedTensor(
                torch.tensor([0.0, -4.0, 1.0, 3]),
                torch.tensor([True, True, False, True]),
            ),
            mt.amax(dim=0),
        )

    def test_amax_grad(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor(d, m, requires_grad=True)
        mt.amax().backward()
        # Only the unmasked maximum (4) receives gradient.
        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m))

    def test_amin(self):
        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = MaskedTensor(d, m)
        _compare_mts(MaskedTensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin())
        _compare_mts(
            MaskedTensor(
                torch.tensor([0.0, -4.0, 1.0, -3]),
                torch.tensor([True, True, False, True]),
            ),
            mt.amin(dim=0),
        )

    def test_amin_grad(self):
        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor(d, m, requires_grad=True)
        mt.amin().backward()
        # Only the unmasked minimum (0) receives gradient.
        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m))

    def test_prod(self):
        # NaN sits in a masked slot and must not poison the product.
        d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = MaskedTensor(d, m)
        _compare_mts(MaskedTensor(torch.tensor(0.0), torch.tensor(True)), mt.prod())
        _compare_mts(
            MaskedTensor(
                torch.tensor([0.0, 4.0, 1.0, 0.0]),
                torch.tensor([True, True, False, True]),
            ),
            mt.prod(dim=0),
        )

    def test_prod_grad(self):
        d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        mt = MaskedTensor(d, m, requires_grad=True)
        mt.prod().backward()
        # d(a*b)/da = b: each unmasked factor's grad is the other factor.
        _compare_mts(mt.grad, MaskedTensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m))

    def test_all(self):
        d = torch.tensor([[True, True, False, False], [False, True, True, True]])
        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
        mt = MaskedTensor(d, m)
        _compare_mts(MaskedTensor(torch.tensor(False), torch.tensor(True)), mt.all())
        _compare_mts(
            MaskedTensor(
                torch.tensor([True, True, True, False]),
                torch.tensor([True, True, False, True]),
            ),
            mt.all(dim=0),
        )

        # A fully-masked column yields all()==True data but a False output mask.
        m = torch.tensor([[True, False, True, False], [False, True, False, False]])
        mt = MaskedTensor(d, m)
        _compare_mts(
            MaskedTensor(
                torch.tensor([True, True, False, True]),
                torch.tensor([True, True, True, False]),
            ),
            mt.all(dim=0),
        )

    def test_grad_dtype(self):
        # Bool tensors cannot require grad, mirroring plain torch.Tensor behavior.
        d = torch.tensor([[True, True, False], [False, True, True]])
        m = torch.tensor([[True, False, False], [False, True, False]])
        msg = "Only Tensors of floating point and complex dtype can require gradients"
        with self.assertRaisesRegex(RuntimeError, msg):
            MaskedTensor(d, m, requires_grad=True)

class TestMatMul(TestCase):
    """Tests for batched matrix multiplication with MaskedTensor operands."""

    def test_bmm(self):
        # A masked bmm must agree with a dense bmm whose masked slots are zeroed.
        x = torch.rand(3, 2, 1)
        key_padding_mask = torch.tensor(
            [
                [False, False, False],
                [False, True, True],
            ]
        )
        x_mt = MaskedTensor(x, ~(key_padding_mask.transpose(0, 1).unsqueeze(-1)))
        x = x.masked_fill(~x_mt.get_mask(), 0)
        attn_2 = torch.bmm(x, x.transpose(-2, -1))
        attn_3 = torch.bmm(x_mt, x_mt.transpose(-2, -1))
        self.assertEqual(attn_3.get_data().masked_fill(~attn_3.get_mask(), 0), attn_2)  # type: ignore[attr-defined]

    def test_masked_bmm(self):
        # Smoke test: masked_bmm with an additive -inf attention mask is differentiable.
        key_padding_mask = torch.tensor(
            [
                [False, False, False, True],
                [False, True, True, True],
                [False, True, False, True],
            ]
        )
        x = torch.arange(4 * 3 * 2).reshape(4, 3, 2).float()
        x_mt = MaskedTensor(
            x,
            ~(key_padding_mask.transpose(0, 1).unsqueeze(-1).expand_as(x)),
            requires_grad=True,
        )
        attn_mask_bool = torch.tensor(
            [
                [False, True, True],
                [False, False, True],
                [True, False, False],
            ]
        )
        attn_mask = attn_mask_bool.float().masked_fill_(attn_mask_bool, float("-inf"))
        v = masked_bmm(x, x_mt.transpose(1, 2), attn_mask)
        v.sum().backward()

    def test_linear(self):
        # Smoke test: construct a MaskedTensor input for a linear layer.
        x = torch.arange(4 * 3 * 2).reshape(4, 3, 2)
        w_x = torch.arange(10).reshape(5, 2) + x.amax()
        linear = torch.nn.functional.linear
        key_padding_mask = torch.tensor(
            [
                [False, False, False, True],
                [False, True, True, True],
                [False, True, False, True],
            ]
        )
        x_mt = MaskedTensor(
            x, ~(key_padding_mask.transpose(0, 1).unsqueeze(-1).expand_as(x))
        )

241557
instantiate_parametrized_tests(TestUnary)
242558
instantiate_parametrized_tests(TestBinary)
559+
instantiate_parametrized_tests(TestReductions)
560+
instantiate_parametrized_tests(TestMatMul)
243561

244562
if __name__ == '__main__':
245563
run_tests()

0 commit comments

Comments (0)