diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2d0cba7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +*.pyc +*.pyo +*.egg-info +dist/ +build/ diff --git a/README.md b/README.md index 3b4dbc8..e7a5217 100644 --- a/README.md +++ b/README.md @@ -1 +1,32 @@ -# dash-3d-viewer \ No newline at end of file +# dash-3d-viewer + +A tool to make it easy to build slice-views on 3D image data, in Dash apps. + +The API is currently a WIP. + + +## Installation + +Eventually, this would be pip-installable. For now, use the developer workflow. + + +## Usage + +TODO, see the examples. + + +## License + +This code is distributed under MIT license. + + +## Developers + + +* Make sure that you have Python with the appropriate dependencies installed, e.g. via `venv`. +* Run `pip install -e .` to do an in-place install of the package. +* Run the examples using e.g. `python examples/slicer_with_1_view.py` + +* Use `black .` to autoformat. +* Use `flake8 .` to lint. +* Use `pytest .` to run the tests. diff --git a/dash_3d_viewer/__init__.py b/dash_3d_viewer/__init__.py new file mode 100644 index 0000000..1f56b2a --- /dev/null +++ b/dash_3d_viewer/__init__.py @@ -0,0 +1,10 @@ +""" +Dash 3d viewer - a tool to make it easy to build slice-views on 3D image data. +""" + + +from .slicer import DashVolumeSlicer # noqa: F401 + + +__version__ = "0.0.1" +version_info = tuple(map(int, __version__.split("."))) diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py new file mode 100644 index 0000000..d367ed4 --- /dev/null +++ b/dash_3d_viewer/slicer.py @@ -0,0 +1,192 @@ +import numpy as np +from plotly.graph_objects import Figure +from dash import Dash +from dash.dependencies import Input, Output, State +from dash_core_components import Graph, Slider, Store + +from .utils import gen_random_id, img_array_to_uri + + +class DashVolumeSlicer: + """A slicer to show 3D image data in Dash.""" + + def __init__(self, app, volume, axis=0, id=None): + if not isinstance(app, Dash): + raise TypeError("Expect first arg to be a Dash app.") + # Check and store volume + if not (isinstance(volume, np.ndarray) and volume.ndim == 3): + raise TypeError("Expected volume to be a 3D numpy array") + self._volume = volume + # Check and store axis + if not (isinstance(axis, int) and 0 <= axis <= 2): + raise ValueError("The given axis must be 0, 1, or 2.") + self._axis = int(axis) + # Check and store id + if id is None: + id = gen_random_id() + elif not isinstance(id, str): + raise TypeError("Id must be a string") + self._id = id + + # Get the slice size (width, height), and max index + arr_shape = list(volume.shape) + arr_shape.pop(self._axis) + slice_size = list(reversed(arr_shape)) + self._max_index = self._volume.shape[self._axis] - 1 + + # Create the figure object + fig = Figure() + fig.update_layout( + template=None, + margin=dict(l=0, r=0, b=0, t=0, pad=4), + ) + fig.update_xaxes( + showgrid=False, + range=(0, slice_size[0]), + showticklabels=False, + zeroline=False, + ) + fig.update_yaxes( + showgrid=False, + scaleanchor="x", + range=(slice_size[1], 0), # todo: allow flipping x or y + showticklabels=False, + zeroline=False, + ) + # Add an empty layout image that we can populate from JS. + fig.add_layout_image( + dict( + source="", + xref="x", + yref="y", + x=0, + y=0, + sizex=slice_size[0], + sizey=slice_size[1], + sizing="contain", + layer="below", + ) + ) + # Wrap the figure in a graph + # todo: or should the user provide this? + self.graph = Graph( + id=self._subid("graph"), + figure=fig, + config={"scrollZoom": True}, + ) + # Create a slider object that the user can put in the layout (or not) + self.slider = Slider( + id=self._subid("slider"), + min=0, + max=self._max_index, + step=1, + value=self._max_index // 2, + updatemode="drag", + ) + # Create the stores that we need (these must be present in the layout) + self.stores = [ + Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2), + Store(id=self._subid("_requested-slice-index"), data=0), + Store(id=self._subid("_slice-data"), data=""), + ] + + self._create_server_callbacks(app) + self._create_client_callbacks(app) + + def _subid(self, subid): + """Given a subid, get the full id including the slicer's prefix.""" + return self._id + "-" + subid + + def _slice(self, index): + """Sample a slice from the volume.""" + indices = [slice(None), slice(None), slice(None)] + indices[self._axis] = index + return self._volume[tuple(indices)] + + def _create_server_callbacks(self, app): + """Create the callbacks that run server-side.""" + + @app.callback( + Output(self._subid("_slice-data"), "data"), + [Input(self._subid("_requested-slice-index"), "data")], + ) + def upload_requested_slice(slice_index): + slice = self._slice(slice_index) + slice = (slice.astype(np.float32) * (255 / slice.max())).astype(np.uint8) + return [slice_index, img_array_to_uri(slice)] + + def _create_client_callbacks(self, app): + """Create the callbacks that run client-side.""" + + app.clientside_callback( + """ + function handle_slider_move(index) { + return index; + } + """, + Output(self._subid("slice-index"), "data"), + [Input(self._subid("slider"), "value")], + ) + + app.clientside_callback( + """ + function handle_slice_index(index) { + if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; } + let slice_cache = window.slicecache_for_{{ID}}; + if (slice_cache[index]) { + return window.dash_clientside.no_update; + } else { + console.log('requesting slice ' + index) + return index; + } + } + """.replace( + "{{ID}}", self._id + ), + Output(self._subid("_requested-slice-index"), "data"), + [Input(self._subid("slice-index"), "data")], + ) + + # app.clientside_callback(""" + # function update_slider_pos(index) { + # return index; + # } + # """, + # [Output("slice-index", "data")], + # [State("slider", "value")], + # ) + + app.clientside_callback( + """ + function handle_incoming_slice(index, index_and_data, ori_figure) { + let new_index = index_and_data[0]; + let new_data = index_and_data[1]; + // Store data in cache + if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; } + let slice_cache = window.slicecache_for_{{ID}}; + slice_cache[new_index] = new_data; + // Get the data we need *now* + let data = slice_cache[index]; + // Maybe we do not need an update + if (!data) { + return window.dash_clientside.no_update; + } + if (data == ori_figure.layout.images[0].source) { + return window.dash_clientside.no_update; + } + // Otherwise, perform update + console.log("updating figure"); + let figure = {...ori_figure}; + figure.layout.images[0].source = data; + return figure; + } + """.replace( + "{{ID}}", self._id + ), + Output(self._subid("graph"), "figure"), + [ + Input(self._subid("slice-index"), "data"), + Input(self._subid("_slice-data"), "data"), + ], + [State(self._subid("graph"), "figure")], + ) diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py new file mode 100644 index 0000000..61846e1 --- /dev/null +++ b/dash_3d_viewer/utils.py @@ -0,0 +1,19 @@ +import random + +import PIL.Image +import skimage +from plotly.utils import ImageUriValidator + + +def gen_random_id(n=6): + return "".join(random.choice("abcdefghijklmnopqrtsuvwxyz") for i in range(n)) + + +def img_array_to_uri(img_array): + img_array = skimage.util.img_as_ubyte(img_array) + # todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency) + # from plotly.express._imshow import _array_to_b64str + # return _array_to_b64str(img_array) + img_pil = PIL.Image.fromarray(img_array) + uri = ImageUriValidator.pil_image_to_uri(img_pil) + return uri diff --git a/examples/slicer_with_1_view.py b/examples/slicer_with_1_view.py new file mode 100644 index 0000000..61ac309 --- /dev/null +++ b/examples/slicer_with_1_view.py @@ -0,0 +1,20 @@ +""" +A truly minimal example. +""" + +import dash +import dash_html_components as html +from dash_3d_viewer import DashVolumeSlicer +import imageio + + +app = dash.Dash(__name__) + +vol = imageio.volread("imageio:stent.npz") +slicer = DashVolumeSlicer(app, vol) + +app.layout = html.Div([slicer.graph, slicer.slider, *slicer.stores]) + + +if __name__ == "__main__": + app.run_server(debug=False) diff --git a/examples/slicer_with_2_views.py b/examples/slicer_with_2_views.py new file mode 100644 index 0000000..7913e2d --- /dev/null +++ b/examples/slicer_with_2_views.py @@ -0,0 +1,46 @@ +""" +An example with two slicers on the same volume. +""" + +import dash +import dash_html_components as html +from dash_3d_viewer import DashVolumeSlicer +import imageio + + +app = dash.Dash(__name__) + +vol = imageio.volread("imageio:stent.npz") +slicer1 = DashVolumeSlicer(app, vol, axis=1, id="slicer1") +slicer2 = DashVolumeSlicer(app, vol, axis=2, id="slicer2") + +app.layout = html.Div( + style={ + "display": "grid", + "grid-template-columns": "40% 40%", + }, + children=[ + html.Div( + [ + html.H1("Coronal"), + slicer1.graph, + html.Br(), + slicer1.slider, + *slicer1.stores, + ] + ), + html.Div( + [ + html.H1("Sagittal"), + slicer2.graph, + html.Br(), + slicer2.slider, + *slicer2.stores, + ] + ), + ], +) + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py new file mode 100644 index 0000000..93e3906 --- /dev/null +++ b/examples/slicer_with_3_views.py @@ -0,0 +1,71 @@ +""" +An example creating three slice-views through a volume, as is common +in medical applications. In the fourth quadrant we put an isosurface mesh. +""" + +import plotly.graph_objects as go +import dash +import dash_html_components as html +import dash_core_components as dcc +from dash_3d_viewer import DashVolumeSlicer +from skimage.measure import marching_cubes +import imageio + +app = dash.Dash(__name__) + +# Read volumes and create slicer objects +vol = imageio.volread("imageio:stent.npz") +slicer1 = DashVolumeSlicer(app, vol, axis=0, id="slicer1") +slicer2 = DashVolumeSlicer(app, vol, axis=1, id="slicer2") +slicer3 = DashVolumeSlicer(app, vol, axis=2, id="slicer3") + +# Calculate isosurface and create a figure with a mesh object +verts, faces, _, _ = marching_cubes(vol, 300, step_size=2) +x, y, z = verts.T +i, j, k = faces.T +fig_mesh = go.Figure() +fig_mesh.add_trace(go.Mesh3d(x=z, y=y, z=x, opacity=0.2, i=k, j=j, k=i)) + +# Put everything together in a 2x2 grid +app.layout = html.Div( + style={ + "display": "grid", + "grid-template-columns": "40% 40%", + }, + children=[ + html.Div( + [ + html.Center(html.H1("Transversal")), + slicer1.graph, + html.Br(), + slicer1.slider, + *slicer1.stores, + ] + ), + html.Div( + [ + html.Center(html.H1("Coronal")), + slicer2.graph, + html.Br(), + slicer2.slider, + *slicer2.stores, + ] + ), + html.Div( + [ + html.Center(html.H1("Sagittal")), + slicer3.graph, + html.Br(), + slicer3.slider, + *slicer3.stores, + ] + ), + html.Div( + [html.Center(html.H1("3D")), dcc.Graph(id="graph-helper", figure=fig_mesh)] + ), + ], +) + + +if __name__ == "__main__": + app.run_server(debug=False) diff --git a/examples/use_components.py b/examples/use_components.py new file mode 100644 index 0000000..229a66a --- /dev/null +++ b/examples/use_components.py @@ -0,0 +1,63 @@ +""" +A small example showing how to write callbacks involving the slicer's +components. The slicer's components are used as both inputs and outputs. +""" + +import dash +import dash_html_components as html +from dash.dependencies import Input, Output, State +from dash_3d_viewer import DashVolumeSlicer +import imageio + + +app = dash.Dash(__name__) + +vol = imageio.volread("imageio:stent.npz") +slicer = DashVolumeSlicer(app, vol) + +# We can access the components, and modify them +slicer.slider.value = 0 + +# Define the layour, including extra buttons +app.layout = html.Div( + [ + slicer.graph, + html.Br(), + html.Div( + style={"display": "flex"}, + children=[ + html.Div("", id="index-show", style={"padding": "0.4em"}), + html.Button("<", id="decrease-index"), + html.Div(slicer.slider, style={"flexGrow": "1"}), + html.Button(">", id="increase-index"), + ], + ), + *slicer.stores, + ] +) + +# New callbacks for our added widgets + + +@app.callback( + Output("index-show", "children"), + [Input(slicer.slider.id, "value")], +) +def show_slider_value(index): + return str(index) + + +@app.callback( + Output(slicer.slider.id, "value"), + [Input("decrease-index", "n_clicks"), Input("increase-index", "n_clicks")], + [State(slicer.slider.id, "value")], +) +def handle_button_input(press1, press2, index): + ctx = dash.callback_context + if ctx.triggered: + index += 1 if "increase" in ctx.triggered[0]["prop_id"] else -1 + return index + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2e416aa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +pillow +numpy +plotly +dash +dash_core_components diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..39e30a6 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,4 @@ +[flake8] +max_line_length = 89 +extend-ignore = E501 +exclude = build,dist,*.egg-info diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..594f4b3 --- /dev/null +++ b/setup.py @@ -0,0 +1,48 @@ +import re + +from setuptools import find_packages, setup + + +NAME = "dash_3d_viewer" +SUMMARY = ( + "A library to make it easy to build slice-views on 3D image data in Dash apps." +) + +with open(f"{NAME}/__init__.py") as fh: + VERSION = re.search(r"__version__ = \"(.*?)\"", fh.read()).group(1) + + +runtime_deps = [ + "pillow", + "numpy", + "plotly", + "dash", + "dash_core_components", + "scikit-image", # may not be needed eventually? +] + + +setup( + name=NAME, + version=VERSION, + packages=find_packages(exclude=["tests", "tests.*", "examples", "examples.*"]), + python_requires=">=3.6.0", + install_requires=runtime_deps, + license="MIT", + description=SUMMARY, + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + author="Plotly", + author_email="almar.klein@gmail.com", + # url="https://github.com/plotly/will be renamed?", + data_files=[("", ["LICENSE"])], + zip_safe=True, # not if we put JS in a seperate file, I think + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Scientific/Engineering :: Visualization", + ], +)