Skip to content

Commit 6534c83

Browse files
committed
Add tests of check_shape functionality
1 parent fca3448 commit 6534c83

File tree

2 files changed

+121
-1
lines changed

2 files changed

+121
-1
lines changed

data_prototype/containers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,18 @@ def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool:
5353
elif len(desc.shape) > len(spec):
5454
return False
5555
for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]):
56+
print(specvars)
5657
if broadcast and desccomp == 1:
5758
continue
5859
if isinstance(speccomp, str):
5960
specv, specoff = speccomp[0], int(speccomp[1:] or 0)
6061

6162
if isinstance(desccomp, str):
62-
descv, descoff = speccomp[0], int(speccomp[1:] or 0)
63+
descv, descoff = desccomp[0], int(desccomp[1:] or 0)
6364
entry = (descv, descoff - specoff)
6465
else:
6566
entry = desccomp - specoff
67+
print(entry)
6668

6769
if specv in specvars and entry != specvars[specv]:
6870
return False
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import pytest
2+
3+
from data_prototype.containers import Desc
4+
5+
6+
@pytest.mark.parametrize(
7+
"spec,actual",
8+
[
9+
([()], [()]),
10+
([(3,)], [(3,)]),
11+
([("N",)], [(3,)]),
12+
([("N",)], [("X",)]),
13+
([("N+1",)], [(3,)]),
14+
([("N", "N+1")], [(3, 4)]),
15+
([("N", "N-1")], [(3, 2)]),
16+
([("N", "N+10")], [(3, 13)]),
17+
([("N", "N+1")], [("X", "X+1")]),
18+
([("N", "N+9")], [("X", "X+9")]),
19+
([("N",), ("N",)], [("X",), ("X",)]),
20+
([("N",), ("N",)], [(3,), (3,)]),
21+
([("N",), ("N+1",)], [(3,), (4,)]),
22+
([("N", "M")], [(3, 4)]),
23+
([("N", "M")], [("X", "Y")]),
24+
([("N", "M")], [("X", "X")]),
25+
([("N", "M", 3)], [(3, 4, 3)]),
26+
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]),
27+
([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]),
28+
],
29+
)
30+
def test_passing_no_broadcast(
31+
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
32+
):
33+
assert Desc.check_shapes(
34+
*[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)]
35+
)
36+
37+
38+
@pytest.mark.parametrize(
39+
"spec,actual",
40+
[
41+
([(2,)], [()]),
42+
([(3,)], [(4,)]),
43+
([(3,)], [(1,)]),
44+
([("N",)], [(3, 4)]),
45+
([("N", "N+1")], [(4, 4)]),
46+
([("N", "N-1")], [(4, 4)]),
47+
([("N", "N+1")], [("X", "Y")]),
48+
([("N", "N+1")], [("X", 3)]),
49+
([("N",), ("N",)], [(3,), (4,)]),
50+
([("N", "N")], [("X", "Y")]),
51+
([("N", "M", 3)], [(3, 4, 4)]),
52+
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]),
53+
],
54+
)
55+
def test_failing_no_broadcast(
56+
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
57+
):
58+
assert not Desc.check_shapes(
59+
*[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)]
60+
)
61+
62+
63+
@pytest.mark.parametrize(
64+
"spec,actual",
65+
[
66+
([()], [()]),
67+
([(2,)], [()]),
68+
([(3,)], [(3,)]),
69+
([(3,)], [(1,)]),
70+
([("N",)], [(3,)]),
71+
([("N",)], [("X",)]),
72+
([("N", 4)], [(3, 1)]),
73+
([("N+1",)], [(3,)]),
74+
([("N", "N+1")], [(3, 4)]),
75+
([("N", "N+1")], [("X", "X+1")]),
76+
([("N", "N+1")], [("X", 1)]),
77+
([("N",), ("N",)], [("X",), ("X",)]),
78+
([("N",), ("N+1",)], [("X",), (1,)]),
79+
([("N",), ("N+1",)], [(3,), (4,)]),
80+
([("N",), ("N+1",)], [(1,), (4,)]),
81+
([("N", "M")], [(3, 4)]),
82+
([("N", "M")], [("X", "Y")]),
83+
([("N", "M")], [("X", "X")]),
84+
([("N", "M", 3)], [(3, 4, 3)]),
85+
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]),
86+
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 1)]),
87+
([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]),
88+
],
89+
)
90+
def test_passing_broadcast(
91+
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
92+
):
93+
assert Desc.check_shapes(
94+
*[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)], broadcast=True
95+
)
96+
97+
98+
@pytest.mark.parametrize(
99+
"spec,actual",
100+
[
101+
([(1,)], [(3,)]),
102+
([(3,)], [(4,)]),
103+
([("N",)], [(3, 4)]),
104+
([("N", "N+1")], [(4, 4)]),
105+
([("N", "N+1")], [("X", "Y")]),
106+
([("N", "N+1")], [("X", 3)]),
107+
([("N",), ("N",)], [(3,), (4,)]),
108+
([("N", "N")], [("X", "Y")]),
109+
([("N", "M", 3)], [(3, 4, 4)]),
110+
([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]),
111+
],
112+
)
113+
def test_failing_broadcast(
114+
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
115+
):
116+
assert not Desc.check_shapes(
117+
*[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)], broadcast=True
118+
)

0 commit comments

Comments
 (0)