Skip to content

Commit a38c182

Browse files
authored
Merge pull request #39 from ksunden/edges
Initial rework as conversion edges
2 parents 830ce25 + 3c0c6ff commit a38c182

15 files changed

+520
-174
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ default_language_version:
22
python: python3
33
repos:
44
- repo: https://github.com/ambv/black
5-
rev: 23.3.0
5+
rev: 24.2.0
66
hooks:
77
- id: black
88
- repo: https://github.com/pre-commit/pre-commit-hooks
99
rev: v2.0.0
1010
hooks:
1111
- id: flake8
1212
- repo: https://github.com/kynan/nbstripout
13-
rev: 0.6.1
13+
rev: 0.7.1
1414
hooks:
1515
- id: nbstripout

data_prototype/axes.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,28 @@ def scatter(
102102
pipeline.append(lambda x: np.ma.ravel(x))
103103
pipeline.append(lambda y: np.ma.ravel(y))
104104
pipeline.append(
105-
lambda s: np.ma.ravel(s)
106-
if s is not None
107-
else [20]
108-
if mpl.rcParams["_internal.classic_mode"]
109-
else [mpl.rcParams["lines.markersize"] ** 2.0]
105+
lambda s: (
106+
np.ma.ravel(s)
107+
if s is not None
108+
else (
109+
[20]
110+
if mpl.rcParams["_internal.classic_mode"]
111+
else [mpl.rcParams["lines.markersize"] ** 2.0]
112+
)
113+
)
110114
)
111115
# TODO plotnonfinite/mask combining
112116
pipeline.append(
113-
lambda marker: marker
114-
if marker is not None
115-
else mpl.rcParams["scatter.marker"]
117+
lambda marker: (
118+
marker if marker is not None else mpl.rcParams["scatter.marker"]
119+
)
116120
)
117121
pipeline.append(
118-
lambda marker: marker
119-
if isinstance(marker, mmarkers.MarkerStyle)
120-
else mmarkers.MarkerStyle(marker)
122+
lambda marker: (
123+
marker
124+
if isinstance(marker, mmarkers.MarkerStyle)
125+
else mmarkers.MarkerStyle(marker)
126+
)
121127
)
122128
pipeline.append(
123129
FunctionConversionNode.from_funcs(

data_prototype/containers.py

Lines changed: 85 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from dataclasses import dataclass
1+
from __future__ import annotations
2+
23
from typing import (
34
Protocol,
45
Dict,
@@ -8,7 +9,6 @@
89
Union,
910
Callable,
1011
MutableMapping,
11-
TypeAlias,
1212
)
1313
import uuid
1414

@@ -17,92 +17,25 @@
1717
import numpy as np
1818
import pandas as pd
1919

20+
from .description import Desc, desc_like
2021

21-
class _MatplotlibTransform(Protocol):
22-
def transform(self, verts):
23-
...
24-
25-
def __sub__(self, other) -> "_MatplotlibTransform":
26-
...
22+
from typing import TYPE_CHECKING
2723

24+
if TYPE_CHECKING:
25+
from .conversion_edge import Graph
2826

29-
ShapeSpec: TypeAlias = Tuple[Union[str, int], ...]
3027

28+
class _MatplotlibTransform(Protocol):
29+
def transform(self, verts): ...
3130

32-
@dataclass(frozen=True)
33-
class Desc:
34-
# TODO: sort out how to actually spell this. We need to know:
35-
# - what the number of dimensions is (1d vs 2d vs ...)
36-
# - is this a fixed size dimension (e.g. 2 for xextent)
37-
# - is this a variable size depending on the query (e.g. N)
38-
# - what is the relative size to the other variable values (N vs N+1)
39-
# We are probably going to have to implement a DSL for this (😞)
40-
shape: ShapeSpec
41-
# TODO: is using a string better?
42-
dtype: np.dtype
43-
# TODO: do we want to include this at this level? "naive" means unit-unaware.
44-
units: str = "naive"
45-
46-
@staticmethod
47-
def validate_shapes(
48-
specification: dict[str, ShapeSpec | "Desc"],
49-
actual: dict[str, ShapeSpec | "Desc"],
50-
*,
51-
broadcast=False,
52-
) -> bool:
53-
specvars: dict[str, int | tuple[str, int]] = {}
54-
for fieldname in specification:
55-
spec = specification[fieldname]
56-
if fieldname not in actual:
57-
raise KeyError(
58-
f"Actual is missing {fieldname!r}, required by specification."
59-
)
60-
desc = actual[fieldname]
61-
if isinstance(spec, Desc):
62-
spec = spec.shape
63-
if isinstance(desc, Desc):
64-
desc = desc.shape
65-
if not broadcast:
66-
if len(spec) != len(desc):
67-
raise ValueError(
68-
f"{fieldname!r} shape {desc} incompatible with specification "
69-
f"{spec}."
70-
)
71-
elif len(desc) > len(spec):
72-
raise ValueError(
73-
f"{fieldname!r} shape {desc} incompatible with specification "
74-
f"{spec}."
75-
)
76-
for speccomp, desccomp in zip(spec[::-1], desc[::-1]):
77-
if broadcast and desccomp == 1:
78-
continue
79-
if isinstance(speccomp, str):
80-
specv, specoff = speccomp[0], int(speccomp[1:] or 0)
81-
82-
if isinstance(desccomp, str):
83-
descv, descoff = desccomp[0], int(desccomp[1:] or 0)
84-
entry = (descv, descoff - specoff)
85-
else:
86-
entry = desccomp - specoff
87-
88-
if specv in specvars and entry != specvars[specv]:
89-
raise ValueError(f"Found two incompatible values for {specv!r}")
90-
91-
specvars[specv] = entry
92-
elif speccomp != desccomp:
93-
raise ValueError(
94-
f"{fieldname!r} shape {desc} incompatible with specification "
95-
f"{spec}"
96-
)
97-
return None
31+
def __sub__(self, other) -> "_MatplotlibTransform": ...
9832

9933

10034
class DataContainer(Protocol):
10135
def query(
10236
self,
103-
# TODO 3D?!!
104-
coord_transform: _MatplotlibTransform,
105-
size: Tuple[int, int],
37+
graph: Graph,
38+
parent_coordinates: str = "axes",
10639
/,
10740
) -> Tuple[Dict[str, Any], Union[str, int]]:
10841
"""
@@ -132,6 +65,7 @@ def query(
13265
This is a key that clients can use to cache down-stream
13366
computations on this data.
13467
"""
68+
...
13569

13670
def describe(self) -> Dict[str, Desc]:
13771
"""
@@ -141,27 +75,29 @@ def describe(self) -> Dict[str, Desc]:
14175
-------
14276
Dict[str, Desc]
14377
"""
78+
...
14479

14580

146-
class NoNewKeys(ValueError):
147-
...
81+
class NoNewKeys(ValueError): ...
14882

14983

15084
class ArrayContainer:
15185
def __init__(self, **data):
15286
self._data = data
15387
self._cache_key = str(uuid.uuid4())
15488
self._desc = {
155-
k: Desc(v.shape, v.dtype)
156-
if isinstance(v, np.ndarray)
157-
else Desc((), type(v))
89+
k: (
90+
Desc(v.shape, v.dtype)
91+
if isinstance(v, np.ndarray)
92+
else Desc((), type(v))
93+
)
15894
for k, v in data.items()
15995
}
16096

16197
def query(
16298
self,
163-
coord_transform: _MatplotlibTransform,
164-
size: Tuple[int, int],
99+
graph: Graph,
100+
parent_coordinates: str = "axes",
165101
) -> Tuple[Dict[str, Any], Union[str, int]]:
166102
return dict(self._data), self._cache_key
167103

@@ -185,8 +121,8 @@ def __init__(self, **shapes):
185121

186122
def query(
187123
self,
188-
coord_transform: _MatplotlibTransform,
189-
size: Tuple[int, int],
124+
graph: Graph,
125+
parent_coordinates: str = "axes",
190126
) -> Tuple[Dict[str, Any], Union[str, int]]:
191127
return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str(
192128
uuid.uuid4()
@@ -253,31 +189,44 @@ def _query_hash(self, coord_transform, size):
253189

254190
def query(
255191
self,
256-
coord_transform: _MatplotlibTransform,
257-
size: Tuple[int, int],
192+
graph: Graph,
193+
parent_coordinates: str = "axes",
258194
) -> Tuple[Dict[str, Any], Union[str, int]]:
259-
hash_key = self._query_hash(coord_transform, size)
260-
if hash_key in self._cache:
261-
return self._cache[hash_key], hash_key
195+
# hash_key = self._query_hash(coord_transform, size)
196+
# if hash_key in self._cache:
197+
# return self._cache[hash_key], hash_key
198+
199+
desc = Desc(("N",), np.dtype("f8"))
200+
xy = {"x": desc, "y": desc}
201+
data_lim = graph.evaluator(
202+
desc_like(xy, coordinates="data"),
203+
desc_like(xy, coordinates=parent_coordinates),
204+
).inverse
205+
206+
screen_size = graph.evaluator(
207+
desc_like(xy, coordinates=parent_coordinates),
208+
desc_like(xy, coordinates="display"),
209+
)
262210

263-
xpix, ypix = size
264-
x_data, _ = coord_transform.transform(
265-
np.vstack(
266-
[
267-
np.linspace(0, 1, int(xpix) * 2),
268-
np.zeros(int(xpix) * 2),
269-
]
270-
).T
271-
).T
272-
_, y_data = coord_transform.transform(
273-
np.vstack(
274-
[
275-
np.zeros(int(ypix) * 2),
276-
np.linspace(0, 1, int(ypix) * 2),
277-
]
278-
).T
279-
).T
211+
screen_dims = screen_size.evaluate({"x": [0, 1], "y": [0, 1]})
212+
xpix, ypix = np.ceil(np.abs(np.diff(screen_dims["x"]))), np.ceil(
213+
np.abs(np.diff(screen_dims["y"]))
214+
)
280215

216+
x_data = data_lim.evaluate(
217+
{
218+
"x": np.linspace(0, 1, int(xpix) * 2),
219+
"y": np.zeros(int(xpix) * 2),
220+
}
221+
)["x"]
222+
y_data = data_lim.evaluate(
223+
{
224+
"x": np.zeros(int(ypix) * 2),
225+
"y": np.linspace(0, 1, int(ypix) * 2),
226+
}
227+
)["y"]
228+
229+
hash_key = str(uuid.uuid4())
281230
ret = self._cache[hash_key] = dict(
282231
**{k: f(x_data) for k, f in self._xfuncs.items()},
283232
**{k: f(y_data) for k, f in self._yfuncs.items()},
@@ -302,11 +251,21 @@ def __init__(self, raw_data, num_bins: int):
302251

303252
def query(
304253
self,
305-
coord_transform: _MatplotlibTransform,
306-
size: Tuple[int, int],
254+
graph: Graph,
255+
parent_coordinates: str = "axes",
307256
) -> Tuple[Dict[str, Any], Union[str, int]]:
308257
dmin, dmax = self._full_range
309-
xmin, ymin, xmax, ymax = coord_transform.transform([[0, 0], [1, 1]]).flatten()
258+
259+
desc = Desc(("N",), np.dtype("f8"))
260+
xy = {"x": desc, "y": desc}
261+
data_lim = graph.evaluator(
262+
desc_like(xy, coordinates="data"),
263+
desc_like(xy, coordinates=parent_coordinates),
264+
).inverse
265+
266+
pts = data_lim.evaluate({"x": (0, 1), "y": (0, 1)})
267+
xmin, xmax = pts["x"]
268+
ymin, ymax = pts["y"]
310269

311270
xmin, xmax = np.clip([xmin, xmax], dmin, dmax)
312271
hash_key = hash((xmin, xmax))
@@ -333,7 +292,7 @@ def describe(self) -> Dict[str, Desc]:
333292

334293

335294
class SeriesContainer:
336-
_data: pd.DataFrame
295+
_data: pd.Series
337296
_index_name: str
338297
_hash_key: str
339298

@@ -350,8 +309,8 @@ def __init__(self, series: pd.Series, *, index_name: str, col_name: str):
350309

351310
def query(
352311
self,
353-
coord_transform: _MatplotlibTransform,
354-
size: Tuple[int, int],
312+
graph: Graph,
313+
parent_coordinates: str = "axes",
355314
) -> Tuple[Dict[str, Any], Union[str, int]]:
356315
return {
357316
self._index_name: self._data.index.values,
@@ -392,8 +351,8 @@ def __init__(
392351

393352
def query(
394353
self,
395-
coord_transform: _MatplotlibTransform,
396-
size: Tuple[int, int],
354+
graph: Graph,
355+
parent_coordinates: str = "axes",
397356
) -> Tuple[Dict[str, Any], Union[str, int]]:
398357
ret = {}
399358
if self._index_name is not None:
@@ -415,10 +374,10 @@ def __init__(self, data: DataContainer, mapping: Dict[str, str]):
415374

416375
def query(
417376
self,
418-
coord_transform: _MatplotlibTransform,
419-
size: Tuple[int, int],
377+
graph: Graph,
378+
parent_coordinates: str = "axes",
420379
) -> Tuple[Dict[str, Any], Union[str, int]]:
421-
base, cache_key = self._data.query(coord_transform, size)
380+
base, cache_key = self._data.query(graph, parent_coordinates)
422381
return {v: base[k] for k, v in self._mapping.items()}, cache_key
423382

424383
def describe(self):
@@ -433,13 +392,13 @@ def __init__(self, *data: DataContainer):
433392

434393
def query(
435394
self,
436-
coord_transform: _MatplotlibTransform,
437-
size: Tuple[int, int],
395+
graph: Graph,
396+
parent_coordinates: str = "axes",
438397
) -> Tuple[Dict[str, Any], Union[str, int]]:
439398
cache_keys = []
440399
ret = {}
441400
for data in self._datas:
442-
base, cache_key = data.query(coord_transform, size)
401+
base, cache_key = data.query(graph, parent_coordinates)
443402
ret.update(base)
444403
cache_keys.append(cache_key)
445404
return ret, hash(tuple(cache_keys))
@@ -451,11 +410,11 @@ def describe(self):
451410
class WebServiceContainer:
452411
def query(
453412
self,
454-
coord_transform: _MatplotlibTransform,
455-
size: Tuple[int, int],
413+
graph: Graph,
414+
parent_coordinates: str = "axes",
456415
) -> Tuple[Dict[str, Any], Union[str, int]]:
457416
def hit_some_database():
458-
{}, "1"
417+
return {}, "1"
459418

460419
data, etag = hit_some_database()
461420
return data, etag

0 commit comments

Comments (0)