@@ -6890,6 +6890,23 @@ def gen_shape_kwargs():
         for input, target, var, kwargs in gen_shape_kwargs():
             yield SampleInput(input, args=(target, var, ), kwargs=kwargs)
 
+def error_inputs_gaussian_nll_loss(op_info, device, **kwargs):
+    _make = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # invalid reduction value
+    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"),
+                     error_type=ValueError, error_regex="abc is not valid")
+
+    # var is of incorrect shape
+    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)),
+                     error_type=ValueError, error_regex="var is of incorrect size")
+
+    # target is of incorrect shape
+    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)),
+                     error_type=RuntimeError,
+                     error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) "
+                                  r"at non-singleton dimension 2"))
+
 def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
     _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -16193,6 +16210,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_gaussian_nll_loss,
+           error_inputs_func=error_inputs_gaussian_nll_loss,
            skips=(
                # Pre-existing condition (calls .item); needs to be fixed
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
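For context, a minimal sketch (not part of the change) of the failures these ErrorInputs exercise, written against the public torch.nn.functional.gaussian_nll_loss API; the shapes and expected messages mirror the error_regex values in the diff above:

import torch
import torch.nn.functional as F

inp = torch.randn(10, 2, 3)
target = torch.randn(10, 2, 3)
var = torch.rand(10, 2, 3)  # variances must be non-negative

# Invalid reduction string -> ValueError ("abc is not valid")
try:
    F.gaussian_nll_loss(inp, target, var, reduction="abc")
except ValueError as e:
    print(e)

# var whose trailing dimension matches neither input's nor 1 -> ValueError ("var is of incorrect size")
try:
    F.gaussian_nll_loss(inp, target, torch.rand(10, 2, 2))
except ValueError as e:
    print(e)

# target that cannot broadcast against input -> RuntimeError about mismatched sizes
try:
    F.gaussian_nll_loss(inp, torch.randn(10, 2, 2), var)
except RuntimeError as e:
    print(e)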