Skip to content

Commit 8e0086f

Browse files
committed
Early PathCollection support
Fix scatter transforms Remove unused imports more unused imports
1 parent 339bc85 commit 8e0086f

File tree

3 files changed

+170
-3
lines changed

3 files changed

+170
-3
lines changed

data_prototype/wrappers.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from matplotlib.image import AxesImage as _AxesImage
1212
from matplotlib.patches import StepPatch as _StepPatch
1313
from matplotlib.text import Text as _Text
14-
from matplotlib.collections import LineCollection as _LineCollection
14+
import matplotlib.transforms as mtransforms
15+
from matplotlib.collections import LineCollection as _LineCollection, PathCollection as _PathCollection
1516
from matplotlib.artist import Artist as _Artist
1617

1718
from data_prototype.containers import DataContainer, _MatplotlibTransform
@@ -198,15 +199,14 @@ def __setattr__(self, key, value):
198199
else:
199200
super().__setattr__(key, value)
200201

201-
202202
class LineWrapper(ProxyWrapper):
203203
_wrapped_class = _Line2D
204204
_privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data")
205205
required_keys = {"x", "y"}
206206

207207
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
208208
super().__init__(data, nus)
209-
self._wrapped_instance = self._wrapped_class([], [], **kwargs)
209+
self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs)
210210

211211
@_stale_wrapper
212212
def draw(self, renderer):
@@ -220,6 +220,34 @@ def _update_wrapped(self, data):
220220
k = {"x": "xdata", "y": "ydata"}.get(k, k)
221221
getattr(self._wrapped_instance, f"set_{k}")(v)
222222

223+
class PathCollectionWrapper(ProxyWrapper):
224+
_wrapped_class = _PathCollection
225+
_privtized_methods = (
226+
"set_facecolors", "set_edgecolors", "set_offsets", "set_sizes", "set_paths",
227+
"get_facecolors", "get_edgecolors", "get_offsets", "get_sizes", "get_paths",
228+
)
229+
230+
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
231+
super().__init__(data, nus)
232+
self._wrapped_instance = self._wrapped_class([], **kwargs)
233+
self._wrapped_instance.set_transform(mtransforms.IdentityTransform())
234+
235+
@_stale_wrapper
236+
def draw(self, renderer):
237+
self._update_wrapped(
238+
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
239+
)
240+
return self._wrapped_instance.draw(renderer)
241+
242+
def _update_wrapped(self, data):
243+
print(data)
244+
self._wrapped_instance.set_offsets(np.array([data["x"], data["y"]]).T)
245+
self._wrapped_instance.set_paths(data["paths"])
246+
self._wrapped_instance.set_facecolors(data["facecolors"])
247+
self._wrapped_instance.set_edgecolors(data["edgecolors"])
248+
self._wrapped_instance.set_sizes(data["sizes"])
249+
250+
223251

224252
class ImageWrapper(ProxyWrapper):
225253
_wrapped_class = _AxesImage

examples/lissajous.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
==========================
3+
An animated lissajous ball
4+
==========================
5+
6+
Inspired by https://twitter.com/_brohrer_/status/1584681864648065027
7+
8+
9+
"""
10+
import time
11+
from typing import Dict, Tuple, Any, Union
12+
from functools import partial
13+
14+
import numpy as np
15+
16+
import matplotlib.pyplot as plt
17+
import matplotlib.markers as mmarkers
18+
from matplotlib.animation import FuncAnimation
19+
20+
from data_prototype.containers import _Transform, Desc
21+
22+
from data_prototype.wrappers import PathCollectionWrapper, FormatedText
23+
24+
25+
class Lissajous:
26+
N = 1024
27+
# cycles per minutes
28+
scale = 2
29+
30+
def describe(self):
31+
return {
32+
"x": Desc([self.N], float),
33+
"y": Desc([self.N], float),
34+
"phase": Desc([], float),
35+
"time": Desc([], float),
36+
"sizes": Desc([], float),
37+
"paths": Desc([], float),
38+
"edgecolors": Desc([], str),
39+
"facecolors": Desc([self.N], str),
40+
}
41+
42+
def query(
43+
self,
44+
transform: _Transform,
45+
size: Tuple[int, int],
46+
) -> Tuple[Dict[str, Any], Union[str, int]]:
47+
def next_time():
48+
cur_time = time.time()
49+
cur_time = np.array([cur_time, cur_time-.1, cur_time-.2, cur_time-0.3])
50+
51+
phase = 15*np.pi * (self.scale * cur_time % 60) / 150
52+
marker_obj = mmarkers.MarkerStyle("o")
53+
return {
54+
"x": np.cos(5*phase),
55+
"y": np.sin(3*phase),
56+
"phase": phase[0],
57+
"sizes": np.array([256]),
58+
"paths": [marker_obj.get_path().transformed(marker_obj.get_transform())],
59+
"edgecolors": "k",
60+
"facecolors": ["#4682b4ff", "#82b446aa", "#46b48288", "#8246b433"],
61+
"time": cur_time[0],
62+
}, hash(cur_time[0])
63+
64+
return next_time()
65+
66+
67+
def update(frame, art):
68+
return art
69+
70+
71+
sot_c = Lissajous()
72+
73+
fc = FormatedText(
74+
sot_c,
75+
"ϕ={phase:.2f} ".format,
76+
x=1,
77+
y=1,
78+
ha="right",
79+
)
80+
fig, ax = plt.subplots()
81+
ax.set_xlim(-1.1, 1.1)
82+
ax.set_ylim(-1.1, 1.1)
83+
lw = PathCollectionWrapper(sot_c, offset_transform=ax.transData)
84+
ax.add_artist(lw)
85+
ax.add_artist(fc)
86+
#ax.set_xticks([])
87+
#ax.set_yticks([])
88+
ax.set_aspect(1)
89+
ani = FuncAnimation(
90+
fig,
91+
partial(update, art=(lw, fc)),
92+
frames=60*15,
93+
interval=1000 / 100,
94+
# TODO: blitting does not work because wrappers do not inherent from Artist
95+
# blit=True,
96+
)
97+
plt.show()

examples/simple_scatter.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""
2+
==========================
3+
An animated lissajous ball
4+
==========================
5+
6+
Inspired by https://twitter.com/_brohrer_/status/1584681864648065027
7+
8+
9+
"""
10+
import numpy as np
11+
12+
import matplotlib.pyplot as plt
13+
import matplotlib.markers as mmarkers
14+
15+
from data_prototype.containers import ArrayContainer
16+
17+
from data_prototype.wrappers import PathCollectionWrapper
18+
19+
20+
def update(frame, art):
21+
return art
22+
23+
marker_obj = mmarkers.MarkerStyle("o")
24+
25+
cont = ArrayContainer(
26+
x = np.array([0,1,2]),
27+
y = np.array([1,4,2]),
28+
paths = np.array([marker_obj.get_path()]),
29+
sizes = np.array([12]),
30+
edgecolors = np.array(["k"]),
31+
facecolors = np.array(["C3"]),
32+
)
33+
34+
fig, ax = plt.subplots()
35+
ax.set_xlim(-.5, 2.5)
36+
ax.set_ylim(0, 5)
37+
lw = PathCollectionWrapper(cont, offset_transform=ax.transData)
38+
ax.add_artist(lw)
39+
#ax.set_xticks([])
40+
#ax.set_yticks([])
41+
#ax.set_aspect(1)
42+
plt.show()

0 commit comments

Comments
 (0)