Skip to content

Commit d18a891

Browse files
authored
Merge pull request #22 from tacaswell/widgets
ENH: add demo of binding container to sliders
2 parents cab4a67 + 82f6c8d commit d18a891

File tree

2 files changed

+139
-4
lines changed

2 files changed

+139
-4
lines changed

data_prototype/containers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,19 @@ def _split(input_dict):
165165
self._xyfuncs = _split(xyfuncs) if xyfuncs is not None else {}
166166
self._cache: MutableMapping[Union[str, int], Any] = LFUCache(64)
167167

168+
def _query_hash(self, coord_transform, size):
169+
# TODO find a better way to compute the hash key, this is not sentative to
170+
# scale changes, only limit changes
171+
data_bounds = tuple(coord_transform.transform([[0, 0], [1, 1]]).flatten())
172+
hash_key = hash((data_bounds, size))
173+
return hash_key
174+
168175
def query(
169176
self,
170177
coord_transform: _MatplotlibTransform,
171178
size: Tuple[int, int],
172179
) -> Tuple[Dict[str, Any], Union[str, int]]:
173-
# TODO find a better way to compute the hash key, this is not sentative to
174-
# scale changes, only limit changes
175-
data_bounds = tuple(coord_transform.transform([[0, 0], [1, 1]]).flatten())
176-
hash_key = hash((data_bounds, size))
180+
hash_key = self._query_hash(coord_transform, size)
177181
if hash_key in self._cache:
178182
return self._cache[hash_key], hash_key
179183

examples/widgets.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""
2+
======
3+
Slider
4+
======
5+
6+
In this example, sliders are used to control the frequency and amplitude of
7+
a sine wave.
8+
9+
"""
10+
import inspect
11+
12+
import numpy as np
13+
import matplotlib.pyplot as plt
14+
from matplotlib.widgets import Slider, Button
15+
16+
from data_prototype.wrappers import LineWrapper
17+
from data_prototype.containers import FuncContainer
18+
19+
20+
class SliderContainer(FuncContainer):
21+
def __init__(self, xfuncs, /, **sliders):
22+
self._sliders = sliders
23+
for slider in sliders.values():
24+
slider.on_changed(
25+
lambda x, sld=slider: sld.ax.figure.canvas.draw_idle(),
26+
)
27+
28+
def get_needed_keys(f, offset=1):
29+
return tuple(inspect.signature(f).parameters)[offset:]
30+
31+
super().__init__(
32+
{
33+
k: (
34+
s,
35+
# this line binds the correct sliders to the functions
36+
# and makes lambdas that match the API FuncContainer needs
37+
lambda x, keys=get_needed_keys(f), f=f: f(x, *(sliders[k].val for k in keys)),
38+
)
39+
for k, (s, f) in xfuncs.items()
40+
},
41+
)
42+
43+
def _query_hash(self, coord_transform, size):
44+
key = super()._query_hash(coord_transform, size)
45+
# inject the slider values into the hashing logic
46+
return hash((key, tuple(s.val for s in self._sliders.values())))
47+
48+
49+
# Define initial parameters
50+
init_amplitude = 5
51+
init_frequency = 3
52+
53+
# Create the figure and the line that we will manipulate
54+
fig, ax = plt.subplots()
55+
ax.set_xlim(0, 1)
56+
ax.set_ylim(-7, 7)
57+
58+
ax.set_xlabel("Time [s]")
59+
60+
# adjust the main plot to make room for the sliders
61+
fig.subplots_adjust(left=0.25, bottom=0.25, right=0.75)
62+
63+
# Make a horizontal slider to control the frequency.
64+
axfreq = fig.add_axes([0.25, 0.1, 0.65, 0.03])
65+
freq_slider = Slider(
66+
ax=axfreq,
67+
label="Frequency [Hz]",
68+
valmin=0.1,
69+
valmax=30,
70+
valinit=init_frequency,
71+
)
72+
73+
# Make a vertically oriented slider to control the amplitude
74+
axamp = fig.add_axes([0.1, 0.25, 0.0225, 0.63])
75+
amp_slider = Slider(
76+
ax=axamp,
77+
label="Amplitude",
78+
valmin=0,
79+
valmax=10,
80+
valinit=init_amplitude,
81+
orientation="vertical",
82+
)
83+
84+
# Make a vertically oriented slider to control the phase
85+
axphase = fig.add_axes([0.85, 0.25, 0.0225, 0.63])
86+
phase_slider = Slider(
87+
ax=axphase,
88+
label="Phase [rad]",
89+
valmin=-2 * np.pi,
90+
valmax=2 * np.pi,
91+
valinit=0,
92+
orientation="vertical",
93+
)
94+
95+
# pick a cyclic color map
96+
cmap = plt.get_cmap("twilight")
97+
98+
# set up the data container
99+
fc = SliderContainer(
100+
{
101+
# the x data does not need the sliders values
102+
"x": (("N",), lambda t: t),
103+
"y": (
104+
("N",),
105+
# the y data needs all three sliders
106+
lambda t, amplitude, frequency, phase: amplitude * np.sin(2 * np.pi * frequency * t + phase),
107+
),
108+
# the color data has to take the x (because reasons), but just
109+
# needs the phase
110+
"color": ((1,), lambda t, phase: phase),
111+
},
112+
# bind the sliders to the data container
113+
amplitude=amp_slider,
114+
frequency=freq_slider,
115+
phase=phase_slider,
116+
)
117+
lw = LineWrapper(
118+
fc,
119+
# color map phase (scaled to 2pi and wrapped to [0, 1])
120+
{"color": lambda color: cmap((color / (2 * np.pi)) % 1)},
121+
lw=5,
122+
)
123+
ax.add_artist(lw)
124+
125+
126+
# Create a `matplotlib.widgets.Button` to reset the sliders to initial values.
127+
resetax = fig.add_axes([0.8, 0.025, 0.1, 0.04])
128+
button = Button(resetax, "Reset", hovercolor="0.975")
129+
button.on_clicked(lambda event: [sld.reset() for sld in (freq_slider, amp_slider, phase_slider)])
130+
131+
plt.show()

0 commit comments

Comments
 (0)