Skip to content

Commit 64cc09d

Browse files
committed
Convert check_shape to raising instead of returning bool/rename to validate_shapes
1 parent 18ccaff commit 64cc09d

File tree

2 files changed

+61
-22
lines changed

2 files changed

+61
-22
lines changed

data_prototype/containers.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,36 @@ class Desc:
4444
units: str = "naive"
4545

4646
@staticmethod
47-
def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool:
47+
def validate_shapes(
48+
specification: dict[str, ShapeSpec | "Desc"],
49+
actual: dict[str, ShapeSpec | "Desc"],
50+
*,
51+
broadcast=False,
52+
) -> bool:
4853
specvars: dict[str, int | tuple[str, int]] = {}
49-
for spec, desc in args:
54+
for fieldname in specification:
55+
spec = specification[fieldname]
56+
if fieldname not in actual:
57+
raise KeyError(
58+
f"Actual is missing {fieldname!r}, required by specification."
59+
)
60+
desc = actual[fieldname]
61+
if isinstance(spec, Desc):
62+
spec = spec.shape
63+
if isinstance(desc, Desc):
64+
desc = desc.shape
5065
if not broadcast:
51-
if len(spec) != len(desc.shape):
52-
return False
53-
elif len(desc.shape) > len(spec):
54-
return False
55-
for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]):
66+
if len(spec) != len(desc):
67+
raise ValueError(
68+
f"{fieldname!r} shape {desc} incompatible with specification "
69+
f"{spec}."
70+
)
71+
elif len(desc) > len(spec):
72+
raise ValueError(
73+
f"{fieldname!r} shape {desc} incompatible with specification "
74+
f"{spec}."
75+
)
76+
for speccomp, desccomp in zip(spec[::-1], desc[::-1]):
5677
if broadcast and desccomp == 1:
5778
continue
5879
if isinstance(speccomp, str):
@@ -65,12 +86,15 @@ def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool:
6586
entry = desccomp - specoff
6687

6788
if specv in specvars and entry != specvars[specv]:
68-
return False
89+
raise ValueError(f"Found two incompatible values for {specv!r}")
6990

7091
specvars[specv] = entry
7192
elif speccomp != desccomp:
72-
return False
73-
return True
93+
raise ValueError(
94+
f"{fieldname!r} shape {desc} incompatible with specification "
95+
f"{spec}"
96+
)
97+
return None
7498

7599

76100
class DataContainer(Protocol):

data_prototype/tests/test_check_shape.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
def test_passing_no_broadcast(
3131
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
3232
):
33-
assert Desc.check_shapes(
34-
*[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)]
35-
)
33+
spec = {var: shape for var, shape in zip("abcdefg", spec)}
34+
actual = {var: shape for var, shape in zip("abcdefg", actual)}
35+
Desc.validate_shapes(spec, actual)
3636

3737

3838
@pytest.mark.parametrize(
@@ -55,9 +55,10 @@ def test_passing_no_broadcast(
5555
def test_failing_no_broadcast(
5656
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
5757
):
58-
assert not Desc.check_shapes(
59-
*[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)]
60-
)
58+
spec = {var: shape for var, shape in zip("abcdefg", spec)}
59+
actual = {var: shape for var, shape in zip("abcdefg", actual)}
60+
with pytest.raises(ValueError):
61+
Desc.validate_shapes(spec, actual)
6162

6263

6364
@pytest.mark.parametrize(
@@ -90,9 +91,9 @@ def test_failing_no_broadcast(
9091
def test_passing_broadcast(
9192
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
9293
):
93-
assert Desc.check_shapes(
94-
*[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)], broadcast=True
95-
)
94+
spec = {var: shape for var, shape in zip("abcdefg", spec)}
95+
actual = {var: shape for var, shape in zip("abcdefg", actual)}
96+
Desc.validate_shapes(spec, actual, broadcast=True)
9697

9798

9899
@pytest.mark.parametrize(
@@ -113,6 +114,20 @@ def test_passing_broadcast(
113114
def test_failing_broadcast(
114115
spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
115116
):
116-
assert not Desc.check_shapes(
117-
*[(s, Desc(dtype=float, shape=a)) for s, a in zip(spec, actual)], broadcast=True
118-
)
117+
spec = {var: shape for var, shape in zip("abcdefg", spec)}
118+
actual = {var: shape for var, shape in zip("abcdefg", actual)}
119+
with pytest.raises(ValueError):
120+
Desc.validate_shapes(spec, actual, broadcast=True)
121+
122+
123+
def test_desc_object():
124+
spec = {"a": Desc(("N",), float), "b": Desc(("N+1",), float)}
125+
actual = {"a": Desc((3,), float), "b": Desc((4,), float)}
126+
Desc.validate_shapes(spec, actual)
127+
128+
129+
def test_missing_key():
130+
spec = {"a": Desc(("N",), float), "b": Desc(("N+1",), float)}
131+
actual = {"a": Desc((3,), float)}
132+
with pytest.raises(KeyError):
133+
Desc.validate_shapes(spec, actual)

0 commit comments

Comments
 (0)