|
| 1 | +# We use __future__ for forward reference type hints - this will work for even py3.8.0 |
| 2 | +# See https://stackoverflow.com/a/33533514/5193926 |
1 | 3 | from __future__ import annotations
|
2 | 4 |
|
3 | 5 | import inspect
|
4 | 6 | import math
|
5 | 7 | import re
|
6 |
| -from dataclasses import dataclass |
| 8 | +from dataclasses import dataclass, field |
7 | 9 | from decimal import ROUND_HALF_EVEN, Decimal
|
8 | 10 | from enum import Enum, auto
|
9 | 11 | from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
|
@@ -169,7 +171,7 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
|
169 | 171 |
|
170 | 172 | @dataclass
|
171 | 173 | class BoundFromDtype(FromDtypeFunc):
|
172 |
| - kwargs: Dict[str, Any] |
| 174 | + kwargs: Dict[str, Any] = field(default_factory=dict) |
173 | 175 | filter_: Optional[Callable[[Array], bool]] = None
|
174 | 176 | base_func: Optional[FromDtypeFunc] = None
|
175 | 177 |
|
@@ -718,26 +720,38 @@ def partial_cond(i1: float, i2: float) -> bool:
|
718 | 720 | x1_cond_from_dtypes.append(cond_from_dtype)
|
719 | 721 | x2_cond_from_dtypes.append(cond_from_dtype)
|
720 | 722 | else:
|
| 723 | + # For "either x1_i or x2_i is <condition>" cases, we want to |
| 724 | + # test three scenarios: |
| 725 | + # |
| 726 | + # 1. x1_i is <condition> |
| 727 | + # 2. x2_i is <condition> |
| 728 | + # 3. x1_i AND x2_i is <condition> |
| 729 | + # |
| 730 | + # This is achieved by a shared base strategy that picks one |
| 731 | + # of these scenarios to determine whether each array will |
| 732 | + # use either cond_from_dtype() (i.e. meet the condition), or |
| 733 | + # simply xps.from_dtype() (i.e. be any value). |
| 734 | + |
721 | 735 | use_x1_or_x2_strat = st.shared(
|
722 |
| - st.sampled_from([(True, False), (True, False), (True, True)]) |
| 736 | + st.sampled_from([(True, False), (False, True), (True, True)]) |
723 | 737 | )
|
724 | 738 |
|
725 |
| - def x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: |
| 739 | + def _x1_cond_from_dtype(dtype) -> st.SearchStrategy[float]: |
726 | 740 | return use_x1_or_x2_strat.flatmap(
|
727 | 741 | lambda t: cond_from_dtype(dtype)
|
728 | 742 | if t[0]
|
729 | 743 | else xps.from_dtype(dtype)
|
730 | 744 | )
|
731 | 745 |
|
732 |
| - def x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: |
| 746 | + def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: |
733 | 747 | return use_x1_or_x2_strat.flatmap(
|
734 | 748 | lambda t: cond_from_dtype(dtype)
|
735 | 749 | if t[1]
|
736 | 750 | else xps.from_dtype(dtype)
|
737 | 751 | )
|
738 | 752 |
|
739 |
| - x1_cond_from_dtypes.append(x1_cond_from_dtype) |
740 |
| - x2_cond_from_dtypes.append(x2_cond_from_dtype) |
| 753 | + x1_cond_from_dtypes.append(_x1_cond_from_dtype) |
| 754 | + x2_cond_from_dtypes.append(_x2_cond_from_dtype) |
741 | 755 |
|
742 | 756 | partial_conds.append(partial_cond)
|
743 | 757 | partial_exprs.append(partial_expr)
|
@@ -768,17 +782,17 @@ def cond(i1: float, i2: float) -> bool:
|
768 | 782 | elif len(x1_cond_from_dtypes) == 1:
|
769 | 783 | x1_cond_from_dtype = x1_cond_from_dtypes[0]
|
770 | 784 | else:
|
771 |
| - # sanity check |
772 |
| - assert all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes) |
773 |
| - x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({})) |
| 785 | + if not all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes): |
| 786 | + raise ValueParseError(case_str) |
| 787 | + x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype()) |
774 | 788 | if len(x2_cond_from_dtypes) == 0:
|
775 | 789 | x2_cond_from_dtype = xps.from_dtype
|
776 | 790 | elif len(x2_cond_from_dtypes) == 1:
|
777 | 791 | x2_cond_from_dtype = x2_cond_from_dtypes[0]
|
778 | 792 | else:
|
779 |
| - # sanity check |
780 |
| - assert all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes) |
781 |
| - x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({})) |
| 793 | + if not all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes): |
| 794 | + raise ValueParseError(case_str) |
| 795 | + x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype()) |
782 | 796 |
|
783 | 797 | return BinaryCase(
|
784 | 798 | cond_expr=cond_expr,
|
@@ -819,10 +833,6 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
|
819 | 833 | return cases
|
820 | 834 |
|
821 | 835 |
|
822 |
| -# Here be the tests |
823 |
| -# ------------------------------------------------------------------------------ |
824 |
| - |
825 |
| - |
826 | 836 | unary_params = []
|
827 | 837 | binary_params = []
|
828 | 838 | for stub in category_to_funcs["elementwise"]:
|
|
0 commit comments