Skip to content

Commit dc8a96b

Browse files
authored
Merge pull request #35 from ksunden/validate_shapes
Initial implementation of shape checking method
2 parents 2e70f58 + 64cc09d commit dc8a96b

File tree

4 files changed

+203
-4
lines changed

4 files changed

+203
-4
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
runs-on: ubuntu-latest
99
strategy:
1010
matrix:
11-
python-version: ["3.9", "3.10"]
11+
python-version: ["3.10", "3.11", "3.12"]
1212
fail-fast: false
1313

1414
steps:

data_prototype/containers.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
from dataclasses import dataclass
2-
from typing import Protocol, Dict, Tuple, Optional, Any, Union, Callable, MutableMapping
2+
from typing import (
3+
Protocol,
4+
Dict,
5+
Tuple,
6+
Optional,
7+
Any,
8+
Union,
9+
Callable,
10+
MutableMapping,
11+
TypeAlias,
12+
)
313
import uuid
414

515
from cachetools import LFUCache
@@ -16,6 +26,9 @@ def __sub__(self, other) -> "_MatplotlibTransform":
1626
...
1727

1828

29+
# A shape is a tuple with one entry per dimension; each entry is either an
# int or a str (Desc.validate_shapes reads a str entry as a one-character
# variable name followed by an optional signed integer offset, e.g. "N+1").
ShapeSpec: TypeAlias = Tuple[Union[str, int], ...]
30+
31+
1932
@dataclass(frozen=True)
2033
class Desc:
2134
# TODO: sort out how to actually spell this. We need to know:
@@ -24,12 +37,65 @@ class Desc:
2437
# - is this a variable size depending on the query (e.g. N)
2538
# - what is the relative size to the other variable values (N vs N+1)
2639
# We are probably going to have to implement a DSL for this (😞)
27-
shape: Tuple[Union[str, int], ...]
40+
shape: ShapeSpec
2841
# TODO: is using a string better?
2942
dtype: np.dtype
3043
# TODO: do we want to include this at this level? "naive" means unit-unaware.
3144
units: str = "naive"
3245

46+
@staticmethod
def validate_shapes(
    specification: dict[str, ShapeSpec | "Desc"],
    actual: dict[str, ShapeSpec | "Desc"],
    *,
    broadcast: bool = False,
) -> None:
    """Check that the shapes in *actual* are compatible with *specification*.

    Both mappings are keyed by field name; each value is either a shape
    tuple or a ``Desc`` (in which case its ``shape`` attribute is used).

    A shape component is either an ``int`` or a ``str``.  A string
    component in the specification is read as a one-character variable
    name followed by an optional signed integer offset (``"N"``,
    ``"N+1"``, ``"N-1"``).  The value implied for each variable must be
    the same everywhere that variable appears, across all fields.  A
    non-string specification component must compare equal to the actual
    component.

    Components are compared pairwise starting from the last dimension.

    Parameters
    ----------
    specification
        The required shape for each field.
    actual
        The shapes to check.  Every key of *specification* must be
        present; extra keys are ignored.
    broadcast
        If false, each actual shape must have the same number of
        dimensions as its specification.  If true, an actual shape may
        have fewer (but not more) dimensions, and an actual component
        equal to ``1`` is accepted without further checks.

    Raises
    ------
    KeyError
        If a field of *specification* is missing from *actual*.
    ValueError
        If a shape has an incompatible number of dimensions, a fixed
        component differs, or a variable would need two different
        values.  ``int()`` also raises ``ValueError`` when the text
        following the variable letter is not an integer.
    """
    # Value bound so far to each specification variable: a plain int when
    # matched against a concrete size, or (variable, offset) when matched
    # against a symbolic component of the actual shape.
    specvars: dict[str, int | tuple[str, int]] = {}
    for fieldname, spec in specification.items():
        if fieldname not in actual:
            raise KeyError(
                f"Actual is missing {fieldname!r}, required by specification."
            )
        desc = actual[fieldname]
        if isinstance(spec, Desc):
            spec = spec.shape
        if isinstance(desc, Desc):
            desc = desc.shape

        if not broadcast:
            if len(spec) != len(desc):
                raise ValueError(
                    f"{fieldname!r} shape {desc} incompatible with specification "
                    f"{spec}."
                )
        elif len(desc) > len(spec):
            raise ValueError(
                f"{fieldname!r} shape {desc} incompatible with specification "
                f"{spec}."
            )

        # Walk from the trailing dimension so that, when broadcasting, a
        # shorter actual shape lines up with the end of the specification.
        for speccomp, desccomp in zip(spec[::-1], desc[::-1]):
            if broadcast and desccomp == 1:
                # Size-1 dimensions are accepted as-is and bind no variable.
                continue
            if isinstance(speccomp, str):
                # First character is the variable; the rest (if any) is a
                # signed offset such as "+1" or "-1".
                specv, specoff = speccomp[0], int(speccomp[1:] or 0)

                if isinstance(desccomp, str):
                    descv, descoff = desccomp[0], int(desccomp[1:] or 0)
                    entry = (descv, descoff - specoff)
                else:
                    entry = desccomp - specoff

                if specv in specvars and entry != specvars[specv]:
                    raise ValueError(f"Found two incompatible values for {specv!r}")

                specvars[specv] = entry
            elif speccomp != desccomp:
                raise ValueError(
                    f"{fieldname!r} shape {desc} incompatible with specification "
                    f"{spec}"
                )
98+
3399

