
JAX float64 precision issues do not play well with Hypothesis #368

Open
@ev-br

Description

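A typical failure (from test_diff) is below. Note that val_0d is reported as Array(2.1122339, dtype=float32) even though the requested dtype is jax.numpy.float64. The drawn float64 value is therefore rounded to float32 precision, and Hypothesis rejects it: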

self = <hypothesis.extra.array_api.ArrayStrategy object at 0x7e6a6cf6c990>, val = 2.112233982580733, val_0d = Array(2.1122339, dtype=float32)
strategy = FloatStrategy(min_value=2.0, max_value=64.0, allow_nan=False, smallest_nonzero_magnitude=2.2250738585072014e-308)

    def check_set_value(self, val, val_0d, strategy):
        if val == val and self.builtin(val_0d) != val:
            if self.builtin is float:
                assert self.finfo is not None  # for mypy
                try:
                    is_subnormal = 0 < abs(val) < self.finfo.smallest_normal
                except Exception:
                    # val may be a non-float that does not support the
                    # operations __lt__ and __abs__
                    is_subnormal = False
                if is_subnormal:
                    raise InvalidArgument(
                        f"Generated subnormal float {val} from strategy "
                        f"{strategy} resulted in {val_0d!r}, probably "
                        f"as a result of array module {self.xp.__name__} "
                        "being built with flush-to-zero compiler options. "
                        "Consider passing allow_subnormal=False."
                    )
>           raise InvalidArgument(
                f"Generated array element {val!r} from strategy {strategy} "
                f"cannot be represented with dtype {self.dtype}. "
                f"Array module {self.xp.__name__} instead "
                f"represents the element as {val_0d}. "
                "Consider using a more precise elements strategy, "
                "for example passing the width argument to floats()."
            )
E           hypothesis.errors.InvalidArgument: Generated array element 2.112233982580733 from strategy FloatStrategy(min_value=2.0, max_value=64.0, allow_nan=False, smallest_nonzero_magnitude=2.2250738585072014e-308) cannot be represented with dtype <class 'jax.numpy.float64'>. Array module jax.numpy instead represents the element as 2.112233877182007. Consider using a more precise elements strategy, for example passing the width argument to floats().
E           while generating 'x' from sampled_from((<class 'jax.numpy.uint8'>, <class 'jax.numpy.int8'>, <class 'jax.numpy.int16'>, <class 'jax.numpy.int32'>, <class 'jax.numpy.float32'>, <class 'jax.numpy.float64'>, <class 'jax.numpy.complex64'>, <class 'jax.numpy.complex128'>)).flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
E           Explanation:
E               These lines were always and only run by failing examples:
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/array.py:328
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/array.py:651
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/numpy/_core/getlimits.py:609
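The failing check can be reproduced without JAX or Hypothesis. The following is a minimal NumPy sketch. It assumes that the jax.numpy.float64 array really holds float32 data, as the dtype=float32 in val_0d suggests. The helper round_trips is hypothetical. It approximates the comparison self.builtin(val_0d) != val in check_set_value: store the value in a 0-d array, read it back as a Python float, and compare.

```python
import numpy as np

# Value taken from the failing example in the traceback above.
val = 2.112233982580733


def round_trips(x: float, dtype) -> bool:
    """Hypothetical helper approximating Hypothesis' check_set_value test:
    store x in a 0-d array of `dtype`, read it back as a Python float and
    compare it with the original value."""
    return float(np.asarray(x, dtype=dtype)) == x


# A genuine float64 array keeps the value exactly.
print(round_trips(val, np.float64))   # True

# A float32 array rounds it to the nearest float32. This is the value the
# traceback reports for the jax.numpy.float64 array.
print(round_trips(val, np.float32))   # False
print(float(np.float32(val)))         # 2.112233877182007

# A value that is already float32-representable survives the round trip.
# This is roughly what floats(..., width=32) guarantees for drawn elements.
val32 = float(np.float32(val))
print(round_trips(val32, np.float32))  # True
```

If this reading is correct, two workarounds seem plausible. One is to enable 64-bit mode in JAX with jax.config.update("jax_enable_x64", True) before the test arrays are created. The other is to follow the error message's own suggestion and pass width=32 to floats() when the target dtype is float64 under JAX's default 32-bit mode.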
