1
1
from dataclasses import dataclass
2
- from typing import Protocol , Dict , Tuple , Optional , Any , Union , Callable , MutableMapping
2
+ from typing import (
3
+ Protocol ,
4
+ Dict ,
5
+ Tuple ,
6
+ Optional ,
7
+ Any ,
8
+ Union ,
9
+ Callable ,
10
+ MutableMapping ,
11
+ TypeAlias ,
12
+ )
3
13
import uuid
4
14
5
15
from cachetools import LFUCache
@@ -16,6 +26,9 @@ def __sub__(self, other) -> "_MatplotlibTransform":
16
26
...
17
27
18
28
29
+ ShapeSpec : TypeAlias = Tuple [Union [str , int ], ...]
30
+
31
+
19
32
@dataclass (frozen = True )
20
33
class Desc :
21
34
# TODO: sort out how to actually spell this. We need to know:
@@ -24,12 +37,41 @@ class Desc:
24
37
# - is this a variable size depending on the query (e.g. N)
25
38
# - what is the relative size to the other variable values (N vs N+1)
26
39
# We are probably going to have to implement a DSL for this (😞)
27
- shape : Tuple [ Union [ str , int ], ...]
40
+ shape : ShapeSpec
28
41
# TODO: is using a string better?
29
42
dtype : np .dtype
30
43
# TODO: do we want to include this at this level? "naive" means unit-unaware.
31
44
units : str = "naive"
32
45
46
+ @staticmethod
47
+ def check_shapes (* args : tuple [ShapeSpec , "Desc" ], broadcast = False ) -> bool :
48
+ specvars : dict [str , int | tuple [str , int ]] = {}
49
+ for spec , desc in args :
50
+ if not broadcast :
51
+ if len (spec ) != len (desc .shape ):
52
+ return False
53
+ elif len (desc .shape ) > len (spec ):
54
+ return False
55
+ for speccomp , desccomp in zip (spec [::- 1 ], desc .shape [::- 1 ]):
56
+ if broadcast and desccomp == 1 :
57
+ continue
58
+ if isinstance (speccomp , str ):
59
+ specv , specoff = speccomp [0 ], int (speccomp [1 :] or 0 )
60
+
61
+ if isinstance (desccomp , str ):
62
+ descv , descoff = speccomp [0 ], int (speccomp [1 :] or 0 )
63
+ entry = (descv , descoff - specoff )
64
+ else :
65
+ entry = desccomp - specoff
66
+
67
+ if specv in specvars and entry != specvars [specv ]:
68
+ return False
69
+
70
+ specvars [specv ] = entry
71
+ elif speccomp != desccomp :
72
+ return False
73
+ return True
74
+
33
75
34
76
class DataContainer (Protocol ):
35
77
def query (
0 commit comments