pytorch test_sum fails with out.dtype=int64, but should be uint64 [sum(uint8)] #160

Closed
@asmeurer


Using the compat library PR at data-apis/array-api-compat#14 (to work around #159):

$ PYTHONPATH=~/Documents/array-api-compat ARRAY_API_TESTS_MODULE=array_api_compat.torch pytest --max-examples=1000 -k test_sum
========================================================================== test session starts ==========================================================================
platform darwin -- Python 3.9.15, pytest-7.2.0, pluggy-1.0.0
rootdir: /Users/aaronmeurer/Documents/array-api-tests
plugins: pudb-0.7.0, html-3.2.0, json-report-1.5.0, doctestplus-0.12.1, dependency-0.5.1, cov-4.0.0, metadata-2.0.4, hypothesis-6.61.0
collected 1114 items / 1113 deselected / 1 selected

array_api_tests/test_statistical_functions.py F                                                                                                                   [100%]

=============================================================================== FAILURES ================================================================================
_______________________________________________________________________________ test_sum ________________________________________________________________________________

    @given(
>       x=xps.arrays(
            dtype=xps.numeric_dtypes(),
            shape=hh.shapes(min_side=1),
            elements={"allow_nan": False},
        ),
        data=st.data(),
    )

array_api_tests/test_statistical_functions.py:209:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
array_api_tests/test_statistical_functions.py:251: in test_sum
    ph.assert_dtype("sum", x.dtype, out.dtype, _dtype)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

func_name = 'sum', in_dtype = torch.uint8, out_dtype = torch.int64, expected = <undefined stub for 'uint64'>

    def assert_dtype(
        func_name: str,
        in_dtype: Union[DataType, Sequence[DataType]],
        out_dtype: DataType,
        expected: Optional[DataType] = None,
        *,
        repr_name: str = "out.dtype",
    ):
        """
        Assert the output dtype is as expected.

        If expected=None, we infer the expected dtype as in_dtype, to test
        out_dtype, e.g.

            >>> x = xp.arange(5, dtype=xp.uint8)
            >>> out = xp.abs(x)
            >>> assert_dtype('abs', x.dtype, out.dtype)

            is equivalent to

            >>> assert out.dtype == xp.uint8

        Or for multiple input dtypes, the expected dtype is inferred from their
        resulting type promotion, e.g.

            >>> x1 = xp.arange(5, dtype=xp.uint8)
            >>> x2 = xp.arange(5, dtype=xp.uint16)
            >>> out = xp.add(x1, x2)
            >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)

            is equivalent to

            >>> assert out.dtype == xp.uint16

        We can also specify the expected dtype ourselves, e.g.

            >>> x = xp.arange(5, dtype=xp.int8)
            >>> out = xp.sum(x)
            >>> default_int = xp.asarray(0).dtype
            >>> assert_dtype('sum', x, out.dtype, default_int)

        """
        in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
        f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
        f_out_dtype = dh.dtype_to_name[out_dtype]
        if expected is None:
            expected = dh.result_type(*in_dtypes)
        f_expected = dh.dtype_to_name[expected]
        msg = (
            f"{repr_name}={f_out_dtype}, but should be {f_expected} "
            f"[{func_name}({f_in_dtypes})]"
        )
>       assert out_dtype == expected, msg
E       AssertionError: out.dtype=int64, but should be uint64 [sum(uint8)]
E       Falsifying example: test_sum(
E           x=tensor(0, dtype=torch.uint8), data=data(...),
E       )
E       Draw 1 (kw): {}

array_api_tests/pytest_helpers.py:134: AssertionError

Not sure if this is something we should work around or not. PyTorch doesn't have a uint64 dtype, so it cannot return one. However, the spec does say that the return type should be uint64 (the unsigned integer dtype with the same number of bits as the default integer dtype, int64): https://data-apis.org/array-api/latest/API_specification/generated/array_api.sum.html#array_api.sum. It's also possible that the spec should be updated here.
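To make the mismatch concrete, here is a minimal sketch of the default-dtype rule for sum() on integer inputs, as I read the spec page linked above. It assumes the default integer dtype is int64 (as in PyTorch), ignores the spec's other clauses (floating-point inputs, an explicit dtype= argument, inputs wider than the default), and uses dtype name strings. `expected_sum_dtype` and `TORCH_INT_DTYPES` are hypothetical names for illustration, not part of the test suite.

```python
# Hypothetical sketch of the spec's default output dtype for sum() on
# integer inputs, assuming the default integer dtype is int64.
SIGNED = {"int8", "int16", "int32", "int64"}
UNSIGNED = {"uint8", "uint16", "uint32", "uint64"}

# Assumption: the PyTorch version in this report exposes uint8 as its
# only unsigned integer dtype (it has no uint64).
TORCH_INT_DTYPES = {"int8", "int16", "int32", "int64", "uint8"}


def expected_sum_dtype(in_dtype: str) -> str:
    """Return the dtype sum() should produce when no dtype= is given."""
    if in_dtype in SIGNED:
        # Signed inputs accumulate in the default integer dtype.
        return "int64"
    if in_dtype in UNSIGNED:
        # Unsigned inputs accumulate in the unsigned dtype with the same
        # number of bits as the default integer dtype.
        return "uint64"
    raise ValueError(f"sketch only covers integer dtypes, got {in_dtype!r}")


expected = expected_sum_dtype("uint8")
# The spec asks for uint64, which this PyTorch build cannot provide.
print(expected, expected in TORCH_INT_DTYPES)
```

This is why no workaround inside the compat layer can make sum(uint8) pass the dtype check: the expected dtype simply isn't in the library's dtype set.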

I will say this error makes it a little harder to verify the other sum() behavior: because the dtype assertion fails first, the remaining checks in test_sum never run (a general problem with the design of the test suite). I suppose I can just comment out the assert_dtype line for now.
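One way to avoid a single failing assertion hiding the rest is to run independent checks and report all failures together. The sketch below illustrates that pattern with plain Python; `collect_failures`, `check_dtype`, and `check_value` are hypothetical stand-ins, not functions from the array-api-tests suite.

```python
# Hypothetical sketch: run independent checks and collect every failure,
# so a dtype mismatch does not stop the remaining sum() checks.
from typing import Callable, List


def collect_failures(checks: List[Callable[[], None]]) -> List[str]:
    """Run every check and return the messages of those that failed."""
    failures: List[str] = []
    for check in checks:
        try:
            check()
        except AssertionError as exc:
            failures.append(str(exc))
    return failures


def check_dtype() -> None:
    # Mirrors the failure reported above: torch returns int64 for sum(uint8).
    out_dtype, expected = "int64", "uint64"
    assert out_dtype == expected, f"out.dtype={out_dtype}, but should be {expected}"


def check_value() -> None:
    # Stand-in for the test's value check.
    assert sum([1, 2, 3]) == 6, "wrong sum"


failures = collect_failures([check_dtype, check_value])
# Only the dtype check fails; the value check still ran and passed.
print(failures)  # → ['out.dtype=int64, but should be uint64']
```

A real test would then fail once at the end (for example `assert not failures, "; ".join(failures)`), reporting every broken property in a single run.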
