Skip to content

Commit bbaa063

Browse files
nkaretnikovpytorchmergebot
authored andcommitted
Add error inputs to gaussian_nll_loss OpInfo (#88486)
Pull Request resolved: #88486 Approved by: https://github.com/lezcano
1 parent 404f254 commit bbaa063

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

test/test_nn.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6144,20 +6144,6 @@ def test_poisson_nll_loss_reduction_modes(self):
61446144
with self.assertRaisesRegex(ValueError, 'is not valid'):
61456145
F.poisson_nll_loss(input, target, reduction='total')
61466146

6147-
def test_gaussian_nll_loss_reduction_modes(self):
6148-
input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
6149-
target = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
6150-
var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]])
6151-
component_wise_loss = 0.5 * (torch.log(var) + (input - target)**2 / var)
6152-
self.assertEqual(component_wise_loss,
6153-
F.gaussian_nll_loss(input, target, var, reduction='none'))
6154-
self.assertEqual(torch.sum(component_wise_loss),
6155-
F.gaussian_nll_loss(input, target, var, reduction='sum'))
6156-
self.assertEqual(torch.mean(component_wise_loss),
6157-
F.gaussian_nll_loss(input, target, var, reduction='mean'))
6158-
with self.assertRaisesRegex(ValueError, 'is not valid'):
6159-
F.gaussian_nll_loss(input, target, var, reduction='total')
6160-
61616147
def test_gaussian_nll_loss_broadcasting(self):
61626148
input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
61636149
target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]])

torch/testing/_internal/common_methods_invocations.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6890,6 +6890,23 @@ def gen_shape_kwargs():
68906890
for input, target, var, kwargs in gen_shape_kwargs():
68916891
yield SampleInput(input, args=(target, var, ), kwargs=kwargs)
68926892

6893+
def error_inputs_gaussian_nll_loss(op_info, device, **kwargs):
6894+
_make = partial(make_tensor, device=device, dtype=torch.float32)
6895+
6896+
# invalid reduction value
6897+
yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"),
6898+
error_type=ValueError, error_regex="abc is not valid")
6899+
6900+
# var is of incorrect shape
6901+
yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)),
6902+
error_type=ValueError, error_regex="var is of incorrect size")
6903+
6904+
# target is of incorrect shape
6905+
yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)),
6906+
error_type=RuntimeError,
6907+
error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) "
6908+
r"at non-singleton dimension 2"))
6909+
68936910
def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
68946911
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
68956912

@@ -16193,6 +16210,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
1619316210
supports_forward_ad=True,
1619416211
supports_fwgrad_bwgrad=True,
1619516212
sample_inputs_func=sample_inputs_gaussian_nll_loss,
16213+
error_inputs_func=error_inputs_gaussian_nll_loss,
1619616214
skips=(
1619716215
# Pre-existing condition (calls .item); needs to be fixed
1619816216
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),

0 commit comments

Comments
 (0)