|
16 | 16 | import pytest
|
17 | 17 | from hypothesis import assume, given
|
18 | 18 | from hypothesis.strategies import (booleans, composite, none, tuples, integers,
|
19 |
| - shared, sampled_from, data, just) |
| 19 | + shared, sampled_from, one_of, data, just) |
| 20 | +from ndindex import iter_indices |
20 | 21 |
|
21 | 22 | from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity
|
22 | 23 | from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
|
|
43 | 44 | # Standin strategy for not yet implemented tests
|
44 | 45 | todo = none()
|
45 | 46 |
|
46 |
| -def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw): |
| 47 | +def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), **kw): |
47 | 48 | """
|
48 | 49 | Test that f(*args, **kw) maps across stacks of matrices
|
49 | 50 |
|
50 |
| - dims is the number of dimensions f should have for a single n x m matrix |
51 |
| - stack. |
| 51 | + dims is the number of dimensions f(*args) should have for a single n x m |
| 52 | + matrix stack. |
| 53 | +
|
| 54 | + matrix_axes are the axes along which matrices (or vectors) are stacked in |
| 55 | + the input. |
| 56 | +
|
| 57 | + true_val may be a function such that true_val(*x_stacks, **kw) gives the |
| 58 | + true value for f on a stack. |
| 59 | +
|
| 60 | + res should be the result of f(*args, **kw). It is computed if not passed |
| 61 | + in. |
52 | 62 |
|
53 |
| - true_val may be a function such that true_val(*x_stacks) gives the true |
54 |
| - value for f on a stack |
55 | 63 | """
|
56 | 64 | if res is None:
|
57 | 65 | res = f(*args, **kw)
|
58 | 66 |
|
59 |
| - shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape |
60 |
| - for x in args]) |
61 |
| - for _idx in sh.ndindex(shape[:-2]): |
62 |
| - idx = _idx + (slice(None),)*dims |
63 |
| - res_stack = res[idx] |
64 |
| - x_stacks = [x[_idx + (...,)] for x in args] |
| 67 | + shapes = [x.shape for x in args] |
| 68 | + |
| 69 | + for (x_idxes, (res_idx,)) in zip( |
| 70 | + iter_indices(*shapes, skip_axes=matrix_axes), |
| 71 | + iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))): |
| 72 | + x_idxes = [x_idx.raw for x_idx in x_idxes] |
| 73 | + res_idx = res_idx.raw |
| 74 | + # res should have `dims` slices in it. Cases where there are more than |
| 75 | + # `dims` slices are ambiguous, but that should only occur in cases |
| 76 | + # where axes = (-2, -1). |
| 77 | + # res_idx2 = [] |
| 78 | + # d = dims |
| 79 | + # for i in res_idx: |
| 80 | + # if isinstance(i, slice): |
| 81 | + # if d: |
| 82 | + # res_idx2.append(i) |
| 83 | + # d -= 1 |
| 84 | + # else: |
| 85 | + # res_idx2.append(i) |
| 86 | + # res_idx2 = tuple(res_idx2) |
| 87 | + |
| 88 | + res_stack = res[res_idx] |
| 89 | + x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)] |
65 | 90 | decomp_res_stack = f(*x_stacks, **kw)
|
66 | 91 | assert_exactly_equal(res_stack, decomp_res_stack)
|
67 | 92 | if true_val:
|
|
0 commit comments