Skip to content

Commit 88e2e7b

Browse files
authored
Support for anisotropic data. (#8)
* Support for anisotropic data. * add support/comments for swapping/reversing axes
1 parent dcd0a6a commit 88e2e7b

File tree

6 files changed

+132
-50
lines changed

6 files changed

+132
-50
lines changed

dash_3d_viewer/slicer.py

Lines changed: 80 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,25 @@
44
from dash.dependencies import Input, Output, State, ALL
55
from dash_core_components import Graph, Slider, Store
66

7-
from .utils import img_array_to_uri, get_thumbnail_size_from_shape
7+
from .utils import img_array_to_uri, get_thumbnail_size_from_shape, shape3d_to_size2d
88

99

1010
class DashVolumeSlicer:
1111
"""A slicer to show 3D image data in Dash.
1212
1313
Parameters:
1414
app (dash.Dash): the Dash application instance.
15-
volume (ndarray): the 3D numpy array to slice through.
15+
volume (ndarray): the 3D numpy array to slice through. The dimensions
16+
are assumed to be in zyx order. If this is not the case, you can
17+
use ``np.swapaxes`` to make it so.
18+
spacing (tuple of floats): The distance between voxels for each dimension (zyx).
19+
The spacing and origin are applied to make the slice drawn in
20+
"scene space" rather than "voxel space".
21+
origin (tuple of floats): The offset for each dimension (zyx).
1622
axis (int): the dimension to slice in. Default 0.
23+
reverse_y (bool): Whether to reverse the y-axis, so that the origin of
24+
the slice is in the top-left, rather than bottom-left. Default True.
25+
(This sets the figure's yaxes ``autorange`` to either "reversed" or True.)
1726
scene_id (str): the scene that this slicer is part of. Slicers
1827
that have the same scene-id show each-other's positions with
1928
line indicators. By default this is a hash of ``id(volume)``.
@@ -38,14 +47,29 @@ class DashVolumeSlicer:
3847

3948
_global_slicer_counter = 0
4049

41-
def __init__(self, app, volume, axis=0, scene_id=None):
50+
def __init__(
51+
self,
52+
app,
53+
volume,
54+
*,
55+
spacing=None,
56+
origin=None,
57+
axis=0,
58+
reverse_y=True,
59+
scene_id=None
60+
):
61+
# todo: also implement xyz dim order?
4262
if not isinstance(app, Dash):
4363
raise TypeError("Expect first arg to be a Dash app.")
4464
self._app = app
4565
# Check and store volume
4666
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
4767
raise TypeError("Expected volume to be a 3D numpy array")
4868
self._volume = volume
69+
spacing = (1, 1, 1) if spacing is None else spacing
70+
spacing = float(spacing[0]), float(spacing[1]), float(spacing[2])
71+
origin = (0, 0, 0) if origin is None else origin
72+
origin = float(origin[0]), float(origin[1]), float(origin[2])
4973
# Check and store axis
5074
if not (isinstance(axis, int) and 0 <= axis <= 2):
5175
raise ValueError("The given axis must be 0, 1, or 2.")
@@ -60,20 +84,26 @@ def __init__(self, app, volume, axis=0, scene_id=None):
6084
DashVolumeSlicer._global_slicer_counter += 1
6185
self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter)
6286

63-
# Get the slice size (width, height), and max index
64-
arr_shape = list(volume.shape)
65-
arr_shape.pop(self._axis)
66-
self._slice_size = tuple(reversed(arr_shape))
67-
self._max_index = self._volume.shape[self._axis] - 1
87+
# Prepare slice info
88+
info = {
89+
"shape": tuple(volume.shape),
90+
"axis": self._axis,
91+
"size": shape3d_to_size2d(volume.shape, axis),
92+
"origin": shape3d_to_size2d(origin, axis),
93+
"spacing": shape3d_to_size2d(spacing, axis),
94+
}
6895

6996
# Prep low-res slices
70-
thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32)
97+
thumbnail_size = get_thumbnail_size_from_shape(
98+
(info["size"][1], info["size"][0]), 32
99+
)
71100
thumbnails = [
72101
img_array_to_uri(self._slice(i), thumbnail_size)
73-
for i in range(self._max_index + 1)
102+
for i in range(info["size"][2])
74103
]
104+
info["lowres_size"] = thumbnail_size
75105

76-
# Create a placeholder trace
106+
# Create traces
77107
# todo: can add "%{z[0]}", but that would be the scaled value ...
78108
image_trace = Image(
79109
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
@@ -97,6 +127,7 @@ def __init__(self, app, volume, axis=0, scene_id=None):
97127
scaleanchor="x",
98128
showticklabels=False,
99129
zeroline=False,
130+
autorange="reversed" if reverse_y else True,
100131
)
101132
# Wrap the figure in a graph
102133
# todo: or should the user provide this?
@@ -106,22 +137,20 @@ def __init__(self, app, volume, axis=0, scene_id=None):
106137
config={"scrollZoom": True},
107138
)
108139
# Create a slider object that the user can put in the layout (or not)
109-
# todo: use tooltip to show current value?
110140
self.slider = Slider(
111141
id=self._subid("slider"),
112142
min=0,
113-
max=self._max_index,
143+
max=info["size"][2] - 1,
114144
step=1,
115-
value=self._max_index // 2,
145+
value=info["size"][2] // 2,
116146
tooltip={"always_visible": False, "placement": "left"},
117147
updatemode="drag",
118148
)
119149
# Create the stores that we need (these must be present in the layout)
120150
self.stores = [
121-
Store(
122-
id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size
123-
),
151+
Store(id=self._subid("info"), data=info),
124152
Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
153+
Store(id=self._subid("position"), data=0),
125154
Store(id=self._subid("_requested-slice-index"), data=0),
126155
Store(id=self._subid("_slice-data"), data=""),
127156
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
@@ -175,6 +204,17 @@ def _create_client_callbacks(self):
175204
[Input(self._subid("slider"), "value")],
176205
)
177206

207+
app.clientside_callback(
208+
"""
209+
function update_position(index, info) {
210+
return info.origin[2] + index * info.spacing[2];
211+
}
212+
""",
213+
Output(self._subid("position"), "data"),
214+
[Input(self._subid("index"), "data")],
215+
[State(self._subid("info"), "data")],
216+
)
217+
178218
app.clientside_callback(
179219
"""
180220
function handle_slice_index(index) {
@@ -205,7 +245,7 @@ def _create_client_callbacks(self):
205245

206246
app.clientside_callback(
207247
"""
208-
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) {
248+
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) {
209249
let new_index = index_and_data[0];
210250
let new_data = index_and_data[1];
211251
// Store data in cache
@@ -214,18 +254,18 @@ def _create_client_callbacks(self):
214254
slice_cache[new_index] = new_data;
215255
// Get the data we need *now*
216256
let data = slice_cache[index];
217-
let x0 = 0, y0 = 0, dx = 1, dy = 1;
257+
let x0 = info.origin[0], y0 = info.origin[1];
258+
let dx = info.spacing[0], dy = info.spacing[1];
218259
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
219260
// Maybe we do not need an update
220261
if (!data) {
221262
data = lowres[index];
222263
// Scale the image to take the exact same space as the full-res
223264
// version. It's not correct, but it looks better ...
224-
// slice_size = full_w, full_h, low_w, low_h
225-
dx = slice_size[0] / slice_size[2];
226-
dy = slice_size[1] / slice_size[3];
227-
x0 = 0.5 * dx - 0.5;
228-
y0 = 0.5 * dy - 0.5;
265+
dx *= info.size[0] / info.lowres_size[0];
266+
dy *= info.size[1] / info.lowres_size[1];
267+
x0 += 0.5 * dx - 0.5 * info.spacing[0];
268+
y0 += 0.5 * dy - 0.5 * info.spacing[1];
229269
}
230270
if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
231271
return window.dash_clientside.no_update;
@@ -253,7 +293,7 @@ def _create_client_callbacks(self):
253293
[
254294
State(self._subid("graph"), "figure"),
255295
State(self._subid("_slice-data-lowres"), "data"),
256-
State(self._subid("_slice-size"), "data"),
296+
State(self._subid("info"), "data"),
257297
],
258298
)
259299

@@ -266,18 +306,22 @@ def _create_client_callbacks(self):
266306
# * match any of the selected axii
267307
app.clientside_callback(
268308
"""
269-
function handle_indicator(indices1, indices2, slice_size, current) {
270-
let w = slice_size[0], h = slice_size[1];
271-
let dx = w / 20, dy = h / 20;
309+
function handle_indicator(positions1, positions2, info, current) {
310+
let x0 = info.origin[0], y0 = info.origin[1];
311+
let x1 = x0 + info.size[0] * info.spacing[0], y1 = y0 + info.size[1] * info.spacing[1];
312+
x0 = x0 - info.spacing[0], y0 = y0 - info.spacing[1];
313+
let d = ((x1 - x0) + (y1 - y0)) * 0.5 * 0.05;
272314
let version = (current.version || 0) + 1;
273315
let x = [], y = [];
274-
for (let index of indices1) {
275-
x.push(...[-dx, -1, null, w, w + dx, null]);
276-
y.push(...[index, index, index, index, index, index]);
316+
for (let pos of positions1) {
317+
// x relative to our slice, y in scene-coords
318+
x.push(...[x0 - d, x0, null, x1, x1 + d, null]);
319+
y.push(...[pos, pos, pos, pos, pos, pos]);
277320
}
278-
for (let index of indices2) {
279-
x.push(...[index, index, index, index, index, index]);
280-
y.push(...[-dy, -1, null, h, h + dy, null]);
321+
for (let pos of positions2) {
322+
// x in scene-coords, y relative to our slice
323+
x.push(...[pos, pos, pos, pos, pos, pos]);
324+
y.push(...[y0 - d, y0, null, y1, y1 + d, null]);
281325
}
282326
return {
283327
type: 'scatter',
@@ -296,15 +340,15 @@ def _create_client_callbacks(self):
296340
{
297341
"scene": self.scene_id,
298342
"context": ALL,
299-
"name": "index",
343+
"name": "position",
300344
"axis": axis,
301345
},
302346
"data",
303347
)
304348
for axis in axii
305349
],
306350
[
307-
State(self._subid("_slice-size"), "data"),
351+
State(self._subid("info"), "data"),
308352
State(self._subid("_indicators"), "data"),
309353
],
310354
)

dash_3d_viewer/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,14 @@ def get_thumbnail_size_from_shape(shape, base_size):
3232
img_pil = PIL.Image.fromarray(img_array)
3333
img_pil.thumbnail((base_size, base_size))
3434
return img_pil.size
35+
36+
37+
def shape3d_to_size2d(shape, axis):
38+
"""Turn a 3d shape (z, y, x) into a local (x', y', z'),
39+
where z' represents the dimension indicated by axis.
40+
"""
41+
shape = list(shape)
42+
axis_value = shape.pop(axis)
43+
size = list(reversed(shape))
44+
size.append(axis_value)
45+
return tuple(size)

examples/slicer_with_1_plus_2_views.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
This demonstrates how multiple indicators can be shown per axis.
44
55
Sharing the same scene_id is enough for the slicers to show each-others
6-
position. If the same volume object is given, it works by default,
6+
position. If the same volume object would be given, it works by default,
77
because the default scene_id is a hash of the volume object. Specifying
88
a scene_id provides slice position indicators even when slicing through
99
different volumes.
10+
11+
Further, this example has one slider showing data with different spacing.
12+
Note how the indicators represent the actual position in "scene coordinates".
13+
1014
"""
1115

1216
import dash
@@ -17,22 +21,33 @@
1721

1822
app = dash.Dash(__name__)
1923

20-
vol = imageio.volread("imageio:stent.npz")
21-
slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene")
22-
slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
23-
slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
24+
vol1 = imageio.volread("imageio:stent.npz")
25+
26+
vol2 = vol1[::3, ::2, :]
27+
spacing = 3, 2, 1
28+
ori = 110, 120, 140
29+
30+
31+
slicer1 = DashVolumeSlicer(
32+
app, vol1, axis=1, origin=ori, reverse_y=False, scene_id="scene1"
33+
)
34+
slicer2 = DashVolumeSlicer(
35+
app, vol1, axis=0, origin=ori, reverse_y=False, scene_id="scene1"
36+
)
37+
slicer3 = DashVolumeSlicer(
38+
app, vol2, axis=0, origin=ori, spacing=spacing, reverse_y=False, scene_id="scene1"
39+
)
2440

2541
app.layout = html.Div(
2642
style={
2743
"display": "grid",
28-
"grid-template-columns": "40% 40%",
44+
"gridTemplateColumns": "40% 40%",
2945
},
3046
children=[
3147
html.Div(
3248
[
3349
html.H1("Coronal"),
3450
slicer1.graph,
35-
html.Br(),
3651
slicer1.slider,
3752
*slicer1.stores,
3853
]
@@ -41,7 +56,6 @@
4156
[
4257
html.H1("Transversal 1"),
4358
slicer2.graph,
44-
html.Br(),
4559
slicer2.slider,
4660
*slicer2.stores,
4761
]
@@ -51,7 +65,6 @@
5165
[
5266
html.H1("Transversal 2"),
5367
slicer3.graph,
54-
html.Br(),
5568
slicer3.slider,
5669
*slicer3.stores,
5770
]

examples/slicer_with_2_views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
app.layout = html.Div(
1818
style={
1919
"display": "grid",
20-
"grid-template-columns": "40% 40%",
20+
"gridTemplateColumns": "40% 40%",
2121
},
2222
children=[
2323
html.Div(

examples/slicer_with_3_views.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
# Read volumes and create slicer objects
1717
vol = imageio.volread("imageio:stent.npz")
18-
slicer1 = DashVolumeSlicer(app, vol, axis=0)
19-
slicer2 = DashVolumeSlicer(app, vol, axis=1)
20-
slicer3 = DashVolumeSlicer(app, vol, axis=2)
18+
slicer1 = DashVolumeSlicer(app, vol, reverse_y=False, axis=0)
19+
slicer2 = DashVolumeSlicer(app, vol, reverse_y=False, axis=1)
20+
slicer3 = DashVolumeSlicer(app, vol, reverse_y=False, axis=2)
2121

2222
# Calculate isosurface and create a figure with a mesh object
2323
verts, faces, _, _ = marching_cubes(vol, 300, step_size=2)
@@ -30,7 +30,7 @@
3030
app.layout = html.Div(
3131
style={
3232
"display": "grid",
33-
"grid-template-columns": "40% 40%",
33+
"gridTemplateColumns": "40% 40%",
3434
},
3535
children=[
3636
html.Div(

tests/test_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from dash_3d_viewer.utils import shape3d_to_size2d
2+
3+
from pytest import raises
4+
5+
6+
def test_shape3d_to_size2d():
7+
# shape -> z, y, x
8+
# size -> x, y, out-of-plane
9+
assert shape3d_to_size2d((12, 13, 14), 0) == (14, 13, 12)
10+
assert shape3d_to_size2d((12, 13, 14), 1) == (14, 12, 13)
11+
assert shape3d_to_size2d((12, 13, 14), 2) == (13, 12, 14)
12+
13+
with raises(IndexError):
14+
shape3d_to_size2d((12, 13, 14), 3)

0 commit comments

Comments
 (0)