diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index ed39127..9bb728a 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -8,7 +8,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.9", "3.10"]
+        python-version: ["3.10", "3.11", "3.12"]
       fail-fast: false
 
     steps:
diff --git a/data_prototype/containers.py b/data_prototype/containers.py
index 4d87446..e6fb72e 100644
--- a/data_prototype/containers.py
+++ b/data_prototype/containers.py
@@ -1,5 +1,15 @@
 from dataclasses import dataclass
-from typing import Protocol, Dict, Tuple, Optional, Any, Union, Callable, MutableMapping
+from typing import (
+    Protocol,
+    Dict,
+    Tuple,
+    Optional,
+    Any,
+    Union,
+    Callable,
+    MutableMapping,
+    TypeAlias,
+)
 import uuid
 
 from cachetools import LFUCache
@@ -16,6 +26,9 @@ def __sub__(self, other) -> "_MatplotlibTransform":
         ...
 
 
+ShapeSpec: TypeAlias = Tuple[Union[str, int], ...]
+
+
 @dataclass(frozen=True)
 class Desc:
     # TODO: sort out how to actually spell this. We need to know:
@@ -24,12 +37,93 @@ class Desc:
     # - is this a variable size depending on the query (e.g. N)
     # - what is the relative size to the other variable values (N vs N+1)
     # We are probably going to have to implement a DSL for this (😞)
-    shape: Tuple[Union[str, int], ...]
+    shape: ShapeSpec
     # TODO: is using a string better?
     dtype: np.dtype
     # TODO: do we want to include this at this level? "naive" means unit-unaware.
     units: str = "naive"
 
