diff --git a/data_prototype/containers.py b/data_prototype/containers.py index b3db13e..8491914 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -8,11 +8,11 @@ import pandas as pd -class _Transform(Protocol): +class _MatplotlibTransform(Protocol): def transform(self, verts): ... - def __sub__(self, other) -> "_Transform": + def __sub__(self, other) -> "_MatplotlibTransform": ... @@ -35,8 +35,9 @@ class DataContainer(Protocol): def query( self, # TODO 3D?!! - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], + /, ) -> Tuple[Dict[str, Any], Union[str, int]]: """ Query the data container for data. @@ -46,7 +47,7 @@ def query( Parameters ---------- - transform : matplotlib.transform.Transform + coord_transform : matplotlib.transform.Transform Must go from axes fraction space -> data space size : 2 integers @@ -88,7 +89,7 @@ def __init__(self, **data): def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: return dict(self._data), self._cache_key @@ -113,7 +114,7 @@ def __init__(self, **shapes): def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str(uuid.uuid4()) @@ -166,18 +167,18 @@ def _split(input_dict): def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: # TODO find a better way to compute the hash key, this is not sentative to # scale changes, only limit changes - data_bounds = tuple(transform.transform([[0, 0], [1, 1]]).flatten()) + data_bounds = tuple(coord_transform.transform([[0, 0], [1, 1]]).flatten()) hash_key = hash((data_bounds, size)) if hash_key in self._cache: return self._cache[hash_key], hash_key xpix, ypix = size - x_data, _ = transform.transform( + x_data, _ = coord_transform.transform( np.vstack( [ np.linspace(0, 1, int(xpix) * 2), @@ -185,7 +186,7 @@ def query( ] ).T ).T - _, y_data = transform.transform( + _, y_data = coord_transform.transform( np.vstack( [ np.zeros(int(ypix) * 2), @@ -218,11 +219,11 @@ def __init__(self, raw_data, num_bins: int): def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: dmin, dmax = self._full_range - xmin, ymin, xmax, ymax = transform.transform([[0, 0], [1, 1]]).flatten() + xmin, ymin, xmax, ymax = coord_transform.transform([[0, 0], [1, 1]]).flatten() xmin, xmax = np.clip([xmin, xmax], dmin, dmax) hash_key = hash((xmin, xmax)) @@ -266,7 +267,7 @@ def __init__(self, series: pd.Series, *, index_name: str, col_name: str): def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: return {self._index_name: self._data.index.values, self._col_name: self._data.values}, self._hash_key @@ -305,7 +306,7 @@ def __init__( def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: ret = {} @@ -328,10 +329,10 @@ def __init__(self, data: DataContainer, mapping: Dict[str, str]): def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: - base, cache_key = self._data.query(transform, size) + base, cache_key = self._data.query(coord_transform, size) return {v: base[k] for k, v in self._mapping.items()}, cache_key def describe(self): @@ -346,13 +347,13 @@ def __init__(self, *data: DataContainer): def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: cache_keys = [] ret = {} for data in self._datas: - base, cache_key = data.query(transform, size) + base, cache_key = data.query(coord_transform, size) ret.update(base) cache_keys.append(cache_key) return ret, hash(tuple(cache_keys)) @@ -364,7 +365,7 @@ def describe(self): class WebServiceContainer: def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: def hit_some_database(): diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index 66bbc33..059e507 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -14,7 +14,7 @@ from matplotlib.collections import LineCollection as _LineCollection from matplotlib.artist import Artist as _Artist -from data_prototype.containers import DataContainer, _Transform +from data_prototype.containers import DataContainer, _MatplotlibTransform class _BBox(Protocol): @@ -30,8 +30,8 @@ class _Axes(Protocol): xaxis: _Axis yaxis: _Axis - transData: _Transform - transAxes: _Transform + transData: _MatplotlibTransform + transAxes: _MatplotlibTransform def get_xlim(self) -> Tuple[float, float]: ... diff --git a/examples/animation.py b/examples/animation.py index e92c176..77161e3 100644 --- a/examples/animation.py +++ b/examples/animation.py @@ -14,7 +14,7 @@ import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation -from data_prototype.containers import _Transform, Desc +from data_prototype.containers import _MatplotlibTransform, Desc from data_prototype.wrappers import LineWrapper, FormatedText @@ -34,7 +34,7 @@ def describe(self): def query( self, - transform: _Transform, + coord_transformtransform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: th = np.linspace(0, 2 * np.pi, self.N) diff --git a/examples/subsample.py b/examples/subsample.py index 07fb8ad..2e0e4ab 100644 --- a/examples/subsample.py +++ b/examples/subsample.py @@ -21,7 +21,7 @@ import numpy as np from data_prototype.wrappers import ImageWrapper -from data_prototype.containers import _Transform +from data_prototype.containers import _MatplotlibTransform, Desc from skimage.transform import downscale_local_mean @@ -45,10 +45,10 @@ def describe(self): def query( self, - transform: _Transform, + coord_transform: _MatplotlibTransform, size: Tuple[int, int], ) -> Tuple[Dict[str, Any], Union[str, int]]: - (x1, y1), (x2, y2) = transform.transform([[0, 0], [1, 1]]) + (x1, y1), (x2, y2) = coord_transform.transform([[0, 0], [1, 1]]) xi1 = np.argmin(np.abs(x - x1)) yi1 = np.argmin(np.abs(y - y1))