Skip to content

Commit e4b6fa8

Browse files
committed
Comment on __future__ and either case strategies
1 parent 429a9a5 commit e4b6fa8

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

array_api_tests/test_special_cases.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
# We use __future__ for forward reference type hints - this will work for even py3.8.0
2+
# See https://stackoverflow.com/a/33533514/5193926
13
from __future__ import annotations
24

35
import inspect
46
import math
57
import re
6-
from dataclasses import dataclass
8+
from dataclasses import dataclass, field
79
from decimal import ROUND_HALF_EVEN, Decimal
810
from enum import Enum, auto
911
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
@@ -169,7 +171,7 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
169171

170172
@dataclass
171173
class BoundFromDtype(FromDtypeFunc):
172-
kwargs: Dict[str, Any]
174+
kwargs: Dict[str, Any] = field(default_factory=dict)
173175
filter_: Optional[Callable[[Array], bool]] = None
174176
base_func: Optional[FromDtypeFunc] = None
175177

@@ -718,26 +720,38 @@ def partial_cond(i1: float, i2: float) -> bool:
718720
x1_cond_from_dtypes.append(cond_from_dtype)
719721
x2_cond_from_dtypes.append(cond_from_dtype)
720722
else:
723+
# For "either x1_i or x2_i is <condition>" cases, we want to
724+
# test three scenarios:
725+
#
726+
# 1. x1_i is <condition>
727+
# 2. x2_i is <condition>
728+
# 3. x1_i AND x2_i is <condition>
729+
#
730+
# This is achieved by a shared base strategy that picks one
731+
# of these scenarios to determine whether each array will
732+
# use either cond_from_dtype() (i.e. meet the condition), or
733+
# simply xps.from_dtype() (i.e. be any value).
734+
721735
use_x1_or_x2_strat = st.shared(
722-
st.sampled_from([(True, False), (True, False), (True, True)])
736+
st.sampled_from([(True, False), (False, True), (True, True)])
723737
)
724738

725-
def x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
739+
def _x1_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
726740
return use_x1_or_x2_strat.flatmap(
727741
lambda t: cond_from_dtype(dtype)
728742
if t[0]
729743
else xps.from_dtype(dtype)
730744
)
731745

732-
def x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
746+
def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
733747
return use_x1_or_x2_strat.flatmap(
734748
lambda t: cond_from_dtype(dtype)
735749
if t[1]
736750
else xps.from_dtype(dtype)
737751
)
738752

739-
x1_cond_from_dtypes.append(x1_cond_from_dtype)
740-
x2_cond_from_dtypes.append(x2_cond_from_dtype)
753+
x1_cond_from_dtypes.append(_x1_cond_from_dtype)
754+
x2_cond_from_dtypes.append(_x2_cond_from_dtype)
741755

742756
partial_conds.append(partial_cond)
743757
partial_exprs.append(partial_expr)
@@ -768,17 +782,17 @@ def cond(i1: float, i2: float) -> bool:
768782
elif len(x1_cond_from_dtypes) == 1:
769783
x1_cond_from_dtype = x1_cond_from_dtypes[0]
770784
else:
771-
# sanity check
772-
assert all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes)
773-
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({}))
785+
if not all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes):
786+
raise ValueParseError(case_str)
787+
x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype())
774788
if len(x2_cond_from_dtypes) == 0:
775789
x2_cond_from_dtype = xps.from_dtype
776790
elif len(x2_cond_from_dtypes) == 1:
777791
x2_cond_from_dtype = x2_cond_from_dtypes[0]
778792
else:
779-
# sanity check
780-
assert all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes)
781-
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({}))
793+
if not all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes):
794+
raise ValueParseError(case_str)
795+
x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype())
782796

783797
return BinaryCase(
784798
cond_expr=cond_expr,
@@ -819,10 +833,6 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
819833
return cases
820834

821835

822-
# Here be the tests
823-
# ------------------------------------------------------------------------------
824-
825-
826836
unary_params = []
827837
binary_params = []
828838
for stub in category_to_funcs["elementwise"]:

0 commit comments

Comments
 (0)