From c92dd0965ded9f6a76745295f37f3ad0f6fd9de6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 28 Oct 2021 10:41:40 +0100 Subject: [PATCH 1/4] Filter undefined dtypes in `mutually_promotable_dtypes` --- array_api_tests/hypothesis_helpers.py | 17 ++++++----------- .../meta/test_hypothesis_helpers.py | 18 ++++++++++++++++-- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index aed4f1a2..9740c3df 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -2,7 +2,7 @@ from functools import reduce from math import sqrt from operator import mul -from typing import Any, List, NamedTuple, Optional, Tuple +from typing import Any, List, NamedTuple, Optional, Tuple, Sequence from hypothesis import assume from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, @@ -68,19 +68,14 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]): promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter) -if FILTER_UNDEFINED_DTYPES: - promotable_dtypes = [ - (i, j) for i, j in promotable_dtypes - if not isinstance(i, _UndefinedStub) - and not isinstance(j, _UndefinedStub) - ] - - def mutually_promotable_dtypes( max_size: Optional[int] = 2, *, - dtypes: Tuple[DataType, ...] = dh.all_dtypes, + dtypes: Sequence[DataType] = dh.all_dtypes, ) -> SearchStrategy[Tuple[DataType, ...]]: + if FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] + assert len(dtypes) > 0, "all dtypes undefined" # sanity check if max_size == 2: return sampled_from( [(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes] @@ -347,7 +342,7 @@ def multiaxis_indices(draw, shapes): def two_mutual_arrays( - dtypes: Tuple[DataType, ...] = dh.all_dtypes, + dtypes: Sequence[DataType] = dh.all_dtypes, two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes, ) -> SearchStrategy: mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes)) diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/array_api_tests/meta/test_hypothesis_helpers.py index b57db4ed..c5ff94d6 100644 --- a/array_api_tests/meta/test_hypothesis_helpers.py +++ b/array_api_tests/meta/test_hypothesis_helpers.py @@ -15,8 +15,8 @@ pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")] @given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes)) -def test_mutually_promotable_dtypes(pairs): - assert pairs in ( +def test_mutually_promotable_dtypes(pair): + assert pair in ( (xp.float32, xp.float32), (xp.float32, xp.float64), (xp.float64, xp.float32), @@ -24,6 +24,20 @@ def test_mutually_promotable_dtypes(pairs): ) +@given( + hh.mutually_promotable_dtypes( + dtypes=[xp.uint8, _UndefinedStub("uint16"), xp.uint32] + ) +) +def test_partial_mutually_promotable_dtypes(pair): + assert pair in ( + (xp.uint8, xp.uint8), + (xp.uint8, xp.uint32), + (xp.uint32, xp.uint8), + (xp.uint32, xp.uint32), + ) + + def valid_shape(shape) -> bool: return ( all(isinstance(side, int) for side in shape) From b436fd889966758a7731f7437751a02bd1857611 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 28 Oct 2021 11:01:46 +0100 Subject: [PATCH 2/4] Use direct array module when making `xps` namespace --- array_api_tests/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index 785aa43b..d01335c9 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,6 +1,6 @@ from hypothesis.extra.array_api import make_strategies_namespace -from . import _array_module as xp +from ._array_module import mod as xp xps = make_strategies_namespace(xp) From eb706ca06b2c4d7c9fc980bb87859203674c22a0 Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 28 Oct 2021 11:11:22 +0100 Subject: [PATCH 3/4] Test case for PyTorch using `mutually_promotable_dtypes` --- array_api_tests/meta/test_partial_adopters.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 array_api_tests/meta/test_partial_adopters.py diff --git a/array_api_tests/meta/test_partial_adopters.py b/array_api_tests/meta/test_partial_adopters.py new file mode 100644 index 00000000..6eda5c89 --- /dev/null +++ b/array_api_tests/meta/test_partial_adopters.py @@ -0,0 +1,18 @@ +import pytest +from hypothesis import given + +from .. import dtype_helpers as dh +from .. import hypothesis_helpers as hh +from .. import _array_module as xp +from .._array_module import _UndefinedStub + + +# e.g. PyTorch only supports uint8 currently +@pytest.mark.skipif(isinstance(xp.uint8, _UndefinedStub), reason="uint8 not defined") +@pytest.mark.skipif( + not all(isinstance(d, _UndefinedStub) for d in dh.uint_dtypes[1:]), + reason="uints defined", +) +@given(hh.mutually_promotable_dtypes(dtypes=dh.uint_dtypes)) +def test_mutually_promotable_dtypes(pair): + assert pair == (xp.uint8, xp.uint8) From 0938424895e0edd98386a3265cfc1b3bac57c106 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 29 Oct 2021 10:57:34 +0100 Subject: [PATCH 4/4] Alias `mod` as `_xp`, clean `__init__` namespace --- array_api_tests/__init__.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index d01335c9..763cef14 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,6 +1,10 @@ from hypothesis.extra.array_api import make_strategies_namespace -from ._array_module import mod as xp +from ._array_module import mod as _xp -xps = make_strategies_namespace(xp) +xps = make_strategies_namespace(_xp) + + +del _xp +del make_strategies_namespace