Skip to content

Commit a8813d5

Browse files
authored
Merge pull request #36 from ksunden/late_conversion
Delayed conversion Node for units
2 parents dc8a96b + d2d364f commit a8813d5

File tree

8 files changed

+278
-30
lines changed

8 files changed

+278
-30
lines changed

data_prototype/axes.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import numpy as np
2+
3+
4+
import matplotlib as mpl
5+
from matplotlib.axes._axes import Axes as MPLAxes, _preprocess_data
6+
import matplotlib.collections as mcoll
7+
import matplotlib.cbook as cbook
8+
import matplotlib.markers as mmarkers
9+
import matplotlib.projections as mprojections
10+
11+
from .containers import ArrayContainer, DataUnion
12+
from .conversion_node import (
13+
DelayedConversionNode,
14+
FunctionConversionNode,
15+
RenameConversionNode,
16+
)
17+
from .wrappers import PathCollectionWrapper
18+
19+
20+
class Axes(MPLAxes):
21+
# Name for registering as a projection so we can experiment with it
22+
name = "data-prototype"
23+
24+
@_preprocess_data(
25+
replace_names=[
26+
"x",
27+
"y",
28+
"s",
29+
"linewidths",
30+
"edgecolors",
31+
"c",
32+
"facecolor",
33+
"facecolors",
34+
"color",
35+
],
36+
label_namer="y",
37+
)
38+
def scatter(
39+
self,
40+
x,
41+
y,
42+
s=None,
43+
c=None,
44+
marker=None,
45+
cmap=None,
46+
norm=None,
47+
vmin=None,
48+
vmax=None,
49+
alpha=None,
50+
linewidths=None,
51+
*,
52+
edgecolors=None,
53+
plotnonfinite=False,
54+
**kwargs
55+
):
56+
# TODO implement normalize kwargs as a pipeline stage
57+
# add edgecolors and linewidths to kwargs so they can be processed by
58+
# normalize_kwargs
59+
if edgecolors is not None:
60+
kwargs.update({"edgecolors": edgecolors})
61+
if linewidths is not None:
62+
kwargs.update({"linewidths": linewidths})
63+
64+
kwargs = cbook.normalize_kwargs(kwargs, mcoll.Collection)
65+
c, colors, edgecolors = self._parse_scatter_color_args(
66+
c,
67+
edgecolors,
68+
kwargs,
69+
np.ma.ravel(x).size,
70+
get_next_color_func=self._get_patches_for_fill.get_next_color,
71+
)
72+
73+
inputs = ArrayContainer(
74+
x=x,
75+
y=y,
76+
s=s,
77+
c=c,
78+
marker=marker,
79+
cmap=cmap,
80+
norm=norm,
81+
vmin=vmin,
82+
vmax=vmax,
83+
alpha=alpha,
84+
plotnonfinite=plotnonfinite,
85+
facecolors=colors,
86+
edgecolors=edgecolors,
87+
**kwargs
88+
)
89+
# TODO should more go in here?
90+
# marker/s are always in Container, but require overriding if None
91+
# Color handling is odd too
92+
defaults = ArrayContainer(
93+
linewidths=mpl.rcParams["lines.linewidth"],
94+
)
95+
96+
cont = DataUnion(defaults, inputs)
97+
98+
pipeline = []
99+
xconvert = DelayedConversionNode.from_keys(("x",), converter_key="xunits")
100+
yconvert = DelayedConversionNode.from_keys(("y",), converter_key="yunits")
101+
pipeline.extend([xconvert, yconvert])
102+
pipeline.append(lambda x: np.ma.ravel(x))
103+
pipeline.append(lambda y: np.ma.ravel(y))
104+
pipeline.append(
105+
lambda s: np.ma.ravel(s)
106+
if s is not None
107+
else [20]
108+
if mpl.rcParams["_internal.classic_mode"]
109+
else [mpl.rcParams["lines.markersize"] ** 2.0]
110+
)
111+
# TODO plotnonfinite/mask combining
112+
pipeline.append(
113+
lambda marker: marker
114+
if marker is not None
115+
else mpl.rcParams["scatter.marker"]
116+
)
117+
pipeline.append(
118+
lambda marker: marker
119+
if isinstance(marker, mmarkers.MarkerStyle)
120+
else mmarkers.MarkerStyle(marker)
121+
)
122+
pipeline.append(
123+
FunctionConversionNode.from_funcs(
124+
{
125+
"paths": lambda marker: [
126+
marker.get_path().transformed(marker.get_transform())
127+
]
128+
}
129+
)
130+
)
131+
pipeline.append(RenameConversionNode.from_mapping({"s": "sizes"}))
132+
133+
# TODO classic mode margin override?
134+
pcw = PathCollectionWrapper(cont, pipeline, offset_transform=self.transData)
135+
self.add_artist(pcw)
136+
self._request_autoscale_view()
137+
return pcw
138+
139+
140+
# This is a handy trick to allow e.g. plt.subplots(subplot_kw={'projection': 'data-prototype'})
141+
mprojections.register_projection(Axes)

