|
10 | 10 | from . import dtype_helpers as dh
|
11 | 11 | from . import hypothesis_helpers as hh
|
12 | 12 | from . import pytest_helpers as ph
|
| 13 | +from . import shape_helpers as sh |
13 | 14 | from . import xps
|
14 | 15 | from .typing import DataType, Scalar
|
15 | 16 |
|
@@ -186,14 +187,59 @@ def test_arange(dtype, data):
|
186 | 187 | ), f"out[0]={out[0]}, but should be {_start} {f_func}"
|
187 | 188 |
|
188 | 189 |
|
189 |
| -@given(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1), data=st.data()) |
190 |
| -def test_asarray_scalars(dtype, shape, data): |
191 |
| - obj = data.draw(hh.scalar_objects(dtype, shape), label="obj") |
| 190 | +@given( |
| 191 | + shape=hh.shapes(min_side=1), |
| 192 | + data=st.data(), |
| 193 | +) |
| 194 | +def test_asarray_scalars(shape, data): |
192 | 195 | kw = data.draw(
|
193 |
| - hh.kwargs(dtype=st.sampled_from([None, dtype]), copy=st.none()), label="kw" |
| 196 | + hh.kwargs(dtype=st.none() | xps.scalar_dtypes(), copy=st.none()), label="kw" |
194 | 197 | )
|
| 198 | + dtype = kw.get("dtype", None) |
| 199 | + if dtype is None: |
| 200 | + dtype_family = data.draw( |
| 201 | + st.sampled_from( |
| 202 | + [(xp.bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)] |
| 203 | + ), |
| 204 | + label="expected out dtypes", |
| 205 | + ) |
| 206 | + _dtype = dtype_family[0] |
| 207 | + else: |
| 208 | + _dtype = dtype |
| 209 | + if dh.is_float_dtype(_dtype): |
| 210 | + elements_strat = xps.from_dtype(_dtype) | xps.from_dtype(xp.int32) |
| 211 | + elif dh.is_int_dtype(_dtype): |
| 212 | + elements_strat = xps.from_dtype(_dtype) | st.booleans() |
| 213 | + else: |
| 214 | + elements_strat = xps.from_dtype(_dtype) |
| 215 | + size = math.prod(shape) |
| 216 | + obj_strat = st.lists(elements_strat, min_size=size, max_size=size) |
| 217 | + if dtype is None: |
| 218 | + # For asarray to infer the dtype we're testing, obj requires at least |
| 219 | + # one element to be the scalar equivalent of the inferred dtype, and so |
| 220 | + # we filter out invalid examples. Note we use type() as Python booleans |
| 221 | + # instance check with ints e.g. isinstance(False, int) == True. |
| 222 | + scalar_type = dh.get_scalar_type(_dtype) |
| 223 | + obj_strat = obj_strat.filter(lambda l: any(type(e) == scalar_type for e in l)) |
| 224 | + obj_strat = obj_strat.map(lambda l: sh.reshape(l, shape)) |
| 225 | + obj = data.draw(obj_strat, label="obj") |
| 226 | + |
| 227 | + out = xp.asarray(obj, **kw) |
195 | 228 |
|
196 |
| - xp.asarray(obj, **kw) |
| 229 | + if dtype is None: |
| 230 | + msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be " |
| 231 | + if dtype_family == (xp.float32, xp.float64): |
| 232 | + msg += "default floating-point dtype (float32 or float64)" |
| 233 | + elif dtype_family == (xp.int32, xp.int64): |
| 234 | + msg += "default integer dtype (int32 or int64)" |
| 235 | + else: |
| 236 | + msg += "boolean dtype" |
| 237 | + msg += " [asarray()]" |
| 238 | + assert out.dtype in dtype_family, msg |
| 239 | + else: |
| 240 | + assert kw["dtype"] == _dtype # sanity check |
| 241 | + ph.assert_kw_dtype("asarray", _dtype, out.dtype) |
| 242 | + ph.assert_shape("asarray", out.shape, shape) |
197 | 243 |
|
198 | 244 |
|
199 | 245 | # TODO: test asarray with arrays and copy (in a seperate method)
|
|
0 commit comments