Skip to content

Commit c588266

Browse files
committed
Play with units
1 parent 2e70f58 commit c588266

File tree

6 files changed

+232
-2
lines changed

6 files changed

+232
-2
lines changed

data_prototype/axes.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import numpy as np
2+
3+
4+
import matplotlib as mpl
5+
from matplotlib.axes._axes import Axes as MPLAxes, _preprocess_data
6+
import matplotlib.collections as mcoll
7+
import matplotlib.cbook as cbook
8+
import matplotlib.markers as mmarkers
9+
import matplotlib.projections as mprojections
10+
11+
from .containers import ArrayContainer, DataUnion
12+
from .conversion_node import (
13+
MatplotlibUnitConversion,
14+
FunctionConversionNode,
15+
RenameConversionNode,
16+
)
17+
from .wrappers import PathCollectionWrapper
18+
19+
20+
class Axes(MPLAxes):
21+
# Name for registering as a projection so we can experiment with it
22+
name = "data-prototype"
23+
24+
@_preprocess_data(
25+
replace_names=[
26+
"x",
27+
"y",
28+
"s",
29+
"linewidths",
30+
"edgecolors",
31+
"c",
32+
"facecolor",
33+
"facecolors",
34+
"color",
35+
],
36+
label_namer="y",
37+
)
38+
def scatter(
39+
self,
40+
x,
41+
y,
42+
s=None,
43+
c=None,
44+
marker=None,
45+
cmap=None,
46+
norm=None,
47+
vmin=None,
48+
vmax=None,
49+
alpha=None,
50+
linewidths=None,
51+
*,
52+
edgecolors=None,
53+
plotnonfinite=False,
54+
**kwargs
55+
):
56+
# TODO implement normalize kwargs as a pipeline stage
57+
# add edgecolors and linewidths to kwargs so they can be processed by
58+
# normalize_kwargs
59+
if edgecolors is not None:
60+
kwargs.update({"edgecolors": edgecolors})
61+
if linewidths is not None:
62+
kwargs.update({"linewidths": linewidths})
63+
64+
kwargs = cbook.normalize_kwargs(kwargs, mcoll.Collection)
65+
c, colors, edgecolors = self._parse_scatter_color_args(
66+
c,
67+
edgecolors,
68+
kwargs,
69+
np.ma.ravel(x).size,
70+
get_next_color_func=self._get_patches_for_fill.get_next_color,
71+
)
72+
73+
inputs = ArrayContainer(
74+
x=x,
75+
y=y,
76+
s=s,
77+
c=c,
78+
marker=marker,
79+
cmap=cmap,
80+
norm=norm,
81+
vmin=vmin,
82+
vmax=vmax,
83+
alpha=alpha,
84+
plotnonfinite=plotnonfinite,
85+
facecolors=colors,
86+
edgecolors=edgecolors,
87+
**kwargs
88+
)
89+
# TODO should more go in here?
90+
# marker/s are always in Container, but require overriding if None
91+
# Color handling is odd too
92+
defaults = ArrayContainer(
93+
linewidths=mpl.rcParams["lines.linewidth"],
94+
)
95+
96+
cont = DataUnion(defaults, inputs)
97+
98+
pipeline = []
99+
xconvert = MatplotlibUnitConversion.from_keys(("x",), axis=self.xaxis)
100+
yconvert = MatplotlibUnitConversion.from_keys(("y",), axis=self.yaxis)
101+
pipeline.extend([xconvert, yconvert])
102+
pipeline.append(lambda x: np.ma.ravel(x))
103+
pipeline.append(lambda y: np.ma.ravel(y))
104+
pipeline.append(
105+
lambda s: np.ma.ravel(s)
106+
if s is not None
107+
else [20]
108+
if mpl.rcParams["_internal.classic_mode"]
109+
else [mpl.rcParams["lines.markersize"] ** 2.0]
110+
)
111+
# TODO plotnonfinite/mask combining
112+
pipeline.append(
113+
lambda marker: marker
114+
if marker is not None
115+
else mpl.rcParams["scatter.marker"]
116+
)
117+
pipeline.append(
118+
lambda marker: marker
119+
if isinstance(marker, mmarkers.MarkerStyle)
120+
else mmarkers.MarkerStyle(marker)
121+
)
122+
pipeline.append(
123+
FunctionConversionNode.from_funcs(
124+
{
125+
"paths": lambda marker: [
126+
marker.get_path().transformed(marker.get_transform())
127+
]
128+
}
129+
)
130+
)
131+
pipeline.append(RenameConversionNode.from_mapping({"s": "sizes"}))
132+
133+
# TODO classic mode margin override?
134+
pcw = PathCollectionWrapper(cont, pipeline, offset_transform=self.transData)
135+
self.add_artist(pcw)
136+
self._request_autoscale_view()
137+
return pcw
138+
139+
140+
# This is a handy trick to allow e.g. plt.subplots(subplot_kw={'projection': 'data-prototype'})
141+
mprojections.register_projection(Axes)