data_prototype/containers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,12 @@ class ArrayContainer:
151151
def __init__(self, **data):
152152
self._data = data
153153
self._cache_key = str(uuid.uuid4())
154-
self._desc = {k: Desc(v.shape, v.dtype) for k, v in data.items()}
154+
self._desc = {
155+
k: Desc(v.shape, v.dtype)
156+
if isinstance(v, np.ndarray)
157+
else Desc((), type(v))
158+
for k, v in data.items()
159+
}
155160

156161
def query(
157162
self,

data_prototype/conversion_node.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,19 @@
99
from typing import Any
1010

1111

12-
def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]):
12+
def evaluate_pipeline(
13+
nodes: Sequence[ConversionNode],
14+
input: dict[str, Any],
15+
delayed_converters: dict[str, Callable] | None = None,
16+
):
1317
for node in nodes:
14-
input = node.evaluate(input)
18+
if isinstance(node, Callable):
19+
k = list(inspect.signature(node).parameters.keys())[0]
20+
node = FunctionConversionNode.from_funcs({k: node})
21+
if isinstance(node, DelayedConversionNode):
22+
input = node.evaluate(input, delayed_converters)
23+
else:
24+
input = node.evaluate(input)
1525
return input
1626

1727

@@ -113,3 +123,27 @@ def from_keys(cls, keys: Sequence[str]):
113123

114124
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
115125
return {k: v for k, v in input.items() if k in self.keys}
126+
127+
128+
@dataclass
129+
class DelayedConversionNode(ConversionNode):
130+
converter_key: str
131+
132+
@classmethod
133+
def from_keys(cls, keys: Sequence[str], converter_key: str):
134+
return cls(
135+
tuple(keys), tuple(keys), trim_keys=False, converter_key=converter_key
136+
)
137+
138+
def evaluate(
139+
self, input: dict[str, Any], converters: dict[str, Callable] | None = None
140+
) -> dict[str, Any]:
141+
return super().evaluate(
142+
{
143+
**input,
144+
**{
145+
k: converters[self.converter_key](input[k])
146+
for k in self.required_keys
147+
},
148+
}
149+
)