34100
class DataContainer(Protocol):
35101
def query(
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import pytest
2+
3+
from data_prototype.containers import Desc
4+
5+
6+
@pytest.mark.parametrize(
    "spec,actual",
    [
        ([()], [()]),
        ([(3,)], [(3,)]),
        ([("N",)], [(3,)]),
        ([("N",)], [("X",)]),
        ([("N+1",)], [(3,)]),
        ([("N", "N+1")], [(3, 4)]),
        ([("N", "N-1")], [(3, 2)]),
        ([("N", "N+10")], [(3, 13)]),
        ([("N", "N+1")], [("X", "X+1")]),
        ([("N", "N+9")], [("X", "X+9")]),
        ([("N",), ("N",)], [("X",), ("X",)]),
        ([("N",), ("N",)], [(3,), (3,)]),
        ([("N",), ("N+1",)], [(3,), (4,)]),
        ([("N", "M")], [(3, 4)]),
        ([("N", "M")], [("X", "Y")]),
        ([("N", "M")], [("X", "X")]),
        ([("N", "M", 3)], [(3, 4, 3)]),
        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]),
        ([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]),
    ],
)
def test_passing_no_broadcast(
    spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
):
    """Each spec/actual pairing validates without raising (no broadcasting)."""
    # Give the i-th shape in each list the same field name: "a", "b", "c", ...
    required = dict(zip("abcdefg", spec))
    provided = dict(zip("abcdefg", actual))
    Desc.validate_shapes(required, provided)
36+
37+
38+
@pytest.mark.parametrize(
    "spec,actual",
    [
        ([(2,)], [()]),
        ([(3,)], [(4,)]),
        ([(3,)], [(1,)]),
        ([("N",)], [(3, 4)]),
        ([("N", "N+1")], [(4, 4)]),
        ([("N", "N-1")], [(4, 4)]),
        ([("N", "N+1")], [("X", "Y")]),
        ([("N", "N+1")], [("X", 3)]),
        ([("N",), ("N",)], [(3,), (4,)]),
        ([("N", "N")], [("X", "Y")]),
        ([("N", "M", 3)], [(3, 4, 4)]),
        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]),
    ],
)
def test_failing_no_broadcast(
    spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
):
    """Each spec/actual pairing is rejected with ValueError (no broadcasting)."""
    # Give the i-th shape in each list the same field name: "a", "b", "c", ...
    required = dict(zip("abcdefg", spec))
    provided = dict(zip("abcdefg", actual))
    with pytest.raises(ValueError):
        Desc.validate_shapes(required, provided)
62+
63+
64+
@pytest.mark.parametrize(
    "spec,actual",
    [
        ([()], [()]),
        ([(2,)], [()]),
        ([(3,)], [(3,)]),
        ([(3,)], [(1,)]),
        ([("N",)], [(3,)]),
        ([("N",)], [("X",)]),
        ([("N", 4)], [(3, 1)]),
        ([("N+1",)], [(3,)]),
        ([("N", "N+1")], [(3, 4)]),
        ([("N", "N+1")], [("X", "X+1")]),
        ([("N", "N+1")], [("X", 1)]),
        ([("N",), ("N",)], [("X",), ("X",)]),
        ([("N",), ("N+1",)], [("X",), (1,)]),
        ([("N",), ("N+1",)], [(3,), (4,)]),
        ([("N",), ("N+1",)], [(1,), (4,)]),
        ([("N", "M")], [(3, 4)]),
        ([("N", "M")], [("X", "Y")]),
        ([("N", "M")], [("X", "X")]),
        ([("N", "M", 3)], [(3, 4, 3)]),
        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]),
        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 1)]),
        ([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]),
    ],
)
def test_passing_broadcast(
    spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
):
    """Each spec/actual pairing validates without raising when broadcasting."""
    # Give the i-th shape in each list the same field name: "a", "b", "c", ...
    required = dict(zip("abcdefg", spec))
    provided = dict(zip("abcdefg", actual))
    Desc.validate_shapes(required, provided, broadcast=True)
97+
98+
99+
@pytest.mark.parametrize(
    "spec,actual",
    [
        ([(1,)], [(3,)]),
        ([(3,)], [(4,)]),
        ([("N",)], [(3, 4)]),
        ([("N", "N+1")], [(4, 4)]),
        ([("N", "N+1")], [("X", "Y")]),
        ([("N", "N+1")], [("X", 3)]),
        ([("N",), ("N",)], [(3,), (4,)]),
        ([("N", "N")], [("X", "Y")]),
        ([("N", "M", 3)], [(3, 4, 4)]),
        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]),
    ],
)
def test_failing_broadcast(
    spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
):
    """Each spec/actual pairing is rejected with ValueError when broadcasting."""
    # Give the i-th shape in each list the same field name: "a", "b", "c", ...
    required = dict(zip("abcdefg", spec))
    provided = dict(zip("abcdefg", actual))
    with pytest.raises(ValueError):
        Desc.validate_shapes(required, provided, broadcast=True)
121+
122+
123+
def test_desc_object():
    """validate_shapes accepts Desc instances in place of bare shape tuples."""
    required = dict(a=Desc(("N",), float), b=Desc(("N+1",), float))
    provided = dict(a=Desc((3,), float), b=Desc((4,), float))
    Desc.validate_shapes(required, provided)
127+
128+
129+
def test_missing_key():
    """A field required by the specification but absent from actual raises KeyError."""
    required = dict(a=Desc(("N",), float), b=Desc(("N+1",), float))
    # "b" is deliberately left out of the actual mapping.
    provided = dict(a=Desc((3,), float))
    with pytest.raises(KeyError):
        Desc.validate_shapes(required, provided)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# NOTE: This file must remain Python 2 compatible for the foreseeable future,
77
# to ensure that we error out properly for people with outdated setuptools
88
# and/or pip.
9-
min_version = (3, 9)
9+
min_version = (3, 10)
1010
if sys.version_info < min_version:
1111
error = """
1212
data_prototype does not support Python {0}.{1}.

0 commit comments

Comments
 (0)