Skip to content

Initial rework as conversion edges #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ default_language_version:
python: python3
repos:
- repo: https://github.com/ambv/black
rev: 23.3.0
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.0.0
hooks:
- id: flake8
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
rev: 0.7.1
hooks:
- id: nbstripout
28 changes: 17 additions & 11 deletions data_prototype/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,28 @@ def scatter(
pipeline.append(lambda x: np.ma.ravel(x))
pipeline.append(lambda y: np.ma.ravel(y))
pipeline.append(
lambda s: np.ma.ravel(s)
if s is not None
else [20]
if mpl.rcParams["_internal.classic_mode"]
else [mpl.rcParams["lines.markersize"] ** 2.0]
lambda s: (
np.ma.ravel(s)
if s is not None
else (
[20]
if mpl.rcParams["_internal.classic_mode"]
else [mpl.rcParams["lines.markersize"] ** 2.0]
)
)
)
# TODO plotnonfinite/mask combining
pipeline.append(
lambda marker: marker
if marker is not None
else mpl.rcParams["scatter.marker"]
lambda marker: (
marker if marker is not None else mpl.rcParams["scatter.marker"]
)
)
pipeline.append(
lambda marker: marker
if isinstance(marker, mmarkers.MarkerStyle)
else mmarkers.MarkerStyle(marker)
lambda marker: (
marker
if isinstance(marker, mmarkers.MarkerStyle)
else mmarkers.MarkerStyle(marker)
)
)
pipeline.append(
FunctionConversionNode.from_funcs(
Expand Down
211 changes: 85 additions & 126 deletions data_prototype/containers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from __future__ import annotations

from typing import (
Protocol,
Dict,
Expand All @@ -8,7 +9,6 @@
Union,
Callable,
MutableMapping,
TypeAlias,
)
import uuid

Expand All @@ -17,92 +17,25 @@
import numpy as np
import pandas as pd

from .description import Desc, desc_like

class _MatplotlibTransform(Protocol):
    """Structural (duck-typed) interface for matplotlib-style transforms.

    Any object providing these two methods is accepted wherever a
    ``_MatplotlibTransform`` is expected; no inheritance is required.
    """

    def transform(self, verts):
        """Map an array of vertices through this transform."""
        ...

    def __sub__(self, other) -> "_MatplotlibTransform":
        """Return a derived transform from this one and *other*.

        NOTE(review): semantics assumed to mirror
        ``matplotlib.transforms.Transform.__sub__`` — confirm against callers.
        """
        ...
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .conversion_edge import Graph

ShapeSpec: TypeAlias = Tuple[Union[str, int], ...]

class _MatplotlibTransform(Protocol):
def transform(self, verts): ...

@dataclass(frozen=True)
class Desc:
    """Describes one field of a data source: its shape, dtype, and units."""

    # TODO: sort out how to actually spell this. We need to know:
    # - what the number of dimensions is (1d vs 2d vs ...)
    # - is this a fixed size dimension (e.g. 2 for xextent)
    # - is this a variable size depending on the query (e.g. N)
    # - what is the relative size to the other variable values (N vs N+1)
    # We are probably going to have to implement a DSL for this (😞)
    shape: ShapeSpec
    # TODO: is using a string better?
    dtype: np.dtype
    # TODO: do we want to include this at this level? "naive" means unit-unaware.
    units: str = "naive"

    @staticmethod
    def validate_shapes(
        specification: dict[str, ShapeSpec | "Desc"],
        actual: dict[str, ShapeSpec | "Desc"],
        *,
        broadcast: bool = False,
    ) -> None:
        """Check that *actual* shapes satisfy *specification* shapes.

        Shape components may be fixed ints or symbolic strings such as
        ``"N"`` or ``"N+1"`` — a single variable character optionally
        followed by an integer offset.  Every use of the same variable must
        resolve to one consistent value across all fields.

        Parameters
        ----------
        specification, actual : dict[str, ShapeSpec | Desc]
            Mappings from field name to a shape tuple (a ``Desc`` value is
            reduced to its ``.shape``).  Every key in *specification* must
            be present in *actual*.
        broadcast : bool, keyword-only
            If True, *actual* may have fewer dimensions than the
            specification and actual size-1 dimensions match anything
            (NumPy-style broadcasting).

        Raises
        ------
        KeyError
            If a specified field is missing from *actual*.
        ValueError
            If any shape is incompatible with the specification, or a
            symbolic variable resolves to two different values.
        """
        # Resolved value per shape variable: a concrete int, or a
        # (variable, offset) pair when it is only tied to another symbolic
        # dimension.
        specvars: dict[str, int | tuple[str, int]] = {}
        for fieldname in specification:
            spec = specification[fieldname]
            if fieldname not in actual:
                raise KeyError(
                    f"Actual is missing {fieldname!r}, required by specification."
                )
            desc = actual[fieldname]
            # Normalize Desc entries to their raw shape tuples.
            if isinstance(spec, Desc):
                spec = spec.shape
            if isinstance(desc, Desc):
                desc = desc.shape
            if not broadcast:
                if len(spec) != len(desc):
                    raise ValueError(
                        f"{fieldname!r} shape {desc} incompatible with specification "
                        f"{spec}."
                    )
            elif len(desc) > len(spec):
                raise ValueError(
                    f"{fieldname!r} shape {desc} incompatible with specification "
                    f"{spec}."
                )
            # Compare trailing dimensions first (broadcasting aligns from
            # the right).
            for speccomp, desccomp in zip(spec[::-1], desc[::-1]):
                if broadcast and desccomp == 1:
                    # A size-1 actual dimension broadcasts against anything.
                    continue
                if isinstance(speccomp, str):
                    # Split e.g. "N+1" into variable "N" and offset 1.
                    # NOTE(review): only single-character variable names are
                    # supported (speccomp[0]).
                    specv, specoff = speccomp[0], int(speccomp[1:] or 0)

                    if isinstance(desccomp, str):
                        descv, descoff = desccomp[0], int(desccomp[1:] or 0)
                        entry = (descv, descoff - specoff)
                    else:
                        entry = desccomp - specoff

                    if specv in specvars and entry != specvars[specv]:
                        raise ValueError(f"Found two incompatible values for {specv!r}")

                    specvars[specv] = entry
                elif speccomp != desccomp:
                    raise ValueError(
                        f"{fieldname!r} shape {desc} incompatible with specification "
                        f"{spec}"
                    )
        return None
def __sub__(self, other) -> "_MatplotlibTransform": ...


class DataContainer(Protocol):
def query(
self,
# TODO 3D?!!
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
/,
) -> Tuple[Dict[str, Any], Union[str, int]]:
"""
Expand Down Expand Up @@ -132,6 +65,7 @@ def query(
This is a key that clients can use to cache down-stream
computations on this data.
"""
...

def describe(self) -> Dict[str, Desc]:
"""
Expand All @@ -141,27 +75,29 @@ def describe(self) -> Dict[str, Desc]:
-------
Dict[str, Desc]
"""
...


class NoNewKeys(ValueError):
...
class NoNewKeys(ValueError): ...


class ArrayContainer:
def __init__(self, **data):
self._data = data
self._cache_key = str(uuid.uuid4())
self._desc = {
k: Desc(v.shape, v.dtype)
if isinstance(v, np.ndarray)
else Desc((), type(v))
k: (
Desc(v.shape, v.dtype)
if isinstance(v, np.ndarray)
else Desc((), type(v))
)
for k, v in data.items()
}

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
return dict(self._data), self._cache_key

Expand All @@ -185,8 +121,8 @@ def __init__(self, **shapes):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str(
uuid.uuid4()
Expand Down Expand Up @@ -253,31 +189,44 @@ def _query_hash(self, coord_transform, size):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
hash_key = self._query_hash(coord_transform, size)
if hash_key in self._cache:
return self._cache[hash_key], hash_key
# hash_key = self._query_hash(coord_transform, size)
# if hash_key in self._cache:
# return self._cache[hash_key], hash_key

desc = Desc(("N",), np.dtype("f8"))
xy = {"x": desc, "y": desc}
data_lim = graph.evaluator(
desc_like(xy, coordinates="data"),
desc_like(xy, coordinates=parent_coordinates),
).inverse

screen_size = graph.evaluator(
desc_like(xy, coordinates=parent_coordinates),
desc_like(xy, coordinates="display"),
)

xpix, ypix = size
x_data, _ = coord_transform.transform(
np.vstack(
[
np.linspace(0, 1, int(xpix) * 2),
np.zeros(int(xpix) * 2),
]
).T
).T
_, y_data = coord_transform.transform(
np.vstack(
[
np.zeros(int(ypix) * 2),
np.linspace(0, 1, int(ypix) * 2),
]
).T
).T
screen_dims = screen_size.evaluate({"x": [0, 1], "y": [0, 1]})
xpix, ypix = np.ceil(np.abs(np.diff(screen_dims["x"]))), np.ceil(
np.abs(np.diff(screen_dims["y"]))
)

x_data = data_lim.evaluate(
{
"x": np.linspace(0, 1, int(xpix) * 2),
"y": np.zeros(int(xpix) * 2),
}
)["x"]
y_data = data_lim.evaluate(
{
"x": np.zeros(int(ypix) * 2),
"y": np.linspace(0, 1, int(ypix) * 2),
}
)["y"]

hash_key = str(uuid.uuid4())
ret = self._cache[hash_key] = dict(
**{k: f(x_data) for k, f in self._xfuncs.items()},
**{k: f(y_data) for k, f in self._yfuncs.items()},
Expand All @@ -302,11 +251,21 @@ def __init__(self, raw_data, num_bins: int):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
dmin, dmax = self._full_range
xmin, ymin, xmax, ymax = coord_transform.transform([[0, 0], [1, 1]]).flatten()

desc = Desc(("N",), np.dtype("f8"))
xy = {"x": desc, "y": desc}
data_lim = graph.evaluator(
desc_like(xy, coordinates="data"),
desc_like(xy, coordinates=parent_coordinates),
).inverse

pts = data_lim.evaluate({"x": (0, 1), "y": (0, 1)})
xmin, xmax = pts["x"]
ymin, ymax = pts["y"]

xmin, xmax = np.clip([xmin, xmax], dmin, dmax)
hash_key = hash((xmin, xmax))
Expand All @@ -333,7 +292,7 @@ def describe(self) -> Dict[str, Desc]:


class SeriesContainer:
_data: pd.DataFrame
_data: pd.Series
_index_name: str
_hash_key: str

Expand All @@ -350,8 +309,8 @@ def __init__(self, series: pd.Series, *, index_name: str, col_name: str):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
return {
self._index_name: self._data.index.values,
Expand Down Expand Up @@ -392,8 +351,8 @@ def __init__(

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
ret = {}
if self._index_name is not None:
Expand All @@ -415,10 +374,10 @@ def __init__(self, data: DataContainer, mapping: Dict[str, str]):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
base, cache_key = self._data.query(coord_transform, size)
base, cache_key = self._data.query(graph, parent_coordinates)
return {v: base[k] for k, v in self._mapping.items()}, cache_key

def describe(self):
Expand All @@ -433,13 +392,13 @@ def __init__(self, *data: DataContainer):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
cache_keys = []
ret = {}
for data in self._datas:
base, cache_key = data.query(coord_transform, size)
base, cache_key = data.query(graph, parent_coordinates)
ret.update(base)
cache_keys.append(cache_key)
return ret, hash(tuple(cache_keys))
Expand All @@ -451,11 +410,11 @@ def describe(self):
class WebServiceContainer:
def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
def hit_some_database():
{}, "1"
return {}, "1"

data, etag = hit_some_database()
return data, etag
Loading