Skip to content

Commit fca3448

Browse files
committed
Initial implementeation of shape checking method
1 parent 2e70f58 commit fca3448

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

data_prototype/containers.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
from dataclasses import dataclass
2-
from typing import Protocol, Dict, Tuple, Optional, Any, Union, Callable, MutableMapping
2+
from typing import (
3+
Protocol,
4+
Dict,
5+
Tuple,
6+
Optional,
7+
Any,
8+
Union,
9+
Callable,
10+
MutableMapping,
11+
TypeAlias,
12+
)
313
import uuid
414

515
from cachetools import LFUCache
@@ -16,6 +26,9 @@ def __sub__(self, other) -> "_MatplotlibTransform":
1626
...
1727

1828

29+
ShapeSpec: TypeAlias = Tuple[Union[str, int], ...]
30+
31+
1932
@dataclass(frozen=True)
2033
class Desc:
2134
# TODO: sort out how to actually spell this. We need to know:
@@ -24,12 +37,41 @@ class Desc:
2437
# - is this a variable size depending on the query (e.g. N)
2538
# - what is the relative size to the other variable values (N vs N+1)
2639
# We are probably going to have to implement a DSL for this (😞)
27-
shape: Tuple[Union[str, int], ...]
40+
shape: ShapeSpec
2841
# TODO: is using a string better?
2942
dtype: np.dtype
3043
# TODO: do we want to include this at this level? "naive" means unit-unaware.
3144
units: str = "naive"
3245

46+
@staticmethod
47+
def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool:
48+
specvars: dict[str, int | tuple[str, int]] = {}
49+
for spec, desc in args:
50+
if not broadcast:
51+
if len(spec) != len(desc.shape):
52+
return False
53+
elif len(desc.shape) > len(spec):
54+
return False
55+
for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]):
56+
if broadcast and desccomp == 1:
57+
continue
58+
if isinstance(speccomp, str):
59+
specv, specoff = speccomp[0], int(speccomp[1:] or 0)
60+
61+
if isinstance(desccomp, str):
62+
descv, descoff = speccomp[0], int(speccomp[1:] or 0)
63+
entry = (descv, descoff - specoff)
64+
else:
65+
entry = desccomp - specoff
66+
67+
if specv in specvars and entry != specvars[specv]:
68+
return False
69+
70+
specvars[specv] = entry
71+
elif speccomp != desccomp:
72+
return False
73+
return True
74+
3375

3476
class DataContainer(Protocol):
3577
def query(

0 commit comments

Comments
 (0)