|
2 | 2 | from itertools import count
|
3 | 3 | from typing import Iterator, NamedTuple, Union
|
4 | 4 |
|
5 |
| -from hypothesis import assume, given |
| 5 | +from hypothesis import assume, given, note |
6 | 6 | from hypothesis import strategies as st
|
7 | 7 |
|
8 | 8 | from . import _array_module as xp
|
9 | 9 | from . import array_helpers as ah
|
10 | 10 | from . import dtype_helpers as dh
|
11 | 11 | from . import hypothesis_helpers as hh
|
12 | 12 | from . import pytest_helpers as ph
|
| 13 | +from . import shape_helpers as sh |
13 | 14 | from . import xps
|
14 | 15 | from .typing import DataType, Scalar
|
15 | 16 |
|
@@ -186,6 +187,104 @@ def test_arange(dtype, data):
|
186 | 187 | ), f"out[0]={out[0]}, but should be {_start} {f_func}"
|
187 | 188 |
|
188 | 189 |
|
| 190 | +@given( |
| 191 | + shape=hh.shapes(min_side=1), |
| 192 | + data=st.data(), |
| 193 | +) |
| 194 | +def test_asarray_scalars(shape, data): |
| 195 | + kw = data.draw( |
| 196 | + hh.kwargs(dtype=st.none() | xps.scalar_dtypes(), copy=st.none()), label="kw" |
| 197 | + ) |
| 198 | + dtype = kw.get("dtype", None) |
| 199 | + if dtype is None: |
| 200 | + dtype_family = data.draw( |
| 201 | + st.sampled_from( |
| 202 | + [(xp.bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)] |
| 203 | + ), |
| 204 | + label="expected out dtypes", |
| 205 | + ) |
| 206 | + _dtype = dtype_family[0] |
| 207 | + else: |
| 208 | + _dtype = dtype |
| 209 | + if dh.is_float_dtype(_dtype): |
| 210 | + elements_strat = xps.from_dtype(_dtype) | xps.from_dtype(xp.int32) |
| 211 | + elif dh.is_int_dtype(_dtype): |
| 212 | + elements_strat = xps.from_dtype(_dtype) | st.booleans() |
| 213 | + else: |
| 214 | + elements_strat = xps.from_dtype(_dtype) |
| 215 | + size = math.prod(shape) |
| 216 | + obj_strat = st.lists(elements_strat, min_size=size, max_size=size) |
| 217 | + scalar_type = dh.get_scalar_type(_dtype) |
| 218 | + if dtype is None: |
| 219 | + # For asarray to infer the dtype we're testing, obj requires at least |
| 220 | + # one element to be the scalar equivalent of the inferred dtype, and so |
| 221 | + # we filter out invalid examples. Note we use type() as Python booleans |
| 222 | + # instance check with ints e.g. isinstance(False, int) == True. |
| 223 | + obj_strat = obj_strat.filter(lambda l: any(type(e) == scalar_type for e in l)) |
| 224 | + _obj = data.draw(obj_strat, label="_obj") |
| 225 | + obj = sh.reshape(_obj, shape) |
| 226 | + note(f"{obj=}") |
| 227 | + |
| 228 | + out = xp.asarray(obj, **kw) |
| 229 | + |
| 230 | + if dtype is None: |
| 231 | + msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be " |
| 232 | + if dtype_family == (xp.float32, xp.float64): |
| 233 | + msg += "default floating-point dtype (float32 or float64)" |
| 234 | + elif dtype_family == (xp.int32, xp.int64): |
| 235 | + msg += "default integer dtype (int32 or int64)" |
| 236 | + else: |
| 237 | + msg += "boolean dtype" |
| 238 | + msg += " [asarray()]" |
| 239 | + assert out.dtype in dtype_family, msg |
| 240 | + else: |
| 241 | + assert kw["dtype"] == _dtype # sanity check |
| 242 | + ph.assert_kw_dtype("asarray", _dtype, out.dtype) |
| 243 | + ph.assert_shape("asarray", out.shape, shape) |
| 244 | + for idx, v_expect in zip(sh.ndindex(out.shape), _obj): |
| 245 | + v = scalar_type(out[idx]) |
| 246 | + ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw) |
| 247 | + |
| 248 | + |
| 249 | +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), st.data()) |
| 250 | +def test_asarray_arrays(x, data): |
| 251 | + # TODO: test other valid dtypes |
| 252 | + kw = data.draw( |
| 253 | + hh.kwargs(dtype=st.none() | st.just(x.dtype), copy=st.none() | st.booleans()), |
| 254 | + label="kw", |
| 255 | + ) |
| 256 | + |
| 257 | + out = xp.asarray(x, **kw) |
| 258 | + |
| 259 | + dtype = kw.get("dtype", None) |
| 260 | + if dtype is None: |
| 261 | + ph.assert_dtype("asarray", x.dtype, out.dtype) |
| 262 | + else: |
| 263 | + ph.assert_kw_dtype("asarray", dtype, out.dtype) |
| 264 | + ph.assert_shape("asarray", out.shape, x.shape) |
| 265 | + if dtype is None or dtype == x.dtype: |
| 266 | + ph.assert_array("asarray", out, x, **kw) |
| 267 | + else: |
| 268 | + pass # TODO |
| 269 | + copy = kw.get("copy", None) |
| 270 | + if copy is not None: |
| 271 | + idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") |
| 272 | + _dtype = x.dtype if dtype is None else dtype |
| 273 | + old_value = x[idx] |
| 274 | + value = data.draw( |
| 275 | + xps.arrays(dtype=_dtype, shape=()).filter(lambda y: y != old_value), |
| 276 | + label="mutating value", |
| 277 | + ) |
| 278 | + x[idx] = value |
| 279 | + note(f"mutated {x=}") |
| 280 | + if copy: |
| 281 | + assert not xp.all( |
| 282 | + out == x |
| 283 | + ), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}" |
| 284 | + elif copy is False: |
| 285 | + pass # TODO |
| 286 | + |
| 287 | + |
189 | 288 | @given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))
|
190 | 289 | def test_empty(shape, kw):
|
191 | 290 | out = xp.empty(shape, **kw)
|
|
0 commit comments