+    @staticmethod
+    def validate_shapes(
+        specification: dict[str, ShapeSpec | "Desc"],
+        actual: dict[str, ShapeSpec | "Desc"],
+        *,
+        broadcast: bool = False,
+    ) -> None:
+        """Check that the shapes in ``actual`` satisfy ``specification``.
+
+        Each specification component is either an integer, which must match
+        exactly, or a string: a single-character variable name optionally
+        followed by a signed integer offset (e.g. ``"N"``, ``"N+1"``, ``"N-1"``).
+        Every use of a variable, across all fields, must resolve to the same
+        value once its offset is subtracted.
+
+        Parameters
+        ----------
+        specification
+            Field name -> required shape, or a `Desc` whose ``shape`` is used.
+        actual
+            Field name -> shape to check, or a `Desc` whose ``shape`` is used.
+            Fields not named in ``specification`` are ignored.
+        broadcast
+            If False, each shape must have as many dimensions as its
+            specification. If True, a shape may have fewer dimensions (but not
+            more), components are matched from the last dimension, and any
+            actual component equal to 1 is accepted.
+
+        Raises
+        ------
+        KeyError
+            If a field named in ``specification`` is missing from ``actual``.
+        ValueError
+            If a shape is incompatible with its specification, or a variable
+            resolves to two different values.
+        """
+        specvars: dict[str, int | tuple[str, int]] = {}
+        for fieldname in specification:
+            spec = specification[fieldname]
+            if fieldname not in actual:
+                raise KeyError(
+                    f"Actual is missing {fieldname!r}, required by specification."
+                )
+            desc = actual[fieldname]
+            if isinstance(spec, Desc):
+                spec = spec.shape
+            if isinstance(desc, Desc):
+                desc = desc.shape
+            if not broadcast:
+                if len(spec) != len(desc):
+                    raise ValueError(
+                        f"{fieldname!r} shape {desc} incompatible with specification "
+                        f"{spec}."
+                    )
+            elif len(desc) > len(spec):
+                raise ValueError(
+                    f"{fieldname!r} shape {desc} incompatible with specification "
+                    f"{spec}."
+                )
+            for speccomp, desccomp in zip(spec[::-1], desc[::-1]):
+                if broadcast and desccomp == 1:
+                    continue
+                if isinstance(speccomp, str):
+                    specv, specoff = speccomp[0], int(speccomp[1:] or 0)
+
+                    if isinstance(desccomp, str):
+                        descv, descoff = desccomp[0], int(desccomp[1:] or 0)
+                        entry = (descv, descoff - specoff)
+                    else:
+                        entry = desccomp - specoff
+
+                    if specv in specvars and entry != specvars[specv]:
+                        raise ValueError(f"Found two incompatible values for {specv!r}")
+
+                    specvars[specv] = entry
+                elif speccomp != desccomp:
+                    raise ValueError(
+                        f"{fieldname!r} shape {desc} incompatible with specification "
+                        f"{spec}"
+                    )
+
 
 class DataContainer(Protocol):
     def query(
diff --git a/data_prototype/tests/test_check_shape.py b/data_prototype/tests/test_check_shape.py
new file mode 100644
index 0000000..0f8d6bb
--- /dev/null
+++ b/data_prototype/tests/test_check_shape.py
@@ -0,0 +1,133 @@
+import pytest
+
+from data_prototype.containers import Desc
+
+
+@pytest.mark.parametrize(
+    "spec,actual",
+    [
+        ([()], [()]),
+        ([(3,)], [(3,)]),
+        ([("N",)], [(3,)]),
+        ([("N",)], [("X",)]),
+        ([("N+1",)], [(3,)]),
+        ([("N", "N+1")], [(3, 4)]),
+        ([("N", "N-1")], [(3, 2)]),
+        ([("N", "N+10")], [(3, 13)]),
+        ([("N", "N+1")], [("X", "X+1")]),
+        ([("N", "N+9")], [("X", "X+9")]),
+        ([("N",), ("N",)], [("X",), ("X",)]),
+        ([("N",), ("N",)], [(3,), (3,)]),
+        ([("N",), ("N+1",)], [(3,), (4,)]),
+        ([("N", "M")], [(3, 4)]),
+        ([("N", "M")], [("X", "Y")]),
+        ([("N", "M")], [("X", "X")]),
+        ([("N", "M", 3)], [(3, 4, 3)]),
+        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]),
+        ([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]),
+    ],
+)
+def test_passing_no_broadcast(
+    spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
+):
+    spec = {var: shape for var, shape in zip("abcdefg", spec)}
+    actual = {var: shape for var, shape in zip("abcdefg", actual)}
+    Desc.validate_shapes(spec, actual)
+
+
+@pytest.mark.parametrize(
+    "spec,actual",
+    [
+        ([(2,)], [()]),
+        ([(3,)], [(4,)]),
+        ([(3,)], [(1,)]),
+        ([("N",)], [(3, 4)]),
+        ([("N", "N+1")], [(4, 4)]),
+        ([("N", "N-1")], [(4, 4)]),
+        ([("N", "N+1")], [("X", "Y")]),
+        ([("N", "N+1")], [("X", 3)]),
+        ([("N",), ("N",)], [(3,), (4,)]),
+        ([("N", "N")], [("X", "Y")]),
+        ([("N", "M", 3)], [(3, 4, 4)]),
+        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]),
+    ],
+)
+def test_failing_no_broadcast(
+    spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
+):
+    spec = {var: shape for var, shape in zip("abcdefg", spec)}
+    actual = {var: shape for var, shape in zip("abcdefg", actual)}
+    with pytest.raises(ValueError):
+        Desc.validate_shapes(spec, actual)
+
+
+@pytest.mark.parametrize(
+    "spec,actual",
+    [
+        ([()], [()]),
+        ([(2,)], [()]),
+        ([(3,)], [(3,)]),
+        ([(3,)], [(1,)]),
+        ([("N",)], [(3,)]),
+        ([("N",)], [("X",)]),
+        ([("N", 4)], [(3, 1)]),
+        ([("N+1",)], [(3,)]),
+        ([("N", "N+1")], [(3, 4)]),
+        ([("N", "N+1")], [("X", "X+1")]),
+        ([("N", "N+1")], [("X", 1)]),
+        ([("N",), ("N",)], [("X",), ("X",)]),
+        ([("N",), ("N+1",)], [("X",), (1,)]),
+        ([("N",), ("N+1",)], [(3,), (4,)]),
+        ([("N",), ("N+1",)], [(1,), (4,)]),
+        ([("N", "M")], [(3, 4)]),
+        ([("N", "M")], [("X", "Y")]),
+        ([("N", "M")], [("X", "X")]),
+        ([("N", "M", 3)], [(3, 4, 3)]),
+        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 4)]),
+        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 1)]),
+        ([("N",), ("M",), ("N", "M")], [("X",), ("Y",), ("X", "Y")]),
+    ],
+)
+def test_passing_broadcast(
+    spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
+):
+    spec = {var: shape for var, shape in zip("abcdefg", spec)}
+    actual = {var: shape for var, shape in zip("abcdefg", actual)}
+    Desc.validate_shapes(spec, actual, broadcast=True)
+
+
+@pytest.mark.parametrize(
+    "spec,actual",
+    [
+        ([(1,)], [(3,)]),
+        ([(3,)], [(4,)]),
+        ([("N",)], [(3, 4)]),
+        ([("N", "N+1")], [(4, 4)]),
+        ([("N", "N+1")], [("X", "Y")]),
+        ([("N", "N+1")], [("X", 3)]),
+        ([("N",), ("N",)], [(3,), (4,)]),
+        ([("N", "N")], [("X", "Y")]),
+        ([("N", "M", 3)], [(3, 4, 4)]),
+        ([("N",), ("M",), ("N", "M")], [(3,), (4,), (3, 5)]),
+    ],
+)
+def test_failing_broadcast(
+    spec: list[tuple[int | str, ...]], actual: list[tuple[int | str, ...]]
+):
+    spec = {var: shape for var, shape in zip("abcdefg", spec)}
+    actual = {var: shape for var, shape in zip("abcdefg", actual)}
+    with pytest.raises(ValueError):
+        Desc.validate_shapes(spec, actual, broadcast=True)
+
+
+def test_desc_object():
+    spec = {"a": Desc(("N",), float), "b": Desc(("N+1",), float)}
+    actual = {"a": Desc((3,), float), "b": Desc((4,), float)}
+    Desc.validate_shapes(spec, actual)
+
+
+def test_missing_key():
+    spec = {"a": Desc(("N",), float), "b": Desc(("N+1",), float)}
+    actual = {"a": Desc((3,), float)}
+    with pytest.raises(KeyError):
+        Desc.validate_shapes(spec, actual)
diff --git a/setup.py b/setup.py
index 3e0bc45..eb88ad7 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 # NOTE: This file must remain Python 2 compatible for the foreseeable future,
 # to ensure that we error out properly for people with outdated setuptools
 # and/or pip.
-min_version = (3, 9)
+min_version = (3, 10)
 if sys.version_info < min_version:
     error = """
 data_prototype does not support Python {0}.{1}.