Skip to content

Use clearer names in query signature #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions data_prototype/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import pandas as pd


class _Transform(Protocol):
class _MatplotlibTransform(Protocol):
def transform(self, verts):
...

def __sub__(self, other) -> "_Transform":
def __sub__(self, other) -> "_MatplotlibTransform":
...


Expand All @@ -35,8 +35,9 @@ class DataContainer(Protocol):
def query(
self,
# TODO 3D?!!
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
/,
) -> Tuple[Dict[str, Any], Union[str, int]]:
"""
Query the data container for data.
Expand All @@ -46,7 +47,7 @@ def query(

Parameters
----------
transform : matplotlib.transform.Transform
coord_transform : matplotlib.transform.Transform
Must go from axes fraction space -> data space

size : 2 integers
Expand Down Expand Up @@ -88,7 +89,7 @@ def __init__(self, **data):

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
return dict(self._data), self._cache_key
Expand All @@ -113,7 +114,7 @@ def __init__(self, **shapes):

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str(uuid.uuid4())
Expand Down Expand Up @@ -166,26 +167,26 @@ def _split(input_dict):

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
# TODO find a better way to compute the hash key, this is not sentative to
# scale changes, only limit changes
data_bounds = tuple(transform.transform([[0, 0], [1, 1]]).flatten())
data_bounds = tuple(coord_transform.transform([[0, 0], [1, 1]]).flatten())
hash_key = hash((data_bounds, size))
if hash_key in self._cache:
return self._cache[hash_key], hash_key

xpix, ypix = size
x_data, _ = transform.transform(
x_data, _ = coord_transform.transform(
np.vstack(
[
np.linspace(0, 1, int(xpix) * 2),
np.zeros(int(xpix) * 2),
]
).T
).T
_, y_data = transform.transform(
_, y_data = coord_transform.transform(
np.vstack(
[
np.zeros(int(ypix) * 2),
Expand Down Expand Up @@ -218,11 +219,11 @@ def __init__(self, raw_data, num_bins: int):

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
dmin, dmax = self._full_range
xmin, ymin, xmax, ymax = transform.transform([[0, 0], [1, 1]]).flatten()
xmin, ymin, xmax, ymax = coord_transform.transform([[0, 0], [1, 1]]).flatten()

xmin, xmax = np.clip([xmin, xmax], dmin, dmax)
hash_key = hash((xmin, xmax))
Expand Down Expand Up @@ -266,7 +267,7 @@ def __init__(self, series: pd.Series, *, index_name: str, col_name: str):

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
return {self._index_name: self._data.index.values, self._col_name: self._data.values}, self._hash_key
Expand Down Expand Up @@ -305,7 +306,7 @@ def __init__(

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
ret = {}
Expand All @@ -328,10 +329,10 @@ def __init__(self, data: DataContainer, mapping: Dict[str, str]):

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
base, cache_key = self._data.query(transform, size)
base, cache_key = self._data.query(coord_transform, size)
return {v: base[k] for k, v in self._mapping.items()}, cache_key

def describe(self):
Expand All @@ -346,13 +347,13 @@ def __init__(self, *data: DataContainer):

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
cache_keys = []
ret = {}
for data in self._datas:
base, cache_key = data.query(transform, size)
base, cache_key = data.query(coord_transform, size)
ret.update(base)
cache_keys.append(cache_key)
return ret, hash(tuple(cache_keys))
Expand All @@ -364,7 +365,7 @@ def describe(self):
class WebServiceContainer:
def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
def hit_some_database():
Expand Down
6 changes: 3 additions & 3 deletions data_prototype/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from matplotlib.collections import LineCollection as _LineCollection
from matplotlib.artist import Artist as _Artist

from data_prototype.containers import DataContainer, _Transform
from data_prototype.containers import DataContainer, _MatplotlibTransform


class _BBox(Protocol):
Expand All @@ -30,8 +30,8 @@ class _Axes(Protocol):
xaxis: _Axis
yaxis: _Axis

transData: _Transform
transAxes: _Transform
transData: _MatplotlibTransform
transAxes: _MatplotlibTransform

def get_xlim(self) -> Tuple[float, float]:
...
Expand Down
4 changes: 2 additions & 2 deletions examples/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from data_prototype.containers import _Transform, Desc
from data_prototype.containers import _MatplotlibTransform, Desc

from data_prototype.wrappers import LineWrapper, FormatedText

Expand All @@ -34,7 +34,7 @@ def describe(self):

def query(
self,
transform: _Transform,
coord_transformtransform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
th = np.linspace(0, 2 * np.pi, self.N)
Expand Down
6 changes: 3 additions & 3 deletions examples/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np

from data_prototype.wrappers import ImageWrapper
from data_prototype.containers import _Transform
from data_prototype.containers import _MatplotlibTransform, Desc

from skimage.transform import downscale_local_mean

Expand All @@ -45,10 +45,10 @@ def describe(self):

def query(
self,
transform: _Transform,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
) -> Tuple[Dict[str, Any], Union[str, int]]:
(x1, y1), (x2, y2) = transform.transform([[0, 0], [1, 1]])
(x1, y1), (x2, y2) = coord_transform.transform([[0, 0], [1, 1]])

xi1 = np.argmin(np.abs(x - x1))
yi1 = np.argmin(np.abs(y - y1))
Expand Down