Skip to content

Commit 56b173e

Browse files
authored
First poc (#2)
* first code comitted * flakeify * some cleaning and tweaking * Implement ids, to support multiple slicers. Add examples * add 3d view to example * comments / clean up * fix * add setup.py * add example that actually used the sub-components * add a note about plotlies _array_to_b64str * update readme some more
1 parent 363f7d2 commit 56b173e

12 files changed

+516
-1
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
__pycache__
2+
*.pyc
3+
*.pyo
4+
*.egg-info
5+
dist/
6+
build/

README.md

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,32 @@
1-
# dash-3d-viewer
1+
# dash-3d-viewer
2+
3+
A tool to make it easy to build slice-views on 3D image data, in Dash apps.
4+
5+
The API is currently a WIP.
6+
7+
8+
## Installation
9+
10+
Eventually, this would be pip-installable. For now, use the developer workflow.
11+
12+
13+
## Usage
14+
15+
TODO, see the examples.
16+
17+
18+
## License
19+
20+
This code is distributed under MIT license.
21+
22+
23+
## Developers
24+
25+
26+
* Make sure that you have Python with the appropriate dependencies installed, e.g. via `venv`.
27+
* Run `pip install -e .` to do an in-place install of the package.
28+
* Run the examples using e.g. `python examples/slicer_with_1_view.py`
29+
30+
* Use `black .` to autoformat.
31+
* Use `flake8 .` to lint.
32+
* Use `pytest .` to run the tests.

dash_3d_viewer/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
Dash 3d viewer - a tool to make it easy to build slice-views on 3D image data.
3+
"""
4+
5+
6+
from .slicer import DashVolumeSlicer # noqa: F401
7+
8+
9+
__version__ = "0.0.1"
10+
version_info = tuple(map(int, __version__.split(".")))

dash_3d_viewer/slicer.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import numpy as np
2+
from plotly.graph_objects import Figure
3+
from dash import Dash
4+
from dash.dependencies import Input, Output, State
5+
from dash_core_components import Graph, Slider, Store
6+
7+
from .utils import gen_random_id, img_array_to_uri
8+
9+
10+
class DashVolumeSlicer:
11+
"""A slicer to show 3D image data in Dash."""
12+
13+
def __init__(self, app, volume, axis=0, id=None):
14+
if not isinstance(app, Dash):
15+
raise TypeError("Expect first arg to be a Dash app.")
16+
# Check and store volume
17+
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
18+
raise TypeError("Expected volume to be a 3D numpy array")
19+
self._volume = volume
20+
# Check and store axis
21+
if not (isinstance(axis, int) and 0 <= axis <= 2):
22+
raise ValueError("The given axis must be 0, 1, or 2.")
23+
self._axis = int(axis)
24+
# Check and store id
25+
if id is None:
26+
id = gen_random_id()
27+
elif not isinstance(id, str):
28+
raise TypeError("Id must be a string")
29+
self._id = id
30+
31+
# Get the slice size (width, height), and max index
32+
arr_shape = list(volume.shape)
33+
arr_shape.pop(self._axis)
34+
slice_size = list(reversed(arr_shape))
35+
self._max_index = self._volume.shape[self._axis] - 1
36+
37+
# Create the figure object
38+
fig = Figure()
39+
fig.update_layout(
40+
template=None,
41+
margin=dict(l=0, r=0, b=0, t=0, pad=4),
42+
)
43+
fig.update_xaxes(
44+
showgrid=False,
45+
range=(0, slice_size[0]),
46+
showticklabels=False,
47+
zeroline=False,
48+
)
49+
fig.update_yaxes(
50+
showgrid=False,
51+
scaleanchor="x",
52+
range=(slice_size[1], 0), # todo: allow flipping x or y
53+
showticklabels=False,
54+
zeroline=False,
55+
)
56+
# Add an empty layout image that we can populate from JS.
57+
fig.add_layout_image(
58+
dict(
59+
source="",
60+
xref="x",
61+
yref="y",
62+
x=0,
63+
y=0,
64+
sizex=slice_size[0],
65+
sizey=slice_size[1],
66+
sizing="contain",
67+
layer="below",
68+
)
69+
)
70+
# Wrap the figure in a graph
71+
# todo: or should the user provide this?
72+
self.graph = Graph(
73+
id=self._subid("graph"),
74+
figure=fig,
75+
config={"scrollZoom": True},
76+
)
77+
# Create a slider object that the user can put in the layout (or not)
78+
self.slider = Slider(
79+
id=self._subid("slider"),
80+
min=0,
81+
max=self._max_index,
82+
step=1,
83+
value=self._max_index // 2,
84+
updatemode="drag",
85+
)
86+
# Create the stores that we need (these must be present in the layout)
87+
self.stores = [
88+
Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2),
89+
Store(id=self._subid("_requested-slice-index"), data=0),
90+
Store(id=self._subid("_slice-data"), data=""),
91+
]
92+
93+
self._create_server_callbacks(app)
94+
self._create_client_callbacks(app)
95+
96+
def _subid(self, subid):
97+
"""Given a subid, get the full id including the slicer's prefix."""
98+
return self._id + "-" + subid
99+
100+
def _slice(self, index):
101+
"""Sample a slice from the volume."""
102+
indices = [slice(None), slice(None), slice(None)]
103+
indices[self._axis] = index
104+
return self._volume[tuple(indices)]
105+
106+
def _create_server_callbacks(self, app):
107+
"""Create the callbacks that run server-side."""
108+
109+
@app.callback(
110+
Output(self._subid("_slice-data"), "data"),
111+
[Input(self._subid("_requested-slice-index"), "data")],
112+
)
113+
def upload_requested_slice(slice_index):
114+
slice = self._slice(slice_index)
115+
slice = (slice.astype(np.float32) * (255 / slice.max())).astype(np.uint8)
116+
return [slice_index, img_array_to_uri(slice)]
117+
118+
def _create_client_callbacks(self, app):
119+
"""Create the callbacks that run client-side."""
120+
121+
app.clientside_callback(
122+
"""
123+
function handle_slider_move(index) {
124+
return index;
125+
}
126+
""",
127+
Output(self._subid("slice-index"), "data"),
128+
[Input(self._subid("slider"), "value")],
129+
)
130+
131+
app.clientside_callback(
132+
"""
133+
function handle_slice_index(index) {
134+
if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; }
135+
let slice_cache = window.slicecache_for_{{ID}};
136+
if (slice_cache[index]) {
137+
return window.dash_clientside.no_update;
138+
} else {
139+
console.log('requesting slice ' + index)
140+
return index;
141+
}
142+
}
143+
""".replace(
144+
"{{ID}}", self._id
145+
),
146+
Output(self._subid("_requested-slice-index"), "data"),
147+
[Input(self._subid("slice-index"), "data")],
148+
)
149+
150+
# app.clientside_callback("""
151+
# function update_slider_pos(index) {
152+
# return index;
153+
# }
154+
# """,
155+
# [Output("slice-index", "data")],
156+
# [State("slider", "value")],
157+
# )
158+
159+
app.clientside_callback(
160+
"""
161+
function handle_incoming_slice(index, index_and_data, ori_figure) {
162+
let new_index = index_and_data[0];
163+
let new_data = index_and_data[1];
164+
// Store data in cache
165+
if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; }
166+
let slice_cache = window.slicecache_for_{{ID}};
167+
slice_cache[new_index] = new_data;
168+
// Get the data we need *now*
169+
let data = slice_cache[index];
170+
// Maybe we do not need an update
171+
if (!data) {
172+
return window.dash_clientside.no_update;
173+
}
174+
if (data == ori_figure.layout.images[0].source) {
175+
return window.dash_clientside.no_update;
176+
}
177+
// Otherwise, perform update
178+
console.log("updating figure");
179+
let figure = {...ori_figure};
180+
figure.layout.images[0].source = data;
181+
return figure;
182+
}
183+
""".replace(
184+
"{{ID}}", self._id
185+
),
186+
Output(self._subid("graph"), "figure"),
187+
[
188+
Input(self._subid("slice-index"), "data"),
189+
Input(self._subid("_slice-data"), "data"),
190+
],
191+
[State(self._subid("graph"), "figure")],
192+
)

dash_3d_viewer/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import random
2+
3+
import PIL.Image
4+
import skimage
5+
from plotly.utils import ImageUriValidator
6+
7+
8+
def gen_random_id(n=6):
9+
return "".join(random.choice("abcdefghijklmnopqrtsuvwxyz") for i in range(n))
10+
11+
12+
def img_array_to_uri(img_array):
13+
img_array = skimage.util.img_as_ubyte(img_array)
14+
# todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency)
15+
# from plotly.express._imshow import _array_to_b64str
16+
# return _array_to_b64str(img_array)
17+
img_pil = PIL.Image.fromarray(img_array)
18+
uri = ImageUriValidator.pil_image_to_uri(img_pil)
19+
return uri

examples/slicer_with_1_view.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""
2+
A truly minimal example.
3+
"""
4+
5+
import dash
6+
import dash_html_components as html
7+
from dash_3d_viewer import DashVolumeSlicer
8+
import imageio
9+
10+
11+
app = dash.Dash(__name__)
12+
13+
vol = imageio.volread("imageio:stent.npz")
14+
slicer = DashVolumeSlicer(app, vol)
15+
16+
app.layout = html.Div([slicer.graph, slicer.slider, *slicer.stores])
17+
18+
19+
if __name__ == "__main__":
20+
app.run_server(debug=False)

examples/slicer_with_2_views.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
An example with two slicers on the same volume.
3+
"""
4+
5+
import dash
6+
import dash_html_components as html
7+
from dash_3d_viewer import DashVolumeSlicer
8+
import imageio
9+
10+
11+
app = dash.Dash(__name__)
12+
13+
vol = imageio.volread("imageio:stent.npz")
14+
slicer1 = DashVolumeSlicer(app, vol, axis=1, id="slicer1")
15+
slicer2 = DashVolumeSlicer(app, vol, axis=2, id="slicer2")
16+
17+
app.layout = html.Div(
18+
style={
19+
"display": "grid",
20+
"grid-template-columns": "40% 40%",
21+
},
22+
children=[
23+
html.Div(
24+
[
25+
html.H1("Coronal"),
26+
slicer1.graph,
27+
html.Br(),
28+
slicer1.slider,
29+
*slicer1.stores,
30+
]
31+
),
32+
html.Div(
33+
[
34+
html.H1("Sagittal"),
35+
slicer2.graph,
36+
html.Br(),
37+
slicer2.slider,
38+
*slicer2.stores,
39+
]
40+
),
41+
],
42+
)
43+
44+
45+
if __name__ == "__main__":
46+
app.run_server(debug=True)

0 commit comments

Comments
 (0)