From 61eec43636db0104bbf15bb41c76d42fab46a656 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 8 Feb 2022 13:07:16 +0000 Subject: [PATCH 01/63] Add `array-api` spec as submodule, load its signatures --- .gitmodules | 3 +++ array_api_tests/array-api | 1 + array_api_tests/stubs.py | 47 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+) create mode 100644 .gitmodules create mode 160000 array_api_tests/array-api create mode 100644 array_api_tests/stubs.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..4128e9f2 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "array_api_tests/array-api"] + path = array_api_tests/array-api + url = https://github.com/data-apis/array-api/ diff --git a/array_api_tests/array-api b/array_api_tests/array-api new file mode 160000 index 00000000..2b9c402e --- /dev/null +++ b/array_api_tests/array-api @@ -0,0 +1 @@ +Subproject commit 2b9c402ebdb9825c2e8787caaabb5c5e3d9cf394 diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py new file mode 100644 index 00000000..ab1851f3 --- /dev/null +++ b/array_api_tests/stubs.py @@ -0,0 +1,47 @@ +import sys +from importlib import import_module +from importlib.util import find_spec +from pathlib import Path +from types import FunctionType, ModuleType +from typing import Dict, List + +__all__ = ["category_to_funcs", "array", "extension_to_funcs"] + + +spec_dir = Path(__file__).parent / "array-api" / "spec" / "API_specification" +assert spec_dir.exists(), f"{spec_dir} not found - try `git pull --recurse-submodules`" +sigs_dir = spec_dir / "signatures" +assert sigs_dir.exists() + +spec_abs_path: str = str(spec_dir.resolve()) +sys.path.append(spec_abs_path) +assert find_spec("signatures") is not None + +name_to_mod: Dict[str, ModuleType] = {} +for path in sigs_dir.glob("*.py"): + name = path.name.replace(".py", "") + name_to_mod[name] = import_module(f"signatures.{name}") + + +category_to_funcs: Dict[str, List[FunctionType]] = {} +for name, mod in name_to_mod.items(): + if name.endswith("_functions"): + category = name.replace("_functions", "") + objects = [getattr(mod, name) for name in mod.__all__] + assert all(isinstance(o, FunctionType) for o in objects) + category_to_funcs[category] = objects + + +array = name_to_mod["array_object"].array + + +EXTENSIONS = ["linalg"] +extension_to_funcs: Dict[str, List[FunctionType]] = {} +for ext in EXTENSIONS: + mod = name_to_mod[ext] + objects = [getattr(mod, name) for name in mod.__all__] + assert all(isinstance(o, FunctionType) for o in objects) + extension_to_funcs[ext] = objects + + +sys.path.remove(spec_abs_path) From 6a4fffc2354ef8677f32ce2343788f178a404cce Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 9 Feb 2022 13:02:25 +0000 Subject: [PATCH 02/63] Rudimentary special case tests on runtime --- array_api_tests/special_cases.py | 240 ++++++++++++ array_api_tests/special_cases/__init__.py | 0 array_api_tests/special_cases/test_abs.py | 53 --- array_api_tests/special_cases/test_acos.py | 66 ---- array_api_tests/special_cases/test_acosh.py | 66 ---- array_api_tests/special_cases/test_add.py | 226 ----------- array_api_tests/special_cases/test_asin.py | 79 ---- array_api_tests/special_cases/test_asinh.py | 79 ---- array_api_tests/special_cases/test_atan.py | 79 ---- array_api_tests/special_cases/test_atan2.py | 314 ---------------- array_api_tests/special_cases/test_atanh.py | 106 ------ array_api_tests/special_cases/test_ceil.py | 92 ----- array_api_tests/special_cases/test_cos.py | 79 ---- array_api_tests/special_cases/test_cosh.py | 79 ---- array_api_tests/special_cases/test_divide.py | 293 --------------- .../special_cases/test_dunder_abs.py | 52 --- .../special_cases/test_dunder_add.py | 225 ----------- .../special_cases/test_dunder_iadd.py | 243 ------------ .../special_cases/test_dunder_imul.py | 133 ------- .../special_cases/test_dunder_ipow.py | 353 ------------------ .../special_cases/test_dunder_itruediv.py | 315 ---------------- .../special_cases/test_dunder_mul.py | 123 ------ .../special_cases/test_dunder_pow.py | 326 ---------------- .../special_cases/test_dunder_truediv.py | 292 --------------- array_api_tests/special_cases/test_exp.py | 79 ---- array_api_tests/special_cases/test_expm1.py | 79 ---- array_api_tests/special_cases/test_floor.py | 92 ----- array_api_tests/special_cases/test_log.py | 80 ---- array_api_tests/special_cases/test_log10.py | 80 ---- array_api_tests/special_cases/test_log1p.py | 92 ----- array_api_tests/special_cases/test_log2.py | 80 ---- .../special_cases/test_logaddexp.py | 54 --- .../special_cases/test_multiply.py | 124 ------ array_api_tests/special_cases/test_pow.py | 327 ---------------- array_api_tests/special_cases/test_round.py | 108 ------ array_api_tests/special_cases/test_sign.py | 54 --- array_api_tests/special_cases/test_sin.py | 66 ---- array_api_tests/special_cases/test_sinh.py | 79 ---- array_api_tests/special_cases/test_sqrt.py | 79 ---- array_api_tests/special_cases/test_tan.py | 66 ---- array_api_tests/special_cases/test_tanh.py | 79 ---- array_api_tests/special_cases/test_trunc.py | 92 ----- 42 files changed, 240 insertions(+), 5383 deletions(-) create mode 100644 array_api_tests/special_cases.py delete mode 100644 array_api_tests/special_cases/__init__.py delete mode 100644 array_api_tests/special_cases/test_abs.py delete mode 100644 array_api_tests/special_cases/test_acos.py delete mode 100644 array_api_tests/special_cases/test_acosh.py delete mode 100644 array_api_tests/special_cases/test_add.py delete mode 100644 array_api_tests/special_cases/test_asin.py delete mode 100644 array_api_tests/special_cases/test_asinh.py delete mode 100644 array_api_tests/special_cases/test_atan.py delete mode 100644 array_api_tests/special_cases/test_atan2.py delete mode 100644 array_api_tests/special_cases/test_atanh.py delete mode 100644 array_api_tests/special_cases/test_ceil.py delete mode 100644 array_api_tests/special_cases/test_cos.py delete mode 100644 array_api_tests/special_cases/test_cosh.py delete mode 100644 array_api_tests/special_cases/test_divide.py delete mode 100644 array_api_tests/special_cases/test_dunder_abs.py delete mode 100644 array_api_tests/special_cases/test_dunder_add.py delete mode 100644 array_api_tests/special_cases/test_dunder_iadd.py delete mode 100644 array_api_tests/special_cases/test_dunder_imul.py delete mode 100644 array_api_tests/special_cases/test_dunder_ipow.py delete mode 100644 array_api_tests/special_cases/test_dunder_itruediv.py delete mode 100644 array_api_tests/special_cases/test_dunder_mul.py delete mode 100644 array_api_tests/special_cases/test_dunder_pow.py delete mode 100644 array_api_tests/special_cases/test_dunder_truediv.py delete mode 100644 array_api_tests/special_cases/test_exp.py delete mode 100644 array_api_tests/special_cases/test_expm1.py delete mode 100644 array_api_tests/special_cases/test_floor.py delete mode 100644 array_api_tests/special_cases/test_log.py delete mode 100644 array_api_tests/special_cases/test_log10.py delete mode 100644 array_api_tests/special_cases/test_log1p.py delete mode 100644 array_api_tests/special_cases/test_log2.py delete mode 100644 array_api_tests/special_cases/test_logaddexp.py delete mode 100644 array_api_tests/special_cases/test_multiply.py delete mode 100644 array_api_tests/special_cases/test_pow.py delete mode 100644 array_api_tests/special_cases/test_round.py delete mode 100644 array_api_tests/special_cases/test_sign.py delete mode 100644 array_api_tests/special_cases/test_sin.py delete mode 100644 array_api_tests/special_cases/test_sinh.py delete mode 100644 array_api_tests/special_cases/test_sqrt.py delete mode 100644 array_api_tests/special_cases/test_tan.py delete mode 100644 array_api_tests/special_cases/test_tanh.py delete mode 100644 array_api_tests/special_cases/test_trunc.py diff --git a/array_api_tests/special_cases.py b/array_api_tests/special_cases.py new file mode 100644 index 00000000..0b5c31fd --- /dev/null +++ b/array_api_tests/special_cases.py @@ -0,0 +1,240 @@ +import inspect +import math +import re +from typing import Callable, Dict, NamedTuple, Pattern +from warnings import warn + +import pytest +from attr import dataclass +from hypothesis import assume, given + +from . import hypothesis_helpers as hh +from . import shape_helpers as sh +from . import xps +from ._array_module import mod as xp +from .stubs import category_to_funcs + +repr_to_value = { + "NaN": float("nan"), + "+infinity": float("infinity"), + "infinity": float("infinity"), + "-infinity": float("-infinity"), + "+0": 0.0, + "0": 0.0, + "-0": -0.0, + "+1": 1.0, + "1": 1.0, + "-1": -1.0, + "+π/2": math.pi / 2, + "π/2": math.pi / 2, + "-π/2": -math.pi / 2, +} + + +def make_eq(v: float) -> Callable[[float], bool]: + if math.isnan(v): + return math.isnan + + def eq(i: float) -> bool: + return i == v + + return eq + + +def make_rough_eq(v: float) -> Callable[[float], bool]: + def rough_eq(i: float) -> bool: + return math.isclose(i, v, abs_tol=0.01) + + return rough_eq + + +def make_gt(v: float): + assert not math.isnan(v) # sanity check + + def gt(i: float): + return i > v + + return gt + + +def make_lt(v: float): + assert not math.isnan(v) # sanity check + + def lt(i: float): + return i < v + + return lt + + +def make_or(cond1: Callable, cond2: Callable): + def or_(i: float): + return cond1(i) or cond2(i) + + return or_ + + +r_value = re.compile(r"``([^\s]+)``") +r_approx_value = re.compile( + rf"an implementation-dependent approximation to {r_value.pattern}" +) + + +@dataclass +class ValueParseError(ValueError): + value: str + + +def parse_value(value: str) -> float: + if m := r_value.match(value): + return repr_to_value[m.group(1)] + raise ValueParseError(value) + + +class Result(NamedTuple): + value: float + repr_: str + strict_check: bool + + +def parse_result(result: str) -> Result: + if m := r_value.match(result): + repr_ = m.group(1) + strict_check = True + elif m := r_approx_value.match(result): + repr_ = m.group(1) + strict_check = False + else: + raise ValueParseError(result) + value = repr_to_value[repr_] + return Result(value, repr_, strict_check) + + +r_special_cases = re.compile( + r"\*\*Special [Cc]ases\*\*\n\n\s*" + r"For floating-point operands,\n\n" + r"((?:\s*-\s*.*\n)+)" +) +r_case = re.compile(r"\s+-\s*(.*)\.\n?") +r_remaining_case = re.compile("In the remaining cases.+") + + +unary_pattern_to_condition_factory: Dict[Pattern, Callable] = { + re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt, + re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt, + re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): ( + lambda v1, v2: make_or(make_eq(v1), make_eq(v2)) + ), + # This pattern must come after the previous patterns to avoid unwanted matches + re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq, + re.compile( + "If two integers are equally close to ``x_i``, the result is (.+)" + ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5), +} + + +def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: + match = r_special_cases.search(docstring) + if match is None: + return {} + cases = match.group(1).split("\n")[:-1] + condition_to_result = {} + for line in cases: + if m := r_case.match(line): + case = m.group(1) + else: + warn(f"line not machine-readable: '{line}'") + continue + for pattern, make_cond in unary_pattern_to_condition_factory.items(): + if m := pattern.search(case): + *s_values, s_result = m.groups() + try: + values = [parse_value(v) for v in s_values] + except ValueParseError as e: + warn(f"value not machine-readable: '{e.value}'") + break + cond = make_cond(*values) + try: + result = parse_result(s_result) + except ValueParseError as e: + warn(f"result not machine-readable: '{e.value}'") + break + condition_to_result[cond] = result + break + else: + if not r_remaining_case.search(case): + warn(f"case not machine-readable: '{case}'") + return condition_to_result + + +unary_params = [] +for stub in category_to_funcs["elementwise"]: + if stub.__doc__ is None: + warn(f"{stub.__name__}() stub has no docstring") + continue + marks = [] + try: + func = getattr(xp, stub.__name__) + except AttributeError: + marks.append( + pytest.mark.skip(reason=f"{stub.__name__} not found in array module") + ) + func = None + sig = inspect.signature(stub) + param_names = list(sig.parameters.keys()) + if len(sig.parameters) == 0: + warn(f"{func=} has no parameters") + continue + if param_names[0] == "x": + if condition_to_result := parse_unary_docstring(stub.__doc__): + p = pytest.param(stub.__name__, func, condition_to_result, id=stub.__name__) + unary_params.append(p) + continue + if len(sig.parameters) == 1: + warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") + continue + if param_names[0] == "x1" and param_names[1] == "x2": + pass # TODO + else: + warn( + f"{func=} starts with two parameters '{param_names[0]}' and " + f"'{param_names[1]}', which are not named 'x1' and 'x2'" + ) + + +# good_example is a flag that tells us whether Hypothesis generated an array +# with at least on element that is special-cased. We reject the example when +# its False - Hypothesis will complain if we reject too many examples, thus +# indicating we should modify the array strategy being used. + + +@pytest.mark.parametrize("func_name, func, condition_to_result", unary_params) +@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) +def test_unary_special_cases(func_name, func, condition_to_result, x): + res = func(x) + good_example = False + for idx in sh.ndindex(res.shape): + in_ = float(x[idx]) + for cond, result in condition_to_result.items(): + if cond(in_): + good_example = True + out = float(res[idx]) + f_in = f"{sh.fmt_idx('x', idx)}={in_}" + f_out = f"{sh.fmt_idx('out', idx)}={out}" + if result.strict_check: + msg = ( + f"{f_out}, but should be {result.repr_} [{func_name}()]\n" + f"{f_in}" + ) + if math.isnan(result.value): + assert math.isnan(out), msg + else: + assert out == result.value, msg + else: + assert math.isfinite(result.value) # sanity check + assert math.isclose(out, result.value, abs_tol=0.1), ( + f"{f_out}, but should be roughly {result.repr_}={result.value} " + f"[{func_name}()]\n" + f"{f_in}" + ) + break + assume(good_example) diff --git a/array_api_tests/special_cases/__init__.py b/array_api_tests/special_cases/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/array_api_tests/special_cases/test_abs.py b/array_api_tests/special_cases/test_abs.py deleted file mode 100644 index 4ed04d02..00000000 --- a/array_api_tests/special_cases/test_abs.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Special cases tests for abs. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import abs - -from hypothesis import given - - -@given(numeric_arrays) -def test_abs_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `abs(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = abs(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_abs_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `abs(x, /)`: - - - If `x_i` is `-0`, the result is `+0`. - - """ - res = abs(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_abs_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `abs(x, /)`: - - - If `x_i` is `-infinity`, the result is `+infinity`. - - """ - res = abs(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_acos.py b/array_api_tests/special_cases/test_acos.py deleted file mode 100644 index b1c3cb56..00000000 --- a/array_api_tests/special_cases/test_acos.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Special cases tests for acos. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, greater, less, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import acos - -from hypothesis import given - - -@given(numeric_arrays) -def test_acos_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `acos(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = acos(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_acos_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `acos(x, /)`: - - - If `x_i` is `1`, the result is `+0`. - - """ - res = acos(arg1) - mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_acos_special_cases_one_arg_greater(arg1): - """ - Special case test for `acos(x, /)`: - - - If `x_i` is greater than `1`, the result is `NaN`. - - """ - res = acos(arg1) - mask = greater(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_acos_special_cases_one_arg_less(arg1): - """ - Special case test for `acos(x, /)`: - - - If `x_i` is less than `-1`, the result is `NaN`. - - """ - res = acos(arg1) - mask = less(arg1, -one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_acosh.py b/array_api_tests/special_cases/test_acosh.py deleted file mode 100644 index 8749eaf2..00000000 --- a/array_api_tests/special_cases/test_acosh.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Special cases tests for acosh. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, less, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import acosh - -from hypothesis import given - - -@given(numeric_arrays) -def test_acosh_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `acosh(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = acosh(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_acosh_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `acosh(x, /)`: - - - If `x_i` is `1`, the result is `+0`. - - """ - res = acosh(arg1) - mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_acosh_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `acosh(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = acosh(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_acosh_special_cases_one_arg_less(arg1): - """ - Special case test for `acosh(x, /)`: - - - If `x_i` is less than `1`, the result is `NaN`. - - """ - res = acosh(arg1) - mask = less(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_add.py b/array_api_tests/special_cases/test_add.py deleted file mode 100644 index eaccb803..00000000 --- a/array_api_tests/special_cases/test_add.py +++ /dev/null @@ -1,226 +0,0 @@ -""" -Special cases tests for add. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, isfinite, - logical_and, logical_or, non_zero, zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import add - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = add(arg1, arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_1(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is `-infinity`, the result is `NaN`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_2(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is `+infinity`, the result is `NaN`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_3(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_4(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is `-infinity`, the result is `-infinity`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_5(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a finite number, the result is `+infinity`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), isfinite(arg2)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_6(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a finite number, the result is `-infinity`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), isfinite(arg2)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_7(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is a finite number and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = add(arg1, arg2) - mask = logical_and(isfinite(arg1), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_8(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is a finite number and `x2_i` is `-infinity`, the result is `-infinity`. - - """ - res = add(arg1, arg2) - mask = logical_and(isfinite(arg1), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_9(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `-0` and `x2_i` is `-0`, the result is `-0`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_10(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `-0` and `x2_i` is `+0`, the result is `+0`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_11(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is `-0`, the result is `+0`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_12(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is `+0`, the result is `+0`. - - """ - res = add(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_13(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is `-x1_i`, the result is `+0`. - - """ - res = add(arg1, arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), exactly_equal(arg2, -arg1)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_either__equal(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is a nonzero finite number, the result is `x2_i`. - - """ - res = add(arg1, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_and(isfinite(arg2), non_zero(arg2))) - assert_exactly_equal(res[mask], (arg2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__either(arg1, arg2): - """ - Special case test for `add(x1, x2, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is either `+0` or `-0`, the result is `x1_i`. - - """ - res = add(arg1, arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (arg1)[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, when neither `infinity`, `+0`, `-0`, nor a `NaN` is involved, and the operands have the same mathematical sign or have different magnitudes, the sum must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported round mode. If the magnitude is too large to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_asin.py b/array_api_tests/special_cases/test_asin.py deleted file mode 100644 index 0a41b716..00000000 --- a/array_api_tests/special_cases/test_asin.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for asin. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, greater, less, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import asin - -from hypothesis import given - - -@given(numeric_arrays) -def test_asin_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `asin(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = asin(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_asin_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `asin(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = asin(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_asin_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `asin(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = asin(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_asin_special_cases_one_arg_greater(arg1): - """ - Special case test for `asin(x, /)`: - - - If `x_i` is greater than `1`, the result is `NaN`. - - """ - res = asin(arg1) - mask = greater(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_asin_special_cases_one_arg_less(arg1): - """ - Special case test for `asin(x, /)`: - - - If `x_i` is less than `-1`, the result is `NaN`. - - """ - res = asin(arg1) - mask = less(arg1, -one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_asinh.py b/array_api_tests/special_cases/test_asinh.py deleted file mode 100644 index a54d3346..00000000 --- a/array_api_tests/special_cases/test_asinh.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for asinh. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import asinh - -from hypothesis import given - - -@given(numeric_arrays) -def test_asinh_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `asinh(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = asinh(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_asinh_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `asinh(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = asinh(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_asinh_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `asinh(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = asinh(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_asinh_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `asinh(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = asinh(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_asinh_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `asinh(x, /)`: - - - If `x_i` is `-infinity`, the result is `-infinity`. - - """ - res = asinh(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_atan.py b/array_api_tests/special_cases/test_atan.py deleted file mode 100644 index 4b6936ed..00000000 --- a/array_api_tests/special_cases/test_atan.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for atan. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero, π -from ..hypothesis_helpers import numeric_arrays -from .._array_module import atan - -from hypothesis import given - - -@given(numeric_arrays) -def test_atan_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `atan(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = atan(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atan_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `atan(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = atan(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atan_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `atan(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = atan(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atan_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `atan(x, /)`: - - - If `x_i` is `+infinity`, the result is an implementation-dependent approximation to `+π/2`. - - """ - res = atan(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/2)[mask]) - - -@given(numeric_arrays) -def test_atan_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `atan(x, /)`: - - - If `x_i` is `-infinity`, the result is an implementation-dependent approximation to `-π/2`. - - """ - res = atan(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/2)[mask]) diff --git a/array_api_tests/special_cases/test_atan2.py b/array_api_tests/special_cases/test_atan2.py deleted file mode 100644 index 9d7452e7..00000000 --- a/array_api_tests/special_cases/test_atan2.py +++ /dev/null @@ -1,314 +0,0 @@ -""" -Special cases tests for atan2. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, isfinite, - less, logical_and, logical_or, zero, π) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import atan2 - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = atan2(arg1, arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_greater__equal_1(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is greater than `0` and `x2_i` is `+0`, the result is an implementation-dependent approximation to `+π/2`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_greater__equal_2(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is greater than `0` and `x2_i` is `-0`, the result is an implementation-dependent approximation to `+π/2`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__greater_1(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__greater_2(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `-0` and `x2_i` is greater than `0`, the result is `-0`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_1(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is `+0`, the result is `+0`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_2(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is `-0`, the result is an implementation-dependent approximation to `+π`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_3(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `-0` and `x2_i` is `+0`, the result is `-0`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_4(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `-0` and `x2_i` is `-0`, the result is an implementation-dependent approximation to `-π`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_5(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is finite, the result is an implementation-dependent approximation to `+π/2`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), isfinite(arg2)) - assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_6(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is finite, the result is an implementation-dependent approximation to `-π/2`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), isfinite(arg2)) - assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_7(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is `+infinity`, the result is an implementation-dependent approximation to `+π/4`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/4)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_8(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is `-infinity`, the result is an implementation-dependent approximation to `+3π/4`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (+3*π(arg1.shape, arg1.dtype)/4)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_9(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is `+infinity`, the result is an implementation-dependent approximation to `-π/4`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/4)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__equal_10(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is `-infinity`, the result is an implementation-dependent approximation to `-3π/4`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-3*π(arg1.shape, arg1.dtype)/4)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__less_1(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is an implementation-dependent approximation to `+π`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_equal__less_2(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is `-0` and `x2_i` is less than `0`, the result is an implementation-dependent approximation to `-π`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_less__equal_1(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is less than `0` and `x2_i` is `+0`, the result is an implementation-dependent approximation to `-π/2`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_less__equal_2(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is less than `0` and `x2_i` is `-0`, the result is an implementation-dependent approximation to `-π/2`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_greater_equal__equal_1(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is greater than `0`, `x1_i` is a finite number, and `x2_i` is `+infinity`, the result is `+0`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_greater_equal__equal_2(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is greater than `0`, `x1_i` is a finite number, and `x2_i` is `-infinity`, the result is an implementation-dependent approximation to `+π`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_less_equal__equal_1(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is less than `0`, `x1_i` is a finite number, and `x2_i` is `+infinity`, the result is `-0`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_atan2_special_cases_two_args_less_equal__equal_2(arg1, arg2): - """ - Special case test for `atan2(x1, x2, /)`: - - - If `x1_i` is less than `0`, `x1_i` is a finite number, and `x2_i` is `-infinity`, the result is an implementation-dependent approximation to `-π`. - - """ - res = atan2(arg1, arg2) - mask = logical_and(logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_atanh.py b/array_api_tests/special_cases/test_atanh.py deleted file mode 100644 index 6e26cc99..00000000 --- a/array_api_tests/special_cases/test_atanh.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -Special cases tests for atanh. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, less, one, - zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import atanh - -from hypothesis import given - - -@given(numeric_arrays) -def test_atanh_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `atanh(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = atanh(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atanh_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `atanh(x, /)`: - - - If `x_i` is `-1`, the result is `-infinity`. - - """ - res = atanh(arg1) - mask = exactly_equal(arg1, -one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atanh_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `atanh(x, /)`: - - - If `x_i` is `+1`, the result is `+infinity`. - - """ - res = atanh(arg1) - mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atanh_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `atanh(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = atanh(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atanh_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `atanh(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = atanh(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atanh_special_cases_one_arg_less(arg1): - """ - Special case test for `atanh(x, /)`: - - - If `x_i` is less than `-1`, the result is `NaN`. - - """ - res = atanh(arg1) - mask = less(arg1, -one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_atanh_special_cases_one_arg_greater(arg1): - """ - Special case test for `atanh(x, /)`: - - - If `x_i` is greater than `1`, the result is `NaN`. - - """ - res = atanh(arg1) - mask = greater(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_ceil.py b/array_api_tests/special_cases/test_ceil.py deleted file mode 100644 index 5c9eee86..00000000 --- a/array_api_tests/special_cases/test_ceil.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Special cases tests for ceil. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, isintegral, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import ceil - -from hypothesis import given - - -@given(numeric_arrays) -def test_ceil_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `ceil(x, /)`: - - - If `x_i` is already integer-valued, the result is `x_i`. - - """ - res = ceil(arg1) - mask = isintegral(arg1) - assert_exactly_equal(res[mask], (arg1)[mask]) - - -@given(numeric_arrays) -def test_ceil_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `ceil(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = ceil(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_ceil_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `ceil(x, /)`: - - - If `x_i` is `-infinity`, the result is `-infinity`. - - """ - res = ceil(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_ceil_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `ceil(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = ceil(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_ceil_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `ceil(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = ceil(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_ceil_special_cases_one_arg_equal_6(arg1): - """ - Special case test for `ceil(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = ceil(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_cos.py b/array_api_tests/special_cases/test_cos.py deleted file mode 100644 index e80a7130..00000000 --- a/array_api_tests/special_cases/test_cos.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for cos. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import cos - -from hypothesis import given - - -@given(numeric_arrays) -def test_cos_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `cos(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = cos(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_cos_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `cos(x, /)`: - - - If `x_i` is `+0`, the result is `1`. - - """ - res = cos(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_cos_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `cos(x, /)`: - - - If `x_i` is `-0`, the result is `1`. - - """ - res = cos(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_cos_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `cos(x, /)`: - - - If `x_i` is `+infinity`, the result is `NaN`. - - """ - res = cos(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_cos_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `cos(x, /)`: - - - If `x_i` is `-infinity`, the result is `NaN`. - - """ - res = cos(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_cosh.py b/array_api_tests/special_cases/test_cosh.py deleted file mode 100644 index bdca4a82..00000000 --- a/array_api_tests/special_cases/test_cosh.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for cosh. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import cosh - -from hypothesis import given - - -@given(numeric_arrays) -def test_cosh_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `cosh(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = cosh(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_cosh_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `cosh(x, /)`: - - - If `x_i` is `+0`, the result is `1`. - - """ - res = cosh(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_cosh_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `cosh(x, /)`: - - - If `x_i` is `-0`, the result is `1`. - - """ - res = cosh(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_cosh_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `cosh(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = cosh(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_cosh_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `cosh(x, /)`: - - - If `x_i` is `-infinity`, the result is `+infinity`. - - """ - res = cosh(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_divide.py b/array_api_tests/special_cases/test_divide.py deleted file mode 100644 index fe1596c9..00000000 --- a/array_api_tests/special_cases/test_divide.py +++ /dev/null @@ -1,293 +0,0 @@ -""" -Special cases tests for divide. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, assert_negative_mathematical_sign, - assert_positive_mathematical_sign, exactly_equal, greater, infinity, - isfinite, isnegative, ispositive, less, logical_and, logical_not, - logical_or, non_zero, same_sign, zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import divide - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = divide(arg1, arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_either__either_1(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+infinity` or `-infinity`, the result is `NaN`. - - """ - res = divide(arg1, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_either__either_2(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is either `+0` or `-0`, the result is `NaN`. - - """ - res = divide(arg1, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__greater_1(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`. - - """ - res = divide(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__greater_2(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is `-0` and `x2_i` is greater than `0`, the result is `-0`. - - """ - res = divide(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__less_1(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is `-0`. - - """ - res = divide(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__less_2(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is `-0` and `x2_i` is less than `0`, the result is `+0`. - - """ - res = divide(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_greater__equal_1(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is greater than `0` and `x2_i` is `+0`, the result is `+infinity`. - - """ - res = divide(arg1, arg2) - mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_greater__equal_2(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is greater than `0` and `x2_i` is `-0`, the result is `-infinity`. - - """ - res = divide(arg1, arg2) - mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_less__equal_1(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is less than `0` and `x2_i` is `+0`, the result is `-infinity`. - - """ - res = divide(arg1, arg2) - mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_less__equal_2(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is less than `0` and `x2_i` is `-0`, the result is `+infinity`. - - """ - res = divide(arg1, arg2) - mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__equal_1(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a positive (i.e., greater than `0`) finite number, the result is `+infinity`. - - """ - res = divide(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), ispositive(arg2))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__equal_2(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a negative (i.e., less than `0`) finite number, the result is `-infinity`. - - """ - res = divide(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), isnegative(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__equal_3(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a positive (i.e., greater than `0`) finite number, the result is `-infinity`. - - """ - res = divide(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), ispositive(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__equal_4(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a negative (i.e., less than `0`) finite number, the result is `+infinity`. - - """ - res = divide(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), isnegative(arg2))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__equal_5(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is a positive (i.e., greater than `0`) finite number and `x2_i` is `+infinity`, the result is `+0`. - - """ - res = divide(arg1, arg2) - mask = logical_and(logical_and(isfinite(arg1), ispositive(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__equal_6(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is a positive (i.e., greater than `0`) finite number and `x2_i` is `-infinity`, the result is `-0`. - - """ - res = divide(arg1, arg2) - mask = logical_and(logical_and(isfinite(arg1), ispositive(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__equal_7(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is a negative (i.e., less than `0`) finite number and `x2_i` is `+infinity`, the result is `-0`. - - """ - res = divide(arg1, arg2) - mask = logical_and(logical_and(isfinite(arg1), isnegative(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_equal__equal_8(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` is a negative (i.e., less than `0`) finite number and `x2_i` is `-infinity`, the result is `+0`. - - """ - res = divide(arg1, arg2) - mask = logical_and(logical_and(isfinite(arg1), isnegative(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_same_sign_both(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` and `x2_i` have the same mathematical sign and are both nonzero finite numbers, the result has a positive mathematical sign. - - """ - res = divide(arg1, arg2) - mask = logical_and(same_sign(arg1, arg2), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2)))) - assert_positive_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_divide_special_cases_two_args_different_signs_both(arg1, arg2): - """ - Special case test for `divide(x1, x2, /)`: - - - If `x1_i` and `x2_i` have different mathematical signs and are both nonzero finite numbers, the result has a negative mathematical sign. - - """ - res = divide(arg1, arg2) - mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2)))) - assert_negative_mathematical_sign(res[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, where neither `-infinity`, `+0`, `-0`, nor `NaN` is involved, the quotient must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too larger to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the operation underflows and the result is a zero of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_dunder_abs.py b/array_api_tests/special_cases/test_dunder_abs.py deleted file mode 100644 index 5028f1b7..00000000 --- a/array_api_tests/special_cases/test_dunder_abs.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Special cases tests for __abs__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays) -def test_abs_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `__abs__(self, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = (arg1).__abs__() - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_abs_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `__abs__(self, /)`: - - - If `x_i` is `-0`, the result is `+0`. - - """ - res = (arg1).__abs__() - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_abs_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `__abs__(self, /)`: - - - If `x_i` is `-infinity`, the result is `+infinity`. - - """ - res = (arg1).__abs__() - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_dunder_add.py b/array_api_tests/special_cases/test_dunder_add.py deleted file mode 100644 index d3b5e169..00000000 --- a/array_api_tests/special_cases/test_dunder_add.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Special cases tests for __add__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, isfinite, - logical_and, logical_or, non_zero, zero) -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = arg1.__add__(arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_1(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is `-infinity`, the result is `NaN`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_2(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is `+infinity`, the result is `NaN`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_3(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_4(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is `-infinity`, the result is `-infinity`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_5(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a finite number, the result is `+infinity`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), isfinite(arg2)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_6(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a finite number, the result is `-infinity`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), isfinite(arg2)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_7(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is a finite number and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(isfinite(arg1), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_8(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is a finite number and `x2_i` is `-infinity`, the result is `-infinity`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(isfinite(arg1), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_9(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `-0` and `x2_i` is `-0`, the result is `-0`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_10(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `-0` and `x2_i` is `+0`, the result is `+0`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_11(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is `-0`, the result is `+0`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_12(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is `+0`, the result is `+0`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__equal_13(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is `-x1_i`, the result is `+0`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), exactly_equal(arg2, -arg1)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_either__equal(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is a nonzero finite number, the result is `x2_i`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_and(isfinite(arg2), non_zero(arg2))) - assert_exactly_equal(res[mask], (arg2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_add_special_cases_two_args_equal__either(arg1, arg2): - """ - Special case test for `__add__(self, other, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is either `+0` or `-0`, the result is `x1_i`. - - """ - res = arg1.__add__(arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (arg1)[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, when neither `infinity`, `+0`, `-0`, nor a `NaN` is involved, and the operands have the same mathematical sign or have different magnitudes, the sum must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported round mode. If the magnitude is too large to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_dunder_iadd.py b/array_api_tests/special_cases/test_dunder_iadd.py deleted file mode 100644 index 692877bf..00000000 --- a/array_api_tests/special_cases/test_dunder_iadd.py +++ /dev/null @@ -1,243 +0,0 @@ -""" -Special cases tests for __iadd__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from operator import iadd - -from ..array_helpers import (NaN, asarray, assert_exactly_equal, exactly_equal, infinity, isfinite, - logical_and, logical_or, non_zero, zero) -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_1(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is `-infinity`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_2(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is `+infinity`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_3(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_4(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is `-infinity`, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_5(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a finite number, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), isfinite(arg2)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_6(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a finite number, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), isfinite(arg2)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_7(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is a finite number and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(isfinite(arg1), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_8(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is a finite number and `x2_i` is `-infinity`, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(isfinite(arg1), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_9(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `-0` and `x2_i` is `-0`, the result is `-0`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_10(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `-0` and `x2_i` is `+0`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_11(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is `-0`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_12(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is `+0`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__equal_13(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is `-x1_i`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), exactly_equal(arg2, -arg1)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_either__equal(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is a nonzero finite number, the result is `x2_i`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_and(isfinite(arg2), non_zero(arg2))) - assert_exactly_equal(res[mask], (arg2)[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_iadd_special_cases_two_args_equal__either(arg1, arg2): - """ - Special case test for `__iadd__(self, other, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is either `+0` or `-0`, the result is `x1_i`. - - """ - res = asarray(arg1, copy=True) - iadd(res, arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (arg1)[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, when neither `infinity`, `+0`, `-0`, nor a `NaN` is involved, and the operands have the same mathematical sign or have different magnitudes, the sum must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported round mode. If the magnitude is too large to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_dunder_imul.py b/array_api_tests/special_cases/test_dunder_imul.py deleted file mode 100644 index a077fb9a..00000000 --- a/array_api_tests/special_cases/test_dunder_imul.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Special cases tests for __imul__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from operator import imul - -from ..array_helpers import (NaN, asarray, assert_exactly_equal, assert_isinf, - assert_negative_mathematical_sign, assert_positive_mathematical_sign, - exactly_equal, infinity, isfinite, logical_and, logical_not, - logical_or, non_zero, same_sign, zero) -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_imul_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `__imul__(self, other, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - imul(res, arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_imul_special_cases_two_args_either__either_1(arg1, arg2): - """ - Special case test for `__imul__(self, other, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+0` or `-0`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - imul(res, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_imul_special_cases_two_args_either__either_2(arg1, arg2): - """ - Special case test for `__imul__(self, other, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is either `+infinity` or `-infinity`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - imul(res, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_imul_special_cases_two_args_either__either_3(arg1, arg2): - """ - Special case test for `__imul__(self, other, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+infinity` or `-infinity`, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = asarray(arg1, copy=True) - imul(res, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_isinf(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_imul_special_cases_two_args_same_sign_except(arg1, arg2): - """ - Special case test for `__imul__(self, other, /)`: - - - If `x1_i` and `x2_i` have the same mathematical sign, the result has a positive mathematical sign, unless the result is `NaN`. If the result is `NaN`, the "sign" of `NaN` is implementation-defined. - - """ - res = asarray(arg1, copy=True) - imul(res, arg2) - mask = logical_and(same_sign(arg1, arg2), logical_not(exactly_equal(res, NaN(res.shape, res.dtype)))) - assert_positive_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_imul_special_cases_two_args_different_signs_except(arg1, arg2): - """ - Special case test for `__imul__(self, other, /)`: - - - If `x1_i` and `x2_i` have different mathematical signs, the result has a negative mathematical sign, unless the result is `NaN`. If the result is `NaN`, the "sign" of `NaN` is implementation-defined. - - """ - res = asarray(arg1, copy=True) - imul(res, arg2) - mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_not(exactly_equal(res, NaN(res.shape, res.dtype)))) - assert_negative_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_imul_special_cases_two_args_either__equal(arg1, arg2): - """ - Special case test for `__imul__(self, other, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is a nonzero finite number, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = asarray(arg1, copy=True) - imul(res, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_and(isfinite(arg2), non_zero(arg2))) - assert_isinf(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_imul_special_cases_two_args_equal__either(arg1, arg2): - """ - Special case test for `__imul__(self, other, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is either `+infinity` or `-infinity`, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = asarray(arg1, copy=True) - imul(res, arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_isinf(res[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, where neither `infinity` nor `NaN` is involved, the product must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too large to represent, the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the result is a zero of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_dunder_ipow.py b/array_api_tests/special_cases/test_dunder_ipow.py deleted file mode 100644 index 0ca5e705..00000000 --- a/array_api_tests/special_cases/test_dunder_ipow.py +++ /dev/null @@ -1,353 +0,0 @@ -""" -Special cases tests for __ipow__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from operator import ipow - -from ..array_helpers import (NaN, asarray, assert_exactly_equal, exactly_equal, greater, infinity, - isfinite, isintegral, isodd, less, logical_and, logical_not, notequal, - one, zero) -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_notequal__equal(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is not equal to `1` and `x2_i` is `NaN`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(logical_not(exactly_equal(arg1, one(arg1.shape, arg1.dtype))), exactly_equal(arg2, NaN(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_even_if_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x2_i` is `+0`, the result is `1`, even if `x1_i` is `NaN`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = exactly_equal(arg2, zero(arg2.shape, arg2.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_even_if_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x2_i` is `-0`, the result is `1`, even if `x1_i` is `NaN`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__notequal_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `NaN` and `x2_i` is not equal to `0`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), notequal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__notequal_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `1` and `x2_i` is not `NaN`, the result is `1`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, one(arg1.shape, arg1.dtype)), logical_not(exactly_equal(arg2, NaN(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_absgreater__equal_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `abs(x1_i)` is greater than `1` and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(greater(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_absgreater__equal_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `abs(x1_i)` is greater than `1` and `x2_i` is `-infinity`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(greater(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_absequal__equal_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `abs(x1_i)` is `1` and `x2_i` is `+infinity`, the result is `1`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_absequal__equal_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `abs(x1_i)` is `1` and `x2_i` is `-infinity`, the result is `1`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_absless__equal_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `abs(x1_i)` is less than `1` and `x2_i` is `+infinity`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(less(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_absless__equal_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `abs(x1_i)` is less than `1` and `x2_i` is `-infinity`, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(less(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__greater_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is greater than `0`, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__greater_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__less_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is less than `0`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__less_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__greater_equal_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__greater_equal_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-0`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__greater_notequal_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is not an odd integer value, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__greater_notequal_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is not an odd integer value, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__less_equal_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is less than `0`, and `x2_i` is an odd integer value, the result is `-0`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__less_equal_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `-0`, `x2_i` is less than `0`, and `x2_i` is an odd integer value, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__less_notequal_1(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is less than `0`, and `x2_i` is not an odd integer value, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_equal__less_notequal_2(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is `-0`, `x2_i` is less than `0`, and `x2_i` is not an odd integer value, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_ipow_special_cases_two_args_less_equal__equal_notequal(arg1, arg2): - """ - Special case test for `__ipow__(self, other, /)`: - - - If `x1_i` is less than `0`, `x1_i` is a finite number, `x2_i` is a finite number, and `x2_i` is not an integer value, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - ipow(res, arg2) - mask = logical_and(logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), logical_and(isfinite(arg2), logical_not(isintegral(arg2)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_dunder_itruediv.py b/array_api_tests/special_cases/test_dunder_itruediv.py deleted file mode 100644 index e6747d40..00000000 --- a/array_api_tests/special_cases/test_dunder_itruediv.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -Special cases tests for __itruediv__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from operator import itruediv - -from ..array_helpers import (NaN, asarray, assert_exactly_equal, assert_negative_mathematical_sign, - assert_positive_mathematical_sign, exactly_equal, greater, infinity, - isfinite, isnegative, ispositive, less, logical_and, logical_not, - logical_or, non_zero, same_sign, zero) -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_either__either_1(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+infinity` or `-infinity`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_either__either_2(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is either `+0` or `-0`, the result is `NaN`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__greater_1(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__greater_2(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is `-0` and `x2_i` is greater than `0`, the result is `-0`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__less_1(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is `-0`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__less_2(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is `-0` and `x2_i` is less than `0`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_greater__equal_1(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is greater than `0` and `x2_i` is `+0`, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_greater__equal_2(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is greater than `0` and `x2_i` is `-0`, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_less__equal_1(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is less than `0` and `x2_i` is `+0`, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_less__equal_2(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is less than `0` and `x2_i` is `-0`, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__equal_1(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a positive (i.e., greater than `0`) finite number, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), ispositive(arg2))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__equal_2(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a negative (i.e., less than `0`) finite number, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), isnegative(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__equal_3(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a positive (i.e., greater than `0`) finite number, the result is `-infinity`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), ispositive(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__equal_4(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a negative (i.e., less than `0`) finite number, the result is `+infinity`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), isnegative(arg2))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__equal_5(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is a positive (i.e., greater than `0`) finite number and `x2_i` is `+infinity`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(logical_and(isfinite(arg1), ispositive(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__equal_6(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is a positive (i.e., greater than `0`) finite number and `x2_i` is `-infinity`, the result is `-0`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(logical_and(isfinite(arg1), ispositive(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__equal_7(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is a negative (i.e., less than `0`) finite number and `x2_i` is `+infinity`, the result is `-0`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(logical_and(isfinite(arg1), isnegative(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_equal__equal_8(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` is a negative (i.e., less than `0`) finite number and `x2_i` is `-infinity`, the result is `+0`. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(logical_and(isfinite(arg1), isnegative(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_same_sign_both(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` and `x2_i` have the same mathematical sign and are both nonzero finite numbers, the result has a positive mathematical sign. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(same_sign(arg1, arg2), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2)))) - assert_positive_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_itruediv_special_cases_two_args_different_signs_both(arg1, arg2): - """ - Special case test for `__itruediv__(self, other, /)`: - - - If `x1_i` and `x2_i` have different mathematical signs and are both nonzero finite numbers, the result has a negative mathematical sign. - - """ - res = asarray(arg1, copy=True) - itruediv(res, arg2) - mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2)))) - assert_negative_mathematical_sign(res[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, where neither `-infinity`, `+0`, `-0`, nor `NaN` is involved, the quotient must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too larger to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the operation underflows and the result is a zero of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_dunder_mul.py b/array_api_tests/special_cases/test_dunder_mul.py deleted file mode 100644 index 983f3654..00000000 --- a/array_api_tests/special_cases/test_dunder_mul.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -Special cases tests for __mul__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, assert_isinf, - assert_negative_mathematical_sign, assert_positive_mathematical_sign, - exactly_equal, infinity, isfinite, logical_and, logical_not, - logical_or, non_zero, same_sign, zero) -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_mul_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `__mul__(self, other, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = arg1.__mul__(arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_mul_special_cases_two_args_either__either_1(arg1, arg2): - """ - Special case test for `__mul__(self, other, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+0` or `-0`, the result is `NaN`. - - """ - res = arg1.__mul__(arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_mul_special_cases_two_args_either__either_2(arg1, arg2): - """ - Special case test for `__mul__(self, other, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is either `+infinity` or `-infinity`, the result is `NaN`. - - """ - res = arg1.__mul__(arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_mul_special_cases_two_args_either__either_3(arg1, arg2): - """ - Special case test for `__mul__(self, other, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+infinity` or `-infinity`, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = arg1.__mul__(arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_isinf(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_mul_special_cases_two_args_same_sign_except(arg1, arg2): - """ - Special case test for `__mul__(self, other, /)`: - - - If `x1_i` and `x2_i` have the same mathematical sign, the result has a positive mathematical sign, unless the result is `NaN`. If the result is `NaN`, the "sign" of `NaN` is implementation-defined. - - """ - res = arg1.__mul__(arg2) - mask = logical_and(same_sign(arg1, arg2), logical_not(exactly_equal(res, NaN(res.shape, res.dtype)))) - assert_positive_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_mul_special_cases_two_args_different_signs_except(arg1, arg2): - """ - Special case test for `__mul__(self, other, /)`: - - - If `x1_i` and `x2_i` have different mathematical signs, the result has a negative mathematical sign, unless the result is `NaN`. If the result is `NaN`, the "sign" of `NaN` is implementation-defined. - - """ - res = arg1.__mul__(arg2) - mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_not(exactly_equal(res, NaN(res.shape, res.dtype)))) - assert_negative_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_mul_special_cases_two_args_either__equal(arg1, arg2): - """ - Special case test for `__mul__(self, other, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is a nonzero finite number, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = arg1.__mul__(arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_and(isfinite(arg2), non_zero(arg2))) - assert_isinf(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_mul_special_cases_two_args_equal__either(arg1, arg2): - """ - Special case test for `__mul__(self, other, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is either `+infinity` or `-infinity`, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = arg1.__mul__(arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_isinf(res[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, where neither `infinity` nor `NaN` is involved, the product must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too large to represent, the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the result is a zero of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_dunder_pow.py b/array_api_tests/special_cases/test_dunder_pow.py deleted file mode 100644 index b90abd4c..00000000 --- a/array_api_tests/special_cases/test_dunder_pow.py +++ /dev/null @@ -1,326 +0,0 @@ -""" -Special cases tests for __pow__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, isfinite, - isintegral, isodd, less, logical_and, logical_not, notequal, one, zero) -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_notequal__equal(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is not equal to `1` and `x2_i` is `NaN`, the result is `NaN`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(logical_not(exactly_equal(arg1, one(arg1.shape, arg1.dtype))), exactly_equal(arg2, NaN(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_even_if_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x2_i` is `+0`, the result is `1`, even if `x1_i` is `NaN`. - - """ - res = arg1.__pow__(arg2) - mask = exactly_equal(arg2, zero(arg2.shape, arg2.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_even_if_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x2_i` is `-0`, the result is `1`, even if `x1_i` is `NaN`. - - """ - res = arg1.__pow__(arg2) - mask = exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__notequal_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `NaN` and `x2_i` is not equal to `0`, the result is `NaN`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), notequal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__notequal_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `1` and `x2_i` is not `NaN`, the result is `1`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, one(arg1.shape, arg1.dtype)), logical_not(exactly_equal(arg2, NaN(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absgreater__equal_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `abs(x1_i)` is greater than `1` and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(greater(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absgreater__equal_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `abs(x1_i)` is greater than `1` and `x2_i` is `-infinity`, the result is `+0`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(greater(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absequal__equal_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `abs(x1_i)` is `1` and `x2_i` is `+infinity`, the result is `1`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absequal__equal_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `abs(x1_i)` is `1` and `x2_i` is `-infinity`, the result is `1`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absless__equal_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `abs(x1_i)` is less than `1` and `x2_i` is `+infinity`, the result is `+0`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(less(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absless__equal_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `abs(x1_i)` is less than `1` and `x2_i` is `-infinity`, the result is `+infinity`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(less(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is greater than `0`, the result is `+infinity`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is less than `0`, the result is `+0`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is `+infinity`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_equal_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-infinity`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_equal_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-0`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_notequal_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is not an odd integer value, the result is `+infinity`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_notequal_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is not an odd integer value, the result is `+0`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_equal_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is less than `0`, and `x2_i` is an odd integer value, the result is `-0`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_equal_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `-0`, `x2_i` is less than `0`, and `x2_i` is an odd integer value, the result is `-infinity`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_notequal_1(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is less than `0`, and `x2_i` is not an odd integer value, the result is `+0`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_notequal_2(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is `-0`, `x2_i` is less than `0`, and `x2_i` is not an odd integer value, the result is `+infinity`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_less_equal__equal_notequal(arg1, arg2): - """ - Special case test for `__pow__(self, other, /)`: - - - If `x1_i` is less than `0`, `x1_i` is a finite number, `x2_i` is a finite number, and `x2_i` is not an integer value, the result is `NaN`. - - """ - res = arg1.__pow__(arg2) - mask = logical_and(logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), logical_and(isfinite(arg2), logical_not(isintegral(arg2)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_dunder_truediv.py b/array_api_tests/special_cases/test_dunder_truediv.py deleted file mode 100644 index d08302f2..00000000 --- a/array_api_tests/special_cases/test_dunder_truediv.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -Special cases tests for __truediv__. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, assert_negative_mathematical_sign, - assert_positive_mathematical_sign, exactly_equal, greater, infinity, - isfinite, isnegative, ispositive, less, logical_and, logical_not, - logical_or, non_zero, same_sign, zero) -from ..hypothesis_helpers import numeric_arrays - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_either__either_1(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+infinity` or `-infinity`, the result is `NaN`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_either__either_2(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is either `+0` or `-0`, the result is `NaN`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__greater_1(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__greater_2(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is `-0` and `x2_i` is greater than `0`, the result is `-0`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__less_1(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is `-0`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__less_2(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is `-0` and `x2_i` is less than `0`, the result is `+0`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_greater__equal_1(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is greater than `0` and `x2_i` is `+0`, the result is `+infinity`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_greater__equal_2(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is greater than `0` and `x2_i` is `-0`, the result is `-infinity`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_less__equal_1(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is less than `0` and `x2_i` is `+0`, the result is `-infinity`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_less__equal_2(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is less than `0` and `x2_i` is `-0`, the result is `+infinity`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__equal_1(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a positive (i.e., greater than `0`) finite number, the result is `+infinity`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), ispositive(arg2))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__equal_2(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is a negative (i.e., less than `0`) finite number, the result is `-infinity`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), isnegative(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__equal_3(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a positive (i.e., greater than `0`) finite number, the result is `-infinity`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), ispositive(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__equal_4(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is `-infinity` and `x2_i` is a negative (i.e., less than `0`) finite number, the result is `+infinity`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), isnegative(arg2))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__equal_5(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is a positive (i.e., greater than `0`) finite number and `x2_i` is `+infinity`, the result is `+0`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(logical_and(isfinite(arg1), ispositive(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__equal_6(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is a positive (i.e., greater than `0`) finite number and `x2_i` is `-infinity`, the result is `-0`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(logical_and(isfinite(arg1), ispositive(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__equal_7(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is a negative (i.e., less than `0`) finite number and `x2_i` is `+infinity`, the result is `-0`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(logical_and(isfinite(arg1), isnegative(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_equal__equal_8(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` is a negative (i.e., less than `0`) finite number and `x2_i` is `-infinity`, the result is `+0`. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(logical_and(isfinite(arg1), isnegative(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_same_sign_both(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` and `x2_i` have the same mathematical sign and are both nonzero finite numbers, the result has a positive mathematical sign. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(same_sign(arg1, arg2), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2)))) - assert_positive_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_truediv_special_cases_two_args_different_signs_both(arg1, arg2): - """ - Special case test for `__truediv__(self, other, /)`: - - - If `x1_i` and `x2_i` have different mathematical signs and are both nonzero finite numbers, the result has a negative mathematical sign. - - """ - res = arg1.__truediv__(arg2) - mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2)))) - assert_negative_mathematical_sign(res[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, where neither `-infinity`, `+0`, `-0`, nor `NaN` is involved, the quotient must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too larger to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the operation underflows and the result is a zero of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_exp.py b/array_api_tests/special_cases/test_exp.py deleted file mode 100644 index 47399648..00000000 --- a/array_api_tests/special_cases/test_exp.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for exp. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import exp - -from hypothesis import given - - -@given(numeric_arrays) -def test_exp_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `exp(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = exp(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_exp_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `exp(x, /)`: - - - If `x_i` is `+0`, the result is `1`. - - """ - res = exp(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_exp_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `exp(x, /)`: - - - If `x_i` is `-0`, the result is `1`. - - """ - res = exp(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_exp_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `exp(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = exp(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_exp_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `exp(x, /)`: - - - If `x_i` is `-infinity`, the result is `+0`. - - """ - res = exp(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_expm1.py b/array_api_tests/special_cases/test_expm1.py deleted file mode 100644 index d96b742e..00000000 --- a/array_api_tests/special_cases/test_expm1.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for expm1. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import expm1 - -from hypothesis import given - - -@given(numeric_arrays) -def test_expm1_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `expm1(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = expm1(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_expm1_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `expm1(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = expm1(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_expm1_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `expm1(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = expm1(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_expm1_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `expm1(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = expm1(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_expm1_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `expm1(x, /)`: - - - If `x_i` is `-infinity`, the result is `-1`. - - """ - res = expm1(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-one(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_floor.py b/array_api_tests/special_cases/test_floor.py deleted file mode 100644 index bf7a1572..00000000 --- a/array_api_tests/special_cases/test_floor.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Special cases tests for floor. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, isintegral, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import floor - -from hypothesis import given - - -@given(numeric_arrays) -def test_floor_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `floor(x, /)`: - - - If `x_i` is already integer-valued, the result is `x_i`. - - """ - res = floor(arg1) - mask = isintegral(arg1) - assert_exactly_equal(res[mask], (arg1)[mask]) - - -@given(numeric_arrays) -def test_floor_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `floor(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = floor(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_floor_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `floor(x, /)`: - - - If `x_i` is `-infinity`, the result is `-infinity`. - - """ - res = floor(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_floor_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `floor(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = floor(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_floor_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `floor(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = floor(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_floor_special_cases_one_arg_equal_6(arg1): - """ - Special case test for `floor(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = floor(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_log.py b/array_api_tests/special_cases/test_log.py deleted file mode 100644 index 0ea6cd25..00000000 --- a/array_api_tests/special_cases/test_log.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Special cases tests for log. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, less, logical_or, - one, zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import log - -from hypothesis import given - - -@given(numeric_arrays) -def test_log_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `log(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = log(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `log(x, /)`: - - - If `x_i` is `1`, the result is `+0`. - - """ - res = log(arg1) - mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `log(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = log(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log_special_cases_one_arg_less(arg1): - """ - Special case test for `log(x, /)`: - - - If `x_i` is less than `0`, the result is `NaN`. - - """ - res = log(arg1) - mask = less(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log_special_cases_one_arg_either(arg1): - """ - Special case test for `log(x, /)`: - - - If `x_i` is either `+0` or `-0`, the result is `-infinity`. - - """ - res = log(arg1) - mask = logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_log10.py b/array_api_tests/special_cases/test_log10.py deleted file mode 100644 index 8dc5a5de..00000000 --- a/array_api_tests/special_cases/test_log10.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Special cases tests for log10. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, less, logical_or, - one, zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import log10 - -from hypothesis import given - - -@given(numeric_arrays) -def test_log10_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `log10(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = log10(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log10_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `log10(x, /)`: - - - If `x_i` is `1`, the result is `+0`. - - """ - res = log10(arg1) - mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log10_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `log10(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = log10(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log10_special_cases_one_arg_less(arg1): - """ - Special case test for `log10(x, /)`: - - - If `x_i` is less than `0`, the result is `NaN`. - - """ - res = log10(arg1) - mask = less(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log10_special_cases_one_arg_either(arg1): - """ - Special case test for `log10(x, /)`: - - - If `x_i` is either `+0` or `-0`, the result is `-infinity`. - - """ - res = log10(arg1) - mask = logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_log1p.py b/array_api_tests/special_cases/test_log1p.py deleted file mode 100644 index 432a761b..00000000 --- a/array_api_tests/special_cases/test_log1p.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Special cases tests for log1p. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, less, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import log1p - -from hypothesis import given - - -@given(numeric_arrays) -def test_log1p_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `log1p(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = log1p(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log1p_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `log1p(x, /)`: - - - If `x_i` is `-1`, the result is `-infinity`. - - """ - res = log1p(arg1) - mask = exactly_equal(arg1, -one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log1p_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `log1p(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = log1p(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log1p_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `log1p(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = log1p(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log1p_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `log1p(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = log1p(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log1p_special_cases_one_arg_less(arg1): - """ - Special case test for `log1p(x, /)`: - - - If `x_i` is less than `-1`, the result is `NaN`. - - """ - res = log1p(arg1) - mask = less(arg1, -one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_log2.py b/array_api_tests/special_cases/test_log2.py deleted file mode 100644 index 41797dd7..00000000 --- a/array_api_tests/special_cases/test_log2.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Special cases tests for log2. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, less, logical_or, - one, zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import log2 - -from hypothesis import given - - -@given(numeric_arrays) -def test_log2_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `log2(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = log2(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log2_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `log2(x, /)`: - - - If `x_i` is `1`, the result is `+0`. - - """ - res = log2(arg1) - mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log2_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `log2(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = log2(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log2_special_cases_one_arg_less(arg1): - """ - Special case test for `log2(x, /)`: - - - If `x_i` is less than `0`, the result is `NaN`. - - """ - res = log2(arg1) - mask = less(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_log2_special_cases_one_arg_either(arg1): - """ - Special case test for `log2(x, /)`: - - - If `x_i` is either `+0` or `-0`, the result is `-infinity`. - - """ - res = log2(arg1) - mask = logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_logaddexp.py b/array_api_tests/special_cases/test_logaddexp.py deleted file mode 100644 index de8081e7..00000000 --- a/array_api_tests/special_cases/test_logaddexp.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Special cases tests for logaddexp. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, logical_and, - logical_not, logical_or) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import logaddexp - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_logaddexp_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `logaddexp(x1, x2)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = logaddexp(arg1, arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_logaddexp_special_cases_two_args_equal__notequal(arg1, arg2): - """ - Special case test for `logaddexp(x1, x2)`: - - - If `x1_i` is `+infinity` and `x2_i` is not `NaN`, the result is `+infinity`. - - """ - res = logaddexp(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_not(exactly_equal(arg2, NaN(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_logaddexp_special_cases_two_args_notequal__equal(arg1, arg2): - """ - Special case test for `logaddexp(x1, x2)`: - - - If `x1_i` is not `NaN` and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = logaddexp(arg1, arg2) - mask = logical_and(logical_not(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_multiply.py b/array_api_tests/special_cases/test_multiply.py deleted file mode 100644 index 0ab2eec0..00000000 --- a/array_api_tests/special_cases/test_multiply.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -Special cases tests for multiply. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, assert_isinf, - assert_negative_mathematical_sign, assert_positive_mathematical_sign, - exactly_equal, infinity, isfinite, logical_and, logical_not, - logical_or, non_zero, same_sign, zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import multiply - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_multiply_special_cases_two_args_either(arg1, arg2): - """ - Special case test for `multiply(x1, x2, /)`: - - - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`. - - """ - res = multiply(arg1, arg2) - mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_multiply_special_cases_two_args_either__either_1(arg1, arg2): - """ - Special case test for `multiply(x1, x2, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+0` or `-0`, the result is `NaN`. - - """ - res = multiply(arg1, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_multiply_special_cases_two_args_either__either_2(arg1, arg2): - """ - Special case test for `multiply(x1, x2, /)`: - - - If `x1_i` is either `+0` or `-0` and `x2_i` is either `+infinity` or `-infinity`, the result is `NaN`. - - """ - res = multiply(arg1, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_multiply_special_cases_two_args_either__either_3(arg1, arg2): - """ - Special case test for `multiply(x1, x2, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+infinity` or `-infinity`, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = multiply(arg1, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_isinf(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_multiply_special_cases_two_args_same_sign_except(arg1, arg2): - """ - Special case test for `multiply(x1, x2, /)`: - - - If `x1_i` and `x2_i` have the same mathematical sign, the result has a positive mathematical sign, unless the result is `NaN`. If the result is `NaN`, the "sign" of `NaN` is implementation-defined. - - """ - res = multiply(arg1, arg2) - mask = logical_and(same_sign(arg1, arg2), logical_not(exactly_equal(res, NaN(res.shape, res.dtype)))) - assert_positive_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_multiply_special_cases_two_args_different_signs_except(arg1, arg2): - """ - Special case test for `multiply(x1, x2, /)`: - - - If `x1_i` and `x2_i` have different mathematical signs, the result has a negative mathematical sign, unless the result is `NaN`. If the result is `NaN`, the "sign" of `NaN` is implementation-defined. - - """ - res = multiply(arg1, arg2) - mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_not(exactly_equal(res, NaN(res.shape, res.dtype)))) - assert_negative_mathematical_sign(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_multiply_special_cases_two_args_either__equal(arg1, arg2): - """ - Special case test for `multiply(x1, x2, /)`: - - - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is a nonzero finite number, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = multiply(arg1, arg2) - mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_and(isfinite(arg2), non_zero(arg2))) - assert_isinf(res[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_multiply_special_cases_two_args_equal__either(arg1, arg2): - """ - Special case test for `multiply(x1, x2, /)`: - - - If `x1_i` is a nonzero finite number and `x2_i` is either `+infinity` or `-infinity`, the result is a signed infinity with the mathematical sign determined by the rule already stated above. - - """ - res = multiply(arg1, arg2) - mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))) - assert_isinf(res[mask]) - -# TODO: Implement REMAINING test for: -# - In the remaining cases, where neither `infinity` nor `NaN` is involved, the product must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too large to represent, the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the result is a zero of appropriate mathematical sign. diff --git a/array_api_tests/special_cases/test_pow.py b/array_api_tests/special_cases/test_pow.py deleted file mode 100644 index a422ffd3..00000000 --- a/array_api_tests/special_cases/test_pow.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -Special cases tests for pow. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, isfinite, - isintegral, isodd, less, logical_and, logical_not, notequal, one, zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import pow - -from hypothesis import given - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_notequal__equal(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is not equal to `1` and `x2_i` is `NaN`, the result is `NaN`. - - """ - res = pow(arg1, arg2) - mask = logical_and(logical_not(exactly_equal(arg1, one(arg1.shape, arg1.dtype))), exactly_equal(arg2, NaN(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_even_if_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x2_i` is `+0`, the result is `1`, even if `x1_i` is `NaN`. - - """ - res = pow(arg1, arg2) - mask = exactly_equal(arg2, zero(arg2.shape, arg2.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_even_if_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x2_i` is `-0`, the result is `1`, even if `x1_i` is `NaN`. - - """ - res = pow(arg1, arg2) - mask = exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__notequal_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `NaN` and `x2_i` is not equal to `0`, the result is `NaN`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), notequal(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__notequal_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `1` and `x2_i` is not `NaN`, the result is `1`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, one(arg1.shape, arg1.dtype)), logical_not(exactly_equal(arg2, NaN(arg2.shape, arg2.dtype)))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absgreater__equal_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `abs(x1_i)` is greater than `1` and `x2_i` is `+infinity`, the result is `+infinity`. - - """ - res = pow(arg1, arg2) - mask = logical_and(greater(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absgreater__equal_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `abs(x1_i)` is greater than `1` and `x2_i` is `-infinity`, the result is `+0`. - - """ - res = pow(arg1, arg2) - mask = logical_and(greater(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absequal__equal_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `abs(x1_i)` is `1` and `x2_i` is `+infinity`, the result is `1`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absequal__equal_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `abs(x1_i)` is `1` and `x2_i` is `-infinity`, the result is `1`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absless__equal_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `abs(x1_i)` is less than `1` and `x2_i` is `+infinity`, the result is `+0`. - - """ - res = pow(arg1, arg2) - mask = logical_and(less(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_absless__equal_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `abs(x1_i)` is less than `1` and `x2_i` is `-infinity`, the result is `+infinity`. - - """ - res = pow(arg1, arg2) - mask = logical_and(less(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is greater than `0`, the result is `+infinity`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `+infinity` and `x2_i` is less than `0`, the result is `+0`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is `+infinity`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_equal_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-infinity`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_equal_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-0`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_notequal_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is not an odd integer value, the result is `+infinity`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__greater_notequal_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is not an odd integer value, the result is `+0`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_equal_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is less than `0`, and `x2_i` is an odd integer value, the result is `-0`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_equal_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `-0`, `x2_i` is less than `0`, and `x2_i` is an odd integer value, the result is `-infinity`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2))) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_notequal_1(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `-infinity`, `x2_i` is less than `0`, and `x2_i` is not an odd integer value, the result is `+0`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_equal__less_notequal_2(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is `-0`, `x2_i` is less than `0`, and `x2_i` is not an odd integer value, the result is `+infinity`. - - """ - res = pow(arg1, arg2) - mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2)))) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays, numeric_arrays) -def test_pow_special_cases_two_args_less_equal__equal_notequal(arg1, arg2): - """ - Special case test for `pow(x1, x2, /)`: - - - If `x1_i` is less than `0`, `x1_i` is a finite number, `x2_i` is a finite number, and `x2_i` is not an integer value, the result is `NaN`. - - """ - res = pow(arg1, arg2) - mask = logical_and(logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), logical_and(isfinite(arg2), logical_not(isintegral(arg2)))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_round.py b/array_api_tests/special_cases/test_round.py deleted file mode 100644 index 13dc2d99..00000000 --- a/array_api_tests/special_cases/test_round.py +++ /dev/null @@ -1,108 +0,0 @@ -""" -Special cases tests for round. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (NaN, assert_exactly_equal, assert_iseven, assert_positive, ceil, equal, - exactly_equal, floor, infinity, isintegral, logical_and, not_equal, - one, subtract, zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import round - -from hypothesis import given - - -@given(numeric_arrays) -def test_round_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `round(x, /)`: - - - If `x_i` is already integer-valued, the result is `x_i`. - - """ - res = round(arg1) - mask = isintegral(arg1) - assert_exactly_equal(res[mask], (arg1)[mask]) - - -@given(numeric_arrays) -def test_round_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `round(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = round(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_round_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `round(x, /)`: - - - If `x_i` is `-infinity`, the result is `-infinity`. - - """ - res = round(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_round_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `round(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = round(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_round_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `round(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = round(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_round_special_cases_one_arg_equal_6(arg1): - """ - Special case test for `round(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = round(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_round_special_cases_one_arg_two_integers_equally_close(arg1): - """ - Special case test for `round(x, /)`: - - - If two integers are equally close to `x_i`, the result is the even integer closest to `x_i`. - - """ - res = round(arg1) - mask = logical_and(not_equal(floor(arg1), ceil(arg1)), equal(subtract(arg1, floor(arg1)), subtract(ceil(arg1), arg1))) - assert_iseven(res[mask]) - assert_positive(subtract(one(arg1[mask].shape, arg1[mask].dtype), abs(subtract(arg1[mask], res[mask])))) diff --git a/array_api_tests/special_cases/test_sign.py b/array_api_tests/special_cases/test_sign.py deleted file mode 100644 index e2ac6630..00000000 --- a/array_api_tests/special_cases/test_sign.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Special cases tests for sign. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import (assert_exactly_equal, exactly_equal, greater, less, logical_or, one, - zero) -from ..hypothesis_helpers import numeric_arrays -from .._array_module import sign - -from hypothesis import given - - -@given(numeric_arrays) -def test_sign_special_cases_one_arg_less(arg1): - """ - Special case test for `sign(x, /)`: - - - If `x_i` is less than `0`, the result is `-1`. - - """ - res = sign(arg1) - mask = less(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sign_special_cases_one_arg_either(arg1): - """ - Special case test for `sign(x, /)`: - - - If `x_i` is either `-0` or `+0`, the result is `0`. - - """ - res = sign(arg1) - mask = logical_or(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, zero(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sign_special_cases_one_arg_greater(arg1): - """ - Special case test for `sign(x, /)`: - - - If `x_i` is greater than `0`, the result is `+1`. - - """ - res = sign(arg1) - mask = greater(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_sin.py b/array_api_tests/special_cases/test_sin.py deleted file mode 100644 index 4af01736..00000000 --- a/array_api_tests/special_cases/test_sin.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Special cases tests for sin. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, logical_or, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import sin - -from hypothesis import given - - -@given(numeric_arrays) -def test_sin_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `sin(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = sin(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sin_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `sin(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = sin(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sin_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `sin(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = sin(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sin_special_cases_one_arg_either(arg1): - """ - Special case test for `sin(x, /)`: - - - If `x_i` is either `+infinity` or `-infinity`, the result is `NaN`. - - """ - res = sin(arg1) - mask = logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_sinh.py b/array_api_tests/special_cases/test_sinh.py deleted file mode 100644 index 4d2ff217..00000000 --- a/array_api_tests/special_cases/test_sinh.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for sinh. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import sinh - -from hypothesis import given - - -@given(numeric_arrays) -def test_sinh_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `sinh(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = sinh(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sinh_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `sinh(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = sinh(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sinh_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `sinh(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = sinh(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sinh_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `sinh(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = sinh(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sinh_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `sinh(x, /)`: - - - If `x_i` is `-infinity`, the result is `-infinity`. - - """ - res = sinh(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_sqrt.py b/array_api_tests/special_cases/test_sqrt.py deleted file mode 100644 index 18244755..00000000 --- a/array_api_tests/special_cases/test_sqrt.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for sqrt. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, less, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import sqrt - -from hypothesis import given - - -@given(numeric_arrays) -def test_sqrt_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `sqrt(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = sqrt(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sqrt_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `sqrt(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = sqrt(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sqrt_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `sqrt(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = sqrt(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sqrt_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `sqrt(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = sqrt(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_sqrt_special_cases_one_arg_less(arg1): - """ - Special case test for `sqrt(x, /)`: - - - If `x_i` is less than `0`, the result is `NaN`. - - """ - res = sqrt(arg1) - mask = less(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_tan.py b/array_api_tests/special_cases/test_tan.py deleted file mode 100644 index ec09878d..00000000 --- a/array_api_tests/special_cases/test_tan.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Special cases tests for tan. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, logical_or, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import tan - -from hypothesis import given - - -@given(numeric_arrays) -def test_tan_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `tan(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = tan(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_tan_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `tan(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = tan(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_tan_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `tan(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = tan(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_tan_special_cases_one_arg_either(arg1): - """ - Special case test for `tan(x, /)`: - - - If `x_i` is either `+infinity` or `-infinity`, the result is `NaN`. - - """ - res = tan(arg1) - mask = logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_tanh.py b/array_api_tests/special_cases/test_tanh.py deleted file mode 100644 index 91304c2f..00000000 --- a/array_api_tests/special_cases/test_tanh.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Special cases tests for tanh. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import tanh - -from hypothesis import given - - -@given(numeric_arrays) -def test_tanh_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `tanh(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = tanh(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_tanh_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `tanh(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = tanh(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_tanh_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `tanh(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = tanh(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_tanh_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `tanh(x, /)`: - - - If `x_i` is `+infinity`, the result is `+1`. - - """ - res = tanh(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_tanh_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `tanh(x, /)`: - - - If `x_i` is `-infinity`, the result is `-1`. - - """ - res = tanh(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-one(arg1.shape, arg1.dtype))[mask]) diff --git a/array_api_tests/special_cases/test_trunc.py b/array_api_tests/special_cases/test_trunc.py deleted file mode 100644 index 6ee7d402..00000000 --- a/array_api_tests/special_cases/test_trunc.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Special cases tests for trunc. - -These tests are generated from the special cases listed in the spec. - -NOTE: This file is generated automatically by the generate_stubs.py script. Do -not modify it directly. -""" - -from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, isintegral, zero -from ..hypothesis_helpers import numeric_arrays -from .._array_module import trunc - -from hypothesis import given - - -@given(numeric_arrays) -def test_trunc_special_cases_one_arg_equal_1(arg1): - """ - Special case test for `trunc(x, /)`: - - - If `x_i` is already integer-valued, the result is `x_i`. - - """ - res = trunc(arg1) - mask = isintegral(arg1) - assert_exactly_equal(res[mask], (arg1)[mask]) - - -@given(numeric_arrays) -def test_trunc_special_cases_one_arg_equal_2(arg1): - """ - Special case test for `trunc(x, /)`: - - - If `x_i` is `+infinity`, the result is `+infinity`. - - """ - res = trunc(arg1) - mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_trunc_special_cases_one_arg_equal_3(arg1): - """ - Special case test for `trunc(x, /)`: - - - If `x_i` is `-infinity`, the result is `-infinity`. - - """ - res = trunc(arg1) - mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_trunc_special_cases_one_arg_equal_4(arg1): - """ - Special case test for `trunc(x, /)`: - - - If `x_i` is `+0`, the result is `+0`. - - """ - res = trunc(arg1) - mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_trunc_special_cases_one_arg_equal_5(arg1): - """ - Special case test for `trunc(x, /)`: - - - If `x_i` is `-0`, the result is `-0`. - - """ - res = trunc(arg1) - mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask]) - - -@given(numeric_arrays) -def test_trunc_special_cases_one_arg_equal_6(arg1): - """ - Special case test for `trunc(x, /)`: - - - If `x_i` is `NaN`, the result is `NaN`. - - """ - res = trunc(arg1) - mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)) - assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask]) From ee68b89d38b416872feacb510df40ecd71f7034e Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 10 Feb 2022 10:22:00 +0000 Subject: [PATCH 03/63] Rudimentary testing for binary elwise special cases --- ...special_cases.py => test_special_cases.py} | 198 ++++++++++++++---- 1 file changed, 162 insertions(+), 36 deletions(-) rename array_api_tests/{special_cases.py => test_special_cases.py} (53%) diff --git a/array_api_tests/special_cases.py b/array_api_tests/test_special_cases.py similarity index 53% rename from array_api_tests/special_cases.py rename to array_api_tests/test_special_cases.py index 0b5c31fd..d7472f72 100644 --- a/array_api_tests/special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -6,34 +6,32 @@ import pytest from attr import dataclass -from hypothesis import assume, given +from hypothesis import HealthCheck, assume, given, settings +from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import shape_helpers as sh from . import xps from ._array_module import mod as xp from .stubs import category_to_funcs -repr_to_value = { - "NaN": float("nan"), - "+infinity": float("infinity"), - "infinity": float("infinity"), - "-infinity": float("-infinity"), - "+0": 0.0, - "0": 0.0, - "-0": -0.0, - "+1": 1.0, - "1": 1.0, - "-1": -1.0, - "+π/2": math.pi / 2, - "π/2": math.pi / 2, - "-π/2": -math.pi / 2, -} + +def is_pos_zero(n: float) -> bool: + return n == 0 and math.copysign(1, n) == 1 + + +def is_neg_zero(n: float) -> bool: + return n == 0 and math.copysign(1, n) == -1 def make_eq(v: float) -> Callable[[float], bool]: if math.isnan(v): return math.isnan + if v == 0: + if is_pos_zero(v): + return is_pos_zero + else: + return is_neg_zero def eq(i: float) -> bool: return i == v @@ -42,6 +40,8 @@ def eq(i: float) -> bool: def make_rough_eq(v: float) -> Callable[[float], bool]: + assert math.isfinite(v) # sanity check + def rough_eq(i: float) -> bool: return math.isclose(i, v, abs_tol=0.01) @@ -73,10 +73,15 @@ def or_(i: float): return or_ -r_value = re.compile(r"``([^\s]+)``") -r_approx_value = re.compile( - rf"an implementation-dependent approximation to {r_value.pattern}" -) +repr_to_value = { + "NaN": float("nan"), + "infinity": float("infinity"), + "0": 0.0, + "1": 1.0, +} + +r_value = re.compile(r"([+-]?)(.+)") +r_pi = re.compile(r"(\d?)π(?:/(\d))?") @dataclass @@ -84,10 +89,36 @@ class ValueParseError(ValueError): value: str -def parse_value(value: str) -> float: - if m := r_value.match(value): - return repr_to_value[m.group(1)] - raise ValueParseError(value) +def parse_value(s_value: str) -> float: + assert not s_value.startswith("``") and not s_value.endswith("``") # sanity check + m = r_value.match(s_value) + if m is None: + raise ValueParseError(s_value) + if pi_m := r_pi.match(m.group(2)): + value = math.pi + if numerator := pi_m.group(1): + value *= int(numerator) + if denominator := pi_m.group(2): + value /= int(denominator) + else: + value = repr_to_value[m.group(2)] + if sign := m.group(1): + if sign == "-": + value *= -1 + return value + + +r_inline_code = re.compile(r"``([^\s]+)``") +r_approx_value = re.compile( + rf"an implementation-dependent approximation to {r_inline_code.pattern}" +) + + +def parse_inline_code(inline_code: str) -> float: + if m := r_inline_code.match(inline_code): + return parse_value(m.group(1)) + else: + raise ValueParseError(inline_code) class Result(NamedTuple): @@ -96,22 +127,24 @@ class Result(NamedTuple): strict_check: bool -def parse_result(result: str) -> Result: - if m := r_value.match(result): - repr_ = m.group(1) +def parse_result(s_result: str) -> Result: + match = None + if m := r_inline_code.match(s_result): + match = m strict_check = True - elif m := r_approx_value.match(result): - repr_ = m.group(1) + elif m := r_approx_value.match(s_result): + match = m strict_check = False else: - raise ValueParseError(result) - value = repr_to_value[repr_] + raise ValueParseError(s_result) + value = parse_value(match.group(1)) + repr_ = match.group(1) return Result(value, repr_, strict_check) r_special_cases = re.compile( - r"\*\*Special [Cc]ases\*\*\n\n\s*" - r"For floating-point operands,\n\n" + r"\*\*Special [Cc]ases\*\*\n+\s*" + r"For floating-point operands,\n+" r"((?:\s*-\s*.*\n)+)" ) r_case = re.compile(r"\s+-\s*(.*)\.\n?") @@ -148,7 +181,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: if m := pattern.search(case): *s_values, s_result = m.groups() try: - values = [parse_value(v) for v in s_values] + values = [parse_inline_code(v) for v in s_values] except ValueParseError as e: warn(f"value not machine-readable: '{e.value}'") break @@ -166,7 +199,56 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: return condition_to_result +binary_pattern_to_condition_factory: Dict[Pattern, Callable] = { + re.compile( + "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" + ): lambda v1, v2: lambda i1, i2: make_eq(v1)(i1) + and make_eq(v2)(i2), +} + + +def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]: + match = r_special_cases.search(docstring) + if match is None: + return {} + cases = match.group(1).split("\n")[:-1] + condition_to_result = {} + for line in cases: + if m := r_case.match(line): + case = m.group(1) + else: + warn(f"line not machine-readable: '{line}'") + continue + for pattern, make_cond in binary_pattern_to_condition_factory.items(): + if m := pattern.search(case): + *s_values, s_result = m.groups() + try: + values = [parse_inline_code(v) for v in s_values] + except ValueParseError as e: + warn(f"value not machine-readable: '{e.value}'") + break + cond = make_cond(*values) + if ( + "atan2" in docstring + and is_pos_zero(values[0]) + and is_neg_zero(values[1]) + ): + breakpoint() + try: + result = parse_result(s_result) + except ValueParseError as e: + warn(f"result not machine-readable: '{e.value}'") + break + condition_to_result[cond] = result + break + else: + if not r_remaining_case.search(case): + warn(f"case not machine-readable: '{case}'") + return condition_to_result + + unary_params = [] +binary_params = [] for stub in category_to_funcs["elementwise"]: if stub.__doc__ is None: warn(f"{stub.__name__}() stub has no docstring") @@ -193,7 +275,10 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and param_names[1] == "x2": - pass # TODO + if condition_to_result := parse_binary_docstring(stub.__doc__): + p = pytest.param(stub.__name__, func, condition_to_result, id=stub.__name__) + binary_params.append(p) + continue else: warn( f"{func=} starts with two parameters '{param_names[0]}' and " @@ -209,7 +294,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: @pytest.mark.parametrize("func_name, func, condition_to_result", unary_params) @given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_unary_special_cases(func_name, func, condition_to_result, x): +def test_unary(func_name, func, condition_to_result, x): res = func(x) good_example = False for idx in sh.ndindex(res.shape): @@ -238,3 +323,44 @@ def test_unary_special_cases(func_name, func, condition_to_result, x): ) break assume(good_example) + + +@pytest.mark.parametrize("func_name, func, condition_to_result", binary_params) +@given( + *hh.two_mutual_arrays( + dtypes=dh.float_dtypes, + two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), + ) +) +@settings(suppress_health_check=[HealthCheck.filter_too_much]) # TODO: remove +def test_binary(func_name, func, condition_to_result, x1, x2): + res = func(x1, x2) + good_example = False + for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, res.shape): + l = float(x1[l_idx]) + r = float(x2[r_idx]) + for cond, result in condition_to_result.items(): + if cond(l, r): + good_example = True + out = float(res[o_idx]) + f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" + f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" + f_out = f"{sh.fmt_idx('out', o_idx)}={out}" + if result.strict_check: + msg = ( + f"{f_out}, but should be {result.repr_} [{func_name}()]\n" + f"{f_left}, {f_right}" + ) + if math.isnan(result.value): + assert math.isnan(out), msg + else: + assert out == result.value, msg + else: + assert math.isfinite(result.value) # sanity check + assert math.isclose(out, result.value, abs_tol=0.1), ( + f"{f_out}, but should be roughly {result.repr_}={result.value} " + f"[{func_name}()]\n" + f"{f_left}, {f_right}" + ) + break + assume(good_example) From 1ef5b3e4a4da5c1cb82d7d6bbe82c38753199e16 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 11 Feb 2022 14:06:13 +0000 Subject: [PATCH 04/63] Use `ph` for pos/neg zero utils --- array_api_tests/test_special_cases.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index d7472f72..8d9aa21d 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -10,28 +10,21 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh +from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps from ._array_module import mod as xp from .stubs import category_to_funcs -def is_pos_zero(n: float) -> bool: - return n == 0 and math.copysign(1, n) == 1 - - -def is_neg_zero(n: float) -> bool: - return n == 0 and math.copysign(1, n) == -1 - - def make_eq(v: float) -> Callable[[float], bool]: if math.isnan(v): return math.isnan if v == 0: - if is_pos_zero(v): - return is_pos_zero + if ph.is_pos_zero(v): + return ph.is_pos_zero else: - return is_neg_zero + return ph.is_neg_zero def eq(i: float) -> bool: return i == v @@ -230,8 +223,8 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]: cond = make_cond(*values) if ( "atan2" in docstring - and is_pos_zero(values[0]) - and is_neg_zero(values[1]) + and ph.is_pos_zero(values[0]) + and ph.is_neg_zero(values[1]) ): breakpoint() try: From 14c4793e654d3dacbd78df19d29b3ac4dc0b0c45 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 11 Feb 2022 19:22:12 +0000 Subject: [PATCH 05/63] More binary cases coverage --- array_api_tests/meta/test_special_cases.py | 8 + array_api_tests/test_special_cases.py | 181 ++++++++++++++++++++- 2 files changed, 180 insertions(+), 9 deletions(-) create mode 100644 array_api_tests/meta/test_special_cases.py diff --git a/array_api_tests/meta/test_special_cases.py b/array_api_tests/meta/test_special_cases.py new file mode 100644 index 00000000..1b8c8358 --- /dev/null +++ b/array_api_tests/meta/test_special_cases.py @@ -0,0 +1,8 @@ +import math + +from ..test_special_cases import parse_result + + +def test_parse_result(): + s_result = "an implementation-dependent approximation to ``+3π/4``" + assert parse_result(s_result).value == 3 * math.pi / 4 diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 8d9aa21d..bb9d3fce 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1,7 +1,7 @@ import inspect import math import re -from typing import Callable, Dict, NamedTuple, Pattern +from typing import Callable, Dict, List, NamedTuple, Pattern from warnings import warn import pytest @@ -16,6 +16,9 @@ from ._array_module import mod as xp from .stubs import category_to_funcs +# Condition factories +# ------------------------------------------------------------------------------ + def make_eq(v: float) -> Callable[[float], bool]: if math.isnan(v): @@ -32,6 +35,15 @@ def eq(i: float) -> bool: return eq +def make_neq(v: float) -> Callable[[float], bool]: + eq = make_eq(v) + + def neq(i: float) -> bool: + return not eq(i) + + return neq + + def make_rough_eq(v: float) -> Callable[[float], bool]: assert math.isfinite(v) # sanity check @@ -66,6 +78,71 @@ def or_(i: float): return or_ +def make_and(cond1: Callable, cond2: Callable) -> Callable: + def and_(i: float) -> bool: + return cond1(i) or cond2(i) + + return and_ + + +def make_bin_and_factory(make_cond1: Callable, make_cond2: Callable) -> Callable: + def make_bin_and(v1: float, v2: float) -> Callable: + cond1 = make_cond1(v1) + cond2 = make_cond2(v2) + + def bin_and(i1: float, i2: float) -> bool: + return cond1(i1) and cond2(i2) + + return bin_and + + return make_bin_and + + +def make_bin_or_factory(make_cond: Callable) -> Callable: + def make_bin_or(v: float) -> Callable: + cond = make_cond(v) + + def bin_or(i1: float, i2: float) -> bool: + return cond(i1) or cond(i2) + + return bin_or + + return make_bin_or + + +def absify_cond_factory(make_cond): + def make_abs_cond(v: float): + cond = make_cond(v) + + def abs_cond(i: float) -> bool: + i = abs(i) + return cond(i) + + return abs_cond + + return make_abs_cond + + +def make_bin_multi_and_factory( + make_conds1: List[Callable], make_conds2: List[Callable] +) -> Callable: + def make_bin_multi_and(*values: float) -> Callable: + assert len(values) == len(make_conds1) + len(make_conds2) + conds1 = [make_cond(v) for make_cond, v in zip(make_conds1, values)] + conds2 = [make_cond(v) for make_cond, v in zip(make_conds2, values[::-1])] + + def bin_multi_and(i1: float, i2: float) -> bool: + return all(cond(i1) for cond in conds1) and all(cond(i2) for cond in conds2) + + return bin_multi_and + + return make_bin_multi_and + + +# Parse utils +# ------------------------------------------------------------------------------ + + repr_to_value = { "NaN": float("nan"), "infinity": float("infinity"), @@ -183,6 +260,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: result = parse_result(s_result) except ValueParseError as e: warn(f"result not machine-readable: '{e.value}'") + break condition_to_result[cond] = result break @@ -193,10 +271,97 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: binary_pattern_to_condition_factory: Dict[Pattern, Callable] = { + re.compile( + "If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)" + ): make_bin_and_factory(make_eq, lambda v: lambda i: i != v), + re.compile( + "If ``x1_i`` is greater than (.+), ``x1_i`` is (.+), " + "and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_multi_and_factory([make_gt, make_eq], [make_eq]), + re.compile( + "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), " + "and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq]), + re.compile( + "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), " + "and ``x2_i`` is not (.+), the result is (.+)" + ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq, make_neq]), + re.compile( + "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), " + "and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_multi_and_factory([make_eq], [make_lt, make_eq]), + re.compile( + "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), " + "and ``x2_i`` is not (.+), the result is (.+)" + ): make_bin_multi_and_factory([make_eq], [make_lt, make_neq]), + re.compile( + "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), " + "and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_multi_and_factory([make_eq], [make_gt, make_eq]), + re.compile( + "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), " + "and ``x2_i`` is not (.+), the result is (.+)" + ): make_bin_multi_and_factory([make_eq], [make_gt, make_neq]), + re.compile( + "If ``x1_i`` is greater than (.+) and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_and_factory(make_gt, make_eq), + re.compile( + "If ``x1_i`` is (.+) and ``x2_i`` is greater than (.+), the result is (.+)" + ): make_bin_and_factory(make_eq, make_gt), + re.compile( + "If ``x1_i`` is less than (.+) and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_and_factory(make_lt, make_eq), + re.compile( + "If ``x1_i`` is (.+) and ``x2_i`` is less than (.+), the result is (.+)" + ): make_bin_and_factory(make_eq, make_lt), + re.compile( + "If ``x1_i`` is not (?:equal to )?(.+) and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_and_factory(make_neq, make_eq), + re.compile( + "If ``x1_i`` is (.+) and ``x2_i`` is not (?:equal to )?(.+), the result is (.+)" + ): make_bin_and_factory(make_eq, make_neq), + re.compile( + r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), " + "the result is (.+)" + ): make_bin_and_factory(absify_cond_factory(make_gt), make_eq), + re.compile( + r"If `abs\(x1_i\)` is less than (.+) and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_and_factory(absify_cond_factory(make_lt), make_eq), + re.compile( + r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)" + ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq), re.compile( "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" - ): lambda v1, v2: lambda i1, i2: make_eq(v1)(i1) - and make_eq(v2)(i2), + ): make_bin_and_factory(make_eq, make_eq), + re.compile( + "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" + ): make_bin_or_factory(make_eq), + re.compile( + "If ``x1_i`` is either (.+) or (.+) and ``x2_i`` is (.+), the result is (.+)" + ): lambda v1, v2, v3: ( + lambda i1, i2: make_or(make_eq(v1), make_eq(v2))(i1) and make_eq(v3)(i2) + ), + re.compile( + "If ``x1_i`` is (.+) and ``x2_i`` is either (.+) or (.+), the result is (.+)" + ): lambda v1, v2, v3: ( + lambda i1, i2: make_eq(v1)(i1) and make_or(make_eq(v2), make_eq(v3))(i2) + ), + re.compile( + "If ``x1_i`` is either (.+) or (.+) and " + "``x2_i`` is either (.+) or (.+), the result is (.+)" + ): lambda v1, v2, v3, v4: ( + lambda i1, i2: ( + make_or(make_eq(v1), make_eq(v2))(i1) + and make_or(make_eq(v3), make_eq(v4))(i2) + ) + ), + # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+)") + # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined") + # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign and are both (.+), the result has a (.+)") + # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+)") + # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined") + # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs and are both (.+), the result has a (.+)") + # re.compile("If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+") } @@ -221,12 +386,6 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]: warn(f"value not machine-readable: '{e.value}'") break cond = make_cond(*values) - if ( - "atan2" in docstring - and ph.is_pos_zero(values[0]) - and ph.is_neg_zero(values[1]) - ): - breakpoint() try: result = parse_result(s_result) except ValueParseError as e: @@ -240,6 +399,10 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]: return condition_to_result +# Here be the tests +# ------------------------------------------------------------------------------ + + unary_params = [] binary_params = [] for stub in category_to_funcs["elementwise"]: From a937d26a7a0aca8145558ecab25b849f2191b192 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Feb 2022 09:59:41 +0000 Subject: [PATCH 06/63] Cover most sign special cases --- array_api_tests/test_special_cases.py | 88 +++++++++++++++++++-------- 1 file changed, 63 insertions(+), 25 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index bb9d3fce..83aa6b18 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -20,7 +20,11 @@ # ------------------------------------------------------------------------------ -def make_eq(v: float) -> Callable[[float], bool]: +UnaryCheck = Callable[[float], bool] +BinaryCheck = Callable[[float, float], bool] + + +def make_eq(v: float) -> UnaryCheck: if math.isnan(v): return math.isnan if v == 0: @@ -35,7 +39,7 @@ def eq(i: float) -> bool: return eq -def make_neq(v: float) -> Callable[[float], bool]: +def make_neq(v: float) -> UnaryCheck: eq = make_eq(v) def neq(i: float) -> bool: @@ -44,7 +48,7 @@ def neq(i: float) -> bool: return neq -def make_rough_eq(v: float) -> Callable[[float], bool]: +def make_rough_eq(v: float) -> UnaryCheck: assert math.isfinite(v) # sanity check def rough_eq(i: float) -> bool: @@ -53,40 +57,42 @@ def rough_eq(i: float) -> bool: return rough_eq -def make_gt(v: float): +def make_gt(v: float) -> UnaryCheck: assert not math.isnan(v) # sanity check - def gt(i: float): + def gt(i: float) -> bool: return i > v return gt -def make_lt(v: float): +def make_lt(v: float) -> UnaryCheck: assert not math.isnan(v) # sanity check - def lt(i: float): + def lt(i: float) -> bool: return i < v return lt -def make_or(cond1: Callable, cond2: Callable): - def or_(i: float): +def make_or(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck: + def or_(i: float) -> bool: return cond1(i) or cond2(i) return or_ -def make_and(cond1: Callable, cond2: Callable) -> Callable: +def make_and(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck: def and_(i: float) -> bool: return cond1(i) or cond2(i) return and_ -def make_bin_and_factory(make_cond1: Callable, make_cond2: Callable) -> Callable: - def make_bin_and(v1: float, v2: float) -> Callable: +def make_bin_and_factory( + make_cond1: Callable[[float], UnaryCheck], make_cond2: Callable[[float], UnaryCheck] +) -> Callable[[float, float], BinaryCheck]: + def make_bin_and(v1: float, v2: float) -> BinaryCheck: cond1 = make_cond1(v1) cond2 = make_cond2(v2) @@ -98,8 +104,10 @@ def bin_and(i1: float, i2: float) -> bool: return make_bin_and -def make_bin_or_factory(make_cond: Callable) -> Callable: - def make_bin_or(v: float) -> Callable: +def make_bin_or_factory( + make_cond: Callable[[float], UnaryCheck] +) -> Callable[[float], BinaryCheck]: + def make_bin_or(v: float) -> BinaryCheck: cond = make_cond(v) def bin_or(i1: float, i2: float) -> bool: @@ -110,8 +118,10 @@ def bin_or(i1: float, i2: float) -> bool: return make_bin_or -def absify_cond_factory(make_cond): - def make_abs_cond(v: float): +def absify_cond_factory( + make_cond: Callable[[float], UnaryCheck] +) -> Callable[[float], UnaryCheck]: + def make_abs_cond(v: float) -> UnaryCheck: cond = make_cond(v) def abs_cond(i: float) -> bool: @@ -124,9 +134,10 @@ def abs_cond(i: float) -> bool: def make_bin_multi_and_factory( - make_conds1: List[Callable], make_conds2: List[Callable] + make_conds1: List[Callable[[float], UnaryCheck]], + make_conds2: List[Callable[[float], UnaryCheck]], ) -> Callable: - def make_bin_multi_and(*values: float) -> Callable: + def make_bin_multi_and(*values: float) -> BinaryCheck: assert len(values) == len(make_conds1) + len(make_conds2) conds1 = [make_cond(v) for make_cond, v in zip(make_conds1, values)] conds2 = [make_cond(v) for make_cond, v in zip(make_conds2, values[::-1])] @@ -139,6 +150,14 @@ def bin_multi_and(i1: float, i2: float) -> bool: return make_bin_multi_and +def same_sign(i1: float, i2: float) -> bool: + return math.copysign(1, i1) == math.copysign(1, i2) + + +def diff_sign(i1: float, i2: float) -> bool: + return not same_sign(i1, i2) + + # Parse utils # ------------------------------------------------------------------------------ @@ -271,6 +290,9 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: binary_pattern_to_condition_factory: Dict[Pattern, Callable] = { + re.compile( + "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" + ): lambda v: lambda _, i2: make_eq(v)(i2), re.compile( "If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)" ): make_bin_and_factory(make_eq, lambda v: lambda i: i != v), @@ -355,13 +377,29 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: and make_or(make_eq(v3), make_eq(v4))(i2) ) ), - # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+)") - # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined") - # re.compile("If ``x1_i`` and ``x2_i`` have the same mathematical sign and are both (.+), the result has a (.+)") - # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+)") - # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the "sign" of (.+) is implementation-defined") - # re.compile("If ``x1_i`` and ``x2_i`` have different mathematical signs and are both (.+), the result has a (.+)") - # re.compile("If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+") + re.compile( + "If ``x1_i`` and ``x2_i`` have the same mathematical sign, " + "the result has a (.+)" + ): lambda: same_sign, + re.compile( + "If ``x1_i`` and ``x2_i`` have different mathematical signs, " + "the result has a (.+)" + ): lambda: diff_sign, + re.compile( + "If ``x1_i`` and ``x2_i`` have the same mathematical sign and " + "are both (.+), the result has a (.+)" + ): lambda v: lambda i1, i2: same_sign(i1, i2) + and make_eq(v)(i1) + and make_eq(v)(i2), + re.compile( + "If ``x1_i`` and ``x2_i`` have different mathematical signs and " + "are both (.+), the result has a (.+)" + ): lambda v: lambda i1, i2: diff_sign(i1, i2) + and make_eq(v)(i1) + and make_eq(v)(i2), + # TODO: support capturing values that come after the result + # re.compile(r"If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is .+, the \"sign\" of .+ is implementation-defined") + # re.compile(r"If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the \"sign\" of (.+) is implementation-defined") } From 08aebdec365aafd1c8308b944fa734cc50e3563a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Feb 2022 11:38:35 +0000 Subject: [PATCH 07/63] Rudimentary rework of value condition factories --- array_api_tests/test_special_cases.py | 458 +++++++++++++++++--------- 1 file changed, 297 insertions(+), 161 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 83aa6b18..cc76bb71 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1,7 +1,17 @@ import inspect import math import re -from typing import Callable, Dict, List, NamedTuple, Pattern +from typing import ( + Callable, + Dict, + List, + Literal, + NamedTuple, + Pattern, + Protocol, + Tuple, + Union, +) from warnings import warn import pytest @@ -197,14 +207,14 @@ def parse_value(s_value: str) -> float: return value -r_inline_code = re.compile(r"``([^\s]+)``") +r_code = re.compile(r"``([^\s]+)``") r_approx_value = re.compile( - rf"an implementation-dependent approximation to {r_inline_code.pattern}" + rf"an implementation-dependent approximation to {r_code.pattern}" ) def parse_inline_code(inline_code: str) -> float: - if m := r_inline_code.match(inline_code): + if m := r_code.match(inline_code): return parse_value(m.group(1)) else: raise ValueParseError(inline_code) @@ -218,7 +228,7 @@ class Result(NamedTuple): def parse_result(s_result: str) -> Result: match = None - if m := r_inline_code.match(s_result): + if m := r_code.match(s_result): match = m strict_check = True elif m := r_approx_value.match(s_result): @@ -259,7 +269,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: if match is None: return {} cases = match.group(1).split("\n")[:-1] - condition_to_result = {} + cases = {} for line in cases: if m := r_case.match(line): case = m.group(1) @@ -281,160 +291,285 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: warn(f"result not machine-readable: '{e.value}'") break - condition_to_result[cond] = result + cases[cond] = result break else: if not r_remaining_case.search(case): warn(f"case not machine-readable: '{case}'") - return condition_to_result + return cases + + +class CondFactory(Protocol): + def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: + ... + + +r_not_code = re.compile(f"not (?:equal to )?{r_code.pattern}") +r_array_element = re.compile(r"``([+-]?)x[12]_i``") +r_gt = re.compile(f"greater than {r_code.pattern}") +r_lt = re.compile(f"less than {r_code.pattern}") +r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}") + + +class ValueCondFactory(NamedTuple): # TODO: inherit from CondFactory as well + input_: Union[Literal["i1"], Literal["i2"]] + re_group: int + + def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: + group = groups[self.re_group] + + if m := r_array_element.match(group): + cond_factory = make_eq if m.group(1) != "-" else make_neq + if self.input_ == "i1": + + def cond(i1: float, i2: float) -> bool: + _cond = cond_factory(i2) + return _cond(i1) + + else: + + def cond(i1: float, i2: float) -> bool: + _cond = cond_factory(i1) + return _cond(i2) + + return cond + # this branch must come after checking for array elements + elif m := r_code.match(group): + value = parse_value(m.group(1)) + _cond = make_eq(value) + elif m := r_not_code.match(group): + value = parse_value(m.group(1)) + _cond = make_neq(value) + elif m := r_gt.match(group): + value = parse_value(m.group(1)) + _cond = make_gt(value) + elif m := r_lt.match(group): + value = parse_value(m.group(1)) + _cond = make_lt(value) + elif m := r_either_code.match(group): + v1 = parse_value(m.group(1)) + v2 = parse_value(m.group(2)) + _cond = make_or(make_eq(v1), make_eq(v2)) + elif group in ["finite", "a finite number"]: + _cond = math.isfinite + elif group in "a positive (i.e., greater than ``0``) finite number": + _cond = lambda i: math.isfinite(i) and i > 0 + elif group == "a negative (i.e., less than ``0``) finite number": + _cond = lambda i: math.isfinite(i) and i < 0 + elif group == "positive": + _cond = lambda i: math.copysign(1, i) == 1 + elif group == "negative": + _cond = lambda i: math.copysign(1, i) == -1 + elif "nonzero finite" in group: + _cond = lambda i: math.isfinite(i) and i != 0 + elif group == "an integer value": + _cond = lambda i: i.is_integer() + elif group == "not an integer value": + _cond = lambda i: not i.is_integer() + elif group == "an odd integer value": + _cond = lambda i: i.is_integer() and i % 2 == 1 + elif group == "not an odd integer value": + _cond = lambda i: not (i.is_integer() and i % 2 == 1) + else: + raise ValueParseError(group) + if self.input_ == "i1": -binary_pattern_to_condition_factory: Dict[Pattern, Callable] = { - re.compile( - "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" - ): lambda v: lambda _, i2: make_eq(v)(i2), - re.compile( - "If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)" - ): make_bin_and_factory(make_eq, lambda v: lambda i: i != v), - re.compile( - "If ``x1_i`` is greater than (.+), ``x1_i`` is (.+), " - "and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_multi_and_factory([make_gt, make_eq], [make_eq]), - re.compile( - "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), " - "and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq]), - re.compile( - "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), " - "and ``x2_i`` is not (.+), the result is (.+)" - ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq, make_neq]), - re.compile( - "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), " - "and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_multi_and_factory([make_eq], [make_lt, make_eq]), - re.compile( - "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), " - "and ``x2_i`` is not (.+), the result is (.+)" - ): make_bin_multi_and_factory([make_eq], [make_lt, make_neq]), - re.compile( - "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), " - "and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_multi_and_factory([make_eq], [make_gt, make_eq]), - re.compile( - "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), " - "and ``x2_i`` is not (.+), the result is (.+)" - ): make_bin_multi_and_factory([make_eq], [make_gt, make_neq]), - re.compile( - "If ``x1_i`` is greater than (.+) and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_and_factory(make_gt, make_eq), - re.compile( - "If ``x1_i`` is (.+) and ``x2_i`` is greater than (.+), the result is (.+)" - ): make_bin_and_factory(make_eq, make_gt), - re.compile( - "If ``x1_i`` is less than (.+) and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_and_factory(make_lt, make_eq), - re.compile( - "If ``x1_i`` is (.+) and ``x2_i`` is less than (.+), the result is (.+)" - ): make_bin_and_factory(make_eq, make_lt), - re.compile( - "If ``x1_i`` is not (?:equal to )?(.+) and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_and_factory(make_neq, make_eq), - re.compile( - "If ``x1_i`` is (.+) and ``x2_i`` is not (?:equal to )?(.+), the result is (.+)" - ): make_bin_and_factory(make_eq, make_neq), - re.compile( - r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), " - "the result is (.+)" - ): make_bin_and_factory(absify_cond_factory(make_gt), make_eq), - re.compile( - r"If `abs\(x1_i\)` is less than (.+) and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_and_factory(absify_cond_factory(make_lt), make_eq), - re.compile( - r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq), + def cond(i1: float, i2: float) -> bool: + return _cond(i1) + + else: + + def cond(i1: float, i2: float) -> bool: + return _cond(i2) + + return cond + + +class AndCondFactory(CondFactory): + def __init__(self, *cond_factories: CondFactory): + self.cond_factories = cond_factories + + def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: + conds = [cond_factory(groups) for cond_factory in self.cond_factories] + + def cond(i1: float, i2: float) -> bool: + return all(cond(i1, i2) for cond in conds) + + return cond + + +class BinaryCase(NamedTuple): + cond: BinaryCheck + check_result: Callable[[float], bool] + + +class BinaryCaseFactory(NamedTuple): + cond_factory: CondFactory + result_re_group: int + + def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: + in_cond = self.cond_factory(groups) + + s_result = groups[self.result_re_group] + if m := r_array_element.match(s_result): + raise ValueParseError(s_result) # TODO + elif m := r_code.match(s_result): + value = parse_value(m.group(1)) + out_cond = make_eq(value) + elif m := r_approx_value.match(s_result): + value = parse_value(m.group(1)) + out_cond = make_rough_eq(value) + else: + raise ValueParseError(s_result) + + return BinaryCase(in_cond, out_cond) + + +binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = { re.compile( "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" - ): make_bin_and_factory(make_eq, make_eq), - re.compile( - "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" - ): make_bin_or_factory(make_eq), - re.compile( - "If ``x1_i`` is either (.+) or (.+) and ``x2_i`` is (.+), the result is (.+)" - ): lambda v1, v2, v3: ( - lambda i1, i2: make_or(make_eq(v1), make_eq(v2))(i1) and make_eq(v3)(i2) - ), - re.compile( - "If ``x1_i`` is (.+) and ``x2_i`` is either (.+) or (.+), the result is (.+)" - ): lambda v1, v2, v3: ( - lambda i1, i2: make_eq(v1)(i1) and make_or(make_eq(v2), make_eq(v3))(i2) + ): BinaryCaseFactory( + AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)), 2 ), - re.compile( - "If ``x1_i`` is either (.+) or (.+) and " - "``x2_i`` is either (.+) or (.+), the result is (.+)" - ): lambda v1, v2, v3, v4: ( - lambda i1, i2: ( - make_or(make_eq(v1), make_eq(v2))(i1) - and make_or(make_eq(v3), make_eq(v4))(i2) - ) - ), - re.compile( - "If ``x1_i`` and ``x2_i`` have the same mathematical sign, " - "the result has a (.+)" - ): lambda: same_sign, - re.compile( - "If ``x1_i`` and ``x2_i`` have different mathematical signs, " - "the result has a (.+)" - ): lambda: diff_sign, - re.compile( - "If ``x1_i`` and ``x2_i`` have the same mathematical sign and " - "are both (.+), the result has a (.+)" - ): lambda v: lambda i1, i2: same_sign(i1, i2) - and make_eq(v)(i1) - and make_eq(v)(i2), - re.compile( - "If ``x1_i`` and ``x2_i`` have different mathematical signs and " - "are both (.+), the result has a (.+)" - ): lambda v: lambda i1, i2: diff_sign(i1, i2) - and make_eq(v)(i1) - and make_eq(v)(i2), + # re.compile( + # "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" + # ): lambda v: lambda _, i2: make_eq(v)(i2), + # re.compile( + # "If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)" + # ): make_bin_and_factory(make_eq, lambda v: lambda i: i != v), + # re.compile( + # "If ``x1_i`` is greater than (.+), ``x1_i`` is (.+), " + # "and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_multi_and_factory([make_gt, make_eq], [make_eq]), + # re.compile( + # "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), " + # "and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq]), + # re.compile( + # "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), " + # "and ``x2_i`` is not (.+), the result is (.+)" + # ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq, make_neq]), + # re.compile( + # "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), " + # "and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_multi_and_factory([make_eq], [make_lt, make_eq]), + # re.compile( + # "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), " + # "and ``x2_i`` is not (.+), the result is (.+)" + # ): make_bin_multi_and_factory([make_eq], [make_lt, make_neq]), + # re.compile( + # "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), " + # "and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_multi_and_factory([make_eq], [make_gt, make_eq]), + # re.compile( + # "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), " + # "and ``x2_i`` is not (.+), the result is (.+)" + # ): make_bin_multi_and_factory([make_eq], [make_gt, make_neq]), + # re.compile( + # "If ``x1_i`` is greater than (.+) and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_and_factory(make_gt, make_eq), + # re.compile( + # "If ``x1_i`` is (.+) and ``x2_i`` is greater than (.+), the result is (.+)" + # ): make_bin_and_factory(make_eq, make_gt), + # re.compile( + # "If ``x1_i`` is less than (.+) and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_and_factory(make_lt, make_eq), + # re.compile( + # "If ``x1_i`` is (.+) and ``x2_i`` is less than (.+), the result is (.+)" + # ): make_bin_and_factory(make_eq, make_lt), + # re.compile( + # "If ``x1_i`` is not (?:equal to )?(.+) and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_and_factory(make_neq, make_eq), + # re.compile( + # "If ``x1_i`` is (.+) and ``x2_i`` is not (?:equal to )?(.+), the result is (.+)" + # ): make_bin_and_factory(make_eq, make_neq), + # re.compile( + # r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), " + # "the result is (.+)" + # ): make_bin_and_factory(absify_cond_factory(make_gt), make_eq), + # re.compile( + # r"If `abs\(x1_i\)` is less than (.+) and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_and_factory(absify_cond_factory(make_lt), make_eq), + # re.compile( + # r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq), + # re.compile( + # "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_and_factory(make_eq, make_eq), + # re.compile( + # "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" + # ): make_bin_or_factory(make_eq), + # re.compile( + # "If ``x1_i`` is either (.+) or (.+) and ``x2_i`` is (.+), the result is (.+)" + # ): lambda v1, v2, v3: ( + # lambda i1, i2: make_or(make_eq(v1), make_eq(v2))(i1) and make_eq(v3)(i2) + # ), + # re.compile( + # "If ``x1_i`` is (.+) and ``x2_i`` is either (.+) or (.+), the result is (.+)" + # ): lambda v1, v2, v3: ( + # lambda i1, i2: make_eq(v1)(i1) and make_or(make_eq(v2), make_eq(v3))(i2) + # ), + # re.compile( + # "If ``x1_i`` is either (.+) or (.+) and " + # "``x2_i`` is either (.+) or (.+), the result is (.+)" + # ): lambda v1, v2, v3, v4: ( + # lambda i1, i2: ( + # make_or(make_eq(v1), make_eq(v2))(i1) + # and make_or(make_eq(v3), make_eq(v4))(i2) + # ) + # ), + # re.compile( + # "If ``x1_i`` and ``x2_i`` have the same mathematical sign, " + # "the result has a (.+)" + # ): lambda: same_sign, + # re.compile( + # "If ``x1_i`` and ``x2_i`` have different mathematical signs, " + # "the result has a (.+)" + # ): lambda: diff_sign, + # re.compile( + # "If ``x1_i`` and ``x2_i`` have the same mathematical sign and " + # "are both (.+), the result has a (.+)" + # ): lambda v: ( + # lambda i1, i2: same_sign(i1, i2) and make_eq(v)(i1) and make_eq(v)(i2) + # ), + # re.compile( + # "If ``x1_i`` and ``x2_i`` have different mathematical signs and " + # "are both (.+), the result has a (.+)" + # ): lambda v: ( + # lambda i1, i2: diff_sign(i1, i2) and make_eq(v)(i1) and make_eq(v)(i2) + # ), # TODO: support capturing values that come after the result # re.compile(r"If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is .+, the \"sign\" of .+ is implementation-defined") # re.compile(r"If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the \"sign\" of (.+) is implementation-defined") } -def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]: +def parse_binary_docstring(docstring: str) -> List[BinaryCase]: match = r_special_cases.search(docstring) if match is None: - return {} - cases = match.group(1).split("\n")[:-1] - condition_to_result = {} - for line in cases: + return [] + lines = match.group(1).split("\n")[:-1] + cases = [] + for line in lines: if m := r_case.match(line): case = m.group(1) else: warn(f"line not machine-readable: '{line}'") continue - for pattern, make_cond in binary_pattern_to_condition_factory.items(): + for pattern, make_case in binary_pattern_to_case_factory.items(): if m := pattern.search(case): - *s_values, s_result = m.groups() - try: - values = [parse_inline_code(v) for v in s_values] - except ValueParseError as e: - warn(f"value not machine-readable: '{e.value}'") - break - cond = make_cond(*values) try: - result = parse_result(s_result) + case = make_case(m.groups()) except ValueParseError as e: - warn(f"result not machine-readable: '{e.value}'") + warn(f"not machine-readable: '{e.value}'") break - condition_to_result[cond] = result + cases.append(case) break else: if not r_remaining_case.search(case): warn(f"case not machine-readable: '{case}'") - return condition_to_result + return cases # Here be the tests @@ -461,16 +596,16 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]: warn(f"{func=} has no parameters") continue if param_names[0] == "x": - if condition_to_result := parse_unary_docstring(stub.__doc__): - p = pytest.param(stub.__name__, func, condition_to_result, id=stub.__name__) - unary_params.append(p) + # if cases := parse_unary_docstring(stub.__doc__): + # p = pytest.param(stub.__name__, func, cases, id=stub.__name__) + # unary_params.append(p) continue if len(sig.parameters) == 1: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and param_names[1] == "x2": - if condition_to_result := parse_binary_docstring(stub.__doc__): - p = pytest.param(stub.__name__, func, condition_to_result, id=stub.__name__) + if cases := parse_binary_docstring(stub.__doc__): + p = pytest.param(stub.__name__, func, cases, id=stub.__name__) binary_params.append(p) continue else: @@ -486,14 +621,14 @@ def parse_binary_docstring(docstring: str) -> Dict[Callable, Result]: # indicating we should modify the array strategy being used. -@pytest.mark.parametrize("func_name, func, condition_to_result", unary_params) +@pytest.mark.parametrize("func_name, func, cases", unary_params) @given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_unary(func_name, func, condition_to_result, x): +def test_unary(func_name, func, cases, x): res = func(x) good_example = False for idx in sh.ndindex(res.shape): in_ = float(x[idx]) - for cond, result in condition_to_result.items(): + for cond, result in cases.items(): if cond(in_): good_example = True out = float(res[idx]) @@ -519,7 +654,7 @@ def test_unary(func_name, func, condition_to_result, x): assume(good_example) -@pytest.mark.parametrize("func_name, func, condition_to_result", binary_params) +@pytest.mark.parametrize("func_name, func, cases", binary_params) @given( *hh.two_mutual_arrays( dtypes=dh.float_dtypes, @@ -527,34 +662,35 @@ def test_unary(func_name, func, condition_to_result, x): ) ) @settings(suppress_health_check=[HealthCheck.filter_too_much]) # TODO: remove -def test_binary(func_name, func, condition_to_result, x1, x2): +def test_binary(func_name, func, cases, x1, x2): res = func(x1, x2) good_example = False for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, res.shape): l = float(x1[l_idx]) r = float(x2[r_idx]) - for cond, result in condition_to_result.items(): - if cond(l, r): + for case in cases: + if case.cond(l, r): good_example = True out = float(res[o_idx]) - f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" - f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" - f_out = f"{sh.fmt_idx('out', o_idx)}={out}" - if result.strict_check: - msg = ( - f"{f_out}, but should be {result.repr_} [{func_name}()]\n" - f"{f_left}, {f_right}" - ) - if math.isnan(result.value): - assert math.isnan(out), msg - else: - assert out == result.value, msg - else: - assert math.isfinite(result.value) # sanity check - assert math.isclose(out, result.value, abs_tol=0.1), ( - f"{f_out}, but should be roughly {result.repr_}={result.value} " - f"[{func_name}()]\n" - f"{f_left}, {f_right}" - ) + assert case.check_result(out) + # f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" + # f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" + # f_out = f"{sh.fmt_idx('out', o_idx)}={out}" + # if result.strict_check: + # msg = ( + # f"{f_out}, but should be {result.repr_} [{func_name}()]\n" + # f"{f_left}, {f_right}" + # ) + # if math.isnan(result.value): + # assert math.isnan(out), msg + # else: + # assert out == result.value, msg + # else: + # assert math.isfinite(result.value) # sanity check + # assert math.isclose(out, result.value, abs_tol=0.1), ( + # f"{f_out}, but should be roughly {result.repr_}={result.value} " + # f"[{func_name}()]\n" + # f"{f_left}, {f_right}" + # ) break assume(good_example) From e6f1064d67231257344604bddbd826f1f87cf943 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Feb 2022 12:47:47 +0000 Subject: [PATCH 08/63] Factories for result check functions --- array_api_tests/test_special_cases.py | 82 +++++++++++++++++++-------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index cc76bb71..3a82d87a 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -305,7 +305,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: r_not_code = re.compile(f"not (?:equal to )?{r_code.pattern}") -r_array_element = re.compile(r"``([+-]?)x[12]_i``") +r_array_element = re.compile(r"``([+-]?)x([12])_i``") r_gt = re.compile(f"greater than {r_code.pattern}") r_lt = re.compile(f"less than {r_code.pattern}") r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}") @@ -333,8 +333,8 @@ def cond(i1: float, i2: float) -> bool: return _cond(i2) return cond - # this branch must come after checking for array elements - elif m := r_code.match(group): + + if m := r_code.match(group): value = parse_value(m.group(1)) _cond = make_eq(value) elif m := r_not_code.match(group): @@ -398,39 +398,75 @@ def cond(i1: float, i2: float) -> bool: return cond + def __repr__(self) -> str: + f_cond_factories = ", ".join( + repr(cond_factory) for cond_factory in self.cond_factories + ) + return f"{self.__class__.__name__}({f_cond_factories})" -class BinaryCase(NamedTuple): - cond: BinaryCheck - check_result: Callable[[float], bool] +BinaryResultCheck = Callable[[float, float, float], bool] -class BinaryCaseFactory(NamedTuple): - cond_factory: CondFactory - result_re_group: int - def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: - in_cond = self.cond_factory(groups) +class ResultCheckFactory(NamedTuple): + re_group: int + + def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: + group = groups[self.re_group] + + if m := r_array_element.match(group): + cond_factory = make_eq if m.group(1) != "-" else make_neq + + if m.group(2) == "1": + + def cond(i1: float, i2: float, result: float) -> bool: + _cond = cond_factory(i1) + return _cond(result) + + else: - s_result = groups[self.result_re_group] - if m := r_array_element.match(s_result): - raise ValueParseError(s_result) # TODO - elif m := r_code.match(s_result): + def cond(i1: float, i2: float, result: float) -> bool: + _cond = cond_factory(i2) + return _cond(result) + + return cond + + if m := r_code.match(group): value = parse_value(m.group(1)) - out_cond = make_eq(value) - elif m := r_approx_value.match(s_result): + _cond = make_eq(value) + elif m := r_approx_value.match(group): value = parse_value(m.group(1)) - out_cond = make_rough_eq(value) + _cond = make_rough_eq(value) else: - raise ValueParseError(s_result) + raise ValueParseError(group) + + def cond(i1: float, i2: float, result: float) -> bool: + return _cond(result) + + return cond - return BinaryCase(in_cond, out_cond) + +class BinaryCase(NamedTuple): + cond: BinaryCheck + check_result: BinaryResultCheck + + +class BinaryCaseFactory(NamedTuple): + cond_factory: CondFactory + check_result_factory: ResultCheckFactory + + def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: + cond = self.cond_factory(groups) + check_result = self.check_result_factory(groups) + return BinaryCase(cond, check_result) binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = { re.compile( "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" ): BinaryCaseFactory( - AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)), 2 + AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)), + ResultCheckFactory(2), ), # re.compile( # "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" @@ -671,8 +707,8 @@ def test_binary(func_name, func, cases, x1, x2): for case in cases: if case.cond(l, r): good_example = True - out = float(res[o_idx]) - assert case.check_result(out) + o = float(res[o_idx]) + assert case.check_result(l, r, o) # f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" # f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" # f_out = f"{sh.fmt_idx('out', o_idx)}={out}" From 74cc8e5f2993f7c435a595a925fb0c7e05c79b06 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Feb 2022 12:59:48 +0000 Subject: [PATCH 09/63] Skip redundant special cases --- array_api_tests/test_special_cases.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 3a82d87a..cc9ec0c7 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -174,7 +174,7 @@ def diff_sign(i1: float, i2: float) -> bool: repr_to_value = { "NaN": float("nan"), - "infinity": float("infinity"), + "infinity": float("inf"), "0": 0.0, "1": 1.0, } @@ -581,6 +581,9 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: } +r_redundant_case = re.compile("result.+determined by the rule already stated above") + + def parse_binary_docstring(docstring: str) -> List[BinaryCase]: match = r_special_cases.search(docstring) if match is None: @@ -593,6 +596,8 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: else: warn(f"line not machine-readable: '{line}'") continue + if r_redundant_case.search(case): + continue for pattern, make_case in binary_pattern_to_case_factory.items(): if m := pattern.search(case): try: From 03b9c003a7780cc6a1aa0d7a1761a0e055eb4c66 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Feb 2022 14:59:15 +0000 Subject: [PATCH 10/63] Generalise not special cases --- array_api_tests/test_special_cases.py | 101 ++++++-------------------- 1 file changed, 21 insertions(+), 80 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index cc9ec0c7..22eeaa61 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -304,7 +304,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: ... -r_not_code = re.compile(f"not (?:equal to )?{r_code.pattern}") +r_not = re.compile("not (?:equal to )?(.+)") r_array_element = re.compile(r"``([+-]?)x([12])_i``") r_gt = re.compile(f"greater than {r_code.pattern}") r_lt = re.compile(f"less than {r_code.pattern}") @@ -313,10 +313,10 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: class ValueCondFactory(NamedTuple): # TODO: inherit from CondFactory as well input_: Union[Literal["i1"], Literal["i2"]] - re_group: int + groups_i: int def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: - group = groups[self.re_group] + group = groups[self.groups_i] if m := r_array_element.match(group): cond_factory = make_eq if m.group(1) != "-" else make_neq @@ -334,12 +334,15 @@ def cond(i1: float, i2: float) -> bool: return cond + if m := r_not.match(group): + group = m.group(1) + notify = True + else: + notify = False + if m := r_code.match(group): value = parse_value(m.group(1)) _cond = make_eq(value) - elif m := r_not_code.match(group): - value = parse_value(m.group(1)) - _cond = make_neq(value) elif m := r_gt.match(group): value = parse_value(m.group(1)) _cond = make_gt(value) @@ -364,24 +367,26 @@ def cond(i1: float, i2: float) -> bool: _cond = lambda i: math.isfinite(i) and i != 0 elif group == "an integer value": _cond = lambda i: i.is_integer() - elif group == "not an integer value": - _cond = lambda i: not i.is_integer() elif group == "an odd integer value": _cond = lambda i: i.is_integer() and i % 2 == 1 - elif group == "not an odd integer value": - _cond = lambda i: not (i.is_integer() and i % 2 == 1) else: + print(f"{group=}") raise ValueParseError(group) + if notify: + final_cond = lambda i: not _cond(i) + else: + final_cond = _cond + if self.input_ == "i1": def cond(i1: float, i2: float) -> bool: - return _cond(i1) + return final_cond(i1) else: def cond(i1: float, i2: float) -> bool: - return _cond(i2) + return final_cond(i2) return cond @@ -409,10 +414,10 @@ def __repr__(self) -> str: class ResultCheckFactory(NamedTuple): - re_group: int + groups_i: int def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: - group = groups[self.re_group] + group = groups[self.groups_i] if m := r_array_element.match(group): cond_factory = make_eq if m.group(1) != "-" else make_neq @@ -472,54 +477,9 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: # "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" # ): lambda v: lambda _, i2: make_eq(v)(i2), # re.compile( - # "If ``x1_i`` is (.+) and ``x2_i`` is not equal to (.+), the result is (.+)" - # ): make_bin_and_factory(make_eq, lambda v: lambda i: i != v), - # re.compile( - # "If ``x1_i`` is greater than (.+), ``x1_i`` is (.+), " - # "and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_multi_and_factory([make_gt, make_eq], [make_eq]), - # re.compile( - # "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), " - # "and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq]), - # re.compile( - # "If ``x1_i`` is less than (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), " - # "and ``x2_i`` is not (.+), the result is (.+)" - # ): make_bin_multi_and_factory([make_lt, make_eq], [make_eq, make_neq]), - # re.compile( - # "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), " - # "and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_multi_and_factory([make_eq], [make_lt, make_eq]), - # re.compile( - # "If ``x1_i`` is (.+), ``x2_i`` is less than (.+), " - # "and ``x2_i`` is not (.+), the result is (.+)" - # ): make_bin_multi_and_factory([make_eq], [make_lt, make_neq]), - # re.compile( - # "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), " + # "If ``x1_i`` is (.+), ``x1_i`` (.+), " # "and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_multi_and_factory([make_eq], [make_gt, make_eq]), - # re.compile( - # "If ``x1_i`` is (.+), ``x2_i`` is greater than (.+), " - # "and ``x2_i`` is not (.+), the result is (.+)" - # ): make_bin_multi_and_factory([make_eq], [make_gt, make_neq]), - # re.compile( - # "If ``x1_i`` is greater than (.+) and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_and_factory(make_gt, make_eq), - # re.compile( - # "If ``x1_i`` is (.+) and ``x2_i`` is greater than (.+), the result is (.+)" - # ): make_bin_and_factory(make_eq, make_gt), - # re.compile( - # "If ``x1_i`` is less than (.+) and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_and_factory(make_lt, make_eq), - # re.compile( - # "If ``x1_i`` is (.+) and ``x2_i`` is less than (.+), the result is (.+)" - # ): make_bin_and_factory(make_eq, make_lt), - # re.compile( - # "If ``x1_i`` is not (?:equal to )?(.+) and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_and_factory(make_neq, make_eq), - # re.compile( - # "If ``x1_i`` is (.+) and ``x2_i`` is not (?:equal to )?(.+), the result is (.+)" - # ): make_bin_and_factory(make_eq, make_neq), + # ) # re.compile( # r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), " # "the result is (.+)" @@ -537,25 +497,6 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: # "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" # ): make_bin_or_factory(make_eq), # re.compile( - # "If ``x1_i`` is either (.+) or (.+) and ``x2_i`` is (.+), the result is (.+)" - # ): lambda v1, v2, v3: ( - # lambda i1, i2: make_or(make_eq(v1), make_eq(v2))(i1) and make_eq(v3)(i2) - # ), - # re.compile( - # "If ``x1_i`` is (.+) and ``x2_i`` is either (.+) or (.+), the result is (.+)" - # ): lambda v1, v2, v3: ( - # lambda i1, i2: make_eq(v1)(i1) and make_or(make_eq(v2), make_eq(v3))(i2) - # ), - # re.compile( - # "If ``x1_i`` is either (.+) or (.+) and " - # "``x2_i`` is either (.+) or (.+), the result is (.+)" - # ): lambda v1, v2, v3, v4: ( - # lambda i1, i2: ( - # make_or(make_eq(v1), make_eq(v2))(i1) - # and make_or(make_eq(v3), make_eq(v4))(i2) - # ) - # ), - # re.compile( # "If ``x1_i`` and ``x2_i`` have the same mathematical sign, " # "the result has a (.+)" # ): lambda: same_sign, From 0058634caf7e0fbad064932348a07621e3e42a1b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Feb 2022 15:21:17 +0000 Subject: [PATCH 11/63] Parse either cases --- array_api_tests/test_special_cases.py | 34 ++++++++++++++++++--------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 22eeaa61..22d10311 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -306,13 +306,14 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: r_not = re.compile("not (?:equal to )?(.+)") r_array_element = re.compile(r"``([+-]?)x([12])_i``") +r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}") r_gt = re.compile(f"greater than {r_code.pattern}") r_lt = re.compile(f"less than {r_code.pattern}") -r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}") -class ValueCondFactory(NamedTuple): # TODO: inherit from CondFactory as well - input_: Union[Literal["i1"], Literal["i2"]] +@dataclass +class ValueCondFactory(CondFactory): + input_: Union[Literal["i1"], Literal["i2"], Literal["either"]] groups_i: int def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: @@ -327,6 +328,7 @@ def cond(i1: float, i2: float) -> bool: return _cond(i1) else: + assert self.input_ == "i2" # sanity check def cond(i1: float, i2: float) -> bool: _cond = cond_factory(i1) @@ -383,11 +385,16 @@ def cond(i1: float, i2: float) -> bool: def cond(i1: float, i2: float) -> bool: return final_cond(i1) - else: + elif self.input_ == "i2": def cond(i1: float, i2: float) -> bool: return final_cond(i2) + else: + + def cond(i1: float, i2: float) -> bool: + return final_cond(i1) or final_cond(i2) + return cond @@ -490,12 +497,12 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: # re.compile( # r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)" # ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq), - # re.compile( - # "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_and_factory(make_eq, make_eq), - # re.compile( - # "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_or_factory(make_eq), + re.compile( + "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" + ): BinaryCaseFactory( + ValueCondFactory("either", 0), + ResultCheckFactory(1), + ), # re.compile( # "If ``x1_i`` and ``x2_i`` have the same mathematical sign, " # "the result has a (.+)" @@ -516,7 +523,6 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: # ): lambda v: ( # lambda i1, i2: diff_sign(i1, i2) and make_eq(v)(i1) and make_eq(v)(i2) # ), - # TODO: support capturing values that come after the result # re.compile(r"If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is .+, the \"sign\" of .+ is implementation-defined") # re.compile(r"If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the \"sign\" of (.+) is implementation-defined") } @@ -676,3 +682,9 @@ def test_binary(func_name, func, cases, x1, x2): # ) break assume(good_example) + + +# TODO: remove +print( + f"no. of cases={sum(len(cases) for _, _, cases in binary_params)}" +) From 0116ac50bf077dec0543025094d74bd20416fc0f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Feb 2022 16:29:20 +0000 Subject: [PATCH 12/63] Parse sign cases --- array_api_tests/test_special_cases.py | 101 +++++++++++++++++--------- 1 file changed, 65 insertions(+), 36 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 22d10311..1e1bfbe3 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -313,11 +313,11 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: @dataclass class ValueCondFactory(CondFactory): - input_: Union[Literal["i1"], Literal["i2"], Literal["either"]] - groups_i: int + input_: Union[Literal["i1"], Literal["i2"], Literal["either"], Literal["both"]] + re_groups_i: int def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: - group = groups[self.groups_i] + group = groups[self.re_groups_i] if m := r_array_element.match(group): cond_factory = make_eq if m.group(1) != "-" else make_neq @@ -390,11 +390,16 @@ def cond(i1: float, i2: float) -> bool: def cond(i1: float, i2: float) -> bool: return final_cond(i2) - else: + elif self.input_ == "either": def cond(i1: float, i2: float) -> bool: return final_cond(i1) or final_cond(i2) + else: + + def cond(i1: float, i2: float) -> bool: + return final_cond(i1) and final_cond(i2) + return cond @@ -417,14 +422,28 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({f_cond_factories})" +@dataclass +class SignCondFactory(CondFactory): + re_groups_i: int + + def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: + group = groups[self.re_groups_i] + if group == "the same mathematical sign": + return same_sign + elif group == "different mathematical signs": + return diff_sign + else: + raise ValueParseError(group) + + BinaryResultCheck = Callable[[float, float, float], bool] class ResultCheckFactory(NamedTuple): - groups_i: int + re_groups_i: int def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: - group = groups[self.groups_i] + group = groups[self.re_groups_i] if m := r_array_element.match(group): cond_factory = make_eq if m.group(1) != "-" else make_neq @@ -458,6 +477,29 @@ def cond(i1: float, i2: float, result: float) -> bool: return cond +class ResultSignCheckFactory(ResultCheckFactory): + def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: + group = groups[self.re_groups_i] + if group == "positive": + + def cond(i1: float, i2: float, result: float) -> bool: + if math.isnan(result): + return True + return result > 0 or ph.is_pos_zero(result) + + elif group == "negative": + + def cond(i1: float, i2: float, result: float) -> bool: + if math.isnan(result): + return True + return result < 0 or ph.is_neg_zero(result) + + else: + raise ValueParseError(group) + + return cond + + class BinaryCase(NamedTuple): cond: BinaryCheck check_result: BinaryResultCheck @@ -473,6 +515,8 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: return BinaryCase(cond, check_result) +r_result_sign = re.compile("([a-z]+) mathematical sign") + binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = { re.compile( "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" @@ -499,32 +543,23 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: # ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq), re.compile( "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" + ): BinaryCaseFactory(ValueCondFactory("either", 0), ResultCheckFactory(1)), + re.compile( + "If ``x1_i`` and ``x2_i`` have (.+signs?), " + f"the result has a {r_result_sign.pattern}" + ): BinaryCaseFactory(SignCondFactory(0), ResultSignCheckFactory(1)), + re.compile( + "If ``x1_i`` and ``x2_i`` have (.+signs?) and are both (.+), " + f"the result has a {r_result_sign.pattern}" ): BinaryCaseFactory( - ValueCondFactory("either", 0), - ResultCheckFactory(1), + AndCondFactory(SignCondFactory(0), ValueCondFactory("both", 1)), + ResultSignCheckFactory(2), ), - # re.compile( - # "If ``x1_i`` and ``x2_i`` have the same mathematical sign, " - # "the result has a (.+)" - # ): lambda: same_sign, - # re.compile( - # "If ``x1_i`` and ``x2_i`` have different mathematical signs, " - # "the result has a (.+)" - # ): lambda: diff_sign, - # re.compile( - # "If ``x1_i`` and ``x2_i`` have the same mathematical sign and " - # "are both (.+), the result has a (.+)" - # ): lambda v: ( - # lambda i1, i2: same_sign(i1, i2) and make_eq(v)(i1) and make_eq(v)(i2) - # ), - # re.compile( - # "If ``x1_i`` and ``x2_i`` have different mathematical signs and " - # "are both (.+), the result has a (.+)" - # ): lambda v: ( - # lambda i1, i2: diff_sign(i1, i2) and make_eq(v)(i1) and make_eq(v)(i2) - # ), - # re.compile(r"If ``x1_i`` and ``x2_i`` have the same mathematical sign, the result has a (.+), unless the result is (.+)\. If the result is .+, the \"sign\" of .+ is implementation-defined") - # re.compile(r"If ``x1_i`` and ``x2_i`` have different mathematical signs, the result has a (.+), unless the result is (.+)\. If the result is (.+), the \"sign\" of (.+) is implementation-defined") + re.compile( + "If ``x1_i`` and ``x2_i`` have (.+signs?), the result has a " + rf"{r_result_sign.pattern} , unless the result is (.+)\. If the result " + r"is ``NaN``, the \"sign\" of ``NaN`` is implementation-defined\." + ): BinaryCaseFactory(SignCondFactory(0), ResultSignCheckFactory(1)), } @@ -682,9 +717,3 @@ def test_binary(func_name, func, cases, x1, x2): # ) break assume(good_example) - - -# TODO: remove -print( - f"no. of cases={sum(len(cases) for _, _, cases in binary_params)}" -) From bc85e7adab0ae7ec92b21120217328e947d5a30b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Feb 2022 10:01:35 +0000 Subject: [PATCH 13/63] Parse more awkward equality special cases --- array_api_tests/test_special_cases.py | 34 ++++++++++++++++++++------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 1e1bfbe3..b1f5d11e 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1,6 +1,7 @@ import inspect import math import re +from dataclasses import dataclass from typing import ( Callable, Dict, @@ -15,7 +16,6 @@ from warnings import warn import pytest -from attr import dataclass from hypothesis import HealthCheck, assume, given, settings from . import dtype_helpers as dh @@ -524,13 +524,28 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)), ResultCheckFactory(2), ), - # re.compile( - # "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" - # ): lambda v: lambda _, i2: make_eq(v)(i2), - # re.compile( - # "If ``x1_i`` is (.+), ``x1_i`` (.+), " - # "and ``x2_i`` is (.+), the result is (.+)" - # ) + re.compile( + "If ``x1_i`` is (.+), ``x1_i`` (.+), " + "and ``x2_i`` is (.+), the result is (.+)" + ): BinaryCaseFactory( + AndCondFactory( + ValueCondFactory("i1", 0), + ValueCondFactory("i1", 1), + ValueCondFactory("i2", 2), + ), + ResultCheckFactory(3), + ), + re.compile( + "If ``x1_i`` is (.+), ``x2_i`` (.+), " + "and ``x2_i`` is (.+), the result is (.+)" + ): BinaryCaseFactory( + AndCondFactory( + ValueCondFactory("i1", 0), + ValueCondFactory("i2", 1), + ValueCondFactory("i2", 2), + ), + ResultCheckFactory(3), + ), # re.compile( # r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), " # "the result is (.+)" @@ -560,6 +575,9 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: rf"{r_result_sign.pattern} , unless the result is (.+)\. If the result " r"is ``NaN``, the \"sign\" of ``NaN`` is implementation-defined\." ): BinaryCaseFactory(SignCondFactory(0), ResultSignCheckFactory(1)), + re.compile( + "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" + ): BinaryCaseFactory(ValueCondFactory("i2", 0), ResultCheckFactory(1)), } From 1891a79fdb7d126b97502502d04db18f34f08b00 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Feb 2022 10:15:13 +0000 Subject: [PATCH 14/63] Parse abs special cases --- array_api_tests/test_special_cases.py | 33 +++++++++++++++++++-------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index b1f5d11e..2b58cdb2 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -403,6 +403,21 @@ def cond(i1: float, i2: float) -> bool: return cond +@dataclass +class AbsCondFactory(CondFactory): + cond_factory: CondFactory + + def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: + _cond = self.cond_factory(groups) + + def cond(i1: float, i2: float) -> bool: + i1 = abs(i1) + i2 = abs(i2) + return _cond(i1, i2) + + return cond + + class AndCondFactory(CondFactory): def __init__(self, *cond_factories: CondFactory): self.cond_factories = cond_factories @@ -546,16 +561,14 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: ), ResultCheckFactory(3), ), - # re.compile( - # r"If `abs\(x1_i\)` is greater than (.+) and ``x2_i`` is (.+), " - # "the result is (.+)" - # ): make_bin_and_factory(absify_cond_factory(make_gt), make_eq), - # re.compile( - # r"If `abs\(x1_i\)` is less than (.+) and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_and_factory(absify_cond_factory(make_lt), make_eq), - # re.compile( - # r"If `abs\(x1_i\)` is (.+) and ``x2_i`` is (.+), the result is (.+)" - # ): make_bin_and_factory(absify_cond_factory(make_eq), make_eq), + re.compile( + r"If ``abs\(x1_i\)`` is (.+) and ``x2_i`` is (.+), the result is (.+)" + ): BinaryCaseFactory( + AndCondFactory( + AbsCondFactory(ValueCondFactory("i1", 0)), ValueCondFactory("i2", 1) + ), + ResultCheckFactory(2), + ), re.compile( "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" ): BinaryCaseFactory(ValueCondFactory("either", 0), ResultCheckFactory(1)), From d7db62dff7bd42dcb59d212e12a2c878974b5fa4 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Feb 2022 13:56:45 +0000 Subject: [PATCH 15/63] Cond reprs, merge abs cond logic, fix approx cond factories --- array_api_tests/test_special_cases.py | 189 +++++++++++++++----------- 1 file changed, 113 insertions(+), 76 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 2b58cdb2..9945801c 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -299,8 +299,19 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: return cases -class CondFactory(Protocol): - def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: +class BinaryCond(NamedTuple): + cond: BinaryCheck + repr_: str + + def __call__(self, i1: float, i2: float) -> bool: + return self.cond(i1, i2) + + def __repr__(self): + return self.repr_ + + +class BinaryCondFactory(Protocol): + def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: ... @@ -310,31 +321,43 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: r_gt = re.compile(f"greater than {r_code.pattern}") r_lt = re.compile(f"less than {r_code.pattern}") +x1_i = "x1ᵢ" +x2_i = "x2ᵢ" + @dataclass -class ValueCondFactory(CondFactory): +class ValueCondFactory(BinaryCondFactory): input_: Union[Literal["i1"], Literal["i2"], Literal["either"], Literal["both"]] re_groups_i: int + abs_: bool = False - def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: + def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: group = groups[self.re_groups_i] if m := r_array_element.match(group): - cond_factory = make_eq if m.group(1) != "-" else make_neq + assert not self.abs_ # sanity check + sign = m.group(1) + if sign == "-": + signer = lambda i: -i + else: + signer = lambda i: i + if self.input_ == "i1": + repr_ = f"{x1_i} == {sign}{x2_i}" def cond(i1: float, i2: float) -> bool: - _cond = cond_factory(i2) + _cond = make_eq(signer(i2)) return _cond(i1) else: assert self.input_ == "i2" # sanity check + repr_ = f"{x2_i} == {sign}{x1_i}" def cond(i1: float, i2: float) -> bool: - _cond = cond_factory(i1) + _cond = make_eq(signer(i1)) return _cond(i2) - return cond + return BinaryCond(cond, repr_) if m := r_not.match(group): group = m.group(1) @@ -345,34 +368,45 @@ def cond(i1: float, i2: float) -> bool: if m := r_code.match(group): value = parse_value(m.group(1)) _cond = make_eq(value) + repr_template = "{} == " + str(value) elif m := r_gt.match(group): value = parse_value(m.group(1)) _cond = make_gt(value) + repr_template = "{} > " + str(value) elif m := r_lt.match(group): value = parse_value(m.group(1)) _cond = make_lt(value) + repr_template = "{} < " + str(value) elif m := r_either_code.match(group): v1 = parse_value(m.group(1)) v2 = parse_value(m.group(2)) _cond = make_or(make_eq(v1), make_eq(v2)) + repr_template = "{} == " + str(v1) + " or {} == " + str(v2) elif group in ["finite", "a finite number"]: _cond = math.isfinite + repr_template = "isfinite({})" elif group in "a positive (i.e., greater than ``0``) finite number": _cond = lambda i: math.isfinite(i) and i > 0 + repr_template = "isfinite({}) and {} > 0" elif group == "a negative (i.e., less than ``0``) finite number": _cond = lambda i: math.isfinite(i) and i < 0 + repr_template = "isfinite({}) and {} < 0" elif group == "positive": _cond = lambda i: math.copysign(1, i) == 1 + repr_template = "copysign(1, {}) == 1" elif group == "negative": _cond = lambda i: math.copysign(1, i) == -1 + repr_template = "copysign(1, {}) == -1" elif "nonzero finite" in group: _cond = lambda i: math.isfinite(i) and i != 0 + repr_template = "copysign(1, {}) == -1" elif group == "an integer value": _cond = lambda i: i.is_integer() + repr_template = "{}.is_integer()" elif group == "an odd integer value": _cond = lambda i: i.is_integer() and i % 2 == 1 + repr_template = "{}.is_integer() and {} % 2 == 1" else: - print(f"{group=}") raise ValueParseError(group) if notify: @@ -380,65 +414,59 @@ def cond(i1: float, i2: float) -> bool: else: final_cond = _cond + f_i1 = x1_i + f_i2 = x2_i + if self.abs_: + f_i1 = f"abs{f_i1}" + f_i2 = f"abs{f_i2}" + if self.input_ == "i1": + repr_ = repr_template.replace("{}", f_i1) def cond(i1: float, i2: float) -> bool: return final_cond(i1) elif self.input_ == "i2": + repr_ = repr_template.replace("{}", f_i2) def cond(i1: float, i2: float) -> bool: return final_cond(i2) elif self.input_ == "either": + repr_ = f"({repr_template.replace('{}', f_i1)}) or ({repr_template.replace('{}', f_i2)})" def cond(i1: float, i2: float) -> bool: return final_cond(i1) or final_cond(i2) else: + assert self.input_ == "both" # sanity check + repr_ = f"({repr_template.replace('{}', f_i1)}) and ({repr_template.replace('{}', f_i2)})" def cond(i1: float, i2: float) -> bool: return final_cond(i1) and final_cond(i2) - return cond - - -@dataclass -class AbsCondFactory(CondFactory): - cond_factory: CondFactory - - def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: - _cond = self.cond_factory(groups) - - def cond(i1: float, i2: float) -> bool: - i1 = abs(i1) - i2 = abs(i2) - return _cond(i1, i2) + if notify: + repr_ = f"not ({repr_})" - return cond + return BinaryCond(cond, repr_) -class AndCondFactory(CondFactory): - def __init__(self, *cond_factories: CondFactory): +class AndCondFactory(BinaryCondFactory): + def __init__(self, *cond_factories: BinaryCondFactory): self.cond_factories = cond_factories - def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: + def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: conds = [cond_factory(groups) for cond_factory in self.cond_factories] + repr_ = " and ".join(f"({cond!r})" for cond in conds) def cond(i1: float, i2: float) -> bool: return all(cond(i1, i2) for cond in conds) - return cond - - def __repr__(self) -> str: - f_cond_factories = ", ".join( - repr(cond_factory) for cond_factory in self.cond_factories - ) - return f"{self.__class__.__name__}({f_cond_factories})" + return BinaryCond(cond, repr_) @dataclass -class SignCondFactory(CondFactory): +class SignCondFactory(BinaryCondFactory): re_groups_i: int def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: @@ -451,45 +479,67 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: raise ValueParseError(group) -BinaryResultCheck = Callable[[float, float, float], bool] +class BinaryResultCheck(NamedTuple): + check_result: Callable[[float, float, float], bool] + repr_: str + + def __call__(self, i1: float, i2: float, result: float) -> bool: + return self.check_result(i1, i2, result) + + def __repr__(self): + return self.repr_ + + +class BinaryResultCheckFactory(Protocol): + def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: + ... -class ResultCheckFactory(NamedTuple): +@dataclass +class ResultCheckFactory(BinaryResultCheckFactory): re_groups_i: int def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: group = groups[self.re_groups_i] if m := r_array_element.match(group): - cond_factory = make_eq if m.group(1) != "-" else make_neq + sign, input_ = m.groups() + if sign == "-": + signer = lambda i: -i + else: + signer = lambda i: i - if m.group(2) == "1": + if input_ == "1": + repr_ = f"{sign}{x1_i}" - def cond(i1: float, i2: float, result: float) -> bool: - _cond = cond_factory(i1) - return _cond(result) + def check_result(i1: float, i2: float, result: float) -> bool: + _check_result = make_eq(signer(i1)) + return _check_result(result) else: + repr_ = f"{sign}{x2_i}" - def cond(i1: float, i2: float, result: float) -> bool: - _cond = cond_factory(i2) - return _cond(result) + def check_result(i1: float, i2: float, result: float) -> bool: + _check_result = make_eq(signer(i2)) + return _check_result(result) - return cond + return BinaryResultCheck(check_result, repr_) if m := r_code.match(group): value = parse_value(m.group(1)) - _cond = make_eq(value) + _check_result = make_eq(value) + repr_ = str(value) elif m := r_approx_value.match(group): value = parse_value(m.group(1)) - _cond = make_rough_eq(value) + _check_result = make_rough_eq(value) + repr_ = f"~{value}" else: raise ValueParseError(group) - def cond(i1: float, i2: float, result: float) -> bool: - return _cond(result) + def check_result(i1: float, i2: float, result: float) -> bool: + return _check_result(result) - return cond + return BinaryResultCheck(check_result, repr_) class ResultSignCheckFactory(ResultCheckFactory): @@ -516,12 +566,15 @@ def cond(i1: float, i2: float, result: float) -> bool: class BinaryCase(NamedTuple): - cond: BinaryCheck + cond: BinaryCond check_result: BinaryResultCheck + def __repr__(self): + return f"BinaryCase(<{self.cond} -> {self.check_result}>)" + class BinaryCaseFactory(NamedTuple): - cond_factory: CondFactory + cond_factory: BinaryCondFactory check_result_factory: ResultCheckFactory def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: @@ -564,9 +617,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: re.compile( r"If ``abs\(x1_i\)`` is (.+) and ``x2_i`` is (.+), the result is (.+)" ): BinaryCaseFactory( - AndCondFactory( - AbsCondFactory(ValueCondFactory("i1", 0)), ValueCondFactory("i2", 1) - ), + AndCondFactory(ValueCondFactory("i1", 0, abs_=True), ValueCondFactory("i2", 1)), ResultCheckFactory(2), ), re.compile( @@ -726,25 +777,11 @@ def test_binary(func_name, func, cases, x1, x2): if case.cond(l, r): good_example = True o = float(res[o_idx]) - assert case.check_result(l, r, o) - # f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" - # f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" - # f_out = f"{sh.fmt_idx('out', o_idx)}={out}" - # if result.strict_check: - # msg = ( - # f"{f_out}, but should be {result.repr_} [{func_name}()]\n" - # f"{f_left}, {f_right}" - # ) - # if math.isnan(result.value): - # assert math.isnan(out), msg - # else: - # assert out == result.value, msg - # else: - # assert math.isfinite(result.value) # sanity check - # assert math.isclose(out, result.value, abs_tol=0.1), ( - # f"{f_out}, but should be roughly {result.repr_}={result.value} " - # f"[{func_name}()]\n" - # f"{f_left}, {f_right}" - # ) + f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" + f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" + f_out = f"{sh.fmt_idx('out', o_idx)}={o}" + assert case.check_result(l, r, o), ( + f"{f_out} not good [{func_name}()]\n" f"{f_left}, {f_right}" + ) break assume(good_example) From b0ed32ff56508c6c642e5d0b8eed475059acb5d2 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Feb 2022 10:46:40 +0000 Subject: [PATCH 16/63] Mark special case tests for CI --- array_api_tests/test_special_cases.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 9945801c..bfd899c3 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -26,6 +26,8 @@ from ._array_module import mod as xp from .stubs import category_to_funcs +pytestmark = pytest.mark.ci + # Condition factories # ------------------------------------------------------------------------------ From 05e179e205c2e3510963aafe02686d29a8a4720b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Feb 2022 11:23:46 +0000 Subject: [PATCH 17/63] `repr_` -> `expr`, return `BinaryCond` in `SignCondFactory` --- array_api_tests/test_special_cases.py | 188 ++++++++++++-------------- 1 file changed, 88 insertions(+), 100 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index bfd899c3..f3e6a988 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -222,27 +222,6 @@ def parse_inline_code(inline_code: str) -> float: raise ValueParseError(inline_code) -class Result(NamedTuple): - value: float - repr_: str - strict_check: bool - - -def parse_result(s_result: str) -> Result: - match = None - if m := r_code.match(s_result): - match = m - strict_check = True - elif m := r_approx_value.match(s_result): - match = m - strict_check = False - else: - raise ValueParseError(s_result) - value = parse_value(match.group(1)) - repr_ = match.group(1) - return Result(value, repr_, strict_check) - - r_special_cases = re.compile( r"\*\*Special [Cc]ases\*\*\n+\s*" r"For floating-point operands,\n+" @@ -252,65 +231,74 @@ def parse_result(s_result: str) -> Result: r_remaining_case = re.compile("In the remaining cases.+") -unary_pattern_to_condition_factory: Dict[Pattern, Callable] = { - re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt, - re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt, - re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): ( - lambda v1, v2: make_or(make_eq(v1), make_eq(v2)) - ), - # This pattern must come after the previous patterns to avoid unwanted matches - re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq, - re.compile( - "If two integers are equally close to ``x_i``, the result is (.+)" - ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5), -} +# unary_pattern_to_condition_factory: Dict[Pattern, Callable] = { +# re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt, +# re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt, +# re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): ( +# lambda v1, v2: make_or(make_eq(v1), make_eq(v2)) +# ), +# # This pattern must come after the previous patterns to avoid unwanted matches +# re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq, +# re.compile( +# "If two integers are equally close to ``x_i``, the result is (.+)" +# ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5), +# } + + +# def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: +# match = r_special_cases.search(docstring) +# if match is None: +# return {} +# cases = match.group(1).split("\n")[:-1] +# cases = {} +# for line in cases: +# if m := r_case.match(line): +# case = m.group(1) +# else: +# warn(f"line not machine-readable: '{line}'") +# continue +# for pattern, make_cond in unary_pattern_to_condition_factory.items(): +# if m := pattern.search(case): +# *s_values, s_result = m.groups() +# try: +# values = [parse_inline_code(v) for v in s_values] +# except ValueParseError as e: +# warn(f"value not machine-readable: '{e.value}'") +# break +# cond = make_cond(*values) +# try: +# result = parse_result(s_result) +# except ValueParseError as e: +# warn(f"result not machine-readable: '{e.value}'") + +# break +# cases[cond] = result +# break +# else: +# if not r_remaining_case.search(case): +# warn(f"case not machine-readable: '{case}'") +# return cases + +x_i = "xᵢ" +x1_i = "x1ᵢ" +x2_i = "x2ᵢ" -def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: - match = r_special_cases.search(docstring) - if match is None: - return {} - cases = match.group(1).split("\n")[:-1] - cases = {} - for line in cases: - if m := r_case.match(line): - case = m.group(1) - else: - warn(f"line not machine-readable: '{line}'") - continue - for pattern, make_cond in unary_pattern_to_condition_factory.items(): - if m := pattern.search(case): - *s_values, s_result = m.groups() - try: - values = [parse_inline_code(v) for v in s_values] - except ValueParseError as e: - warn(f"value not machine-readable: '{e.value}'") - break - cond = make_cond(*values) - try: - result = parse_result(s_result) - except ValueParseError as e: - warn(f"result not machine-readable: '{e.value}'") +class Cond(Protocol): + expr: str - break - cases[cond] = result - break - else: - if not r_remaining_case.search(case): - warn(f"case not machine-readable: '{case}'") - return cases + def __call__(self, *args) -> bool: + ... -class BinaryCond(NamedTuple): +@dataclass +class BinaryCond(Cond): cond: BinaryCheck - repr_: str + expr: str def __call__(self, i1: float, i2: float) -> bool: return self.cond(i1, i2) - def __repr__(self): - return self.repr_ - class BinaryCondFactory(Protocol): def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: @@ -323,9 +311,6 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: r_gt = re.compile(f"greater than {r_code.pattern}") r_lt = re.compile(f"less than {r_code.pattern}") -x1_i = "x1ᵢ" -x2_i = "x2ᵢ" - @dataclass class ValueCondFactory(BinaryCondFactory): @@ -345,7 +330,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: signer = lambda i: i if self.input_ == "i1": - repr_ = f"{x1_i} == {sign}{x2_i}" + expr = f"{x1_i} == {sign}{x2_i}" def cond(i1: float, i2: float) -> bool: _cond = make_eq(signer(i2)) @@ -353,13 +338,13 @@ def cond(i1: float, i2: float) -> bool: else: assert self.input_ == "i2" # sanity check - repr_ = f"{x2_i} == {sign}{x1_i}" + expr = f"{x2_i} == {sign}{x1_i}" def cond(i1: float, i2: float) -> bool: _cond = make_eq(signer(i1)) return _cond(i2) - return BinaryCond(cond, repr_) + return BinaryCond(cond, expr) if m := r_not.match(group): group = m.group(1) @@ -419,38 +404,38 @@ def cond(i1: float, i2: float) -> bool: f_i1 = x1_i f_i2 = x2_i if self.abs_: - f_i1 = f"abs{f_i1}" - f_i2 = f"abs{f_i2}" + f_i1 = f"abs({f_i1})" + f_i2 = f"abs({f_i2})" if self.input_ == "i1": - repr_ = repr_template.replace("{}", f_i1) + expr = repr_template.replace("{}", f_i1) def cond(i1: float, i2: float) -> bool: return final_cond(i1) elif self.input_ == "i2": - repr_ = repr_template.replace("{}", f_i2) + expr = repr_template.replace("{}", f_i2) def cond(i1: float, i2: float) -> bool: return final_cond(i2) elif self.input_ == "either": - repr_ = f"({repr_template.replace('{}', f_i1)}) or ({repr_template.replace('{}', f_i2)})" + expr = f"({repr_template.replace('{}', f_i1)}) or ({repr_template.replace('{}', f_i2)})" def cond(i1: float, i2: float) -> bool: return final_cond(i1) or final_cond(i2) else: assert self.input_ == "both" # sanity check - repr_ = f"({repr_template.replace('{}', f_i1)}) and ({repr_template.replace('{}', f_i2)})" + expr = f"({repr_template.replace('{}', f_i1)}) and ({repr_template.replace('{}', f_i2)})" def cond(i1: float, i2: float) -> bool: return final_cond(i1) and final_cond(i2) if notify: - repr_ = f"not ({repr_})" + expr = f"not ({expr})" - return BinaryCond(cond, repr_) + return BinaryCond(cond, expr) class AndCondFactory(BinaryCondFactory): @@ -459,37 +444,40 @@ def __init__(self, *cond_factories: BinaryCondFactory): def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: conds = [cond_factory(groups) for cond_factory in self.cond_factories] - repr_ = " and ".join(f"({cond!r})" for cond in conds) + expr = " and ".join(f"({cond.expr})" for cond in conds) def cond(i1: float, i2: float) -> bool: return all(cond(i1, i2) for cond in conds) - return BinaryCond(cond, repr_) + return BinaryCond(cond, expr) @dataclass class SignCondFactory(BinaryCondFactory): re_groups_i: int - def __call__(self, groups: Tuple[str, ...]) -> BinaryCheck: + def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: group = groups[self.re_groups_i] if group == "the same mathematical sign": - return same_sign + cond = same_sign + expr = f"copysign(1, {x1_i}) == copysign(1, {x2_i})" elif group == "different mathematical signs": - return diff_sign + cond = diff_sign + expr = f"copysign(1, {x1_i}) != copysign(1, {x2_i})" else: raise ValueParseError(group) + return BinaryCond(cond, expr) class BinaryResultCheck(NamedTuple): check_result: Callable[[float, float, float], bool] - repr_: str + expr: str def __call__(self, i1: float, i2: float, result: float) -> bool: return self.check_result(i1, i2, result) def __repr__(self): - return self.repr_ + return self.expr class BinaryResultCheckFactory(Protocol): @@ -512,36 +500,36 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: signer = lambda i: i if input_ == "1": - repr_ = f"{sign}{x1_i}" + expr = f"{sign}{x1_i}" def check_result(i1: float, i2: float, result: float) -> bool: _check_result = make_eq(signer(i1)) return _check_result(result) else: - repr_ = f"{sign}{x2_i}" + expr = f"{sign}{x2_i}" def check_result(i1: float, i2: float, result: float) -> bool: _check_result = make_eq(signer(i2)) return _check_result(result) - return BinaryResultCheck(check_result, repr_) + return BinaryResultCheck(check_result, expr) if m := r_code.match(group): value = parse_value(m.group(1)) _check_result = make_eq(value) - repr_ = str(value) + expr = str(value) elif m := r_approx_value.match(group): value = parse_value(m.group(1)) _check_result = make_rough_eq(value) - repr_ = f"~{value}" + expr = f"~{value}" else: raise ValueParseError(group) def check_result(i1: float, i2: float, result: float) -> bool: return _check_result(result) - return BinaryResultCheck(check_result, repr_) + return BinaryResultCheck(check_result, expr) class ResultSignCheckFactory(ResultCheckFactory): @@ -572,7 +560,7 @@ class BinaryCase(NamedTuple): check_result: BinaryResultCheck def __repr__(self): - return f"BinaryCase(<{self.cond} -> {self.check_result}>)" + return f"BinaryCase(<{self.cond.expr} -> {self.check_result}>)" class BinaryCaseFactory(NamedTuple): @@ -743,7 +731,7 @@ def test_unary(func_name, func, cases, x): f_out = f"{sh.fmt_idx('out', idx)}={out}" if result.strict_check: msg = ( - f"{f_out}, but should be {result.repr_} [{func_name}()]\n" + f"{f_out}, but should be {result.expr} [{func_name}()]\n" f"{f_in}" ) if math.isnan(result.value): @@ -753,7 +741,7 @@ def test_unary(func_name, func, cases, x): else: assert math.isfinite(result.value) # sanity check assert math.isclose(out, result.value, abs_tol=0.1), ( - f"{f_out}, but should be roughly {result.repr_}={result.value} " + f"{f_out}, but should be roughly {result.expr}={result.value} " f"[{func_name}()]\n" f"{f_in}" ) From d6f4f5697712cfca564ec3747a38e713f309f651 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Feb 2022 11:30:41 +0000 Subject: [PATCH 18/63] Fix abs conds not actually absifying --- array_api_tests/test_special_cases.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index f3e6a988..7fdf9978 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -396,8 +396,11 @@ def cond(i1: float, i2: float) -> bool: else: raise ValueParseError(group) + assert not (notify and self.abs_) # sanity check if notify: final_cond = lambda i: not _cond(i) + elif self.abs_: + final_cond = lambda i: _cond(abs(i)) else: final_cond = _cond From 3d0478f9244bf5fae280c13463bf410946717fbb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Feb 2022 11:56:36 +0000 Subject: [PATCH 19/63] Reorder binary case patterns to avoid false matches --- array_api_tests/test_special_cases.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 7fdf9978..4faa12ee 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -579,12 +579,6 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: r_result_sign = re.compile("([a-z]+) mathematical sign") binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = { - re.compile( - "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" - ): BinaryCaseFactory( - AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)), - ResultCheckFactory(2), - ), re.compile( "If ``x1_i`` is (.+), ``x1_i`` (.+), " "and ``x2_i`` is (.+), the result is (.+)" @@ -607,6 +601,13 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: ), ResultCheckFactory(3), ), + # This case must come after the above to avoid false matches + re.compile( + "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" + ): BinaryCaseFactory( + AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)), + ResultCheckFactory(2), + ), re.compile( r"If ``abs\(x1_i\)`` is (.+) and ``x2_i`` is (.+), the result is (.+)" ): BinaryCaseFactory( From 0ba67539f5ca83534b527e70602d10837ccacef7 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Feb 2022 12:01:34 +0000 Subject: [PATCH 20/63] Parse quad cases --- array_api_tests/test_special_cases.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 4faa12ee..977f59d6 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -16,7 +16,7 @@ from warnings import warn import pytest -from hypothesis import HealthCheck, assume, given, settings +from hypothesis import assume, given from . import dtype_helpers as dh from . import hypothesis_helpers as hh @@ -580,7 +580,19 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = { re.compile( - "If ``x1_i`` is (.+), ``x1_i`` (.+), " + "If ``x1_i`` is (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), " + "and ``x2_i`` is (.+), the result is (.+)" + ): BinaryCaseFactory( + AndCondFactory( + ValueCondFactory("i1", 0), + ValueCondFactory("i1", 1), + ValueCondFactory("i2", 2), + ValueCondFactory("i2", 3), + ), + ResultCheckFactory(4), + ), + re.compile( + "If ``x1_i`` is (.+), ``x1_i`` is (.+), " "and ``x2_i`` is (.+), the result is (.+)" ): BinaryCaseFactory( AndCondFactory( @@ -591,7 +603,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: ResultCheckFactory(3), ), re.compile( - "If ``x1_i`` is (.+), ``x2_i`` (.+), " + "If ``x1_i`` is (.+), ``x2_i`` is (.+), " "and ``x2_i`` is (.+), the result is (.+)" ): BinaryCaseFactory( AndCondFactory( @@ -601,7 +613,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: ), ResultCheckFactory(3), ), - # This case must come after the above to avoid false matches + # This pattern must come after the above to avoid false matches re.compile( "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" ): BinaryCaseFactory( @@ -760,7 +772,6 @@ def test_unary(func_name, func, cases, x): two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), ) ) -@settings(suppress_health_check=[HealthCheck.filter_too_much]) # TODO: remove def test_binary(func_name, func, cases, x1, x2): res = func(x1, x2) good_example = False From 404542f4ac13949d271018968a5c8418623e6437 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 18 Feb 2022 11:05:16 +0000 Subject: [PATCH 21/63] Parse most unary cases again, refactor parse utils --- array_api_tests/test_special_cases.py | 346 ++++++++++++++------------ 1 file changed, 191 insertions(+), 155 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 977f59d6..91b7d52b 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -130,6 +130,20 @@ def bin_or(i1: float, i2: float) -> bool: return make_bin_or +def notify_cond(cond: UnaryCheck) -> UnaryCheck: + def not_cond(i: float) -> bool: + return not cond(i) + + return not_cond + + +def absify_cond(cond: UnaryCheck) -> UnaryCheck: + def abs_cond(i: float) -> bool: + return cond(abs(i)) + + return abs_cond + + def absify_cond_factory( make_cond: Callable[[float], UnaryCheck] ) -> Callable[[float], UnaryCheck]: @@ -190,11 +204,10 @@ class ValueParseError(ValueError): value: str -def parse_value(s_value: str) -> float: - assert not s_value.startswith("``") and not s_value.endswith("``") # sanity check - m = r_value.match(s_value) +def parse_value(value_str: str) -> float: + m = r_value.match(value_str) if m is None: - raise ValueParseError(s_value) + raise ValueParseError(value_str) if pi_m := r_pi.match(m.group(2)): value = math.pi if numerator := pi_m.group(1): @@ -230,60 +243,84 @@ def parse_inline_code(inline_code: str) -> float: r_case = re.compile(r"\s+-\s*(.*)\.\n?") r_remaining_case = re.compile("In the remaining cases.+") - -# unary_pattern_to_condition_factory: Dict[Pattern, Callable] = { -# re.compile("If ``x_i`` is greater than (.+), the result is (.+)"): make_gt, -# re.compile("If ``x_i`` is less than (.+), the result is (.+)"): make_lt, -# re.compile("If ``x_i`` is either (.+) or (.+), the result is (.+)"): ( -# lambda v1, v2: make_or(make_eq(v1), make_eq(v2)) -# ), -# # This pattern must come after the previous patterns to avoid unwanted matches -# re.compile("If ``x_i`` is (.+), the result is (.+)"): make_eq, -# re.compile( -# "If two integers are equally close to ``x_i``, the result is (.+)" -# ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5), -# } - - -# def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]: -# match = r_special_cases.search(docstring) -# if match is None: -# return {} -# cases = match.group(1).split("\n")[:-1] -# cases = {} -# for line in cases: -# if m := r_case.match(line): -# case = m.group(1) -# else: -# warn(f"line not machine-readable: '{line}'") -# continue -# for pattern, make_cond in unary_pattern_to_condition_factory.items(): -# if m := pattern.search(case): -# *s_values, s_result = m.groups() -# try: -# values = [parse_inline_code(v) for v in s_values] -# except ValueParseError as e: -# warn(f"value not machine-readable: '{e.value}'") -# break -# cond = make_cond(*values) -# try: -# result = parse_result(s_result) -# except ValueParseError as e: -# warn(f"result not machine-readable: '{e.value}'") - -# break -# cases[cond] = result -# break -# else: -# if not r_remaining_case.search(case): -# warn(f"case not machine-readable: '{case}'") -# return cases - x_i = "xᵢ" x1_i = "x1ᵢ" x2_i = "x2ᵢ" +def parse_cond(cond_str: str): + if m := r_not.match(cond_str): + cond_str = m.group(1) + notify = True + else: + notify = False + + if m := r_code.match(cond_str): + value = parse_value(m.group(1)) + cond = make_eq(value) + expr_template = "{} == " + str(value) + elif m := r_gt.match(cond_str): + value = parse_value(m.group(1)) + cond = make_gt(value) + expr_template = "{} > " + str(value) + elif m := r_lt.match(cond_str): + value = parse_value(m.group(1)) + cond = make_lt(value) + expr_template = "{} < " + str(value) + elif m := r_either_code.match(cond_str): + v1 = parse_value(m.group(1)) + v2 = parse_value(m.group(2)) + cond = make_or(make_eq(v1), make_eq(v2)) + expr_template = "{} == " + str(v1) + " or {} == " + str(v2) + elif cond_str in ["finite", "a finite number"]: + cond = math.isfinite + expr_template = "isfinite({})" + elif cond_str in "a positive (i.e., greater than ``0``) finite number": + cond = lambda i: math.isfinite(i) and i > 0 + expr_template = "isfinite({}) and {} > 0" + elif cond_str == "a negative (i.e., less than ``0``) finite number": + cond = lambda i: math.isfinite(i) and i < 0 + expr_template = "isfinite({}) and {} < 0" + elif cond_str == "positive": + cond = lambda i: math.copysign(1, i) == 1 + expr_template = "copysign(1, {}) == 1" + elif cond_str == "negative": + cond = lambda i: math.copysign(1, i) == -1 + expr_template = "copysign(1, {}) == -1" + elif "nonzero finite" in cond_str: + cond = lambda i: math.isfinite(i) and i != 0 + expr_template = "copysign(1, {}) == -1" + elif cond_str == "an integer value": + cond = lambda i: i.is_integer() + expr_template = "{}.is_integer()" + elif cond_str == "an odd integer value": + cond = lambda i: i.is_integer() and i % 2 == 1 + expr_template = "{}.is_integer() and {} % 2 == 1" + else: + raise ValueParseError(cond_str) + + if notify: + cond = notify_cond(cond) + expr_template = f"not ({expr_template})" + + return cond, expr_template + + +def parse_result(result_str: str): + if m := r_code.match(result_str): + value = parse_value(m.group(1)) + check_result = make_eq(value) + expr = str(value) + elif m := r_approx_value.match(result_str): + value = parse_value(m.group(1)) + check_result = make_rough_eq(value) + expr = f"~{value}" + else: + raise ValueParseError(result_str) + + return check_result, expr + + class Cond(Protocol): expr: str @@ -291,6 +328,82 @@ def __call__(self, *args) -> bool: ... +@dataclass +class UnaryCond(Cond): + cond: UnaryCheck + expr: str + + def __call__(self, i: float) -> bool: + return self.cond(i) + + +@dataclass +class UnaryResultCheck: + check_result: Callable + expr: str + + def __call__(self, result: float) -> bool: + return self.check_result(result) + + +class Case(Protocol): + def cond(self, *args) -> bool: + ... + + def check_result(self, *args) -> bool: + ... + + +@dataclass +class UnaryCase(Case): + cond: UnaryCond + check_result: UnaryResultCheck + + @classmethod + def from_strings(cls, cond_str: str, result_str: str): + cond, cond_expr_template = parse_cond(cond_str) + cond_expr = cond_expr_template.replace("{}", x_i) + check_result, check_result_expr = parse_result(result_str) + return cls( + UnaryCond(cond, cond_expr), + UnaryResultCheck(check_result, check_result_expr), + ) + + def __repr__(self): + return f"UnaryCase(<{self.cond.expr} -> {self.check_result.expr}>)" + + +r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") +# re.compile( +# "If two integers are equally close to ``x_i``, the result is (.+)" +# ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5), + + +def parse_unary_docstring(docstring: str) -> List[UnaryCase]: + match = r_special_cases.search(docstring) + if match is None: + return [] + lines = match.group(1).split("\n")[:-1] + cases = [] + for line in lines: + if m := r_case.match(line): + case = m.group(1) + else: + warn(f"line not machine-readable: '{line}'") + continue + if m := r_unary_case.search(case): + try: + case = UnaryCase.from_strings(*m.groups()) + except ValueParseError as e: + warn(f"not machine-readable: '{e.value}'") + continue + cases.append(case) + else: + if not r_remaining_case.search(case): + warn(f"case not machine-readable: '{case}'") + return cases + + @dataclass class BinaryCond(Cond): cond: BinaryCheck @@ -346,63 +459,10 @@ def cond(i1: float, i2: float) -> bool: return BinaryCond(cond, expr) - if m := r_not.match(group): - group = m.group(1) - notify = True - else: - notify = False - - if m := r_code.match(group): - value = parse_value(m.group(1)) - _cond = make_eq(value) - repr_template = "{} == " + str(value) - elif m := r_gt.match(group): - value = parse_value(m.group(1)) - _cond = make_gt(value) - repr_template = "{} > " + str(value) - elif m := r_lt.match(group): - value = parse_value(m.group(1)) - _cond = make_lt(value) - repr_template = "{} < " + str(value) - elif m := r_either_code.match(group): - v1 = parse_value(m.group(1)) - v2 = parse_value(m.group(2)) - _cond = make_or(make_eq(v1), make_eq(v2)) - repr_template = "{} == " + str(v1) + " or {} == " + str(v2) - elif group in ["finite", "a finite number"]: - _cond = math.isfinite - repr_template = "isfinite({})" - elif group in "a positive (i.e., greater than ``0``) finite number": - _cond = lambda i: math.isfinite(i) and i > 0 - repr_template = "isfinite({}) and {} > 0" - elif group == "a negative (i.e., less than ``0``) finite number": - _cond = lambda i: math.isfinite(i) and i < 0 - repr_template = "isfinite({}) and {} < 0" - elif group == "positive": - _cond = lambda i: math.copysign(1, i) == 1 - repr_template = "copysign(1, {}) == 1" - elif group == "negative": - _cond = lambda i: math.copysign(1, i) == -1 - repr_template = "copysign(1, {}) == -1" - elif "nonzero finite" in group: - _cond = lambda i: math.isfinite(i) and i != 0 - repr_template = "copysign(1, {}) == -1" - elif group == "an integer value": - _cond = lambda i: i.is_integer() - repr_template = "{}.is_integer()" - elif group == "an odd integer value": - _cond = lambda i: i.is_integer() and i % 2 == 1 - repr_template = "{}.is_integer() and {} % 2 == 1" - else: - raise ValueParseError(group) + _cond, expr_template = parse_cond(group) - assert not (notify and self.abs_) # sanity check - if notify: - final_cond = lambda i: not _cond(i) - elif self.abs_: - final_cond = lambda i: _cond(abs(i)) - else: - final_cond = _cond + if self.abs_: + _cond = absify_cond(_cond) f_i1 = x1_i f_i2 = x2_i @@ -411,32 +471,29 @@ def cond(i1: float, i2: float) -> bool: f_i2 = f"abs({f_i2})" if self.input_ == "i1": - expr = repr_template.replace("{}", f_i1) + expr = expr_template.replace("{}", f_i1) def cond(i1: float, i2: float) -> bool: - return final_cond(i1) + return _cond(i1) elif self.input_ == "i2": - expr = repr_template.replace("{}", f_i2) + expr = expr_template.replace("{}", f_i2) def cond(i1: float, i2: float) -> bool: - return final_cond(i2) + return _cond(i2) elif self.input_ == "either": - expr = f"({repr_template.replace('{}', f_i1)}) or ({repr_template.replace('{}', f_i2)})" + expr = f"({expr_template.replace('{}', f_i1)}) or ({expr_template.replace('{}', f_i2)})" def cond(i1: float, i2: float) -> bool: - return final_cond(i1) or final_cond(i2) + return _cond(i1) or _cond(i2) else: assert self.input_ == "both" # sanity check - expr = f"({repr_template.replace('{}', f_i1)}) and ({repr_template.replace('{}', f_i2)})" + expr = f"({expr_template.replace('{}', f_i1)}) and ({expr_template.replace('{}', f_i2)})" def cond(i1: float, i2: float) -> bool: - return final_cond(i1) and final_cond(i2) - - if notify: - expr = f"not ({expr})" + return _cond(i1) and _cond(i2) return BinaryCond(cond, expr) @@ -518,16 +575,7 @@ def check_result(i1: float, i2: float, result: float) -> bool: return BinaryResultCheck(check_result, expr) - if m := r_code.match(group): - value = parse_value(m.group(1)) - _check_result = make_eq(value) - expr = str(value) - elif m := r_approx_value.match(group): - value = parse_value(m.group(1)) - _check_result = make_rough_eq(value) - expr = f"~{value}" - else: - raise ValueParseError(group) + _check_result, expr = parse_result(group) def check_result(i1: float, i2: float, result: float) -> bool: return _check_result(result) @@ -558,7 +606,8 @@ def cond(i1: float, i2: float, result: float) -> bool: return cond -class BinaryCase(NamedTuple): +@dataclass +class BinaryCase(Case): cond: BinaryCond check_result: BinaryResultCheck @@ -707,9 +756,9 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: warn(f"{func=} has no parameters") continue if param_names[0] == "x": - # if cases := parse_unary_docstring(stub.__doc__): - # p = pytest.param(stub.__name__, func, cases, id=stub.__name__) - # unary_params.append(p) + if cases := parse_unary_docstring(stub.__doc__): + p = pytest.param(stub.__name__, func, cases, id=stub.__name__) + unary_params.append(p) continue if len(sig.parameters) == 1: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") @@ -739,28 +788,15 @@ def test_unary(func_name, func, cases, x): good_example = False for idx in sh.ndindex(res.shape): in_ = float(x[idx]) - for cond, result in cases.items(): - if cond(in_): + for case in cases: + if case.cond(in_): good_example = True out = float(res[idx]) f_in = f"{sh.fmt_idx('x', idx)}={in_}" f_out = f"{sh.fmt_idx('out', idx)}={out}" - if result.strict_check: - msg = ( - f"{f_out}, but should be {result.expr} [{func_name}()]\n" - f"{f_in}" - ) - if math.isnan(result.value): - assert math.isnan(out), msg - else: - assert out == result.value, msg - else: - assert math.isfinite(result.value) # sanity check - assert math.isclose(out, result.value, abs_tol=0.1), ( - f"{f_out}, but should be roughly {result.expr}={result.value} " - f"[{func_name}()]\n" - f"{f_in}" - ) + assert case.check_result( + out + ), f"{f_out} not good [{func_name}()]\n{f_in}" break assume(good_example) @@ -785,8 +821,8 @@ def test_binary(func_name, func, cases, x1, x2): f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" f_out = f"{sh.fmt_idx('out', o_idx)}={o}" - assert case.check_result(l, r, o), ( - f"{f_out} not good [{func_name}()]\n" f"{f_left}, {f_right}" - ) + assert case.check_result( + l, r, o + ), f"{f_out} not good [{func_name}()]\n{f_left}, {f_right}" break assume(good_example) From 4ea01546a344406b4a3d003b25271ae22ab79142 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 18 Feb 2022 11:58:07 +0000 Subject: [PATCH 22/63] Parse even round case --- array_api_tests/test_special_cases.py | 28 +++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 91b7d52b..5f3d0bdd 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -2,6 +2,7 @@ import math import re from dataclasses import dataclass +from decimal import ROUND_HALF_EVEN, Decimal from typing import ( Callable, Dict, @@ -236,7 +237,7 @@ def parse_inline_code(inline_code: str) -> float: r_special_cases = re.compile( - r"\*\*Special [Cc]ases\*\*\n+\s*" + r"\*\*Special [Cc]ases\*\*(?:\n.*)+" r"For floating-point operands,\n+" r"((?:\s*-\s*.*\n)+)" ) @@ -342,8 +343,8 @@ class UnaryResultCheck: check_result: Callable expr: str - def __call__(self, result: float) -> bool: - return self.check_result(result) + def __call__(self, i: float, result: float) -> bool: + return self.check_result(i, result) class Case(Protocol): @@ -366,7 +367,7 @@ def from_strings(cls, cond_str: str, result_str: str): check_result, check_result_expr = parse_result(result_str) return cls( UnaryCond(cond, cond_expr), - UnaryResultCheck(check_result, check_result_expr), + UnaryResultCheck(lambda _, r: check_result(r), check_result_expr), ) def __repr__(self): @@ -374,9 +375,18 @@ def __repr__(self): r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") -# re.compile( -# "If two integers are equally close to ``x_i``, the result is (.+)" -# ): lambda: (lambda i: (abs(i) - math.floor(abs(i))) == 0.5), +r_even_int_round_case = re.compile( + "If two integers are equally close to ``x_i``, " + "the result is the even integer closest to ``x_i``" +) + +even_int_round_case = UnaryCase( + cond=UnaryCond(lambda i: i % 0.5 == 0, "i % 0.5 == 0"), + check_result=UnaryResultCheck( + lambda i, r: r == float(Decimal(i).to_integral_exact(ROUND_HALF_EVEN)), + "Decimal(i).to_integral_exact(ROUND_HALF_EVEN)", + ), +) def parse_unary_docstring(docstring: str) -> List[UnaryCase]: @@ -398,6 +408,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: warn(f"not machine-readable: '{e.value}'") continue cases.append(case) + elif m := r_even_int_round_case.search(case): + cases.append(even_int_round_case) else: if not r_remaining_case.search(case): warn(f"case not machine-readable: '{case}'") @@ -795,7 +807,7 @@ def test_unary(func_name, func, cases, x): f_in = f"{sh.fmt_idx('x', idx)}={in_}" f_out = f"{sh.fmt_idx('out', idx)}={out}" assert case.check_result( - out + in_, out ), f"{f_out} not good [{func_name}()]\n{f_in}" break assume(good_example) From e808673533bc6cdb5716563ec1ed2d685fb83392 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 18 Feb 2022 14:23:05 +0000 Subject: [PATCH 23/63] Factor out class bloat for unary params --- array_api_tests/test_special_cases.py | 230 ++++++++------------------ 1 file changed, 68 insertions(+), 162 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 5f3d0bdd..5a06c6c6 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -29,10 +29,6 @@ pytestmark = pytest.mark.ci -# Condition factories -# ------------------------------------------------------------------------------ - - UnaryCheck = Callable[[float], bool] BinaryCheck = Callable[[float, float], bool] @@ -102,35 +98,6 @@ def and_(i: float) -> bool: return and_ -def make_bin_and_factory( - make_cond1: Callable[[float], UnaryCheck], make_cond2: Callable[[float], UnaryCheck] -) -> Callable[[float, float], BinaryCheck]: - def make_bin_and(v1: float, v2: float) -> BinaryCheck: - cond1 = make_cond1(v1) - cond2 = make_cond2(v2) - - def bin_and(i1: float, i2: float) -> bool: - return cond1(i1) and cond2(i2) - - return bin_and - - return make_bin_and - - -def make_bin_or_factory( - make_cond: Callable[[float], UnaryCheck] -) -> Callable[[float], BinaryCheck]: - def make_bin_or(v: float) -> BinaryCheck: - cond = make_cond(v) - - def bin_or(i1: float, i2: float) -> bool: - return cond(i1) or cond(i2) - - return bin_or - - return make_bin_or - - def notify_cond(cond: UnaryCheck) -> UnaryCheck: def not_cond(i: float) -> bool: return not cond(i) @@ -145,57 +112,12 @@ def abs_cond(i: float) -> bool: return abs_cond -def absify_cond_factory( - make_cond: Callable[[float], UnaryCheck] -) -> Callable[[float], UnaryCheck]: - def make_abs_cond(v: float) -> UnaryCheck: - cond = make_cond(v) - - def abs_cond(i: float) -> bool: - i = abs(i) - return cond(i) - - return abs_cond - - return make_abs_cond - - -def make_bin_multi_and_factory( - make_conds1: List[Callable[[float], UnaryCheck]], - make_conds2: List[Callable[[float], UnaryCheck]], -) -> Callable: - def make_bin_multi_and(*values: float) -> BinaryCheck: - assert len(values) == len(make_conds1) + len(make_conds2) - conds1 = [make_cond(v) for make_cond, v in zip(make_conds1, values)] - conds2 = [make_cond(v) for make_cond, v in zip(make_conds2, values[::-1])] - - def bin_multi_and(i1: float, i2: float) -> bool: - return all(cond(i1) for cond in conds1) and all(cond(i2) for cond in conds2) - - return bin_multi_and - - return make_bin_multi_and - - -def same_sign(i1: float, i2: float) -> bool: - return math.copysign(1, i1) == math.copysign(1, i2) - - -def diff_sign(i1: float, i2: float) -> bool: - return not same_sign(i1, i2) - - -# Parse utils -# ------------------------------------------------------------------------------ - - repr_to_value = { "NaN": float("nan"), "infinity": float("inf"), "0": 0.0, "1": 1.0, } - r_value = re.compile(r"([+-]?)(.+)") r_pi = re.compile(r"(\d?)π(?:/(\d))?") @@ -244,12 +166,8 @@ def parse_inline_code(inline_code: str) -> float: r_case = re.compile(r"\s+-\s*(.*)\.\n?") r_remaining_case = re.compile("In the remaining cases.+") -x_i = "xᵢ" -x1_i = "x1ᵢ" -x2_i = "x2ᵢ" - -def parse_cond(cond_str: str): +def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str]: if m := r_not.match(cond_str): cond_str = m.group(1) notify = True @@ -290,7 +208,7 @@ def parse_cond(cond_str: str): expr_template = "copysign(1, {}) == -1" elif "nonzero finite" in cond_str: cond = lambda i: math.isfinite(i) and i != 0 - expr_template = "copysign(1, {}) == -1" + expr_template = "isfinite({}) and {} != 0" elif cond_str == "an integer value": cond = lambda i: i.is_integer() expr_template = "{}.is_integer()" @@ -307,7 +225,7 @@ def parse_cond(cond_str: str): return cond, expr_template -def parse_result(result_str: str): +def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: if m := r_code.match(result_str): value = parse_value(m.group(1)) check_result = make_eq(value) @@ -316,62 +234,69 @@ def parse_result(result_str: str): value = parse_value(m.group(1)) check_result = make_rough_eq(value) expr = f"~{value}" + elif result_str == "positive": + + def check_result(result: float) -> bool: + if math.isnan(result): + # The sign of NaN is out-of-scope + return True + return math.copysign(1, result) == 1 + + expr = "+" + elif result_str == "negative": + + def check_result(result: float) -> bool: + if math.isnan(result): + # The sign of NaN is out-of-scope + return True + return math.copysign(1, result) == -1 + + expr = "-" else: raise ValueParseError(result_str) return check_result, expr -class Cond(Protocol): +class Case(Protocol): expr: str - def __call__(self, *args) -> bool: + def cond(self, *args) -> bool: ... + def check_result(self, *args) -> bool: + ... -@dataclass -class UnaryCond(Cond): - cond: UnaryCheck - expr: str - - def __call__(self, i: float) -> bool: - return self.cond(i) - - -@dataclass -class UnaryResultCheck: - check_result: Callable - expr: str + def __repr__(self): + return f"{self.__class__.__name__}(<{self.expr}>)" - def __call__(self, i: float, result: float) -> bool: - return self.check_result(i, result) + def __str__(self): + return self.expr -class Case(Protocol): - def cond(self, *args) -> bool: +class UnaryCond(Protocol): + def __call__(self, i: float) -> bool: ... - def check_result(self, *args) -> bool: + +class UnaryResultCheck(Protocol): + def __call__(self, i: float, result: float) -> bool: ... -@dataclass +@dataclass(repr=False) class UnaryCase(Case): - cond: UnaryCond + expr: str + cond: UnaryCheck check_result: UnaryResultCheck @classmethod def from_strings(cls, cond_str: str, result_str: str): cond, cond_expr_template = parse_cond(cond_str) - cond_expr = cond_expr_template.replace("{}", x_i) check_result, check_result_expr = parse_result(result_str) - return cls( - UnaryCond(cond, cond_expr), - UnaryResultCheck(lambda _, r: check_result(r), check_result_expr), - ) - - def __repr__(self): - return f"UnaryCase(<{self.cond.expr} -> {self.check_result.expr}>)" + cond_expr = cond_expr_template.replace("{}", "xᵢ") + expr = f"{cond_expr} -> {check_result_expr}" + return cls(expr, cond, lambda i, result: check_result(result)) r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") @@ -379,12 +304,11 @@ def __repr__(self): "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) - even_int_round_case = UnaryCase( - cond=UnaryCond(lambda i: i % 0.5 == 0, "i % 0.5 == 0"), - check_result=UnaryResultCheck( - lambda i, r: r == float(Decimal(i).to_integral_exact(ROUND_HALF_EVEN)), - "Decimal(i).to_integral_exact(ROUND_HALF_EVEN)", + expr="i % 0.5 == 0 -> Decimal(i).to_integral_exact(ROUND_HALF_EVEN)", + cond=lambda i: i % 0.5 == 0, + check_result=lambda i, result: ( + result == float(Decimal(i).to_integral_exact(ROUND_HALF_EVEN)) ), ) @@ -417,7 +341,7 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: @dataclass -class BinaryCond(Cond): +class BinaryCond: cond: BinaryCheck expr: str @@ -455,7 +379,7 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: signer = lambda i: i if self.input_ == "i1": - expr = f"{x1_i} == {sign}{x2_i}" + expr = f"x1ᵢ == {sign}x2ᵢ" def cond(i1: float, i2: float) -> bool: _cond = make_eq(signer(i2)) @@ -463,7 +387,7 @@ def cond(i1: float, i2: float) -> bool: else: assert self.input_ == "i2" # sanity check - expr = f"{x2_i} == {sign}{x1_i}" + expr = f"x2ᵢ == {sign}x1ᵢ" def cond(i1: float, i2: float) -> bool: _cond = make_eq(signer(i1)) @@ -476,8 +400,8 @@ def cond(i1: float, i2: float) -> bool: if self.abs_: _cond = absify_cond(_cond) - f_i1 = x1_i - f_i2 = x2_i + f_i1 = "x1ᵢ" + f_i2 = "x2ᵢ" if self.abs_: f_i1 = f"abs({f_i1})" f_i2 = f"abs({f_i2})" @@ -531,11 +455,17 @@ class SignCondFactory(BinaryCondFactory): def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: group = groups[self.re_groups_i] if group == "the same mathematical sign": - cond = same_sign - expr = f"copysign(1, {x1_i}) == copysign(1, {x2_i})" + + def cond(i1: float, i2: float) -> bool: + return math.copysign(1, i1) == math.copysign(1, i2) + + expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)" elif group == "different mathematical signs": - cond = diff_sign - expr = f"copysign(1, {x1_i}) != copysign(1, {x2_i})" + + def cond(i1: float, i2: float) -> bool: + return math.copysign(1, i1) != math.copysign(1, i2) + + expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)" else: raise ValueParseError(group) return BinaryCond(cond, expr) @@ -572,14 +502,14 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: signer = lambda i: i if input_ == "1": - expr = f"{sign}{x1_i}" + expr = f"{sign}x1ᵢ" def check_result(i1: float, i2: float, result: float) -> bool: _check_result = make_eq(signer(i1)) return _check_result(result) else: - expr = f"{sign}{x2_i}" + expr = f"{sign}x2ᵢ" def check_result(i1: float, i2: float, result: float) -> bool: _check_result = make_eq(signer(i2)) @@ -595,37 +525,12 @@ def check_result(i1: float, i2: float, result: float) -> bool: return BinaryResultCheck(check_result, expr) -class ResultSignCheckFactory(ResultCheckFactory): - def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: - group = groups[self.re_groups_i] - if group == "positive": - - def cond(i1: float, i2: float, result: float) -> bool: - if math.isnan(result): - return True - return result > 0 or ph.is_pos_zero(result) - - elif group == "negative": - - def cond(i1: float, i2: float, result: float) -> bool: - if math.isnan(result): - return True - return result < 0 or ph.is_neg_zero(result) - - else: - raise ValueParseError(group) - - return cond - - -@dataclass +@dataclass(repr=False) class BinaryCase(Case): + expr: str cond: BinaryCond check_result: BinaryResultCheck - def __repr__(self): - return f"BinaryCase(<{self.cond.expr} -> {self.check_result}>)" - class BinaryCaseFactory(NamedTuple): cond_factory: BinaryCondFactory @@ -634,7 +539,8 @@ class BinaryCaseFactory(NamedTuple): def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: cond = self.cond_factory(groups) check_result = self.check_result_factory(groups) - return BinaryCase(cond, check_result) + expr = f"{cond.expr} -> {check_result.expr}" + return BinaryCase(expr, cond, check_result) r_result_sign = re.compile("([a-z]+) mathematical sign") @@ -693,19 +599,19 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: re.compile( "If ``x1_i`` and ``x2_i`` have (.+signs?), " f"the result has a {r_result_sign.pattern}" - ): BinaryCaseFactory(SignCondFactory(0), ResultSignCheckFactory(1)), + ): BinaryCaseFactory(SignCondFactory(0), ResultCheckFactory(1)), re.compile( "If ``x1_i`` and ``x2_i`` have (.+signs?) and are both (.+), " f"the result has a {r_result_sign.pattern}" ): BinaryCaseFactory( AndCondFactory(SignCondFactory(0), ValueCondFactory("both", 1)), - ResultSignCheckFactory(2), + ResultCheckFactory(2), ), re.compile( "If ``x1_i`` and ``x2_i`` have (.+signs?), the result has a " rf"{r_result_sign.pattern} , unless the result is (.+)\. If the result " r"is ``NaN``, the \"sign\" of ``NaN`` is implementation-defined\." - ): BinaryCaseFactory(SignCondFactory(0), ResultSignCheckFactory(1)), + ): BinaryCaseFactory(SignCondFactory(0), ResultCheckFactory(1)), re.compile( "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" ): BinaryCaseFactory(ValueCondFactory("i2", 0), ResultCheckFactory(1)), From a90760c86e3f47b50d20b7529298063badb125cc Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 21 Feb 2022 10:47:14 +0000 Subject: [PATCH 24/63] Rudimentary generic condition parsing for binary cases --- array_api_tests/test_special_cases.py | 105 +++++++++++++++++++++++--- 1 file changed, 93 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 5a06c6c6..cdd5f442 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -8,6 +8,7 @@ Dict, List, Literal, + Match, NamedTuple, Pattern, Protocol, @@ -617,6 +618,87 @@ def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: ): BinaryCaseFactory(ValueCondFactory("i2", 0), ResultCheckFactory(1)), } +r_binary_case = re.compile("If (.+), the result (.+)") + +r_cond_sep = re.compile(", | and ") +r_cond = re.compile("(.+) (?:is|have) (.+)") + +r_element = re.compile("x([12])_i") +r_input = re.compile(rf"``{r_element.pattern}``") +r_abs_input = re.compile(r"``abs\({r_element.pattern}\)``") +r_and_input = re.compile(f"{r_input.pattern} and {r_input.pattern}") +r_or_input = re.compile(f"either {r_input.pattern} or {r_input.pattern}") + +r_result = re.compile(r"(?:is|has a) (.+)") + +r_both_inputs_are_value = re.compile("are both (.+)") + + +def parse_binary_case(case_m: Match) -> BinaryCase: + cond_strs = r_cond_sep.split(case_m.group(1)) + conds = [] + cond_exprs = [] + for cond_str in cond_strs: + if m := r_both_inputs_are_value.match(cond_str): + raise ValueParseError(cond_str) + else: + cond_m = r_cond.match(cond_str) + if cond_m is None: + raise ValueParseError(cond_str) + input_str, value_str = cond_m.groups() + + unary_cond, expr_template = parse_cond(value_str) + + if m := r_input.match(input_str): + x_no = m.group(1) + args_i = int(x_no) - 1 + expr = expr_template.replace("{}", f"x{x_no}ᵢ") + + def cond(*inputs) -> bool: + return unary_cond(inputs[args_i]) + + elif m := r_abs_input.match(input_str): + x_no = m.group(1) + args_i = int(x_no) - 1 + expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)") + + def cond(*inputs) -> bool: + return unary_cond(abs(inputs[args_i])) + + elif r_and_input.match(input_str): + left_expr = expr_template.replace("{}", "x1ᵢ") + right_expr = expr_template.replace("{}", "x2ᵢ") + expr = f"({left_expr}) and ({right_expr})" + + def cond(i1: float, i2: float) -> bool: + return unary_cond(i1) and unary_cond(i2) + + elif r_or_input.match(input_str): + left_expr = expr_template.replace("{}", "x1ᵢ") + right_expr = expr_template.replace("{}", "x2ᵢ") + expr = f"({left_expr}) and ({right_expr})" + + def cond(i1: float, i2: float) -> bool: + return unary_cond(i1) or unary_cond(i2) + + else: + raise ValueParseError(input_str) + + conds.append(cond) + cond_exprs.append(expr) + + result_m = r_result.match(case_m.group(2)) + if result_m is None: + raise ValueParseError(case_m.group(2)) + check_result, result_expr = parse_result(result_m.group(1)) + + expr = " and ".join(f"({expr})" for expr in cond_exprs) + " -> " + result_expr + + def cond(i1: float, i2: float) -> bool: + return all(c(i1, i2) for c in conds) + + return BinaryCase(expr, cond, lambda l, r, o: check_result(o)) + r_redundant_case = re.compile("result.+determined by the rule already stated above") @@ -629,24 +711,23 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: cases = [] for line in lines: if m := r_case.match(line): - case = m.group(1) + case_str = m.group(1) else: warn(f"line not machine-readable: '{line}'") continue - if r_redundant_case.search(case): + if r_redundant_case.search(case_str): continue - for pattern, make_case in binary_pattern_to_case_factory.items(): - if m := pattern.search(case): - try: - case = make_case(m.groups()) - except ValueParseError as e: - warn(f"not machine-readable: '{e.value}'") - break - cases.append(case) + if m := r_binary_case.search(case_str): + try: + case = parse_binary_case(m) + except ValueParseError as e: + warn(f"not machine-readable: '{e.value}'") break + cases.append(case) + break else: - if not r_remaining_case.search(case): - warn(f"case not machine-readable: '{case}'") + if not r_remaining_case.search(case_str): + warn(f"case not machine-readable: '{case_str}'") return cases From f91599374740b2d367a50ca0eb2b54c2479dcf7d Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 21 Feb 2022 14:15:02 +0000 Subject: [PATCH 25/63] Bring back full case coverage --- array_api_tests/test_special_cases.py | 463 ++++++++------------------ 1 file changed, 141 insertions(+), 322 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index cdd5f442..bff35f5c 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -3,18 +3,7 @@ import re from dataclasses import dataclass from decimal import ROUND_HALF_EVEN, Decimal -from typing import ( - Callable, - Dict, - List, - Literal, - Match, - NamedTuple, - Pattern, - Protocol, - Tuple, - Union, -) +from typing import Callable, List, Match, Protocol, Tuple from warnings import warn import pytest @@ -159,13 +148,11 @@ def parse_inline_code(inline_code: str) -> float: raise ValueParseError(inline_code) -r_special_cases = re.compile( - r"\*\*Special [Cc]ases\*\*(?:\n.*)+" - r"For floating-point operands,\n+" - r"((?:\s*-\s*.*\n)+)" -) -r_case = re.compile(r"\s+-\s*(.*)\.\n?") -r_remaining_case = re.compile("In the remaining cases.+") +r_not = re.compile("not (?:equal to )?(.+)") +r_array_element = re.compile(r"``([+-]?)x([12])_i``") +r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}") +r_gt = re.compile(f"greater than {r_code.pattern}") +r_lt = re.compile(f"less than {r_code.pattern}") def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str]: @@ -235,7 +222,7 @@ def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: value = parse_value(m.group(1)) check_result = make_rough_eq(value) expr = f"~{value}" - elif result_str == "positive": + elif "positive" in result_str: def check_result(result: float) -> bool: if math.isnan(result): @@ -244,7 +231,7 @@ def check_result(result: float) -> bool: return math.copysign(1, result) == 1 expr = "+" - elif result_str == "negative": + elif "negative" in result_str: def check_result(result: float) -> bool: if math.isnan(result): @@ -341,191 +328,16 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: return cases -@dataclass -class BinaryCond: - cond: BinaryCheck - expr: str - +class BinaryCond(Protocol): def __call__(self, i1: float, i2: float) -> bool: - return self.cond(i1, i2) - - -class BinaryCondFactory(Protocol): - def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: ... -r_not = re.compile("not (?:equal to )?(.+)") -r_array_element = re.compile(r"``([+-]?)x([12])_i``") -r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}") -r_gt = re.compile(f"greater than {r_code.pattern}") -r_lt = re.compile(f"less than {r_code.pattern}") - - -@dataclass -class ValueCondFactory(BinaryCondFactory): - input_: Union[Literal["i1"], Literal["i2"], Literal["either"], Literal["both"]] - re_groups_i: int - abs_: bool = False - - def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: - group = groups[self.re_groups_i] - - if m := r_array_element.match(group): - assert not self.abs_ # sanity check - sign = m.group(1) - if sign == "-": - signer = lambda i: -i - else: - signer = lambda i: i - - if self.input_ == "i1": - expr = f"x1ᵢ == {sign}x2ᵢ" - - def cond(i1: float, i2: float) -> bool: - _cond = make_eq(signer(i2)) - return _cond(i1) - - else: - assert self.input_ == "i2" # sanity check - expr = f"x2ᵢ == {sign}x1ᵢ" - - def cond(i1: float, i2: float) -> bool: - _cond = make_eq(signer(i1)) - return _cond(i2) - - return BinaryCond(cond, expr) - - _cond, expr_template = parse_cond(group) - - if self.abs_: - _cond = absify_cond(_cond) - - f_i1 = "x1ᵢ" - f_i2 = "x2ᵢ" - if self.abs_: - f_i1 = f"abs({f_i1})" - f_i2 = f"abs({f_i2})" - - if self.input_ == "i1": - expr = expr_template.replace("{}", f_i1) - - def cond(i1: float, i2: float) -> bool: - return _cond(i1) - - elif self.input_ == "i2": - expr = expr_template.replace("{}", f_i2) - - def cond(i1: float, i2: float) -> bool: - return _cond(i2) - - elif self.input_ == "either": - expr = f"({expr_template.replace('{}', f_i1)}) or ({expr_template.replace('{}', f_i2)})" - - def cond(i1: float, i2: float) -> bool: - return _cond(i1) or _cond(i2) - - else: - assert self.input_ == "both" # sanity check - expr = f"({expr_template.replace('{}', f_i1)}) and ({expr_template.replace('{}', f_i2)})" - - def cond(i1: float, i2: float) -> bool: - return _cond(i1) and _cond(i2) - - return BinaryCond(cond, expr) - - -class AndCondFactory(BinaryCondFactory): - def __init__(self, *cond_factories: BinaryCondFactory): - self.cond_factories = cond_factories - - def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: - conds = [cond_factory(groups) for cond_factory in self.cond_factories] - expr = " and ".join(f"({cond.expr})" for cond in conds) - - def cond(i1: float, i2: float) -> bool: - return all(cond(i1, i2) for cond in conds) - - return BinaryCond(cond, expr) - - -@dataclass -class SignCondFactory(BinaryCondFactory): - re_groups_i: int - - def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: - group = groups[self.re_groups_i] - if group == "the same mathematical sign": - - def cond(i1: float, i2: float) -> bool: - return math.copysign(1, i1) == math.copysign(1, i2) - - expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)" - elif group == "different mathematical signs": - - def cond(i1: float, i2: float) -> bool: - return math.copysign(1, i1) != math.copysign(1, i2) - - expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)" - else: - raise ValueParseError(group) - return BinaryCond(cond, expr) - - -class BinaryResultCheck(NamedTuple): - check_result: Callable[[float, float, float], bool] - expr: str - +class BinaryResultCheck(Protocol): def __call__(self, i1: float, i2: float, result: float) -> bool: - return self.check_result(i1, i2, result) - - def __repr__(self): - return self.expr - - -class BinaryResultCheckFactory(Protocol): - def __call__(self, groups: Tuple[str, ...]) -> BinaryCond: ... -@dataclass -class ResultCheckFactory(BinaryResultCheckFactory): - re_groups_i: int - - def __call__(self, groups: Tuple[str, ...]) -> BinaryResultCheck: - group = groups[self.re_groups_i] - - if m := r_array_element.match(group): - sign, input_ = m.groups() - if sign == "-": - signer = lambda i: -i - else: - signer = lambda i: i - - if input_ == "1": - expr = f"{sign}x1ᵢ" - - def check_result(i1: float, i2: float, result: float) -> bool: - _check_result = make_eq(signer(i1)) - return _check_result(result) - - else: - expr = f"{sign}x2ᵢ" - - def check_result(i1: float, i2: float, result: float) -> bool: - _check_result = make_eq(signer(i2)) - return _check_result(result) - - return BinaryResultCheck(check_result, expr) - - _check_result, expr = parse_result(group) - - def check_result(i1: float, i2: float, result: float) -> bool: - return _check_result(result) - - return BinaryResultCheck(check_result, expr) - - @dataclass(repr=False) class BinaryCase(Case): expr: str @@ -533,104 +345,29 @@ class BinaryCase(Case): check_result: BinaryResultCheck -class BinaryCaseFactory(NamedTuple): - cond_factory: BinaryCondFactory - check_result_factory: ResultCheckFactory - - def __call__(self, groups: Tuple[str, ...]) -> BinaryCase: - cond = self.cond_factory(groups) - check_result = self.check_result_factory(groups) - expr = f"{cond.expr} -> {check_result.expr}" - return BinaryCase(expr, cond, check_result) - - -r_result_sign = re.compile("([a-z]+) mathematical sign") - -binary_pattern_to_case_factory: Dict[Pattern, BinaryCaseFactory] = { - re.compile( - "If ``x1_i`` is (.+), ``x1_i`` is (.+), ``x2_i`` is (.+), " - "and ``x2_i`` is (.+), the result is (.+)" - ): BinaryCaseFactory( - AndCondFactory( - ValueCondFactory("i1", 0), - ValueCondFactory("i1", 1), - ValueCondFactory("i2", 2), - ValueCondFactory("i2", 3), - ), - ResultCheckFactory(4), - ), - re.compile( - "If ``x1_i`` is (.+), ``x1_i`` is (.+), " - "and ``x2_i`` is (.+), the result is (.+)" - ): BinaryCaseFactory( - AndCondFactory( - ValueCondFactory("i1", 0), - ValueCondFactory("i1", 1), - ValueCondFactory("i2", 2), - ), - ResultCheckFactory(3), - ), - re.compile( - "If ``x1_i`` is (.+), ``x2_i`` is (.+), " - "and ``x2_i`` is (.+), the result is (.+)" - ): BinaryCaseFactory( - AndCondFactory( - ValueCondFactory("i1", 0), - ValueCondFactory("i2", 1), - ValueCondFactory("i2", 2), - ), - ResultCheckFactory(3), - ), - # This pattern must come after the above to avoid false matches - re.compile( - "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)" - ): BinaryCaseFactory( - AndCondFactory(ValueCondFactory("i1", 0), ValueCondFactory("i2", 1)), - ResultCheckFactory(2), - ), - re.compile( - r"If ``abs\(x1_i\)`` is (.+) and ``x2_i`` is (.+), the result is (.+)" - ): BinaryCaseFactory( - AndCondFactory(ValueCondFactory("i1", 0, abs_=True), ValueCondFactory("i2", 1)), - ResultCheckFactory(2), - ), - re.compile( - "If either ``x1_i`` or ``x2_i`` is (.+), the result is (.+)" - ): BinaryCaseFactory(ValueCondFactory("either", 0), ResultCheckFactory(1)), - re.compile( - "If ``x1_i`` and ``x2_i`` have (.+signs?), " - f"the result has a {r_result_sign.pattern}" - ): BinaryCaseFactory(SignCondFactory(0), ResultCheckFactory(1)), - re.compile( - "If ``x1_i`` and ``x2_i`` have (.+signs?) and are both (.+), " - f"the result has a {r_result_sign.pattern}" - ): BinaryCaseFactory( - AndCondFactory(SignCondFactory(0), ValueCondFactory("both", 1)), - ResultCheckFactory(2), - ), - re.compile( - "If ``x1_i`` and ``x2_i`` have (.+signs?), the result has a " - rf"{r_result_sign.pattern} , unless the result is (.+)\. If the result " - r"is ``NaN``, the \"sign\" of ``NaN`` is implementation-defined\." - ): BinaryCaseFactory(SignCondFactory(0), ResultCheckFactory(1)), - re.compile( - "If ``x2_i`` is (.+), the result is (.+), even if ``x1_i`` is .+" - ): BinaryCaseFactory(ValueCondFactory("i2", 0), ResultCheckFactory(1)), -} - +r_special_cases = re.compile( + r"\*\*Special [Cc]ases\*\*(?:\n.*)+" + r"For floating-point operands,\n+" + r"((?:\s*-\s*.*\n)+)" +) +r_case = re.compile(r"\s+-\s*(.*)\.\n?") r_binary_case = re.compile("If (.+), the result (.+)") +r_remaining_case = re.compile("In the remaining cases.+") -r_cond_sep = re.compile(", | and ") +r_cond_sep = re.compile(r"(? BinaryCase: conds = [] cond_exprs = [] for cond_str in cond_strs: - if m := r_both_inputs_are_value.match(cond_str): - raise ValueParseError(cond_str) + if m := r_input_is_array_element.match(cond_str): + in_sign, input_array, value_sign, value_array = m.groups() + assert in_sign == "" and value_array != input_array # sanity check + expr = f"{in_sign}x{input_array}ᵢ == {value_sign}x{value_array}ᵢ" + if value_array == "1": + if value_sign != "-": + + def cond(i1: float, i2: float) -> bool: + eq = make_eq(i1) + return eq(i2) + + else: + + def cond(i1: float, i2: float) -> bool: + eq = make_eq(-i1) + return eq(i2) + + else: + if value_sign != "-": + + def cond(i1: float, i2: float) -> bool: + eq = make_eq(i2) + return eq(i1) + + else: + + def cond(i1: float, i2: float) -> bool: + eq = make_eq(-i2) + return eq(i1) + + elif m := r_both_inputs_are_value.match(cond_str): + unary_cond, expr_template = parse_cond(m.group(1)) + left_expr = expr_template.replace("{}", "x1ᵢ") + right_expr = expr_template.replace("{}", "x2ᵢ") + expr = f"({left_expr}) and ({right_expr})" + + def cond(i1: float, i2: float) -> bool: + return unary_cond(i1) and unary_cond(i2) + else: cond_m = r_cond.match(cond_str) if cond_m is None: raise ValueParseError(cond_str) input_str, value_str = cond_m.groups() - unary_cond, expr_template = parse_cond(value_str) + if value_str == "the same mathematical sign": + expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)" + + def cond(i1: float, i2: float) -> bool: + return math.copysign(1, i1) == math.copysign(1, i2) + + elif value_str == "different mathematical signs": + expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)" - if m := r_input.match(input_str): - x_no = m.group(1) - args_i = int(x_no) - 1 - expr = expr_template.replace("{}", f"x{x_no}ᵢ") + def cond(i1: float, i2: float) -> bool: + return math.copysign(1, i1) != math.copysign(1, i2) - def cond(*inputs) -> bool: - return unary_cond(inputs[args_i]) + else: + unary_cond, expr_template = parse_cond(value_str) - elif m := r_abs_input.match(input_str): - x_no = m.group(1) - args_i = int(x_no) - 1 - expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)") + if m := r_input.match(input_str): + x_no = m.group(1) + args_i = int(x_no) - 1 + expr = expr_template.replace("{}", f"x{x_no}ᵢ") - def cond(*inputs) -> bool: - return unary_cond(abs(inputs[args_i])) + def cond(*inputs) -> bool: + return unary_cond(inputs[args_i]) - elif r_and_input.match(input_str): - left_expr = expr_template.replace("{}", "x1ᵢ") - right_expr = expr_template.replace("{}", "x2ᵢ") - expr = f"({left_expr}) and ({right_expr})" + elif m := r_abs_input.match(input_str): + x_no = m.group(1) + args_i = int(x_no) - 1 + expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)") - def cond(i1: float, i2: float) -> bool: - return unary_cond(i1) and unary_cond(i2) + def cond(*inputs) -> bool: + return unary_cond(abs(inputs[args_i])) - elif r_or_input.match(input_str): - left_expr = expr_template.replace("{}", "x1ᵢ") - right_expr = expr_template.replace("{}", "x2ᵢ") - expr = f"({left_expr}) and ({right_expr})" + elif r_and_input.match(input_str): + left_expr = expr_template.replace("{}", "x1ᵢ") + right_expr = expr_template.replace("{}", "x2ᵢ") + expr = f"({left_expr}) and ({right_expr})" - def cond(i1: float, i2: float) -> bool: - return unary_cond(i1) or unary_cond(i2) + def cond(i1: float, i2: float) -> bool: + return unary_cond(i1) and unary_cond(i2) - else: - raise ValueParseError(input_str) + elif r_or_input.match(input_str): + left_expr = expr_template.replace("{}", "x1ᵢ") + right_expr = expr_template.replace("{}", "x2ᵢ") + expr = f"({left_expr}) or ({right_expr})" + + def cond(i1: float, i2: float) -> bool: + return unary_cond(i1) or unary_cond(i2) + + else: + raise ValueParseError(input_str) conds.append(cond) cond_exprs.append(expr) @@ -690,14 +477,48 @@ def cond(i1: float, i2: float) -> bool: result_m = r_result.match(case_m.group(2)) if result_m is None: raise ValueParseError(case_m.group(2)) - check_result, result_expr = parse_result(result_m.group(1)) + result_str = result_m.group(1) + if m := r_array_element.match(result_str): + sign, input_ = m.groups() + result_expr = f"{sign}x{input_}ᵢ" + if input_ == "1": + if sign != "-": - expr = " and ".join(f"({expr})" for expr in cond_exprs) + " -> " + result_expr + def check_result(i1: float, i2: float, result: float) -> bool: + eq = make_eq(i1) + return eq(result) + + else: + + def check_result(i1: float, i2: float, result: float) -> bool: + eq = make_eq(-i1) + return eq(result) + + else: + if sign != "-": + + def check_result(i1: float, i2: float, result: float) -> bool: + eq = make_eq(i2) + return eq(result) + + else: + + def check_result(i1: float, i2: float, result: float) -> bool: + eq = make_eq(-i2) + return eq(result) + + else: + _check_result, result_expr = parse_result(result_m.group(1)) + + def check_result(i1: float, i2: float, result: float) -> bool: + return _check_result(result) + + expr = " and ".join(cond_exprs) + " -> " + result_expr def cond(i1: float, i2: float) -> bool: return all(c(i1, i2) for c in conds) - return BinaryCase(expr, cond, lambda l, r, o: check_result(o)) + return BinaryCase(expr, cond, check_result) r_redundant_case = re.compile("result.+determined by the rule already stated above") @@ -717,16 +538,14 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: continue if r_redundant_case.search(case_str): continue - if m := r_binary_case.search(case_str): + if m := r_binary_case.match(case_str): try: case = parse_binary_case(m) + cases.append(case) except ValueParseError as e: warn(f"not machine-readable: '{e.value}'") - break - cases.append(case) - break else: - if not r_remaining_case.search(case_str): + if not r_remaining_case.match(case_str): warn(f"case not machine-readable: '{case_str}'") return cases From a1a5780a08313ec9a28781788dbb9462aca6b5d2 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Mar 2022 14:13:52 +0000 Subject: [PATCH 26/63] Fix local func definitions causing pass-by-reference problems --- array_api_tests/test_special_cases.py | 135 +++++++++++++++++--------- 1 file changed, 89 insertions(+), 46 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index bff35f5c..6c61bc21 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -3,6 +3,7 @@ import re from dataclasses import dataclass from decimal import ROUND_HALF_EVEN, Decimal +from enum import Enum, auto from typing import Callable, List, Match, Protocol, Tuple from warnings import warn @@ -371,38 +372,77 @@ class BinaryCase(Case): r_both_inputs_are_value = re.compile("are both (.+)") +class BinaryCondInput(Enum): + FIRST = auto() + SECOND = auto() + BOTH = auto() + EITHER = auto() + + +def noop(obj): + return obj + + +def make_partial_cond( + input_: BinaryCondInput, unary_check: UnaryCheck, *, input_wrapper=None +) -> BinaryCond: + if input_wrapper is None: + input_wrapper = noop + if input_ == BinaryCondInput.FIRST: + + def partial_cond(i1: float, i2: float) -> bool: + return unary_check(input_wrapper(i1)) + + elif input_ == BinaryCondInput.SECOND: + + def partial_cond(i1: float, i2: float) -> bool: + return unary_check(input_wrapper(i2)) + + elif input_ == BinaryCondInput.BOTH: + + def partial_cond(i1: float, i2: float) -> bool: + return unary_check(input_wrapper(i1)) and unary_check(input_wrapper(i2)) + + else: + + def partial_cond(i1: float, i2: float) -> bool: + return unary_check(input_wrapper(i1)) or unary_check(input_wrapper(i2)) + + return partial_cond + + def parse_binary_case(case_m: Match) -> BinaryCase: cond_strs = r_cond_sep.split(case_m.group(1)) - conds = [] - cond_exprs = [] + partial_conds = [] + partial_exprs = [] for cond_str in cond_strs: if m := r_input_is_array_element.match(cond_str): in_sign, input_array, value_sign, value_array = m.groups() assert in_sign == "" and value_array != input_array # sanity check - expr = f"{in_sign}x{input_array}ᵢ == {value_sign}x{value_array}ᵢ" + partial_expr = f"{in_sign}x{input_array}ᵢ == {value_sign}x{value_array}ᵢ" if value_array == "1": if value_sign != "-": - def cond(i1: float, i2: float) -> bool: + def partial_cond(i1: float, i2: float) -> bool: eq = make_eq(i1) return eq(i2) else: - def cond(i1: float, i2: float) -> bool: + def partial_cond(i1: float, i2: float) -> bool: eq = make_eq(-i1) return eq(i2) else: if value_sign != "-": - def cond(i1: float, i2: float) -> bool: + def partial_cond(i1: float, i2: float) -> bool: eq = make_eq(i2) return eq(i1) else: - def cond(i1: float, i2: float) -> bool: + def partial_cond(i1: float, i2: float) -> bool: eq = make_eq(-i2) return eq(i1) @@ -410,10 +450,10 @@ def cond(i1: float, i2: float) -> bool: unary_cond, expr_template = parse_cond(m.group(1)) left_expr = expr_template.replace("{}", "x1ᵢ") right_expr = expr_template.replace("{}", "x2ᵢ") - expr = f"({left_expr}) and ({right_expr})" - - def cond(i1: float, i2: float) -> bool: - return unary_cond(i1) and unary_cond(i2) + partial_expr = f"({left_expr}) and ({right_expr})" + partial_cond = make_partial_cond( # type: ignore + BinaryCondInput.BOTH, unary_cond + ) else: cond_m = r_cond.match(cond_str) @@ -422,57 +462,58 @@ def cond(i1: float, i2: float) -> bool: input_str, value_str = cond_m.groups() if value_str == "the same mathematical sign": - expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)" + partial_expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)" - def cond(i1: float, i2: float) -> bool: + def partial_cond(i1: float, i2: float) -> bool: return math.copysign(1, i1) == math.copysign(1, i2) elif value_str == "different mathematical signs": - expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)" + partial_expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)" - def cond(i1: float, i2: float) -> bool: + def partial_cond(i1: float, i2: float) -> bool: return math.copysign(1, i1) != math.copysign(1, i2) else: - unary_cond, expr_template = parse_cond(value_str) - + unary_check, expr_template = parse_cond(value_str) + # Do not define partial_cond via the def keyword, as one + # partial_cond definition can mess up previous definitions + # in the partial_conds list. This is a hard-limitation of + # using local functions with the same name and that use the same + # outer variables (i.e. unary_cond). + input_wrapper = None if m := r_input.match(input_str): x_no = m.group(1) - args_i = int(x_no) - 1 - expr = expr_template.replace("{}", f"x{x_no}ᵢ") - - def cond(*inputs) -> bool: - return unary_cond(inputs[args_i]) - + partial_expr = expr_template.replace("{}", f"x{x_no}ᵢ") + if x_no == "1": + input_ = BinaryCondInput.FIRST + else: + input_ = BinaryCondInput.SECOND elif m := r_abs_input.match(input_str): x_no = m.group(1) - args_i = int(x_no) - 1 - expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)") - - def cond(*inputs) -> bool: - return unary_cond(abs(inputs[args_i])) - + partial_expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)") + if x_no == "1": + input_ = BinaryCondInput.FIRST + else: + input_ = BinaryCondInput.SECOND + input_wrapper = abs elif r_and_input.match(input_str): left_expr = expr_template.replace("{}", "x1ᵢ") right_expr = expr_template.replace("{}", "x2ᵢ") - expr = f"({left_expr}) and ({right_expr})" - - def cond(i1: float, i2: float) -> bool: - return unary_cond(i1) and unary_cond(i2) - + partial_expr = f"({left_expr}) and ({right_expr})" + input_ = BinaryCondInput.BOTH elif r_or_input.match(input_str): left_expr = expr_template.replace("{}", "x1ᵢ") right_expr = expr_template.replace("{}", "x2ᵢ") - expr = f"({left_expr}) or ({right_expr})" - - def cond(i1: float, i2: float) -> bool: - return unary_cond(i1) or unary_cond(i2) - + partial_expr = f"({left_expr}) or ({right_expr})" + input_ = BinaryCondInput.EITHER else: raise ValueParseError(input_str) + partial_cond = make_partial_cond( # type: ignore + input_, unary_check, input_wrapper=input_wrapper + ) - conds.append(cond) - cond_exprs.append(expr) + partial_conds.append(partial_cond) + partial_exprs.append(partial_expr) result_m = r_result.match(case_m.group(2)) if result_m is None: @@ -513,10 +554,10 @@ def check_result(i1: float, i2: float, result: float) -> bool: def check_result(i1: float, i2: float, result: float) -> bool: return _check_result(result) - expr = " and ".join(cond_exprs) + " -> " + result_expr + expr = " and ".join(partial_exprs) + " -> " + result_expr def cond(i1: float, i2: float) -> bool: - return all(c(i1, i2) for c in conds) + return all(pc(i1, i2) for pc in partial_conds) return BinaryCase(expr, cond, check_result) @@ -639,8 +680,10 @@ def test_binary(func_name, func, cases, x1, x2): f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" f_out = f"{sh.fmt_idx('out', o_idx)}={o}" - assert case.check_result( - l, r, o - ), f"{f_out} not good [{func_name}()]\n{f_left}, {f_right}" + assert case.check_result(l, r, o), ( + f"{f_out} not good [{func_name}()]\n" + f"{case.expr}\n" + f"{f_left}, {f_right}" + ) break assume(good_example) From 03b8f32396b5177941b52c9d3164fc4929f8fd56 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 1 Mar 2022 14:43:20 +0000 Subject: [PATCH 27/63] Use spec reprs for case `expr` --- array_api_tests/test_special_cases.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 6c61bc21..93fcc4a4 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -166,20 +166,20 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str]: if m := r_code.match(cond_str): value = parse_value(m.group(1)) cond = make_eq(value) - expr_template = "{} == " + str(value) + expr_template = "{} == " + m.group(1) elif m := r_gt.match(cond_str): value = parse_value(m.group(1)) cond = make_gt(value) - expr_template = "{} > " + str(value) + expr_template = "{} > " + m.group(1) elif m := r_lt.match(cond_str): value = parse_value(m.group(1)) cond = make_lt(value) - expr_template = "{} < " + str(value) + expr_template = "{} < " + m.group(1) elif m := r_either_code.match(cond_str): v1 = parse_value(m.group(1)) v2 = parse_value(m.group(2)) cond = make_or(make_eq(v1), make_eq(v2)) - expr_template = "{} == " + str(v1) + " or {} == " + str(v2) + expr_template = "{} == " + m.group(1) + " or {} == " + m.group(2) elif cond_str in ["finite", "a finite number"]: cond = math.isfinite expr_template = "isfinite({})" @@ -217,12 +217,12 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str]: def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: if m := r_code.match(result_str): value = parse_value(m.group(1)) - check_result = make_eq(value) - expr = str(value) + check_result = make_eq(value) # type: ignore + expr = m.group(1) elif m := r_approx_value.match(result_str): value = parse_value(m.group(1)) - check_result = make_rough_eq(value) - expr = f"~{value}" + check_result = make_rough_eq(value) # type: ignore + expr = f"roughly {m.group(1)}" elif "positive" in result_str: def check_result(result: float) -> bool: From ef10e907c8f50a9920ce77f63cb63c9bc297eb40 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Mar 2022 18:58:09 +0000 Subject: [PATCH 28/63] Factor out logic for input eq conds/results --- array_api_tests/test_special_cases.py | 170 ++++++++++++++------------ 1 file changed, 91 insertions(+), 79 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 93fcc4a4..9fd9f9da 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from decimal import ROUND_HALF_EVEN, Decimal from enum import Enum, auto -from typing import Callable, List, Match, Protocol, Tuple +from typing import Callable, List, Match, Optional, Protocol, Tuple from warnings import warn import pytest @@ -372,33 +372,46 @@ class BinaryCase(Case): r_both_inputs_are_value = re.compile("are both (.+)") -class BinaryCondInput(Enum): +class BinaryCondArg(Enum): FIRST = auto() SECOND = auto() BOTH = auto() EITHER = auto() + @classmethod + def from_x_no(cls, string): + if string == "1": + return cls.FIRST + elif string == "2": + return cls.SECOND + else: + raise ValueError(f"{string=} not '1' or '2'") -def noop(obj): - return obj +def noop(n: float) -> float: + return n -def make_partial_cond( - input_: BinaryCondInput, unary_check: UnaryCheck, *, input_wrapper=None + +def make_binary_cond( + cond_arg: BinaryCondArg, + unary_check: UnaryCheck, + *, + input_wrapper: Optional[Callable[[float], float]] = None, ) -> BinaryCond: if input_wrapper is None: input_wrapper = noop - if input_ == BinaryCondInput.FIRST: + + if cond_arg == BinaryCondArg.FIRST: def partial_cond(i1: float, i2: float) -> bool: return unary_check(input_wrapper(i1)) - elif input_ == BinaryCondInput.SECOND: + elif cond_arg == BinaryCondArg.SECOND: def partial_cond(i1: float, i2: float) -> bool: return unary_check(input_wrapper(i2)) - elif input_ == BinaryCondInput.BOTH: + elif cond_arg == BinaryCondArg.BOTH: def partial_cond(i1: float, i2: float) -> bool: return unary_check(input_wrapper(i1)) and unary_check(input_wrapper(i2)) @@ -411,50 +424,78 @@ def partial_cond(i1: float, i2: float) -> bool: return partial_cond -def parse_binary_case(case_m: Match) -> BinaryCase: - cond_strs = r_cond_sep.split(case_m.group(1)) - partial_conds = [] - partial_exprs = [] - for cond_str in cond_strs: - if m := r_input_is_array_element.match(cond_str): - in_sign, input_array, value_sign, value_array = m.groups() - assert in_sign == "" and value_array != input_array # sanity check - partial_expr = f"{in_sign}x{input_array}ᵢ == {value_sign}x{value_array}ᵢ" - if value_array == "1": - if value_sign != "-": +def make_eq_other_input_cond( + eq_to: BinaryCondArg, *, eq_neg: bool = False +) -> BinaryCond: + if eq_neg: + input_wrapper = lambda i: -i + else: + input_wrapper = noop - def partial_cond(i1: float, i2: float) -> bool: - eq = make_eq(i1) - return eq(i2) + if eq_to == BinaryCondArg.FIRST: - else: + def cond(i1: float, i2: float) -> bool: + eq = make_eq(input_wrapper(i1)) + return eq(i2) - def partial_cond(i1: float, i2: float) -> bool: - eq = make_eq(-i1) - return eq(i2) + elif eq_to == BinaryCondArg.SECOND: - else: - if value_sign != "-": + def cond(i1: float, i2: float) -> bool: + eq = make_eq(input_wrapper(i2)) + return eq(i1) + + else: + raise ValueError(f"{eq_to=} must be FIRST or SECOND") - def partial_cond(i1: float, i2: float) -> bool: - eq = make_eq(i2) - return eq(i1) + return cond - else: - def partial_cond(i1: float, i2: float) -> bool: - eq = make_eq(-i2) - return eq(i1) +def make_eq_input_check_result( + eq_to: BinaryCondArg, *, eq_neg: bool = False +) -> BinaryResultCheck: + if eq_neg: + input_wrapper = lambda i: -i + else: + input_wrapper = noop + + if eq_to == BinaryCondArg.FIRST: + + def check_result(i1: float, i2: float, result: float) -> bool: + eq = make_eq(input_wrapper(i1)) + return eq(result) + + elif eq_to == BinaryCondArg.SECOND: + + def check_result(i1: float, i2: float, result: float) -> bool: + eq = make_eq(input_wrapper(i2)) + return eq(result) + + else: + raise ValueError(f"{eq_to=} must be FIRST or SECOND") + + return check_result + +def parse_binary_case(case_m: Match) -> BinaryCase: + cond_strs = r_cond_sep.split(case_m.group(1)) + partial_conds = [] + partial_exprs = [] + for cond_str in cond_strs: + if m := r_input_is_array_element.match(cond_str): + in_sign, in_no, other_sign, other_no = m.groups() + assert in_sign == "" and other_no != in_no # sanity check + partial_expr = f"{in_sign}x{in_no}ᵢ == {other_sign}x{other_no}ᵢ" + partial_cond = make_eq_other_input_cond( # type: ignore + BinaryCondArg.from_x_no(other_no), eq_neg=other_sign == "-" + ) elif m := r_both_inputs_are_value.match(cond_str): unary_cond, expr_template = parse_cond(m.group(1)) left_expr = expr_template.replace("{}", "x1ᵢ") right_expr = expr_template.replace("{}", "x2ᵢ") partial_expr = f"({left_expr}) and ({right_expr})" - partial_cond = make_partial_cond( # type: ignore - BinaryCondInput.BOTH, unary_cond + partial_cond = make_binary_cond( # type: ignore + BinaryCondArg.BOTH, unary_cond ) - else: cond_m = r_cond.match(cond_str) if cond_m is None: @@ -484,32 +525,26 @@ def partial_cond(i1: float, i2: float) -> bool: if m := r_input.match(input_str): x_no = m.group(1) partial_expr = expr_template.replace("{}", f"x{x_no}ᵢ") - if x_no == "1": - input_ = BinaryCondInput.FIRST - else: - input_ = BinaryCondInput.SECOND + cond_arg = BinaryCondArg.from_x_no(x_no) elif m := r_abs_input.match(input_str): x_no = m.group(1) partial_expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)") - if x_no == "1": - input_ = BinaryCondInput.FIRST - else: - input_ = BinaryCondInput.SECOND + cond_arg = BinaryCondArg.from_x_no(x_no) input_wrapper = abs elif r_and_input.match(input_str): left_expr = expr_template.replace("{}", "x1ᵢ") right_expr = expr_template.replace("{}", "x2ᵢ") partial_expr = f"({left_expr}) and ({right_expr})" - input_ = BinaryCondInput.BOTH + cond_arg = BinaryCondArg.BOTH elif r_or_input.match(input_str): left_expr = expr_template.replace("{}", "x1ᵢ") right_expr = expr_template.replace("{}", "x2ᵢ") partial_expr = f"({left_expr}) or ({right_expr})" - input_ = BinaryCondInput.EITHER + cond_arg = BinaryCondArg.EITHER else: raise ValueParseError(input_str) - partial_cond = make_partial_cond( # type: ignore - input_, unary_check, input_wrapper=input_wrapper + partial_cond = make_binary_cond( # type: ignore + cond_arg, unary_check, input_wrapper=input_wrapper ) partial_conds.append(partial_cond) @@ -520,34 +555,11 @@ def partial_cond(i1: float, i2: float) -> bool: raise ValueParseError(case_m.group(2)) result_str = result_m.group(1) if m := r_array_element.match(result_str): - sign, input_ = m.groups() - result_expr = f"{sign}x{input_}ᵢ" - if input_ == "1": - if sign != "-": - - def check_result(i1: float, i2: float, result: float) -> bool: - eq = make_eq(i1) - return eq(result) - - else: - - def check_result(i1: float, i2: float, result: float) -> bool: - eq = make_eq(-i1) - return eq(result) - - else: - if sign != "-": - - def check_result(i1: float, i2: float, result: float) -> bool: - eq = make_eq(i2) - return eq(result) - - else: - - def check_result(i1: float, i2: float, result: float) -> bool: - eq = make_eq(-i2) - return eq(result) - + sign, x_no = m.groups() + result_expr = f"{sign}x{x_no}ᵢ" + check_result = make_eq_input_check_result( # type: ignore + BinaryCondArg.from_x_no(x_no), eq_neg=sign == "-" + ) else: _check_result, result_expr = parse_result(result_m.group(1)) From e9e99ab96c4e6173fede8affa84d7acfdf53a7c4 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 2 Mar 2022 19:23:27 +0000 Subject: [PATCH 29/63] Enable submodules for NumPy workflow --- .github/workflows/numpy.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index 82420452..a6118330 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -11,7 +11,10 @@ jobs: python-version: [3.8, 3.9] steps: - - uses: actions/checkout@v1 + - name: Checkout array-api-tests + uses: actions/checkout@v1 + with: + submodules: 'true' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v1 with: From 709cb4c0105941793b54aea5e170308716f5858d Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 3 Mar 2022 12:42:51 +0000 Subject: [PATCH 30/63] Rudimentary cond element strategy pattern to ensure good examples --- array_api_tests/test_special_cases.py | 292 +++++++++++++++++++------- 1 file changed, 221 insertions(+), 71 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 9fd9f9da..3c797d68 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1,14 +1,29 @@ +from __future__ import annotations + import inspect import math import re from dataclasses import dataclass from decimal import ROUND_HALF_EVEN, Decimal from enum import Enum, auto -from typing import Callable, List, Match, Optional, Protocol, Tuple +from typing import ( + Any, + Callable, + Dict, + List, + Match, + NamedTuple, + Optional, + Protocol, + Tuple, +) from warnings import warn import pytest -from hypothesis import assume, given +from hypothesis import assume, given, note +from hypothesis import strategies as st + +from array_api_tests.typing import Array, DataType from . import dtype_helpers as dh from . import hypothesis_helpers as hh @@ -89,7 +104,7 @@ def and_(i: float) -> bool: return and_ -def notify_cond(cond: UnaryCheck) -> UnaryCheck: +def make_not_cond(cond: UnaryCheck) -> UnaryCheck: def not_cond(i: float) -> bool: return not cond(i) @@ -156,62 +171,160 @@ def parse_inline_code(inline_code: str) -> float: r_lt = re.compile(f"less than {r_code.pattern}") -def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str]: +FromDataType = Callable[[DataType], st.SearchStrategy] + + +class ElementsStrategyFactory(NamedTuple): + from_dtype: FromDataType + kwargs: Optional[Dict[str, Any]] + filter_: Optional[Callable[[Array], bool]] + + def __add__(self, other: ElementsStrategyFactory) -> ElementsStrategyFactory: + assert not ( + isinstance(self.kwargs, Callable) or isinstance(other.kwargs, Callable) + ), ( + f"{self.kwargs=} and {other.kwargs=}, " "but both must be from_dtype kwargs" + ) + kwargs1 = self.kwargs or {} + kwargs2 = other.kwargs or {} + for k in kwargs1.keys(): + if k in kwargs2.keys(): + assert kwargs1[k] == kwargs2[k] # sanity check + + if self.filter_ is not None and other.filter_ is not None: + filter_ = lambda i: self.filter_(i) and other.filter_(i) + else: + try: + filter_ = next( + f for f in [self.filter_, other.filter_] if f is not None + ) + except StopIteration: + filter_ = None + + return ElementsStrategyFactory( + kwargs={**kwargs1, **kwargs2}, + filter_=filter_, + ) + + def to_strategy(self, dtype: DataType) -> st.SearchStrategy[float]: + kw = self.kwargs or {} + if self.from_dtype != xps.from_dtype: + assert kw == {} # sanity check + strat = self.from_dtype(dtype, **kw) + if self.filter_ is not None: + strat = strat.filter(self.filter_) + return strat + + +def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, ElementsStrategyFactory]: if m := r_not.match(cond_str): cond_str = m.group(1) - notify = True + not_cond = True else: - notify = False + not_cond = False + from_dtype = xps.from_dtype # type: ignore + kwargs = None + filter_ = None if m := r_code.match(cond_str): value = parse_value(m.group(1)) cond = make_eq(value) expr_template = "{} == " + m.group(1) + if not not_cond: + from_dtype = lambda _: st.just(value) # type: ignore elif m := r_gt.match(cond_str): value = parse_value(m.group(1)) cond = make_gt(value) expr_template = "{} > " + m.group(1) + if not not_cond: + kwargs = {"min_value": value, "exclude_min": True} elif m := r_lt.match(cond_str): value = parse_value(m.group(1)) cond = make_lt(value) expr_template = "{} < " + m.group(1) + if not not_cond: + kwargs = {"max_value": value, "exclude_max": True} elif m := r_either_code.match(cond_str): v1 = parse_value(m.group(1)) v2 = parse_value(m.group(2)) cond = make_or(make_eq(v1), make_eq(v2)) expr_template = "{} == " + m.group(1) + " or {} == " + m.group(2) + if not not_cond: + from_dtype = lambda _: st.sampled_from([v1, v2]) # type: ignore elif cond_str in ["finite", "a finite number"]: cond = math.isfinite expr_template = "isfinite({})" + if not not_cond: + kwargs = {"allow_nan": False, "allow_infinity": False} elif cond_str in "a positive (i.e., greater than ``0``) finite number": cond = lambda i: math.isfinite(i) and i > 0 expr_template = "isfinite({}) and {} > 0" + if not not_cond: + kwargs = { + "allow_nan": False, + "allow_infinity": False, + "min_value": 0, + "exclude_min": True, + } elif cond_str == "a negative (i.e., less than ``0``) finite number": cond = lambda i: math.isfinite(i) and i < 0 expr_template = "isfinite({}) and {} < 0" + if not not_cond: + kwargs = { + "allow_nan": False, + "allow_infinity": False, + "max_value": 0, + "exclude_max": True, + } elif cond_str == "positive": cond = lambda i: math.copysign(1, i) == 1 expr_template = "copysign(1, {}) == 1" + if not not_cond: + # We assume (positive) zero is special cased seperately + kwargs = {"min_value": 0, "exclude_min": True} elif cond_str == "negative": cond = lambda i: math.copysign(1, i) == -1 expr_template = "copysign(1, {}) == -1" + if not not_cond: + # We assume (positive) zero is special cased seperately + kwargs = {"max_value": 0, "exclude_max": True} elif "nonzero finite" in cond_str: cond = lambda i: math.isfinite(i) and i != 0 expr_template = "isfinite({}) and {} != 0" + if not not_cond: + kwargs = {"allow_nan": False, "allow_infinity": False} + filter_ = lambda n: n != 0 elif cond_str == "an integer value": cond = lambda i: i.is_integer() expr_template = "{}.is_integer()" + if not not_cond: + + def from_dtype(dtype: DataType) -> st.SearchStrategy: + m, M = dh.dtype_ranges[dtype] + return st.integers(math.ceil(m), math.floor(M)).map(float) + elif cond_str == "an odd integer value": cond = lambda i: i.is_integer() and i % 2 == 1 expr_template = "{}.is_integer() and {} % 2 == 1" + if not not_cond: + + def from_dtype(dtype: DataType) -> st.SearchStrategy: + m, M = dh.dtype_ranges[dtype] + return ( + st.integers(math.ceil(m), math.floor(M)) + .filter(lambda n: n % 2 == 1) + .map(float) + ) + else: raise ValueParseError(cond_str) - if notify: - cond = notify_cond(cond) - expr_template = f"not ({expr_template})" + if not_cond: + expr_template = f"not {expr_template}" + cond = make_not_cond(cond) + filter_ = cond - return cond, expr_template + return cond, expr_template, ElementsStrategyFactory(from_dtype, kwargs, filter_) def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: @@ -222,7 +335,8 @@ def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: elif m := r_approx_value.match(result_str): value = parse_value(m.group(1)) check_result = make_rough_eq(value) # type: ignore - expr = f"roughly {m.group(1)}" + repr_ = m.group(1).replace("π", "pi") # for pytest param names + expr = f"roughly {repr_}" elif "positive" in result_str: def check_result(result: float) -> bool: @@ -248,7 +362,8 @@ def check_result(result: float) -> bool: class Case(Protocol): - expr: str + cond_expr: str + result_expr: str def cond(self, *args) -> bool: ... @@ -256,11 +371,11 @@ def cond(self, *args) -> bool: def check_result(self, *args) -> bool: ... - def __repr__(self): - return f"{self.__class__.__name__}(<{self.expr}>)" + def __str__(self) -> str: + return f"{self.cond_expr} -> {self.result_expr}" - def __str__(self): - return self.expr + def __repr__(self) -> str: + return f"{self.__class__.__name__}(<{self}>)" class UnaryCond(Protocol): @@ -275,17 +390,29 @@ def __call__(self, i: float, result: float) -> bool: @dataclass(repr=False) class UnaryCase(Case): - expr: str + cond_expr: str + result_expr: str + cond_strat: FromDataType cond: UnaryCheck check_result: UnaryResultCheck @classmethod def from_strings(cls, cond_str: str, result_str: str): - cond, cond_expr_template = parse_cond(cond_str) - check_result, check_result_expr = parse_result(result_str) - cond_expr = cond_expr_template.replace("{}", "xᵢ") - expr = f"{cond_expr} -> {check_result_expr}" - return cls(expr, cond, lambda i, result: check_result(result)) + cond, cond_expr_template, strat_factory = parse_cond(cond_str) + cond_expr = cond_expr_template.replace("{}", "x_i") + cond_strat = strat_factory.to_strategy + _check_result, result_expr = parse_result(result_str) + + def check_result(i: float, result: float) -> bool: + return _check_result(result) + + return cls( + cond_expr=cond_expr, + cond=cond, + cond_strat=cond_strat, + result_expr=result_expr, + check_result=check_result, + ) r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") @@ -293,9 +420,18 @@ def from_strings(cls, cond_str: str, result_str: str): "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) + + +def point_5_from_dtype(dtype: DataType): + m, M = dh.dtype_ranges[dtype] + return st.integers(math.ceil(m) // 2, math.floor(M) // 2).map(lambda n: n * 0.5) + + even_int_round_case = UnaryCase( - expr="i % 0.5 == 0 -> Decimal(i).to_integral_exact(ROUND_HALF_EVEN)", + cond_expr="i % 0.5 == 0", cond=lambda i: i % 0.5 == 0, + cond_strat=point_5_from_dtype, + result_expr="Decimal(i).to_integral_exact(ROUND_HALF_EVEN)", check_result=lambda i, result: ( result == float(Decimal(i).to_integral_exact(ROUND_HALF_EVEN)) ), @@ -341,7 +477,8 @@ def __call__(self, i1: float, i2: float, result: float) -> bool: @dataclass(repr=False) class BinaryCase(Case): - expr: str + cond_expr: str + result_expr: str cond: BinaryCond check_result: BinaryResultCheck @@ -484,14 +621,14 @@ def parse_binary_case(case_m: Match) -> BinaryCase: if m := r_input_is_array_element.match(cond_str): in_sign, in_no, other_sign, other_no = m.groups() assert in_sign == "" and other_no != in_no # sanity check - partial_expr = f"{in_sign}x{in_no}ᵢ == {other_sign}x{other_no}ᵢ" + partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i" partial_cond = make_eq_other_input_cond( # type: ignore BinaryCondArg.from_x_no(other_no), eq_neg=other_sign == "-" ) elif m := r_both_inputs_are_value.match(cond_str): unary_cond, expr_template = parse_cond(m.group(1)) - left_expr = expr_template.replace("{}", "x1ᵢ") - right_expr = expr_template.replace("{}", "x2ᵢ") + left_expr = expr_template.replace("{}", "x1_i") + right_expr = expr_template.replace("{}", "x2_i") partial_expr = f"({left_expr}) and ({right_expr})" partial_cond = make_binary_cond( # type: ignore BinaryCondArg.BOTH, unary_cond @@ -503,13 +640,13 @@ def parse_binary_case(case_m: Match) -> BinaryCase: input_str, value_str = cond_m.groups() if value_str == "the same mathematical sign": - partial_expr = "copysign(1, x1ᵢ) == copysign(1, x2ᵢ)" + partial_expr = "copysign(1, x1_i) == copysign(1, x2_i)" def partial_cond(i1: float, i2: float) -> bool: return math.copysign(1, i1) == math.copysign(1, i2) elif value_str == "different mathematical signs": - partial_expr = "copysign(1, x1ᵢ) != copysign(1, x2ᵢ)" + partial_expr = "copysign(1, x1_i) != copysign(1, x2_i)" def partial_cond(i1: float, i2: float) -> bool: return math.copysign(1, i1) != math.copysign(1, i2) @@ -524,21 +661,21 @@ def partial_cond(i1: float, i2: float) -> bool: input_wrapper = None if m := r_input.match(input_str): x_no = m.group(1) - partial_expr = expr_template.replace("{}", f"x{x_no}ᵢ") + partial_expr = expr_template.replace("{}", f"x{x_no}_i") cond_arg = BinaryCondArg.from_x_no(x_no) elif m := r_abs_input.match(input_str): x_no = m.group(1) - partial_expr = expr_template.replace("{}", f"abs(x{x_no}ᵢ)") + partial_expr = expr_template.replace("{}", f"abs(x{x_no}_i)") cond_arg = BinaryCondArg.from_x_no(x_no) input_wrapper = abs elif r_and_input.match(input_str): - left_expr = expr_template.replace("{}", "x1ᵢ") - right_expr = expr_template.replace("{}", "x2ᵢ") + left_expr = expr_template.replace("{}", "x1_i") + right_expr = expr_template.replace("{}", "x2_i") partial_expr = f"({left_expr}) and ({right_expr})" cond_arg = BinaryCondArg.BOTH elif r_or_input.match(input_str): - left_expr = expr_template.replace("{}", "x1ᵢ") - right_expr = expr_template.replace("{}", "x2ᵢ") + left_expr = expr_template.replace("{}", "x1_i") + right_expr = expr_template.replace("{}", "x2_i") partial_expr = f"({left_expr}) or ({right_expr})" cond_arg = BinaryCondArg.EITHER else: @@ -556,7 +693,7 @@ def partial_cond(i1: float, i2: float) -> bool: result_str = result_m.group(1) if m := r_array_element.match(result_str): sign, x_no = m.groups() - result_expr = f"{sign}x{x_no}ᵢ" + result_expr = f"{sign}x{x_no}_i" check_result = make_eq_input_check_result( # type: ignore BinaryCondArg.from_x_no(x_no), eq_neg=sign == "-" ) @@ -566,12 +703,12 @@ def partial_cond(i1: float, i2: float) -> bool: def check_result(i1: float, i2: float, result: float) -> bool: return _check_result(result) - expr = " and ".join(partial_exprs) + " -> " + result_expr + cond_expr = " and ".join(partial_exprs) def cond(i1: float, i2: float) -> bool: return all(pc(i1, i2) for pc in partial_conds) - return BinaryCase(expr, cond, check_result) + return BinaryCase(cond_expr, result_expr, cond, check_result) r_redundant_case = re.compile("result.+determined by the rule already stated above") @@ -628,16 +765,20 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: continue if param_names[0] == "x": if cases := parse_unary_docstring(stub.__doc__): - p = pytest.param(stub.__name__, func, cases, id=stub.__name__) - unary_params.append(p) + for case in cases: + id_ = f"{stub.__name__}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(stub.__name__, func, case, id=id_) + unary_params.append(p) continue if len(sig.parameters) == 1: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and param_names[1] == "x2": - if cases := parse_binary_docstring(stub.__doc__): - p = pytest.param(stub.__name__, func, cases, id=stub.__name__) - binary_params.append(p) + # if cases := parse_binary_docstring(stub.__doc__): + # for case in cases: + # id_ = f"{stub.__name__}({case.cond_expr}) -> {case.result_expr}" + # p = pytest.param(stub.__name__, func, case, id=id_) + # binary_params.append(p) continue else: warn( @@ -652,50 +793,59 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: # indicating we should modify the array strategy being used. -@pytest.mark.parametrize("func_name, func, cases", unary_params) -@given(x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1))) -def test_unary(func_name, func, cases, x): +@pytest.mark.parametrize("func_name, func, case", unary_params) +@given( + x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + data=st.data(), +) +def test_unary(func_name, func, case, x, data): + set_idx = data.draw( + xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" + ) + set_value = data.draw(case.cond_strat(x.dtype), label="set value") + x[set_idx] = set_value + note(f"{x=}") + res = func(x) + good_example = False for idx in sh.ndindex(res.shape): in_ = float(x[idx]) - for case in cases: - if case.cond(in_): - good_example = True - out = float(res[idx]) - f_in = f"{sh.fmt_idx('x', idx)}={in_}" - f_out = f"{sh.fmt_idx('out', idx)}={out}" - assert case.check_result( - in_, out - ), f"{f_out} not good [{func_name}()]\n{f_in}" - break + if case.cond(in_): + good_example = True + out = float(res[idx]) + f_in = f"{sh.fmt_idx('x', idx)}={in_}" + f_out = f"{sh.fmt_idx('out', idx)}={out}" + assert case.check_result(in_, out), ( + f"{f_out} not good [{func_name}()]\n" f"{case}\n" f"{f_in}" + ) + break assume(good_example) -@pytest.mark.parametrize("func_name, func, cases", binary_params) +@pytest.mark.parametrize("func_name, func, case", binary_params) @given( *hh.two_mutual_arrays( dtypes=dh.float_dtypes, two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), ) ) -def test_binary(func_name, func, cases, x1, x2): +def test_binary(func_name, func, case, x1, x2): res = func(x1, x2) good_example = False for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, res.shape): l = float(x1[l_idx]) r = float(x2[r_idx]) - for case in cases: - if case.cond(l, r): - good_example = True - o = float(res[o_idx]) - f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" - f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" - f_out = f"{sh.fmt_idx('out', o_idx)}={o}" - assert case.check_result(l, r, o), ( - f"{f_out} not good [{func_name}()]\n" - f"{case.expr}\n" - f"{f_left}, {f_right}" - ) - break + if case.cond(l, r): + good_example = True + o = float(res[o_idx]) + f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" + f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" + f_out = f"{sh.fmt_idx('out', o_idx)}={o}" + assert case.check_result(l, r, o), ( + f"{f_out} not good [{func_name}()]\n" + f"{case.expr}\n" + f"{f_left}, {f_right}" + ) + break assume(good_example) From 9c7c05134d203faf39cac00c1f6397807f3d203b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 9 Mar 2022 16:07:06 +0000 Subject: [PATCH 31/63] Cond strategies for most binary cases --- array_api_tests/test_special_cases.py | 249 +++++++++++++++++--------- 1 file changed, 161 insertions(+), 88 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 3c797d68..2b3b0df0 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -12,7 +12,6 @@ Dict, List, Match, - NamedTuple, Optional, Protocol, Tuple, @@ -171,25 +170,21 @@ def parse_inline_code(inline_code: str) -> float: r_lt = re.compile(f"less than {r_code.pattern}") -FromDataType = Callable[[DataType], st.SearchStrategy] +class FromDtypeFunc(Protocol): + def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: + ... -class ElementsStrategyFactory(NamedTuple): - from_dtype: FromDataType - kwargs: Optional[Dict[str, Any]] +@dataclass +class BoundFromDtype(FromDtypeFunc): + kwargs: Dict[str, Any] filter_: Optional[Callable[[Array], bool]] - def __add__(self, other: ElementsStrategyFactory) -> ElementsStrategyFactory: - assert not ( - isinstance(self.kwargs, Callable) or isinstance(other.kwargs, Callable) - ), ( - f"{self.kwargs=} and {other.kwargs=}, " "but both must be from_dtype kwargs" - ) - kwargs1 = self.kwargs or {} - kwargs2 = other.kwargs or {} - for k in kwargs1.keys(): - if k in kwargs2.keys(): - assert kwargs1[k] == kwargs2[k] # sanity check + def __add__(self, other: BoundFromDtype) -> BoundFromDtype: + for k in self.kwargs.keys(): + if k in other.kwargs.keys(): + assert self.kwargs[k] == other.kwargs[k] # sanity check + kwargs = {**self.kwargs, **other.kwargs} if self.filter_ is not None and other.filter_ is not None: filter_ = lambda i: self.filter_(i) and other.filter_(i) @@ -201,37 +196,43 @@ def __add__(self, other: ElementsStrategyFactory) -> ElementsStrategyFactory: except StopIteration: filter_ = None - return ElementsStrategyFactory( - kwargs={**kwargs1, **kwargs2}, - filter_=filter_, - ) + return BoundFromDtype(kwargs, filter_) - def to_strategy(self, dtype: DataType) -> st.SearchStrategy[float]: - kw = self.kwargs or {} - if self.from_dtype != xps.from_dtype: - assert kw == {} # sanity check - strat = self.from_dtype(dtype, **kw) + def __call__(self, dtype: DataType) -> st.SearchStrategy[float]: + strat = xps.from_dtype(dtype, **self.kwargs) if self.filter_ is not None: strat = strat.filter(self.filter_) return strat -def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, ElementsStrategyFactory]: +def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc: + def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + assert kw == {} # sanity check + return strat + + return from_dtype + + +def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: + if "equal to" in cond_str: + raise ValueParseError(cond_str) # TODO + if m := r_not.match(cond_str): cond_str = m.group(1) not_cond = True else: not_cond = False - from_dtype = xps.from_dtype # type: ignore - kwargs = None + kwargs = {} filter_ = None + from_dtype = None # type: ignore + strat = None if m := r_code.match(cond_str): value = parse_value(m.group(1)) cond = make_eq(value) expr_template = "{} == " + m.group(1) if not not_cond: - from_dtype = lambda _: st.just(value) # type: ignore + strat = st.just(value) elif m := r_gt.match(cond_str): value = parse_value(m.group(1)) cond = make_gt(value) @@ -250,7 +251,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, ElementsStrategyFactory] cond = make_or(make_eq(v1), make_eq(v2)) expr_template = "{} == " + m.group(1) + " or {} == " + m.group(2) if not not_cond: - from_dtype = lambda _: st.sampled_from([v1, v2]) # type: ignore + strat = st.sampled_from([v1, v2]) elif cond_str in ["finite", "a finite number"]: cond = math.isfinite expr_template = "isfinite({})" @@ -286,7 +287,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, ElementsStrategyFactory] cond = lambda i: math.copysign(1, i) == -1 expr_template = "copysign(1, {}) == -1" if not not_cond: - # We assume (positive) zero is special cased seperately + # We assume (negative) zero is special cased seperately kwargs = {"max_value": 0, "exclude_max": True} elif "nonzero finite" in cond_str: cond = lambda i: math.isfinite(i) and i != 0 @@ -294,37 +295,41 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, ElementsStrategyFactory] if not not_cond: kwargs = {"allow_nan": False, "allow_infinity": False} filter_ = lambda n: n != 0 - elif cond_str == "an integer value": - cond = lambda i: i.is_integer() - expr_template = "{}.is_integer()" - if not not_cond: - - def from_dtype(dtype: DataType) -> st.SearchStrategy: - m, M = dh.dtype_ranges[dtype] - return st.integers(math.ceil(m), math.floor(M)).map(float) - - elif cond_str == "an odd integer value": - cond = lambda i: i.is_integer() and i % 2 == 1 - expr_template = "{}.is_integer() and {} % 2 == 1" - if not not_cond: - - def from_dtype(dtype: DataType) -> st.SearchStrategy: - m, M = dh.dtype_ranges[dtype] - return ( - st.integers(math.ceil(m), math.floor(M)) - .filter(lambda n: n % 2 == 1) - .map(float) - ) + elif "integer value" in cond_str: + raise ValueError( + "integer values are only specified in dual cases, " + "which cannot be handled in parse_cond()" + ) + # elif cond_str == "an integer value": + # cond = lambda i: i.is_integer() + # expr_template = "{}.is_integer()" + # if not not_cond: + # from_dtype = integers_from_dtype # type: ignore + # elif cond_str == "an odd integer value": + # cond = lambda i: i.is_integer() and i % 2 == 1 + # expr_template = "{}.is_integer() and {} % 2 == 1" + # if not not_cond: + # def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + # return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1) else: raise ValueParseError(cond_str) + if strat is not None: + # sanity checks + assert not not_cond + assert kwargs == {} + assert filter_ is None + assert from_dtype is None + return cond, expr_template, wrap_strat_as_from_dtype(strat) + if not_cond: expr_template = f"not {expr_template}" cond = make_not_cond(cond) + kwargs = {} filter_ = cond - - return cond, expr_template, ElementsStrategyFactory(from_dtype, kwargs, filter_) + assert kwargs is not None + return cond, expr_template, BoundFromDtype(kwargs, filter_) def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: @@ -392,15 +397,14 @@ def __call__(self, i: float, result: float) -> bool: class UnaryCase(Case): cond_expr: str result_expr: str - cond_strat: FromDataType + cond_from_dtype: FromDtypeFunc cond: UnaryCheck check_result: UnaryResultCheck @classmethod def from_strings(cls, cond_str: str, result_str: str): - cond, cond_expr_template, strat_factory = parse_cond(cond_str) + cond, cond_expr_template, cond_from_dtype = parse_cond(cond_str) cond_expr = cond_expr_template.replace("{}", "x_i") - cond_strat = strat_factory.to_strategy _check_result, result_expr = parse_result(result_str) def check_result(i: float, result: float) -> bool: @@ -409,7 +413,7 @@ def check_result(i: float, result: float) -> bool: return cls( cond_expr=cond_expr, cond=cond, - cond_strat=cond_strat, + cond_from_dtype=cond_from_dtype, result_expr=result_expr, check_result=check_result, ) @@ -422,7 +426,7 @@ def check_result(i: float, result: float) -> bool: ) -def point_5_from_dtype(dtype: DataType): +def trailing_halves_from_dtype(dtype: DataType): m, M = dh.dtype_ranges[dtype] return st.integers(math.ceil(m) // 2, math.floor(M) // 2).map(lambda n: n * 0.5) @@ -430,7 +434,7 @@ def point_5_from_dtype(dtype: DataType): even_int_round_case = UnaryCase( cond_expr="i % 0.5 == 0", cond=lambda i: i % 0.5 == 0, - cond_strat=point_5_from_dtype, + cond_from_dtype=trailing_halves_from_dtype, result_expr="Decimal(i).to_integral_exact(ROUND_HALF_EVEN)", check_result=lambda i, result: ( result == float(Decimal(i).to_integral_exact(ROUND_HALF_EVEN)) @@ -479,6 +483,8 @@ def __call__(self, i1: float, i2: float, result: float) -> bool: class BinaryCase(Case): cond_expr: str result_expr: str + x1_cond_from_dtype: FromDtypeFunc + x2_cond_from_dtype: FromDtypeFunc cond: BinaryCond check_result: BinaryResultCheck @@ -491,23 +497,19 @@ class BinaryCase(Case): r_case = re.compile(r"\s+-\s*(.*)\.\n?") r_binary_case = re.compile("If (.+), the result (.+)") r_remaining_case = re.compile("In the remaining cases.+") - r_cond_sep = re.compile(r"(? float: return n +def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + for k in kw.keys(): + # sanity check + assert k in ["min_value", "max_value", "exclude_min", "exclude_max"] + m, M = dh.dtype_ranges[dtype] + if "min_value" in kw.keys(): + m = kw["min_value"] + if "exclude_min" in kw.keys(): + m += 1 + if "max_value" in kw.keys(): + M = kw["max_value"] + if "exclude_max" in kw.keys(): + M -= 1 + return st.integers(math.ceil(m), math.floor(M)).map(float) + + def make_binary_cond( cond_arg: BinaryCondArg, unary_check: UnaryCheck, @@ -615,8 +633,14 @@ def check_result(i1: float, i2: float, result: float) -> bool: def parse_binary_case(case_m: Match) -> BinaryCase: cond_strs = r_cond_sep.split(case_m.group(1)) + + if len(cond_strs) > 2: + raise ValueParseError(", ".join(cond_strs)) + partial_conds = [] partial_exprs = [] + x1_cond_from_dtypes = [] + x2_cond_from_dtypes = [] for cond_str in cond_strs: if m := r_input_is_array_element.match(cond_str): in_sign, in_no, other_sign, other_no = m.groups() @@ -626,13 +650,15 @@ def parse_binary_case(case_m: Match) -> BinaryCase: BinaryCondArg.from_x_no(other_no), eq_neg=other_sign == "-" ) elif m := r_both_inputs_are_value.match(cond_str): - unary_cond, expr_template = parse_cond(m.group(1)) + unary_cond, expr_template, cond_from_dtype = parse_cond(m.group(1)) left_expr = expr_template.replace("{}", "x1_i") right_expr = expr_template.replace("{}", "x2_i") partial_expr = f"({left_expr}) and ({right_expr})" partial_cond = make_binary_cond( # type: ignore BinaryCondArg.BOTH, unary_cond ) + x1_cond_from_dtypes.append(cond_from_dtype) + x2_cond_from_dtypes.append(cond_from_dtype) else: cond_m = r_cond.match(cond_str) if cond_m is None: @@ -652,7 +678,7 @@ def partial_cond(i1: float, i2: float) -> bool: return math.copysign(1, i1) != math.copysign(1, i2) else: - unary_check, expr_template = parse_cond(value_str) + unary_check, expr_template, cond_from_dtype = parse_cond(value_str) # Do not define partial_cond via the def keyword, as one # partial_cond definition can mess up previous definitions # in the partial_conds list. This is a hard-limitation of @@ -683,6 +709,14 @@ def partial_cond(i1: float, i2: float) -> bool: partial_cond = make_binary_cond( # type: ignore cond_arg, unary_check, input_wrapper=input_wrapper ) + if cond_arg == BinaryCondArg.FIRST: + x1_cond_from_dtypes.append(cond_from_dtype) + elif cond_arg == BinaryCondArg.SECOND: + x2_cond_from_dtypes.append(cond_from_dtype) + else: + # TODO: xor scenarios + x1_cond_from_dtypes.append(cond_from_dtype) + x2_cond_from_dtypes.append(cond_from_dtype) partial_conds.append(partial_cond) partial_exprs.append(partial_expr) @@ -708,7 +742,31 @@ def check_result(i1: float, i2: float, result: float) -> bool: def cond(i1: float, i2: float) -> bool: return all(pc(i1, i2) for pc in partial_conds) - return BinaryCase(cond_expr, result_expr, cond, check_result) + if len(x1_cond_from_dtypes) == 0: + x1_cond_from_dtype = xps.from_dtype + elif len(x1_cond_from_dtypes) == 1: + x1_cond_from_dtype = x1_cond_from_dtypes[0] + else: + # sanity check + assert all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes) + x1_cond_from_dtype = sum(x1_cond_from_dtypes) + if len(x2_cond_from_dtypes) == 0: + x2_cond_from_dtype = xps.from_dtype + elif len(x2_cond_from_dtypes) == 1: + x2_cond_from_dtype = x2_cond_from_dtypes[0] + else: + # sanity check + assert all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes) + x2_cond_from_dtype = sum(x2_cond_from_dtypes) + + return BinaryCase( + cond_expr=cond_expr, + cond=cond, + x1_cond_from_dtype=x1_cond_from_dtype, + x2_cond_from_dtype=x2_cond_from_dtype, + result_expr=result_expr, + check_result=check_result, + ) r_redundant_case = re.compile("result.+determined by the rule already stated above") @@ -774,11 +832,11 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and param_names[1] == "x2": - # if cases := parse_binary_docstring(stub.__doc__): - # for case in cases: - # id_ = f"{stub.__name__}({case.cond_expr}) -> {case.result_expr}" - # p = pytest.param(stub.__name__, func, case, id=id_) - # binary_params.append(p) + if cases := parse_binary_docstring(stub.__doc__): + for case in cases: + id_ = f"{stub.__name__}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(stub.__name__, func, case, id=id_) + binary_params.append(p) continue else: warn( @@ -802,7 +860,7 @@ def test_unary(func_name, func, case, x, data): set_idx = data.draw( xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" ) - set_value = data.draw(case.cond_strat(x.dtype), label="set value") + set_value = data.draw(case.cond_from_dtype(x.dtype), label="set value") x[set_idx] = set_value note(f"{x=}") @@ -823,17 +881,34 @@ def test_unary(func_name, func, case, x, data): assume(good_example) -@pytest.mark.parametrize("func_name, func, case", binary_params) -@given( - *hh.two_mutual_arrays( - dtypes=dh.float_dtypes, - two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), - ) +x1_strat, x2_strat = hh.two_mutual_arrays( + dtypes=dh.float_dtypes, + two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), ) -def test_binary(func_name, func, case, x1, x2): + + +@pytest.mark.parametrize("func_name, func, case", binary_params) +@given(x1=x1_strat, x2=x2_strat, data=st.data()) +def test_binary(func_name, func, case, x1, x2, data): + result_shape = sh.broadcast_shapes(x1.shape, x2.shape) + all_indices = list(sh.iter_indices(x1.shape, x2.shape, result_shape)) + + indices_strat = st.shared(st.sampled_from(all_indices)) + set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx") + set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") + set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value") + set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") + x1[set_x1_idx] = set_x1_value + note(f"{x1=}") + x2[set_x2_idx] = set_x2_value + note(f"{x2=}") + res = func(x1, x2) + # sanity check + ph.assert_result_shape(func_name, [x1.shape, x2.shape], res.shape, result_shape) + good_example = False - for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, res.shape): + for l_idx, r_idx, o_idx in all_indices: l = float(x1[l_idx]) r = float(x2[r_idx]) if case.cond(l, r): @@ -843,9 +918,7 @@ def test_binary(func_name, func, case, x1, x2): f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" f_out = f"{sh.fmt_idx('out', o_idx)}={o}" assert case.check_result(l, r, o), ( - f"{f_out} not good [{func_name}()]\n" - f"{case.expr}\n" - f"{f_left}, {f_right}" + f"{f_out} not good [{func_name}()]\n" f"{case}\n" f"{f_left}, {f_right}" ) break assume(good_example) From 420a5498f30b4c66447fcc29cf3060e92e704bb6 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 10 Mar 2022 17:24:15 +0000 Subject: [PATCH 32/63] Cover binary cases with two unary conds for one array --- array_api_tests/test_special_cases.py | 130 +++++++++++++------------- 1 file changed, 64 insertions(+), 66 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 2b3b0df0..a54b7dba 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -6,16 +6,7 @@ from dataclasses import dataclass from decimal import ROUND_HALF_EVEN, Decimal from enum import Enum, auto -from typing import ( - Any, - Callable, - Dict, - List, - Match, - Optional, - Protocol, - Tuple, -) +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple from warnings import warn import pytest @@ -178,7 +169,8 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: @dataclass class BoundFromDtype(FromDtypeFunc): kwargs: Dict[str, Any] - filter_: Optional[Callable[[Array], bool]] + filter_: Optional[Callable[[Array], bool]] = None + base_func: Optional[FromDtypeFunc] = None def __add__(self, other: BoundFromDtype) -> BoundFromDtype: for k in self.kwargs.keys(): @@ -189,17 +181,28 @@ def __add__(self, other: BoundFromDtype) -> BoundFromDtype: if self.filter_ is not None and other.filter_ is not None: filter_ = lambda i: self.filter_(i) and other.filter_(i) else: - try: - filter_ = next( - f for f in [self.filter_, other.filter_] if f is not None - ) - except StopIteration: + if self.filter_ is not None: + filter_ = self.filter_ + elif other.filter_ is not None: + filter_ = other.filter_ + else: filter_ = None - return BoundFromDtype(kwargs, filter_) + # sanity check + assert not (self.base_func is not None and other.base_func is not None) + if self.base_func is not None: + base_func = self.base_func + elif other.base_func is not None: + base_func = other.base_func + else: + base_func = None + + return BoundFromDtype(kwargs, filter_, base_func) - def __call__(self, dtype: DataType) -> st.SearchStrategy[float]: - strat = xps.from_dtype(dtype, **self.kwargs) + def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: + assert len(kw) == 0 # sanity check + from_dtype = self.base_func or xps.from_dtype + strat = from_dtype(dtype, **self.kwargs) if self.filter_ is not None: strat = strat.filter(self.filter_) return strat @@ -295,22 +298,18 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: if not not_cond: kwargs = {"allow_nan": False, "allow_infinity": False} filter_ = lambda n: n != 0 - elif "integer value" in cond_str: - raise ValueError( - "integer values are only specified in dual cases, " - "which cannot be handled in parse_cond()" - ) - # elif cond_str == "an integer value": - # cond = lambda i: i.is_integer() - # expr_template = "{}.is_integer()" - # if not not_cond: - # from_dtype = integers_from_dtype # type: ignore - # elif cond_str == "an odd integer value": - # cond = lambda i: i.is_integer() and i % 2 == 1 - # expr_template = "{}.is_integer() and {} % 2 == 1" - # if not not_cond: - # def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: - # return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1) + elif cond_str == "an integer value": + cond = lambda i: i.is_integer() + expr_template = "{}.is_integer()" + if not not_cond: + from_dtype = integers_from_dtype # type: ignore + elif cond_str == "an odd integer value": + cond = lambda i: i.is_integer() and i % 2 == 1 + expr_template = "{}.is_integer() and {} % 2 == 1" + if not not_cond: + + def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1) else: raise ValueParseError(cond_str) @@ -329,7 +328,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: kwargs = {} filter_ = cond assert kwargs is not None - return cond, expr_template, BoundFromDtype(kwargs, filter_) + return cond, expr_template, BoundFromDtype(kwargs, filter_, from_dtype) def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: @@ -531,25 +530,9 @@ def noop(n: float) -> float: return n -def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: - for k in kw.keys(): - # sanity check - assert k in ["min_value", "max_value", "exclude_min", "exclude_max"] - m, M = dh.dtype_ranges[dtype] - if "min_value" in kw.keys(): - m = kw["min_value"] - if "exclude_min" in kw.keys(): - m += 1 - if "max_value" in kw.keys(): - M = kw["max_value"] - if "exclude_max" in kw.keys(): - M -= 1 - return st.integers(math.ceil(m), math.floor(M)).map(float) - - def make_binary_cond( cond_arg: BinaryCondArg, - unary_check: UnaryCheck, + unary_cond: UnaryCheck, *, input_wrapper: Optional[Callable[[float], float]] = None, ) -> BinaryCond: @@ -559,22 +542,22 @@ def make_binary_cond( if cond_arg == BinaryCondArg.FIRST: def partial_cond(i1: float, i2: float) -> bool: - return unary_check(input_wrapper(i1)) + return unary_cond(input_wrapper(i1)) elif cond_arg == BinaryCondArg.SECOND: def partial_cond(i1: float, i2: float) -> bool: - return unary_check(input_wrapper(i2)) + return unary_cond(input_wrapper(i2)) elif cond_arg == BinaryCondArg.BOTH: def partial_cond(i1: float, i2: float) -> bool: - return unary_check(input_wrapper(i1)) and unary_check(input_wrapper(i2)) + return unary_cond(input_wrapper(i1)) and unary_cond(input_wrapper(i2)) else: def partial_cond(i1: float, i2: float) -> bool: - return unary_check(input_wrapper(i1)) or unary_check(input_wrapper(i2)) + return unary_cond(input_wrapper(i1)) or unary_cond(input_wrapper(i2)) return partial_cond @@ -631,11 +614,26 @@ def check_result(i1: float, i2: float, result: float) -> bool: return check_result -def parse_binary_case(case_m: Match) -> BinaryCase: - cond_strs = r_cond_sep.split(case_m.group(1)) +def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + for k in kw.keys(): + # sanity check + assert k in ["min_value", "max_value", "exclude_min", "exclude_max"] + m, M = dh.dtype_ranges[dtype] + if "min_value" in kw.keys(): + m = kw["min_value"] + if "exclude_min" in kw.keys(): + m += 1 + if "max_value" in kw.keys(): + M = kw["max_value"] + if "exclude_max" in kw.keys(): + M -= 1 + return st.integers(math.ceil(m), math.floor(M)).map(float) - if len(cond_strs) > 2: - raise ValueParseError(", ".join(cond_strs)) + +def parse_binary_case(case_str: str) -> BinaryCase: + case_m = r_binary_case.match(case_str) + assert case_m is not None # sanity check + cond_strs = r_cond_sep.split(case_m.group(1)) partial_conds = [] partial_exprs = [] @@ -678,7 +676,7 @@ def partial_cond(i1: float, i2: float) -> bool: return math.copysign(1, i1) != math.copysign(1, i2) else: - unary_check, expr_template, cond_from_dtype = parse_cond(value_str) + unary_cond, expr_template, cond_from_dtype = parse_cond(value_str) # Do not define partial_cond via the def keyword, as one # partial_cond definition can mess up previous definitions # in the partial_conds list. This is a hard-limitation of @@ -707,7 +705,7 @@ def partial_cond(i1: float, i2: float) -> bool: else: raise ValueParseError(input_str) partial_cond = make_binary_cond( # type: ignore - cond_arg, unary_check, input_wrapper=input_wrapper + cond_arg, unary_cond, input_wrapper=input_wrapper ) if cond_arg == BinaryCondArg.FIRST: x1_cond_from_dtypes.append(cond_from_dtype) @@ -749,7 +747,7 @@ def cond(i1: float, i2: float) -> bool: else: # sanity check assert all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes) - x1_cond_from_dtype = sum(x1_cond_from_dtypes) + x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({}, None)) if len(x2_cond_from_dtypes) == 0: x2_cond_from_dtype = xps.from_dtype elif len(x2_cond_from_dtypes) == 1: @@ -757,7 +755,7 @@ def cond(i1: float, i2: float) -> bool: else: # sanity check assert all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes) - x2_cond_from_dtype = sum(x2_cond_from_dtypes) + x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({}, None)) return BinaryCase( cond_expr=cond_expr, @@ -788,7 +786,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: continue if m := r_binary_case.match(case_str): try: - case = parse_binary_case(m) + case = parse_binary_case(case_str) cases.append(case) except ValueParseError as e: warn(f"not machine-readable: '{e.value}'") From e6f95481c173c733714309bd5ca6abc654178164 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 10 Mar 2022 17:34:18 +0000 Subject: [PATCH 33/63] Cover "equal to" cases (as opposed to "is" cases) --- array_api_tests/test_special_cases.py | 33 +++++++++++++++------------ 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index a54b7dba..f79224a6 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -29,7 +29,7 @@ BinaryCheck = Callable[[float, float], bool] -def make_eq(v: float) -> UnaryCheck: +def make_strict_eq(v: float) -> UnaryCheck: if math.isnan(v): return math.isnan if v == 0: @@ -38,14 +38,14 @@ def make_eq(v: float) -> UnaryCheck: else: return ph.is_neg_zero - def eq(i: float) -> bool: + def strict_eq(i: float) -> bool: return i == v - return eq + return strict_eq def make_neq(v: float) -> UnaryCheck: - eq = make_eq(v) + eq = make_strict_eq(v) def neq(i: float) -> bool: return not eq(i) @@ -154,7 +154,8 @@ def parse_inline_code(inline_code: str) -> float: raise ValueParseError(inline_code) -r_not = re.compile("not (?:equal to )?(.+)") +r_not = re.compile("not (.+)") +r_equal_to = re.compile(f"equal to {r_code.pattern}") r_array_element = re.compile(r"``([+-]?)x([12])_i``") r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}") r_gt = re.compile(f"greater than {r_code.pattern}") @@ -217,9 +218,6 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: - if "equal to" in cond_str: - raise ValueParseError(cond_str) # TODO - if m := r_not.match(cond_str): cond_str = m.group(1) not_cond = True @@ -232,10 +230,15 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: strat = None if m := r_code.match(cond_str): value = parse_value(m.group(1)) - cond = make_eq(value) + cond = make_strict_eq(value) expr_template = "{} == " + m.group(1) if not not_cond: strat = st.just(value) + elif m := r_equal_to.match(cond_str): + value = parse_value(m.group(1)) + assert not math.isnan(value) # sanity check + cond = lambda i: i == value + expr_template = "{} == " + m.group(1) elif m := r_gt.match(cond_str): value = parse_value(m.group(1)) cond = make_gt(value) @@ -251,7 +254,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: elif m := r_either_code.match(cond_str): v1 = parse_value(m.group(1)) v2 = parse_value(m.group(2)) - cond = make_or(make_eq(v1), make_eq(v2)) + cond = make_or(make_strict_eq(v1), make_strict_eq(v2)) expr_template = "{} == " + m.group(1) + " or {} == " + m.group(2) if not not_cond: strat = st.sampled_from([v1, v2]) @@ -334,7 +337,7 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: if m := r_code.match(result_str): value = parse_value(m.group(1)) - check_result = make_eq(value) # type: ignore + check_result = make_strict_eq(value) # type: ignore expr = m.group(1) elif m := r_approx_value.match(result_str): value = parse_value(m.group(1)) @@ -573,13 +576,13 @@ def make_eq_other_input_cond( if eq_to == BinaryCondArg.FIRST: def cond(i1: float, i2: float) -> bool: - eq = make_eq(input_wrapper(i1)) + eq = make_strict_eq(input_wrapper(i1)) return eq(i2) elif eq_to == BinaryCondArg.SECOND: def cond(i1: float, i2: float) -> bool: - eq = make_eq(input_wrapper(i2)) + eq = make_strict_eq(input_wrapper(i2)) return eq(i1) else: @@ -599,13 +602,13 @@ def make_eq_input_check_result( if eq_to == BinaryCondArg.FIRST: def check_result(i1: float, i2: float, result: float) -> bool: - eq = make_eq(input_wrapper(i1)) + eq = make_strict_eq(input_wrapper(i1)) return eq(result) elif eq_to == BinaryCondArg.SECOND: def check_result(i1: float, i2: float, result: float) -> bool: - eq = make_eq(input_wrapper(i2)) + eq = make_strict_eq(input_wrapper(i2)) return eq(result) else: From c5a30d09955744c01be6bc6a375b33e7d5ad4062 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 11 Mar 2022 09:48:22 +0000 Subject: [PATCH 34/63] Test xor scenarios for either special cases --- array_api_tests/test_special_cases.py | 32 ++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index f79224a6..ecdc2eec 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -714,10 +714,30 @@ def partial_cond(i1: float, i2: float) -> bool: x1_cond_from_dtypes.append(cond_from_dtype) elif cond_arg == BinaryCondArg.SECOND: x2_cond_from_dtypes.append(cond_from_dtype) - else: - # TODO: xor scenarios + elif cond_arg == BinaryCondArg.BOTH: x1_cond_from_dtypes.append(cond_from_dtype) x2_cond_from_dtypes.append(cond_from_dtype) + else: + use_x1_or_x2_strat = st.shared( + st.sampled_from([(True, False), (True, False), (True, True)]) + ) + + def x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: + return use_x1_or_x2_strat.flatmap( + lambda t: cond_from_dtype(dtype) + if t[0] + else xps.from_dtype(dtype) + ) + + def x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: + return use_x1_or_x2_strat.flatmap( + lambda t: cond_from_dtype(dtype) + if t[1] + else xps.from_dtype(dtype) + ) + + x1_cond_from_dtypes.append(x1_cond_from_dtype) + x2_cond_from_dtypes.append(x2_cond_from_dtype) partial_conds.append(partial_cond) partial_exprs.append(partial_expr) @@ -750,7 +770,7 @@ def cond(i1: float, i2: float) -> bool: else: # sanity check assert all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes) - x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({}, None)) + x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({})) if len(x2_cond_from_dtypes) == 0: x2_cond_from_dtype = xps.from_dtype elif len(x2_cond_from_dtypes) == 1: @@ -758,7 +778,7 @@ def cond(i1: float, i2: float) -> bool: else: # sanity check assert all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes) - x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({}, None)) + x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({})) return BinaryCase( cond_expr=cond_expr, @@ -896,11 +916,11 @@ def test_binary(func_name, func, case, x1, x2, data): indices_strat = st.shared(st.sampled_from(all_indices)) set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx") - set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value") - set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") x1[set_x1_idx] = set_x1_value note(f"{x1=}") + set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") + set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") x2[set_x2_idx] = set_x2_value note(f"{x2=}") From 2217bb9691ebe831218763e5fcfd579082a0736f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 11 Mar 2022 11:48:49 +0000 Subject: [PATCH 35/63] Comment on `__future__` and either case strategies --- array_api_tests/test_special_cases.py | 44 ++++++++++++++++----------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index ecdc2eec..643c5bcb 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1,9 +1,11 @@ +# We use __future__ for forward reference type hints - this will work for even py3.8.0 +# See https://stackoverflow.com/a/33533514/5193926 from __future__ import annotations import inspect import math import re -from dataclasses import dataclass +from dataclasses import dataclass, field from decimal import ROUND_HALF_EVEN, Decimal from enum import Enum, auto from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple @@ -169,7 +171,7 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: @dataclass class BoundFromDtype(FromDtypeFunc): - kwargs: Dict[str, Any] + kwargs: Dict[str, Any] = field(default_factory=dict) filter_: Optional[Callable[[Array], bool]] = None base_func: Optional[FromDtypeFunc] = None @@ -718,26 +720,38 @@ def partial_cond(i1: float, i2: float) -> bool: x1_cond_from_dtypes.append(cond_from_dtype) x2_cond_from_dtypes.append(cond_from_dtype) else: + # For "either x1_i or x2_i is " cases, we want to + # test three scenarios: + # + # 1. x1_i is + # 2. x2_i is + # 3. x1_i AND x2_i is + # + # This is achieved by a shared base strategy that picks one + # of these scenarios to determine whether each array will + # use either cond_from_dtype() (i.e. meet the condition), or + # simply xps.from_dtype() (i.e. be any value). + use_x1_or_x2_strat = st.shared( - st.sampled_from([(True, False), (True, False), (True, True)]) + st.sampled_from([(True, False), (False, True), (True, True)]) ) - def x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: + def _x1_cond_from_dtype(dtype) -> st.SearchStrategy[float]: return use_x1_or_x2_strat.flatmap( lambda t: cond_from_dtype(dtype) if t[0] else xps.from_dtype(dtype) ) - def x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: + def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: return use_x1_or_x2_strat.flatmap( lambda t: cond_from_dtype(dtype) if t[1] else xps.from_dtype(dtype) ) - x1_cond_from_dtypes.append(x1_cond_from_dtype) - x2_cond_from_dtypes.append(x2_cond_from_dtype) + x1_cond_from_dtypes.append(_x1_cond_from_dtype) + x2_cond_from_dtypes.append(_x2_cond_from_dtype) partial_conds.append(partial_cond) partial_exprs.append(partial_expr) @@ -768,17 +782,17 @@ def cond(i1: float, i2: float) -> bool: elif len(x1_cond_from_dtypes) == 1: x1_cond_from_dtype = x1_cond_from_dtypes[0] else: - # sanity check - assert all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes) - x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype({})) + if not all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes): + raise ValueParseError(case_str) + x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype()) if len(x2_cond_from_dtypes) == 0: x2_cond_from_dtype = xps.from_dtype elif len(x2_cond_from_dtypes) == 1: x2_cond_from_dtype = x2_cond_from_dtypes[0] else: - # sanity check - assert all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes) - x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype({})) + if not all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes): + raise ValueParseError(case_str) + x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype()) return BinaryCase( cond_expr=cond_expr, @@ -819,10 +833,6 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: return cases -# Here be the tests -# ------------------------------------------------------------------------------ - - unary_params = [] binary_params = [] for stub in category_to_funcs["elementwise"]: From 4372e905b42f24c2d8990b458d607363f77e271e Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 11 Mar 2022 13:08:24 +0000 Subject: [PATCH 36/63] Update NumPy workflow for new special cases --- .github/workflows/numpy.yml | 7 +++++++ array_api_tests/test_special_cases.py | 4 +++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index a6118330..04090923 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -43,6 +43,13 @@ jobs: # waiting on NumPy to allow/revert distinct NaNs for np.unique # https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448 array_api_tests/test_set_functions.py + # noted diversions from spec + array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i == +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] + array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i == +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] + array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i == -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] + array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i == -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] + array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i == -infinity) -> -0] + array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i == +infinity) -> -0] EOF diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 643c5bcb..9c9ace88 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -705,7 +705,9 @@ def partial_cond(i1: float, i2: float) -> bool: elif r_or_input.match(input_str): left_expr = expr_template.replace("{}", "x1_i") right_expr = expr_template.replace("{}", "x2_i") - partial_expr = f"({left_expr}) or ({right_expr})" + partial_expr = f"{left_expr} or {right_expr}" + if len(cond_strs) != 1: + partial_expr = f"({partial_expr})" cond_arg = BinaryCondArg.EITHER else: raise ValueParseError(input_str) From 6c775a73444b10929d78c8799c148d0782c42d83 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 10:16:20 +0000 Subject: [PATCH 37/63] Generate x1 is x2 conds (and visa versa) healthily --- array_api_tests/test_special_cases.py | 62 ++++++++++++++------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 9c9ace88..ed315b63 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -567,32 +567,6 @@ def partial_cond(i1: float, i2: float) -> bool: return partial_cond -def make_eq_other_input_cond( - eq_to: BinaryCondArg, *, eq_neg: bool = False -) -> BinaryCond: - if eq_neg: - input_wrapper = lambda i: -i - else: - input_wrapper = noop - - if eq_to == BinaryCondArg.FIRST: - - def cond(i1: float, i2: float) -> bool: - eq = make_strict_eq(input_wrapper(i1)) - return eq(i2) - - elif eq_to == BinaryCondArg.SECOND: - - def cond(i1: float, i2: float) -> bool: - eq = make_strict_eq(input_wrapper(i2)) - return eq(i1) - - else: - raise ValueError(f"{eq_to=} must be FIRST or SECOND") - - return cond - - def make_eq_input_check_result( eq_to: BinaryCondArg, *, eq_neg: bool = False ) -> BinaryResultCheck: @@ -616,8 +590,6 @@ def check_result(i1: float, i2: float, result: float) -> bool: else: raise ValueError(f"{eq_to=} must be FIRST or SECOND") - return check_result - def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: for k in kw.keys(): @@ -649,9 +621,39 @@ def parse_binary_case(case_str: str) -> BinaryCase: in_sign, in_no, other_sign, other_no = m.groups() assert in_sign == "" and other_no != in_no # sanity check partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i" - partial_cond = make_eq_other_input_cond( # type: ignore - BinaryCondArg.from_x_no(other_no), eq_neg=other_sign == "-" + input_wrapper = lambda i: -i if other_sign == "-" else noop + shared_from_dtype = lambda d, **kw: st.shared( + xps.from_dtype(d, **kw), key=cond_str ) + + if other_no == "1": + + def partial_cond(i1: float, i2: float) -> bool: + eq = make_strict_eq(input_wrapper(i1)) + return eq(i2) + + _x2_cond_from_dtype = shared_from_dtype # type: ignore + + def _x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: + return shared_from_dtype(dtype, **kw).map(input_wrapper) + + elif other_no == "2": + + def partial_cond(i1: float, i2: float) -> bool: + eq = make_strict_eq(input_wrapper(i2)) + return eq(i1) + + _x1_cond_from_dtype = shared_from_dtype # type: ignore + + def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: + return shared_from_dtype(dtype, **kw).map(input_wrapper) + + else: + raise ValueParseError(cond_str) + + x1_cond_from_dtypes.append(BoundFromDtype(base_func=_x1_cond_from_dtype)) + x2_cond_from_dtypes.append(BoundFromDtype(base_func=_x2_cond_from_dtype)) + elif m := r_both_inputs_are_value.match(cond_str): unary_cond, expr_template, cond_from_dtype = parse_cond(m.group(1)) left_expr = expr_template.replace("{}", "x1_i") From 6d8c3ca1bffc5315651b22257d86259dfc7e6d69 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 10:26:37 +0000 Subject: [PATCH 38/63] Favour `ParseError` to assertions --- array_api_tests/meta/test_special_cases.py | 6 +- array_api_tests/test_special_cases.py | 113 ++++++++++++++------- 2 files changed, 83 insertions(+), 36 deletions(-) diff --git a/array_api_tests/meta/test_special_cases.py b/array_api_tests/meta/test_special_cases.py index 1b8c8358..826e5969 100644 --- a/array_api_tests/meta/test_special_cases.py +++ b/array_api_tests/meta/test_special_cases.py @@ -4,5 +4,7 @@ def test_parse_result(): - s_result = "an implementation-dependent approximation to ``+3π/4``" - assert parse_result(s_result).value == 3 * math.pi / 4 + check_result, _ = parse_result( + "an implementation-dependent approximation to ``+3π/4``" + ) + assert check_result(3 * math.pi / 4) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index ed315b63..253abda3 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -27,6 +27,13 @@ pytestmark = pytest.mark.ci +# The special case test casess are built on runtime via the parametrized +# test_unary and test_binary functions. Most of this file consists of utility +# classes and functions, all bought together to create the test cases (pytest +# params), to finally be run through the general test logic of either test_unary +# or test_binary. + + UnaryCheck = Callable[[float], bool] BinaryCheck = Callable[[float, float], bool] @@ -46,13 +53,13 @@ def strict_eq(i: float) -> bool: return strict_eq -def make_neq(v: float) -> UnaryCheck: - eq = make_strict_eq(v) +def make_strict_neq(v: float) -> UnaryCheck: + strict_eq = make_strict_eq(v) - def neq(i: float) -> bool: - return not eq(i) + def strict_neq(i: float) -> bool: + return not strict_eq(i) - return neq + return strict_neq def make_rough_eq(v: float) -> UnaryCheck: @@ -121,14 +128,25 @@ def abs_cond(i: float) -> bool: @dataclass -class ValueParseError(ValueError): +class ParseError(ValueError): value: str def parse_value(value_str: str) -> float: + """ + Parse a value string to return a float, e.g. + + >>> parse_value('1') + 1. + >>> parse_value('-infinity') + -float('inf') + >>> parse_value('3π/4') + 2.356194490192345 + + """ m = r_value.match(value_str) if m is None: - raise ValueParseError(value_str) + raise ParseError(value_str) if pi_m := r_pi.match(m.group(2)): value = math.pi if numerator := pi_m.group(1): @@ -150,10 +168,19 @@ def parse_value(value_str: str) -> float: def parse_inline_code(inline_code: str) -> float: + """ + Parse a Sphinx code string to return a float, e.g. + + >>> parse_value('``0``') + 0. + >>> parse_value('``NaN``') + float('nan') + + """ if m := r_code.match(inline_code): return parse_value(m.group(1)) else: - raise ValueParseError(inline_code) + raise ParseError(inline_code) r_not = re.compile("not (.+)") @@ -165,16 +192,37 @@ def parse_inline_code(inline_code: str) -> float: class FromDtypeFunc(Protocol): + """ + Type hint for functions that return an elements strategy for arrays of the + given dtype, e.g. xps.from_dtype(). + """ + def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: ... @dataclass class BoundFromDtype(FromDtypeFunc): + """ + A callable which bounds kwargs and strategy filters to xps.from_dtype() or + equivalent function. + + + + """ + kwargs: Dict[str, Any] = field(default_factory=dict) filter_: Optional[Callable[[Array], bool]] = None base_func: Optional[FromDtypeFunc] = None + def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: + assert len(kw) == 0 # sanity check + from_dtype = self.base_func or xps.from_dtype + strat = from_dtype(dtype, **self.kwargs) + if self.filter_ is not None: + strat = strat.filter(self.filter_) + return strat + def __add__(self, other: BoundFromDtype) -> BoundFromDtype: for k in self.kwargs.keys(): if k in other.kwargs.keys(): @@ -202,14 +250,6 @@ def __add__(self, other: BoundFromDtype) -> BoundFromDtype: return BoundFromDtype(kwargs, filter_, base_func) - def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: - assert len(kw) == 0 # sanity check - from_dtype = self.base_func or xps.from_dtype - strat = from_dtype(dtype, **self.kwargs) - if self.filter_ is not None: - strat = strat.filter(self.filter_) - return strat - def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc: def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: @@ -238,7 +278,8 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: strat = st.just(value) elif m := r_equal_to.match(cond_str): value = parse_value(m.group(1)) - assert not math.isnan(value) # sanity check + if math.isnan(value): + raise ParseError(cond_str) cond = lambda i: i == value expr_template = "{} == " + m.group(1) elif m := r_gt.match(cond_str): @@ -317,14 +358,16 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1) else: - raise ValueParseError(cond_str) + raise ParseError(cond_str) if strat is not None: - # sanity checks - assert not not_cond - assert kwargs == {} - assert filter_ is None - assert from_dtype is None + if ( + not_cond + or len(kwargs) != 0 + or filter_ is not None + or from_dtype is not None + ): + raise ParseError(cond_str) return cond, expr_template, wrap_strat_as_from_dtype(strat) if not_cond: @@ -365,7 +408,7 @@ def check_result(result: float) -> bool: expr = "-" else: - raise ValueParseError(result_str) + raise ParseError(result_str) return check_result, expr @@ -461,7 +504,7 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: if m := r_unary_case.search(case): try: case = UnaryCase.from_strings(*m.groups()) - except ValueParseError as e: + except ParseError as e: warn(f"not machine-readable: '{e.value}'") continue cases.append(case) @@ -609,7 +652,8 @@ def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: def parse_binary_case(case_str: str) -> BinaryCase: case_m = r_binary_case.match(case_str) - assert case_m is not None # sanity check + if case_m is None: + raise ParseError(case_str) cond_strs = r_cond_sep.split(case_m.group(1)) partial_conds = [] @@ -619,7 +663,8 @@ def parse_binary_case(case_str: str) -> BinaryCase: for cond_str in cond_strs: if m := r_input_is_array_element.match(cond_str): in_sign, in_no, other_sign, other_no = m.groups() - assert in_sign == "" and other_no != in_no # sanity check + if in_sign != "" or other_no == in_no: + raise ParseError(cond_str) partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i" input_wrapper = lambda i: -i if other_sign == "-" else noop shared_from_dtype = lambda d, **kw: st.shared( @@ -649,7 +694,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: return shared_from_dtype(dtype, **kw).map(input_wrapper) else: - raise ValueParseError(cond_str) + raise ParseError(cond_str) x1_cond_from_dtypes.append(BoundFromDtype(base_func=_x1_cond_from_dtype)) x2_cond_from_dtypes.append(BoundFromDtype(base_func=_x2_cond_from_dtype)) @@ -667,7 +712,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: else: cond_m = r_cond.match(cond_str) if cond_m is None: - raise ValueParseError(cond_str) + raise ParseError(cond_str) input_str, value_str = cond_m.groups() if value_str == "the same mathematical sign": @@ -712,7 +757,7 @@ def partial_cond(i1: float, i2: float) -> bool: partial_expr = f"({partial_expr})" cond_arg = BinaryCondArg.EITHER else: - raise ValueParseError(input_str) + raise ParseError(input_str) partial_cond = make_binary_cond( # type: ignore cond_arg, unary_cond, input_wrapper=input_wrapper ) @@ -762,7 +807,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: result_m = r_result.match(case_m.group(2)) if result_m is None: - raise ValueParseError(case_m.group(2)) + raise ParseError(case_m.group(2)) result_str = result_m.group(1) if m := r_array_element.match(result_str): sign, x_no = m.groups() @@ -787,7 +832,7 @@ def cond(i1: float, i2: float) -> bool: x1_cond_from_dtype = x1_cond_from_dtypes[0] else: if not all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes): - raise ValueParseError(case_str) + raise ParseError(case_str) x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype()) if len(x2_cond_from_dtypes) == 0: x2_cond_from_dtype = xps.from_dtype @@ -795,7 +840,7 @@ def cond(i1: float, i2: float) -> bool: x2_cond_from_dtype = x2_cond_from_dtypes[0] else: if not all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes): - raise ValueParseError(case_str) + raise ParseError(case_str) x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype()) return BinaryCase( @@ -829,7 +874,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: try: case = parse_binary_case(case_str) cases.append(case) - except ValueParseError as e: + except ParseError as e: warn(f"not machine-readable: '{e.value}'") else: if not r_remaining_case.match(case_str): From 63d63e1e62b25d27a105cdc2cc10a2272fa4811e Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 11:06:17 +0000 Subject: [PATCH 39/63] Note check_result def problems --- array_api_tests/test_special_cases.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 253abda3..9ce0de0f 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -633,6 +633,15 @@ def check_result(i1: float, i2: float, result: float) -> bool: else: raise ValueError(f"{eq_to=} must be FIRST or SECOND") + return check_result + + +def make_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck: + def check_result(i1: float, i2: float, result: float) -> bool: + return check_just_result(result) + + return check_result + def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: for k in kw.keys(): @@ -733,7 +742,8 @@ def partial_cond(i1: float, i2: float) -> bool: # partial_cond definition can mess up previous definitions # in the partial_conds list. This is a hard-limitation of # using local functions with the same name and that use the same - # outer variables (i.e. unary_cond). + # outer variables (i.e. unary_cond). Use def in a called + # function avoids this problem. input_wrapper = None if m := r_input.match(input_str): x_no = m.group(1) @@ -809,6 +819,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: if result_m is None: raise ParseError(case_m.group(2)) result_str = result_m.group(1) + # Like with partial_cond, do not define check_result via the def keyword if m := r_array_element.match(result_str): sign, x_no = m.groups() result_expr = f"{sign}x{x_no}_i" @@ -817,9 +828,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: ) else: _check_result, result_expr = parse_result(result_m.group(1)) - - def check_result(i1: float, i2: float, result: float) -> bool: - return _check_result(result) + check_result = make_check_result(_check_result) cond_expr = " and ".join(partial_exprs) From 171ee5a082bbde88a22a7712e65e907628fbcd89 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 13:30:25 +0000 Subject: [PATCH 40/63] Document `BoundFromDtype` with extensive examples --- array_api_tests/test_special_cases.py | 59 ++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 9ce0de0f..4ddf9105 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -204,10 +204,65 @@ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: @dataclass class BoundFromDtype(FromDtypeFunc): """ - A callable which bounds kwargs and strategy filters to xps.from_dtype() or - equivalent function. + A xps.from_dtype()-like callable with bounded kwargs, filters and base function. + We can bound: + 1. Keyword arguments that xps.from_dtype() can use, e.g. + + >>> from_dtype = BoundFromDtype(kwargs={'min_value': 0, 'allow_infinity': False}) + >>> strategy = from_dtype(xp.float64) + + is equivalent to + + >>> strategy = xps.from_dtype(xp.float64, min_value=0, allow_infinity=False) + + i.e. a strategy that generates finite floats above 0 + + 2. Functions that filter the elements strategy that xps.from_dtype() returns, e.g. + + >>> from_dtype = BoundFromDtype(filter=lambda i: i != 0) + >>> strategy = from_dtype(xp.float64) + + is equivalent to + + >>> strategy = xps.from_dtype(xp.float64).filter(lambda i: i != 0) + + i.e. a strategy that generates any floats except 0 + + 3. The underlying function that returns an elements strategy from a dtype, e.g. + + >>> from_dtype = BoundFromDtype( + ... from_dtype=lambda d: st.integers( + ... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max) + ... ) + ... ) + >>> strategy = from_dtype(xp.float64) + + is equivalent to + + >>> strategy = lambda d: st.integers( + ... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max) + ... ) + + i.e. a strategy that generates integers (within the dtypes range) + + This is useful to avoid translating special case conditions into either a + dict, filter or "base func", and instead allows us to generalise these three + components into a callable equivalent of xps.from_dtype(). + + Additionally, BoundFromDtype instances can be added together. This allows us + to keep parsing each condition individually - so we don't need to duplicate + complicated parsing code - as ultimately we can represent (and subsequently + test for) special cases which have more than one condition per array, e.g. + + "If x1_i is greater than 0 and x1_i is not 42, ..." + + could be translated as + + >>> gt_0_from_dtype = BoundFromDtype(kwargs={'min_value': 0}) + >>> not_42_from_dtype = BoundFromDtype(filter=lambda i: i != 42) + >>> from_dtype = gt_0_from_dtype + not_42_from_dtype """ From cb810b96fb655642fa9c2a4bc5e8b44aa66ec95a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 15:39:29 +0000 Subject: [PATCH 41/63] Document `parse_cond` --- array_api_tests/test_special_cases.py | 46 +++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 4ddf9105..6b221512 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -134,7 +134,7 @@ class ParseError(ValueError): def parse_value(value_str: str) -> float: """ - Parse a value string to return a float, e.g. + Parses a value string to return a float, e.g. >>> parse_value('1') 1. @@ -169,7 +169,7 @@ def parse_value(value_str: str) -> float: def parse_inline_code(inline_code: str) -> float: """ - Parse a Sphinx code string to return a float, e.g. + Parses a Sphinx code string to return a float, e.g. >>> parse_value('``0``') 0. @@ -208,7 +208,7 @@ class BoundFromDtype(FromDtypeFunc): We can bound: - 1. Keyword arguments that xps.from_dtype() can use, e.g. + 1. Keyword arguments that xps.from_dtype() can use, e.g. >>> from_dtype = BoundFromDtype(kwargs={'min_value': 0, 'allow_infinity': False}) >>> strategy = from_dtype(xp.float64) @@ -219,7 +219,7 @@ class BoundFromDtype(FromDtypeFunc): i.e. a strategy that generates finite floats above 0 - 2. Functions that filter the elements strategy that xps.from_dtype() returns, e.g. + 2. Functions that filter the elements strategy that xps.from_dtype() returns, e.g. >>> from_dtype = BoundFromDtype(filter=lambda i: i != 0) >>> strategy = from_dtype(xp.float64) @@ -230,7 +230,7 @@ class BoundFromDtype(FromDtypeFunc): i.e. a strategy that generates any floats except 0 - 3. The underlying function that returns an elements strategy from a dtype, e.g. + 3. The underlying function that returns an elements strategy from a dtype, e.g. >>> from_dtype = BoundFromDtype( ... from_dtype=lambda d: st.integers( @@ -307,14 +307,48 @@ def __add__(self, other: BoundFromDtype) -> BoundFromDtype: def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc: + """ + Wraps an elements strategy as a xps.from_dtype()-like function + """ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: - assert kw == {} # sanity check + assert len(kw) == 0 # sanity check return strat return from_dtype def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: + """ + Parses a Sphinx-formatted condition string to return: + + 1. A function which takes an input and returns True if it meets the + condition, otherwise False. + 2. A string template for expressing the condition. + 3. A xps.from_dtype()-like function which returns a strategy that generates + elements which meet the condition. + + e.g. + + >>> cond_func, expr_template, cond_from_dtype = parse_cond( + ... 'greater than ``0``' + ... ) + >>> expr_template.replace('{}', 'x_i') + >>> expr_template.replace('{}', 'x_i') + 'x_i > 0' + >>> cond_func(42) + True + >>> cond_func(-128) + False + >>> strategy = cond_from_dtype(xp.float64) + >>> for _ in range(5): + ... print(strategy.example()) + 1. + 0.1 + 1.7976931348623155e+179 + inf + 124.978 + + """ if m := r_not.match(cond_str): cond_str = m.group(1) not_cond = True From ddb287a9367bd5256304e3cfc08d439d17775080 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 16:05:11 +0000 Subject: [PATCH 42/63] Document `parse_result()` --- array_api_tests/test_special_cases.py | 34 ++++++++++++++++++++------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 6b221512..07cf8909 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -262,7 +262,8 @@ class BoundFromDtype(FromDtypeFunc): >>> gt_0_from_dtype = BoundFromDtype(kwargs={'min_value': 0}) >>> not_42_from_dtype = BoundFromDtype(filter=lambda i: i != 42) - >>> from_dtype = gt_0_from_dtype + not_42_from_dtype + >>> gt_0_from_dtype + not_42_from_dtype + BoundFromDtype(kwargs={'min_value': 0}, filter=(i)) """ @@ -329,17 +330,14 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: e.g. - >>> cond_func, expr_template, cond_from_dtype = parse_cond( - ... 'greater than ``0``' - ... ) - >>> expr_template.replace('{}', 'x_i') + >>> cond, expr_template, from_dtype = parse_cond('greater than ``0``') >>> expr_template.replace('{}', 'x_i') 'x_i > 0' - >>> cond_func(42) + >>> cond(42) True - >>> cond_func(-128) + >>> cond(-128) False - >>> strategy = cond_from_dtype(xp.float64) + >>> strategy = from_dtype(xp.float64) >>> for _ in range(5): ... print(strategy.example()) 1. @@ -387,7 +385,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: v1 = parse_value(m.group(1)) v2 = parse_value(m.group(2)) cond = make_or(make_strict_eq(v1), make_strict_eq(v2)) - expr_template = "{} == " + m.group(1) + " or {} == " + m.group(2) + expr_template = "({} == " + m.group(1) + " or {} == " + m.group(2) + ")" if not not_cond: strat = st.sampled_from([v1, v2]) elif cond_str in ["finite", "a finite number"]: @@ -469,6 +467,24 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: + """ + Parses a Sphinx-formatted result string to return: + + 1. A function which takes an input and returns True if it is the expected + result (or meets the condition of the expected result), otherwise False. + 2. A string that expresses the result. + + e.g. + + >>> check_result, expr = parse_result('``42``') + >>> expr_template.replace('{}', 'x_i') + '42' + >>> check_result(7) + False + >>> check_result(42) + True + + """ if m := r_code.match(result_str): value = parse_value(m.group(1)) check_result = make_strict_eq(value) # type: ignore From 67891b87e8f3b8d9be5b0e0fcb5b826ceddc07af Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 16:36:16 +0000 Subject: [PATCH 43/63] Factor out `UnaryCase.from_strings()` --- array_api_tests/test_special_cases.py | 59 ++++++++++++++------------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 07cf8909..4cad73b1 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -311,6 +311,7 @@ def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc: """ Wraps an elements strategy as a xps.from_dtype()-like function """ + def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: assert len(kw) == 0 # sanity check return strat @@ -553,23 +554,6 @@ class UnaryCase(Case): cond: UnaryCheck check_result: UnaryResultCheck - @classmethod - def from_strings(cls, cond_str: str, result_str: str): - cond, cond_expr_template, cond_from_dtype = parse_cond(cond_str) - cond_expr = cond_expr_template.replace("{}", "x_i") - _check_result, result_expr = parse_result(result_str) - - def check_result(i: float, result: float) -> bool: - return _check_result(result) - - return cls( - cond_expr=cond_expr, - cond=cond, - cond_from_dtype=cond_from_dtype, - result_expr=result_expr, - check_result=check_result, - ) - r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") r_even_int_round_case = re.compile( @@ -578,7 +562,7 @@ def check_result(i: float, result: float) -> bool: ) -def trailing_halves_from_dtype(dtype: DataType): +def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: m, M = dh.dtype_ranges[dtype] return st.integers(math.ceil(m) // 2, math.floor(M) // 2).map(lambda n: n * 0.5) @@ -594,6 +578,13 @@ def trailing_halves_from_dtype(dtype: DataType): ) +def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck: + def check_result(i: float, result: float) -> bool: + return check_just_result(result) + + return check_result + + def parse_unary_docstring(docstring: str) -> List[UnaryCase]: match = r_special_cases.search(docstring) if match is None: @@ -608,10 +599,22 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: continue if m := r_unary_case.search(case): try: - case = UnaryCase.from_strings(*m.groups()) + cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1)) + _check_result, result_expr = parse_result(m.group(2)) except ParseError as e: warn(f"not machine-readable: '{e.value}'") continue + cond_expr = cond_expr_template.replace("{}", "x_i") + # Do not define check_result in this function's body - see + # parse_binary_case comment. + check_result = make_unary_check_result(_check_result) + case = UnaryCase( + cond_expr=cond_expr, + cond=cond, + cond_from_dtype=cond_from_dtype, + result_expr=result_expr, + check_result=check_result, + ) cases.append(case) elif m := r_even_int_round_case.search(case): cases.append(even_int_round_case) @@ -741,7 +744,7 @@ def check_result(i1: float, i2: float, result: float) -> bool: return check_result -def make_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck: +def make_binary_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck: def check_result(i1: float, i2: float, result: float) -> bool: return check_just_result(result) @@ -843,12 +846,12 @@ def partial_cond(i1: float, i2: float) -> bool: else: unary_cond, expr_template, cond_from_dtype = parse_cond(value_str) - # Do not define partial_cond via the def keyword, as one - # partial_cond definition can mess up previous definitions - # in the partial_conds list. This is a hard-limitation of - # using local functions with the same name and that use the same - # outer variables (i.e. unary_cond). Use def in a called - # function avoids this problem. + # Do not define partial_cond via the def keyword or lambda + # expressions, as one partial_cond definition can mess up + # previous definitions in the partial_conds list. This is a + # hard-limitation of using local functions with the same name + # and that use the same outer variables (i.e. unary_cond). Use + # def in a called function avoids this problem. input_wrapper = None if m := r_input.match(input_str): x_no = m.group(1) @@ -924,7 +927,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: if result_m is None: raise ParseError(case_m.group(2)) result_str = result_m.group(1) - # Like with partial_cond, do not define check_result via the def keyword + # Like with partial_cond, do not define check_result in this function's body. if m := r_array_element.match(result_str): sign, x_no = m.groups() result_expr = f"{sign}x{x_no}_i" @@ -933,7 +936,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: ) else: _check_result, result_expr = parse_result(result_m.group(1)) - check_result = make_check_result(_check_result) + check_result = make_binary_check_result(_check_result) cond_expr = " and ".join(partial_exprs) From b6002306a8e007a181ab8c25a1129e72b31ac134 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 16:57:57 +0000 Subject: [PATCH 44/63] Better exception messages --- array_api_tests/test_special_cases.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 4cad73b1..34e37b9e 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1072,7 +1072,9 @@ def test_unary(func_name, func, case, x, data): f_in = f"{sh.fmt_idx('x', idx)}={in_}" f_out = f"{sh.fmt_idx('out', idx)}={out}" assert case.check_result(in_, out), ( - f"{f_out} not good [{func_name}()]\n" f"{case}\n" f"{f_in}" + f"{f_out}, but should be {case.result_expr} [{func_name}()]\n" + f"condition: {case.cond_expr}\n" + f"{f_in}" ) break assume(good_example) @@ -1115,7 +1117,9 @@ def test_binary(func_name, func, case, x1, x2, data): f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" f_out = f"{sh.fmt_idx('out', o_idx)}={o}" assert case.check_result(l, r, o), ( - f"{f_out} not good [{func_name}()]\n" f"{case}\n" f"{f_left}, {f_right}" + f"{f_out}, but should be {case.result_expr} [{func_name}()]\n" + f"condition: {case}\n" + f"{f_left}, {f_right}" ) break assume(good_example) From a9a523445aec5b93b840e5db77c9b5c9d5967daf Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 14 Mar 2022 17:00:11 +0000 Subject: [PATCH 45/63] Not worry about removing spec package from path --- array_api_tests/stubs.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py index ab1851f3..2b2e216c 100644 --- a/array_api_tests/stubs.py +++ b/array_api_tests/stubs.py @@ -42,6 +42,3 @@ objects = [getattr(mod, name) for name in mod.__all__] assert all(isinstance(o, FunctionType) for o in objects) extension_to_funcs[ext] = objects - - -sys.path.remove(spec_abs_path) From 180b152c36e51057b778b70c5cc1e9bc314c889f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Mar 2022 08:52:49 +0000 Subject: [PATCH 46/63] Doc fixes --- array_api_tests/test_special_cases.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 34e37b9e..075537e0 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -228,7 +228,7 @@ class BoundFromDtype(FromDtypeFunc): >>> strategy = xps.from_dtype(xp.float64).filter(lambda i: i != 0) - i.e. a strategy that generates any floats except 0 + i.e. a strategy that generates any float except +0 and -0 3. The underlying function that returns an elements strategy from a dtype, e.g. @@ -245,7 +245,7 @@ class BoundFromDtype(FromDtypeFunc): ... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max) ... ) - i.e. a strategy that generates integers (within the dtypes range) + i.e. a strategy that generates integers (within the dtype's range) This is useful to avoid translating special case conditions into either a dict, filter or "base func", and instead allows us to generalise these three @@ -332,12 +332,12 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: e.g. >>> cond, expr_template, from_dtype = parse_cond('greater than ``0``') - >>> expr_template.replace('{}', 'x_i') - 'x_i > 0' >>> cond(42) True >>> cond(-128) False + >>> expr_template.replace('{}', 'x_i') + 'x_i > 0' >>> strategy = from_dtype(xp.float64) >>> for _ in range(5): ... print(strategy.example()) @@ -478,12 +478,12 @@ def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: e.g. >>> check_result, expr = parse_result('``42``') - >>> expr_template.replace('{}', 'x_i') - '42' >>> check_result(7) False >>> check_result(42) True + >>> expr + '42' """ if m := r_code.match(result_str): From 6c802c359fbfeb78cd58943cf288d40fe85166d9 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Mar 2022 10:05:30 +0000 Subject: [PATCH 47/63] Document `parse_unary_docstring()` --- array_api_tests/test_special_cases.py | 40 +++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 075537e0..51d3da30 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -586,6 +586,46 @@ def check_result(i: float, result: float) -> bool: def parse_unary_docstring(docstring: str) -> List[UnaryCase]: + """ + Parses a Sphinx-formatted docstring of a unary function to return a list of + codified unary cases, e.g. + + >>> def sqrt(x: array, /) -> array: + ... ''' + ... Calculates the square root + ... + ... **Special Cases** + ... + ... For floating-point operands, + ... + ... - If ``x_i`` is ``NaN``, the result is ``NaN``. + ... - If ``x_i`` is less than ``0``, the result is ``NaN``. + ... - If ``x_i`` is ``+0``, the result is ``+0``. + ... - If ``x_i`` is ``-0``, the result is ``-0``. + ... - If ``x_i`` is ``+infinity``, the result is ``+infinity``. + ... + ... Parameters + ... ---------- + ... x: array + ... input array. Should have a floating-point data type + ... + ... Returns + ... ------- + ... out: array + ... an array containing the square root of each element in ``x`` + ... ''' + ... ... + >>> unary_cases = parse_unary_docstring(sqrt.__doc__) + >>> for case in unary_cases: + ... print(repr(case)) + UnaryCase(x_i == NaN -> NaN) + UnaryCase(x_i < 0 -> NaN) + UnaryCase(x_i == +0 -> +0) + UnaryCase(x_i == -0 -> -0) + UnaryCase(x_i == +infinity -> +infinity) + + """ + match = r_special_cases.search(docstring) if match is None: return [] From eaa3b05e20e71950e2cdb51a2387f920f596a191 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Mar 2022 10:20:58 +0000 Subject: [PATCH 48/63] Document `parse_binary_docstring()` --- array_api_tests/test_special_cases.py | 46 +++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 51d3da30..0761936f 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -590,7 +590,7 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: Parses a Sphinx-formatted docstring of a unary function to return a list of codified unary cases, e.g. - >>> def sqrt(x: array, /) -> array: + >>> def sqrt(x): ... ''' ... Calculates the square root ... @@ -607,14 +607,14 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: ... Parameters ... ---------- ... x: array - ... input array. Should have a floating-point data type + ... input array ... ... Returns ... ------- ... out: array ... an array containing the square root of each element in ``x`` ... ''' - ... ... + ... >>> unary_cases = parse_unary_docstring(sqrt.__doc__) >>> for case in unary_cases: ... print(repr(case)) @@ -1014,6 +1014,46 @@ def cond(i1: float, i2: float) -> bool: def parse_binary_docstring(docstring: str) -> List[BinaryCase]: + """ + Parses a Sphinx-formatted docstring of a binary function to return a list of + codified binary cases, e.g. + + >>> def logaddexp(x1, x2): + ... ''' + ... Calculates the logarithm of the sum of exponentiations + ... + ... **Special Cases** + ... + ... For floating-point operands, + ... + ... - If either ``x1_i`` or ``x2_i`` is ``NaN``, the result is ``NaN``. + ... - If ``x1_i`` is ``+infinity`` and ``x2_i`` is not ``NaN``, the + ... result is ``+infinity``. + ... - If ``x1_i`` is not ``NaN`` and ``x2_i`` is ``+infinity``, the + ... result is ``+infinity``. + ... + ... Parameters + ... ---------- + ... x1: array + ... first input array + ... x2: array + ... second input array + ... + ... Returns + ... ------- + ... out: array + ... an array containing the results + ... ''' + ... + >>> binary_cases = parse_binary_docstring(logaddexp.__doc__) + >>> for case in binary_cases: + ... print(repr(case)) + BinaryCase(x1_i == NaN or x2_i == NaN -> NaN) + BinaryCase(x1_i == +infinity and not x2_i == NaN -> +infinity) + BinaryCase(not x1_i == NaN and x2_i == +infinity -> +infinity) + + """ + match = r_special_cases.search(docstring) if match is None: return [] From 8b24204bbbe7526143a9cf72e831e440873293ad Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Mar 2022 10:38:59 +0000 Subject: [PATCH 49/63] Extend `good_examples` comment --- array_api_tests/test_special_cases.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 0761936f..0a973d32 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1122,10 +1122,14 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: ) +# test_unary and test_binary naively generate arrays, i.e. arrays that might not +# meet the condition that is being test. We then forcibly make the array meet +# the condition by picking a random index to insert an acceptable element. +# # good_example is a flag that tells us whether Hypothesis generated an array # with at least on element that is special-cased. We reject the example when # its False - Hypothesis will complain if we reject too many examples, thus -# indicating we should modify the array strategy being used. +# indicating we've done something wrong. @pytest.mark.parametrize("func_name, func, case", unary_params) From 5a9ea21975a55d82a2d4ea33434422308d06122b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Mar 2022 10:46:14 +0000 Subject: [PATCH 50/63] Move `array-api` submodule to the top-level --- .gitmodules | 2 +- array_api_tests/array-api => array-api | 0 array_api_tests/stubs.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) rename array_api_tests/array-api => array-api (100%) diff --git a/.gitmodules b/.gitmodules index 4128e9f2..c225c24e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "array_api_tests/array-api"] - path = array_api_tests/array-api + path = array-api url = https://github.com/data-apis/array-api/ diff --git a/array_api_tests/array-api b/array-api similarity index 100% rename from array_api_tests/array-api rename to array-api diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py index 2b2e216c..15fb7646 100644 --- a/array_api_tests/stubs.py +++ b/array_api_tests/stubs.py @@ -8,8 +8,8 @@ __all__ = ["category_to_funcs", "array", "extension_to_funcs"] -spec_dir = Path(__file__).parent / "array-api" / "spec" / "API_specification" -assert spec_dir.exists(), f"{spec_dir} not found - try `git pull --recurse-submodules`" +spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification" +assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`" sigs_dir = spec_dir / "signatures" assert sigs_dir.exists() From 795fd04175f771d9813a07553576466b65dc79a3 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 15 Mar 2022 19:01:37 +0000 Subject: [PATCH 51/63] Minor doc fixes --- array_api_tests/test_special_cases.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 0a973d32..68cd77aa 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -241,8 +241,8 @@ class BoundFromDtype(FromDtypeFunc): is equivalent to - >>> strategy = lambda d: st.integers( - ... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max) + >>> strategy = st.integers( + ... math.ceil(xp.finfo(xp.float64).min), math.floor(xp.finfo(xp.float64).max) ... ) i.e. a strategy that generates integers (within the dtype's range) @@ -1027,10 +1027,8 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: ... For floating-point operands, ... ... - If either ``x1_i`` or ``x2_i`` is ``NaN``, the result is ``NaN``. - ... - If ``x1_i`` is ``+infinity`` and ``x2_i`` is not ``NaN``, the - ... result is ``+infinity``. - ... - If ``x1_i`` is not ``NaN`` and ``x2_i`` is ``+infinity``, the - ... result is ``+infinity``. + ... - If ``x1_i`` is ``+infinity`` and ``x2_i`` is not ``NaN``, the result is ``+infinity``. + ... - If ``x1_i`` is not ``NaN`` and ``x2_i`` is ``+infinity``, the result is ``+infinity``. ... ... Parameters ... ---------- From f0c3df44fa5972115abefed4279e437f11cace21 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Mar 2022 11:42:59 +0000 Subject: [PATCH 52/63] Document `make_binary_cond()` --- array_api_tests/test_special_cases.py | 33 ++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 68cd77aa..d9e290f6 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -732,6 +732,35 @@ def make_binary_cond( *, input_wrapper: Optional[Callable[[float], float]] = None, ) -> BinaryCond: + """ + Wraps a unary condition as a binary condition, e.g. + + >>> unary_cond = lambda i: i == 42 + + >>> binary_cond_first = make_binary_cond(BinaryCondArg.FIRST, unary_cond) + >>> binary_cond_first(42, 0) + True + >>> binary_cond_second = make_binary_cond(BinaryCondArg.SECOND, unary_cond) + >>> binary_cond_second(42, 0) + False + >>> binary_cond_second(0, 42) + True + >>> binary_cond_both = make_binary_cond(BinaryCondArg.BOTH, unary_cond) + >>> binary_cond_both(42, 0) + False + >>> binary_cond_both(42, 42) + True + >>> binary_cond_either = make_binary_cond(BinaryCondArg.EITHER, unary_cond) + >>> binary_cond_either(0, 0) + False + >>> binary_cond_either(42, 0) + True + >>> binary_cond_either(0, 42) + True + >>> binary_cond_either(42, 42) + True + + """ if input_wrapper is None: input_wrapper = noop @@ -823,11 +852,13 @@ def parse_binary_case(case_str: str) -> BinaryCase: if in_sign != "" or other_no == in_no: raise ParseError(cond_str) partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i" + input_wrapper = lambda i: -i if other_sign == "-" else noop + # For these scenarios, we want to make sure both array elements + # generate respective to one another by using a shared strategy. shared_from_dtype = lambda d, **kw: st.shared( xps.from_dtype(d, **kw), key=cond_str ) - if other_no == "1": def partial_cond(i1: float, i2: float) -> bool: From 546827ecabbdf0fba9b3ea9afa9c5ad5be8dec46 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Mar 2022 12:02:43 +0000 Subject: [PATCH 53/63] Document `make_eq_input_check_result()` --- array_api_tests/test_special_cases.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index d9e290f6..a9b28c69 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -736,7 +736,6 @@ def make_binary_cond( Wraps a unary condition as a binary condition, e.g. >>> unary_cond = lambda i: i == 42 - >>> binary_cond_first = make_binary_cond(BinaryCondArg.FIRST, unary_cond) >>> binary_cond_first(42, 0) True @@ -790,6 +789,24 @@ def partial_cond(i1: float, i2: float) -> bool: def make_eq_input_check_result( eq_to: BinaryCondArg, *, eq_neg: bool = False ) -> BinaryResultCheck: + """ + Returns a result checker for cases where the result equals an array element + + >>> check_result_first = make_eq_input_check_result(BinaryCondArg.FIRST) + >>> check_result(42, 0, 42) + True + >>> check_result_second = make_eq_input_check_result(BinaryCondArg.SECOND) + >>> check_result(42, 0, 42) + False + >>> check_result(0, 42, 42) + True + >>> check_result_neg_first = make_eq_input_check_result(BinaryCondArg.FIRST, eq_neg=True) + >>> check_result_neg_first(42, 0, 42) + False + >>> check_result_neg_first(42, 0, -42) + True + + """ if eq_neg: input_wrapper = lambda i: -i else: From f5b0975a8769145e887f39cb675a92eadbd01779 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Mar 2022 14:08:23 +0000 Subject: [PATCH 54/63] Make `parse_cond()` only return `BoundFromDtype` Also add some more granular documentation to it --- array_api_tests/test_special_cases.py | 109 +++++++++++--------------- 1 file changed, 46 insertions(+), 63 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index a9b28c69..9eff25be 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -319,7 +319,7 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: return from_dtype -def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: +def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, BoundFromDtype]: """ Parses a Sphinx-formatted condition string to return: @@ -348,22 +348,30 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: 124.978 """ + # We first identify whether the condition starts with "not". If so, we note + # this but parse the condition as if it was not negated. if m := r_not.match(cond_str): cond_str = m.group(1) not_cond = True else: not_cond = False + # We parse the condition to identify the condition function, expression + # template, and xps.from_dtype()-like condition strategy. kwargs = {} filter_ = None from_dtype = None # type: ignore - strat = None if m := r_code.match(cond_str): value = parse_value(m.group(1)) cond = make_strict_eq(value) expr_template = "{} == " + m.group(1) - if not not_cond: - strat = st.just(value) + from_dtype = wrap_strat_as_from_dtype(st.just(value)) + elif m := r_either_code.match(cond_str): + v1 = parse_value(m.group(1)) + v2 = parse_value(m.group(2)) + cond = make_or(make_strict_eq(v1), make_strict_eq(v2)) + expr_template = "({} == " + m.group(1) + " or {} == " + m.group(2) + ")" + from_dtype = wrap_strat_as_from_dtype(st.sampled_from([v1, v2])) elif m := r_equal_to.match(cond_str): value = parse_value(m.group(1)) if math.isnan(value): @@ -374,97 +382,73 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]: value = parse_value(m.group(1)) cond = make_gt(value) expr_template = "{} > " + m.group(1) - if not not_cond: - kwargs = {"min_value": value, "exclude_min": True} + kwargs = {"min_value": value, "exclude_min": True} elif m := r_lt.match(cond_str): value = parse_value(m.group(1)) cond = make_lt(value) expr_template = "{} < " + m.group(1) - if not not_cond: - kwargs = {"max_value": value, "exclude_max": True} - elif m := r_either_code.match(cond_str): - v1 = parse_value(m.group(1)) - v2 = parse_value(m.group(2)) - cond = make_or(make_strict_eq(v1), make_strict_eq(v2)) - expr_template = "({} == " + m.group(1) + " or {} == " + m.group(2) + ")" - if not not_cond: - strat = st.sampled_from([v1, v2]) + kwargs = {"max_value": value, "exclude_max": True} elif cond_str in ["finite", "a finite number"]: cond = math.isfinite expr_template = "isfinite({})" - if not not_cond: - kwargs = {"allow_nan": False, "allow_infinity": False} + kwargs = {"allow_nan": False, "allow_infinity": False} elif cond_str in "a positive (i.e., greater than ``0``) finite number": cond = lambda i: math.isfinite(i) and i > 0 expr_template = "isfinite({}) and {} > 0" - if not not_cond: - kwargs = { - "allow_nan": False, - "allow_infinity": False, - "min_value": 0, - "exclude_min": True, - } + kwargs = { + "allow_nan": False, + "allow_infinity": False, + "min_value": 0, + "exclude_min": True, + } elif cond_str == "a negative (i.e., less than ``0``) finite number": cond = lambda i: math.isfinite(i) and i < 0 expr_template = "isfinite({}) and {} < 0" - if not not_cond: - kwargs = { - "allow_nan": False, - "allow_infinity": False, - "max_value": 0, - "exclude_max": True, - } + kwargs = { + "allow_nan": False, + "allow_infinity": False, + "max_value": 0, + "exclude_max": True, + } elif cond_str == "positive": cond = lambda i: math.copysign(1, i) == 1 expr_template = "copysign(1, {}) == 1" - if not not_cond: - # We assume (positive) zero is special cased seperately - kwargs = {"min_value": 0, "exclude_min": True} + # We assume (positive) zero is special cased seperately + kwargs = {"min_value": 0, "exclude_min": True} elif cond_str == "negative": cond = lambda i: math.copysign(1, i) == -1 expr_template = "copysign(1, {}) == -1" - if not not_cond: - # We assume (negative) zero is special cased seperately - kwargs = {"max_value": 0, "exclude_max": True} + # We assume (negative) zero is special cased seperately + kwargs = {"max_value": 0, "exclude_max": True} elif "nonzero finite" in cond_str: cond = lambda i: math.isfinite(i) and i != 0 expr_template = "isfinite({}) and {} != 0" - if not not_cond: - kwargs = {"allow_nan": False, "allow_infinity": False} - filter_ = lambda n: n != 0 + kwargs = {"allow_nan": False, "allow_infinity": False} + filter_ = lambda n: n != 0 elif cond_str == "an integer value": cond = lambda i: i.is_integer() expr_template = "{}.is_integer()" - if not not_cond: - from_dtype = integers_from_dtype # type: ignore + from_dtype = integers_from_dtype # type: ignore elif cond_str == "an odd integer value": cond = lambda i: i.is_integer() and i % 2 == 1 expr_template = "{}.is_integer() and {} % 2 == 1" - if not not_cond: + from_dtype = integers_from_dtype # type: ignore - def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: - return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1) + def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1) else: raise ParseError(cond_str) - if strat is not None: - if ( - not_cond - or len(kwargs) != 0 - or filter_ is not None - or from_dtype is not None - ): - raise ParseError(cond_str) - return cond, expr_template, wrap_strat_as_from_dtype(strat) - if not_cond: - expr_template = f"not {expr_template}" + # We handle negated conitions by simply negating the condition function + # and using it as a filter for xps.from_dtype() (or an equivalent). cond = make_not_cond(cond) - kwargs = {} + expr_template = f"not {expr_template}" filter_ = cond - assert kwargs is not None - return cond, expr_template, BoundFromDtype(kwargs, filter_, from_dtype) + return cond, expr_template, BoundFromDtype(filter_=filter_) + else: + return cond, expr_template, BoundFromDtype(kwargs, filter_, from_dtype) def parse_result(result_str: str) -> Tuple[UnaryCheck, str]: @@ -838,6 +822,9 @@ def check_result(i1: float, i2: float, result: float) -> bool: def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: + """ + Returns a strategy that generates float-casted integers within the bounds of dtype. + """ for k in kw.keys(): # sanity check assert k in ["min_value", "max_value", "exclude_min", "exclude_max"] @@ -1036,16 +1023,12 @@ def cond(i1: float, i2: float) -> bool: elif len(x1_cond_from_dtypes) == 1: x1_cond_from_dtype = x1_cond_from_dtypes[0] else: - if not all(isinstance(fd, BoundFromDtype) for fd in x1_cond_from_dtypes): - raise ParseError(case_str) x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype()) if len(x2_cond_from_dtypes) == 0: x2_cond_from_dtype = xps.from_dtype elif len(x2_cond_from_dtypes) == 1: x2_cond_from_dtype = x2_cond_from_dtypes[0] else: - if not all(isinstance(fd, BoundFromDtype) for fd in x2_cond_from_dtypes): - raise ParseError(case_str) x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype()) return BinaryCase( From 3773a4df3a500cf3b6e5e30f84869224ae87a4c0 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Mar 2022 16:48:27 +0000 Subject: [PATCH 55/63] Document `parse_binary_case()` --- array_api_tests/test_special_cases.py | 45 ++++++++++++++++++++------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 9eff25be..1c67dab0 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -327,14 +327,14 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, BoundFromDtype]: condition, otherwise False. 2. A string template for expressing the condition. 3. A xps.from_dtype()-like function which returns a strategy that generates - elements which meet the condition. + elements that meet the condition. e.g. >>> cond, expr_template, from_dtype = parse_cond('greater than ``0``') >>> cond(42) True - >>> cond(-128) + >>> cond(-123) False >>> expr_template.replace('{}', 'x_i') 'x_i > 0' @@ -582,8 +582,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: ... ... For floating-point operands, ... - ... - If ``x_i`` is ``NaN``, the result is ``NaN``. ... - If ``x_i`` is less than ``0``, the result is ``NaN``. + ... - If ``x_i`` is ``NaN``, the result is ``NaN``. ... - If ``x_i`` is ``+0``, the result is ``+0``. ... - If ``x_i`` is ``-0``, the result is ``-0``. ... - If ``x_i`` is ``+infinity``, the result is ``+infinity``. @@ -602,11 +602,16 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: >>> unary_cases = parse_unary_docstring(sqrt.__doc__) >>> for case in unary_cases: ... print(repr(case)) - UnaryCase(x_i == NaN -> NaN) - UnaryCase(x_i < 0 -> NaN) - UnaryCase(x_i == +0 -> +0) - UnaryCase(x_i == -0 -> -0) - UnaryCase(x_i == +infinity -> +infinity) + UnaryCase( NaN>) + UnaryCase( NaN>) + UnaryCase( +0>) + UnaryCase( -0>) + UnaryCase( +infinity>) + >>> lt_0_case = unary_cases[0] + >>> lt_0_case.cond(-123) + True + >>> lt_0_case.check_result(-123, float('nan')) + True """ @@ -841,6 +846,22 @@ def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: def parse_binary_case(case_str: str) -> BinaryCase: + """ + Parses a Sphinx-formatted binary case string to return codified binary cases, e.g. + + >>> case_str = ( + ... "If ``x1_i`` is greater than ``0``, ``x1_i`` is a finite number, " + ... "and ``x2_i`` is ``+infinity``, the result is ``NaN``." + ... ) + >>> case = parse_binary_case(case_str) + >>> case + BinaryCase( 0 and isfinite(x1_i) and x2_i == +infinity -> NaN>) + >>> case.cond(42, float('inf')) + True + >>> case.check_result(42, float('inf'), float('nan')) + True + + """ case_m = r_binary_case.match(case_str) if case_m is None: raise ParseError(case_str) @@ -857,12 +878,12 @@ def parse_binary_case(case_str: str) -> BinaryCase: raise ParseError(cond_str) partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i" - input_wrapper = lambda i: -i if other_sign == "-" else noop # For these scenarios, we want to make sure both array elements # generate respective to one another by using a shared strategy. shared_from_dtype = lambda d, **kw: st.shared( xps.from_dtype(d, **kw), key=cond_str ) + input_wrapper = lambda i: -i if other_sign == "-" else noop if other_no == "1": def partial_cond(i1: float, i2: float) -> bool: @@ -1077,9 +1098,9 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: >>> binary_cases = parse_binary_docstring(logaddexp.__doc__) >>> for case in binary_cases: ... print(repr(case)) - BinaryCase(x1_i == NaN or x2_i == NaN -> NaN) - BinaryCase(x1_i == +infinity and not x2_i == NaN -> +infinity) - BinaryCase(not x1_i == NaN and x2_i == +infinity -> +infinity) + BinaryCase( NaN>) + BinaryCase( +infinity>) + BinaryCase( +infinity>) """ From f82fee507adc3ab7816b8949498294c69a42ac20 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Mar 2022 17:26:23 +0000 Subject: [PATCH 56/63] Generate and test the even rounding halves case correctly --- array_api_tests/test_special_cases.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 1c67dab0..1a19958f 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -540,20 +540,31 @@ class UnaryCase(Case): r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") -r_even_int_round_case = re.compile( +r_even_round_halves_case = re.compile( "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: - m, M = dh.dtype_ranges[dtype] - return st.integers(math.ceil(m) // 2, math.floor(M) // 2).map(lambda n: n * 0.5) + """ + Returns a strategy that generates floats that end with .5 and are within the + bounds of dtype. + """ + # We bound our base integers strategy to a range of values which should be + # able to represent a decimal 5 when .5 is added or subtracted. + if dtype == xp.float32: + abs_max = 10**4 + else: + abs_max = 10**16 + return st.sampled_from([0.5, -0.5]).flatmap( + lambda half: st.integers(-abs_max, abs_max).map(lambda n: n + half) + ) -even_int_round_case = UnaryCase( - cond_expr="i % 0.5 == 0", - cond=lambda i: i % 0.5 == 0, +even_round_halves_case = UnaryCase( + cond_expr="modf(i)[0] == 0.5", + cond=lambda i: math.modf(i)[0] == 0.5, cond_from_dtype=trailing_halves_from_dtype, result_expr="Decimal(i).to_integral_exact(ROUND_HALF_EVEN)", check_result=lambda i, result: ( @@ -645,8 +656,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]: check_result=check_result, ) cases.append(case) - elif m := r_even_int_round_case.search(case): - cases.append(even_int_round_case) + elif m := r_even_round_halves_case.search(case): + cases.append(even_round_halves_case) else: if not r_remaining_case.search(case): warn(f"case not machine-readable: '{case}'") From 558ffdce7e96b444365eb358e1d45228f091a5b6 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 16 Mar 2022 17:56:15 +0000 Subject: [PATCH 57/63] Case expression fixes --- array_api_tests/test_special_cases.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 1a19958f..5667de8b 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -364,13 +364,13 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, BoundFromDtype]: if m := r_code.match(cond_str): value = parse_value(m.group(1)) cond = make_strict_eq(value) - expr_template = "{} == " + m.group(1) + expr_template = "{} is " + m.group(1) from_dtype = wrap_strat_as_from_dtype(st.just(value)) elif m := r_either_code.match(cond_str): v1 = parse_value(m.group(1)) v2 = parse_value(m.group(2)) cond = make_or(make_strict_eq(v1), make_strict_eq(v2)) - expr_template = "({} == " + m.group(1) + " or {} == " + m.group(2) + ")" + expr_template = "({} is " + m.group(1) + " or {} == " + m.group(2) + ")" from_dtype = wrap_strat_as_from_dtype(st.sampled_from([v1, v2])) elif m := r_equal_to.match(cond_str): value = parse_value(m.group(1)) @@ -487,7 +487,7 @@ def check_result(result: float) -> bool: return True return math.copysign(1, result) == 1 - expr = "+" + expr = "positive sign" elif "negative" in result_str: def check_result(result: float) -> bool: @@ -496,7 +496,7 @@ def check_result(result: float) -> bool: return True return math.copysign(1, result) == -1 - expr = "-" + expr = "negative sign" else: raise ParseError(result_str) @@ -927,7 +927,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: unary_cond, expr_template, cond_from_dtype = parse_cond(m.group(1)) left_expr = expr_template.replace("{}", "x1_i") right_expr = expr_template.replace("{}", "x2_i") - partial_expr = f"({left_expr}) and ({right_expr})" + partial_expr = f"{left_expr} and {right_expr}" partial_cond = make_binary_cond( # type: ignore BinaryCondArg.BOTH, unary_cond ) @@ -972,7 +972,7 @@ def partial_cond(i1: float, i2: float) -> bool: elif r_and_input.match(input_str): left_expr = expr_template.replace("{}", "x1_i") right_expr = expr_template.replace("{}", "x2_i") - partial_expr = f"({left_expr}) and ({right_expr})" + partial_expr = f"{left_expr} and {right_expr}" cond_arg = BinaryCondArg.BOTH elif r_or_input.match(input_str): left_expr = expr_template.replace("{}", "x1_i") From 456fc6ccf16b31a7a68021dab4e2888c28b6cd5b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Mar 2022 10:11:26 +0000 Subject: [PATCH 58/63] Drop outdated `x<1/2>_cond_from_dtype` logic --- array_api_tests/test_special_cases.py | 28 ++++++++++++--------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 5667de8b..67895f44 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -1010,22 +1010,28 @@ def partial_cond(i1: float, i2: float) -> bool: st.sampled_from([(True, False), (False, True), (True, True)]) ) - def _x1_cond_from_dtype(dtype) -> st.SearchStrategy[float]: + def _x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: + assert len(kw) == 0 # sanity check return use_x1_or_x2_strat.flatmap( lambda t: cond_from_dtype(dtype) if t[0] else xps.from_dtype(dtype) ) - def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: + def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]: + assert len(kw) == 0 # sanity check return use_x1_or_x2_strat.flatmap( lambda t: cond_from_dtype(dtype) if t[1] else xps.from_dtype(dtype) ) - x1_cond_from_dtypes.append(_x1_cond_from_dtype) - x2_cond_from_dtypes.append(_x2_cond_from_dtype) + x1_cond_from_dtypes.append( + BoundFromDtype(base_func=_x1_cond_from_dtype) + ) + x2_cond_from_dtypes.append( + BoundFromDtype(base_func=_x2_cond_from_dtype) + ) partial_conds.append(partial_cond) partial_exprs.append(partial_expr) @@ -1050,18 +1056,8 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]: def cond(i1: float, i2: float) -> bool: return all(pc(i1, i2) for pc in partial_conds) - if len(x1_cond_from_dtypes) == 0: - x1_cond_from_dtype = xps.from_dtype - elif len(x1_cond_from_dtypes) == 1: - x1_cond_from_dtype = x1_cond_from_dtypes[0] - else: - x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype()) - if len(x2_cond_from_dtypes) == 0: - x2_cond_from_dtype = xps.from_dtype - elif len(x2_cond_from_dtypes) == 1: - x2_cond_from_dtype = x2_cond_from_dtypes[0] - else: - x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype()) + x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype()) + x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype()) return BinaryCase( cond_expr=cond_expr, From d95c2ab1c8851ab5389385f67ad22a014d461ab8 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Mar 2022 11:36:51 +0000 Subject: [PATCH 59/63] Test special cases for operators --- array_api_tests/test_special_cases.py | 94 ++++++++++++++++++++++++--- 1 file changed, 86 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 67895f44..a68171b7 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -4,6 +4,7 @@ import inspect import math +import operator import re from dataclasses import dataclass, field from decimal import ROUND_HALF_EVEN, Decimal @@ -24,6 +25,10 @@ from . import xps from ._array_module import mod as xp from .stubs import category_to_funcs +from .test_operators_and_elementwise_functions import ( + oneway_broadcastable_shapes, + oneway_promotable_dtypes, +) pytestmark = pytest.mark.ci @@ -1138,6 +1143,8 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: unary_params = [] binary_params = [] +iop_params = [] +func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()} for stub in category_to_funcs["elementwise"]: if stub.__doc__ is None: warn(f"{stub.__name__}() stub has no docstring") @@ -1157,20 +1164,39 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]: continue if param_names[0] == "x": if cases := parse_unary_docstring(stub.__doc__): - for case in cases: - id_ = f"{stub.__name__}({case.cond_expr}) -> {case.result_expr}" - p = pytest.param(stub.__name__, func, case, id=id_) - unary_params.append(p) + func_name_to_func = {stub.__name__: func} + if stub.__name__ in func_to_op.keys(): + op_name = func_to_op[stub.__name__] + op = getattr(operator, op_name) + func_name_to_func[op_name] = op + for func_name, func in func_name_to_func.items(): + for case in cases: + id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(func_name, func, case, id=id_) + unary_params.append(p) continue if len(sig.parameters) == 1: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and param_names[1] == "x2": if cases := parse_binary_docstring(stub.__doc__): - for case in cases: - id_ = f"{stub.__name__}({case.cond_expr}) -> {case.result_expr}" - p = pytest.param(stub.__name__, func, case, id=id_) - binary_params.append(p) + func_name_to_func = {stub.__name__: func} + if stub.__name__ in func_to_op.keys(): + op_name = func_to_op[stub.__name__] + op = getattr(operator, op_name) + func_name_to_func[op_name] = op + # We collect inplaceoperator test cases seperately + iop_name = "__i" + op_name[2:] + iop = getattr(operator, iop_name) + for case in cases: + id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(iop_name, iop, case, id=id_) + iop_params.append(p) + for func_name, func in func_name_to_func.items(): + for case in cases: + id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}" + p = pytest.param(func_name, func, case, id=id_) + binary_params.append(p) continue else: warn( @@ -1264,3 +1290,55 @@ def test_binary(func_name, func, case, x1, x2, data): ) break assume(good_example) + + +@pytest.mark.parametrize("iop_name, iop, case", iop_params) +@given( + oneway_dtypes=oneway_promotable_dtypes(dh.float_dtypes), + oneway_shapes=oneway_broadcastable_shapes(), + data=st.data(), +) +def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data): + x1 = data.draw( + xps.arrays(dtype=oneway_dtypes.result_dtype, shape=oneway_shapes.result_shape), + label="x1", + ) + x2 = data.draw( + xps.arrays(dtype=oneway_dtypes.input_dtype, shape=oneway_shapes.input_shape), + label="x2", + ) + + all_indices = list(sh.iter_indices(x1.shape, x2.shape, x1.shape)) + + indices_strat = st.shared(st.sampled_from(all_indices)) + set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx") + set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value") + x1[set_x1_idx] = set_x1_value + note(f"{x1=}") + set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") + set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") + x2[set_x2_idx] = set_x2_value + note(f"{x2=}") + + res = xp.asarray(x1, copy=True) + iop(res, x2) + # sanity check + ph.assert_result_shape(iop_name, [x1.shape, x2.shape], res.shape) + + good_example = False + for l_idx, r_idx, o_idx in all_indices: + l = float(x1[l_idx]) + r = float(x2[r_idx]) + if case.cond(l, r): + good_example = True + o = float(res[o_idx]) + f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" + f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" + f_out = f"{sh.fmt_idx('out', o_idx)}={o}" + assert case.check_result(l, r, o), ( + f"{f_out}, but should be {case.result_expr} [{iop_name}()]\n" + f"condition: {case}\n" + f"{f_left}, {f_right}" + ) + break + assume(good_example) From 9eac45b736f9207075e72f28e535091f806aeb5c Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Mar 2022 13:44:10 +0000 Subject: [PATCH 60/63] Update skipped NumPy workflow tests --- .github/workflows/numpy.yml | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index 04090923..6f45e636 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -43,13 +43,27 @@ jobs: # waiting on NumPy to allow/revert distinct NaNs for np.unique # https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448 array_api_tests/test_set_functions.py + # https://github.com/numpy/numpy/issues/21211 + array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # noted diversions from spec - array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i == +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] - array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i == +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] - array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i == -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] - array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i == -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] - array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i == -infinity) -> -0] - array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i == +infinity) -> -0] + array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] + array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] + array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] + array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] + array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] + array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] + array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] + array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] + array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] + array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] + array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] + array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] + array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] + array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] + array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] + array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] EOF From 2068ee48eaa7f20e66c575386a0ca91fafa6cac6 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 17 Mar 2022 18:58:43 +0000 Subject: [PATCH 61/63] Better repr for failing `ipow` special case test --- .github/workflows/numpy.yml | 2 ++ array_api_tests/test_special_cases.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index 6f45e636..d88e6342 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -45,6 +45,8 @@ jobs: array_api_tests/test_set_functions.py # https://github.com/numpy/numpy/issues/21211 array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] + # https://github.com/numpy/numpy/issues/21213 + array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] # noted diversions from spec array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index a68171b7..b8979298 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -437,7 +437,8 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, BoundFromDtype]: elif cond_str == "an odd integer value": cond = lambda i: i.is_integer() and i % 2 == 1 expr_template = "{}.is_integer() and {} % 2 == 1" - from_dtype = integers_from_dtype # type: ignore + if not_cond: + expr_template = f"({expr_template})" def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1) From 27a00d71e245cb1aa5ea6d8ef596f88ed1e8475b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 18 Mar 2022 10:07:02 +0000 Subject: [PATCH 62/63] Update workflow with another failing `__ipow__()` case --- .github/workflows/numpy.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index d88e6342..e3819ca4 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -46,6 +46,7 @@ jobs: # https://github.com/numpy/numpy/issues/21211 array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # https://github.com/numpy/numpy/issues/21213 + array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] # noted diversions from spec array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] From 6642522d97ff5f23020bfb2308db7d30963965c5 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 22 Mar 2022 11:28:42 +0000 Subject: [PATCH 63/63] Note submodule in readme --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 1d4ad770..9eebc397 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,15 @@ welcome! ### Setup +Currently we pin the Array API specification repo [`array-api`](https://github.com/data-apis/array-api/) +as a git submodule. This might change in the future to better support vendoring +use cases (see [#107](https://github.com/data-apis/array-api-tests/issues/107)), +but for now be sure submodules are pulled too, e.g. + +```bash +$ git submodule update --init +``` + To run the tests, install the testing dependencies. ```bash