data_prototype/containers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,12 @@ class ArrayContainer:
8585
def __init__(self, **data):
8686
self._data = data
8787
self._cache_key = str(uuid.uuid4())
88-
self._desc = {k: Desc(v.shape, v.dtype) for k, v in data.items()}
88+
self._desc = {
89+
k: Desc(v.shape, v.dtype)
90+
if isinstance(v, np.ndarray)
91+
else Desc((), type(v))
92+
for k, v in data.items()
93+
}
8994

9095
def query(
9196
self,

data_prototype/conversion_node.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
import inspect
77
from functools import cached_property
88

9+
from matplotlib.axis import Axis
10+
911
from typing import Any
1012

1113

1214
def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]):
1315
for node in nodes:
16+
if isinstance(node, Callable):
17+
k = list(inspect.signature(node).parameters.keys())[0]
18+
node = FunctionConversionNode.from_funcs({k: node})
19+
1420
input = node.evaluate(input)
1521
return input
1622

@@ -113,3 +119,20 @@ def from_keys(cls, keys: Sequence[str]):
113119

114120
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
115121
return {k: v for k, v in input.items() if k in self.keys}
122+
123+
124+
@dataclass
125+
class MatplotlibUnitConversion(ConversionNode):
126+
axis: Axis
127+
128+
@classmethod
129+
def from_keys(cls, keys: Sequence[str], axis: Axis):
130+
return cls(tuple(keys), tuple(keys), trim_keys=False, axis=axis)
131+
132+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
133+
return super().evaluate(
134+
{
135+
**input,
136+
**{k: self.axis.convert_units(input[k]) for k in self.required_keys},
137+
}
138+
)

data_prototype/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
265265
@_stale_wrapper
266266
def draw(self, renderer):
267267
self._update_wrapped(
268-
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
268+
self._query_and_transform(renderer, xunits=[], yunits=[]),
269269
)
270270
return self._wrapped_instance.draw(renderer)
271271

examples/scatter_with_custom_axes.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import data_prototype.axes # side-effect registers projection # noqa
2+
3+
import matplotlib.pyplot as plt
4+
5+
fig = plt.figure()
6+
newstyle = fig.add_subplot(2, 1, 1, projection="data-prototype")
7+
oldstyle = fig.add_subplot(2, 1, 2)
8+
9+
newstyle.scatter([0, 1, 2], [2, 5, 1])
10+
oldstyle.scatter([0, 1, 2], [2, 5, 1])
11+
newstyle.scatter([0, 1, 2], [3, 1, 2])
12+
oldstyle.scatter([0, 1, 2], [3, 1, 2])
13+
14+
15+
# Autoscaling not working
16+
newstyle.set_xlim(oldstyle.get_xlim())
17+
newstyle.set_ylim(oldstyle.get_ylim())
18+
19+
plt.show()

examples/units.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""
2+
==================================================
3+
An simple scatter plot using PathCollectionWrapper
4+
==================================================
5+
6+
A quick scatter plot using :class:`.containers.ArrayContainer` and
7+
:class:`.wrappers.PathCollectionWrapper`.
8+
"""
9+
import numpy as np
10+
11+
import matplotlib.pyplot as plt
12+
import matplotlib.markers as mmarkers
13+
14+
from data_prototype.containers import ArrayContainer
15+
from data_prototype.conversion_node import MatplotlibUnitConversion
16+
17+
from data_prototype.wrappers import PathCollectionWrapper
18+
19+
import pint
20+
21+
ureg = pint.UnitRegistry()
22+
ureg.setup_matplotlib()
23+
24+
marker_obj = mmarkers.MarkerStyle("o")
25+
26+
cont = ArrayContainer(
27+
x=np.array([0, 1, 2]) * ureg.m,
28+
y=np.array([1, 4, 2]) * ureg.m,
29+
paths=np.array([marker_obj.get_path()]),
30+
sizes=np.array([12]),
31+
edgecolors=np.array(["k"]),
32+
facecolors=np.array(["C3"]),
33+
)
34+
35+
fig, ax = plt.subplots()
36+
ax.set_xlim(-0.5, 7)
37+
ax.set_ylim(0, 5)
38+
conv = MatplotlibUnitConversion.from_keys(("x",), axis=ax.xaxis)
39+
lw = PathCollectionWrapper(cont, [conv], offset_transform=ax.transData)
40+
ax.add_artist(lw)
41+
ax.xaxis.set_units(ureg.feet)
42+
plt.show()

0 commit comments

Comments
 (0)