data_prototype/patches.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ class PatchWrapper(ProxyWrapper):
3030
"set_joinstyle",
3131
"set_path",
3232
)
33-
_xunits = ()
34-
_yunits = ()
3533
required_keys = {
3634
"edgecolor",
3735
"facecolor",
@@ -50,11 +48,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
5048

5149
@_stale_wrapper
5250
def draw(self, renderer):
53-
self._update_wrapped(
54-
self._query_and_transform(
55-
renderer, xunits=self._xunits, yunits=self._yunits
56-
)
57-
)
51+
self._update_wrapped(self._query_and_transform(renderer))
5852
return self._wrapped_instance.draw(renderer)
5953

6054
def _update_wrapped(self, data):
@@ -77,8 +71,6 @@ class RectangleWrapper(PatchWrapper):
7771
"set_angle",
7872
"set_rotation_point",
7973
)
80-
_xunits = ("x", "width")
81-
_yunits = ("y", "height")
8274
required_keys = PatchWrapper.required_keys | {
8375
"x",
8476
"y",

data_prototype/wrappers.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict, Any, Protocol, Tuple, get_type_hints
1+
from typing import Dict, Any, Protocol, Tuple, get_type_hints
22
import inspect
33

44
import numpy as np
@@ -121,17 +121,13 @@ def draw(self, renderer):
121121
def _update_wrapped(self, data):
122122
raise NotImplementedError
123123

124-
def _query_and_transform(
125-
self, renderer, *, xunits: List[str], yunits: List[str]
126-
) -> Dict[str, Any]:
124+
def _query_and_transform(self, renderer) -> Dict[str, Any]:
127125
"""
128126
Helper to centralize the data querying and python-side transforms
129127
130128
Parameters
131129
----------
132130
renderer : RendererBase
133-
xunits, yunits : List[str]
134-
The list of keys that need to be run through the x and y unit machinery.
135131
"""
136132
# extract what we need to about the axes to query the data
137133
ax = self.axes
@@ -153,8 +149,11 @@ def _query_and_transform(
153149
return self._cache[cache_key]
154150
except KeyError:
155151
...
156-
# TODO units
157-
transformed_data = evaluate_pipeline(self._converters, data)
152+
delayed_conversion = {
153+
"xunits": ax.xaxis.convert_units,
154+
"yunits": ax.yaxis.convert_units,
155+
}
156+
transformed_data = evaluate_pipeline(self._converters, data, delayed_conversion)
158157

159158
self._cache[cache_key] = transformed_data
160159
return transformed_data
@@ -232,7 +231,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
232231
@_stale_wrapper
233232
def draw(self, renderer):
234233
self._update_wrapped(
235-
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
234+
self._query_and_transform(renderer),
236235
)
237236
return self._wrapped_instance.draw(renderer)
238237

@@ -265,7 +264,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
265264
@_stale_wrapper
266265
def draw(self, renderer):
267266
self._update_wrapped(
268-
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
267+
self._query_and_transform(renderer),
269268
)
270269
return self._wrapped_instance.draw(renderer)
271270

@@ -304,7 +303,7 @@ def __init__(
304303
@_stale_wrapper
305304
def draw(self, renderer):
306305
self._update_wrapped(
307-
self._query_and_transform(renderer, xunits=["xextent"], yunits=["yextent"]),
306+
self._query_and_transform(renderer),
308307
)
309308
return self._wrapped_instance.draw(renderer)
310309

@@ -325,7 +324,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
325324
@_stale_wrapper
326325
def draw(self, renderer):
327326
self._update_wrapped(
328-
self._query_and_transform(renderer, xunits=["edges"], yunits=["density"]),
327+
self._query_and_transform(renderer),
329328
)
330329
return self._wrapped_instance.draw(renderer)
331330

@@ -344,7 +343,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
344343
@_stale_wrapper
345344
def draw(self, renderer):
346345
self._update_wrapped(
347-
self._query_and_transform(renderer, xunits=[], yunits=[]),
346+
self._query_and_transform(renderer),
348347
)
349348
return self._wrapped_instance.draw(renderer)
350349

@@ -425,11 +424,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
425424
@_stale_wrapper
426425
def draw(self, renderer):
427426
self._update_wrapped(
428-
self._query_and_transform(
429-
renderer,
430-
xunits=["x", "xupper", "xlower"],
431-
yunits=["y", "yupper", "ylower"],
432-
),
427+
self._query_and_transform(renderer),
433428
)
434429
for k, v in self._wrapped_instances.items():
435430
v.draw(renderer)

examples/scatter_with_custom_axes.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""
2+
=========================================
3+
An simple scatter plot using `ax.scatter`
4+
=========================================
5+
6+
This is a quick comparison between the current Matplotlib `scatter` and
7+
the version in :file:`data_prototype/axes.py`, which uses data containers
8+
and a conversion pipeline.
9+
10+
This is here to show what does work and what does not work with the current
11+
implementation of container-based artist drawing.
12+
"""
13+
14+
15+
import data_prototype.axes # side-effect registers projection # noqa
16+
17+
import matplotlib.pyplot as plt
18+
19+
fig = plt.figure()
20+
newstyle = fig.add_subplot(2, 1, 1, projection="data-prototype")
21+
oldstyle = fig.add_subplot(2, 1, 2)
22+
23+
newstyle.scatter([0, 1, 2], [2, 5, 1])
24+
oldstyle.scatter([0, 1, 2], [2, 5, 1])
25+
newstyle.scatter([0, 1, 2], [3, 1, 2])
26+
oldstyle.scatter([0, 1, 2], [3, 1, 2])
27+
28+
29+
# Autoscaling not working
30+
newstyle.set_xlim(oldstyle.get_xlim())
31+
newstyle.set_ylim(oldstyle.get_ylim())
32+
33+
plt.show()

0 commit comments

Comments
 (0)