Commit 4e1e316

surajpaib and wyli authored
Add demo for 2d inference on 3D volume (#479)
* Add demo for 2d inference on 3D volume
  Signed-off-by: Suraj Pai <b.pai@maastrichtuniversity.nl>
* Add slice_axis to SlidingWindowInferer
  Signed-off-by: Suraj <b.pai@maastrichtuniversity.nl>
* Address comments + PEP compliance
  Signed-off-by: Suraj <b.pai@maastrichtuniversity.nl>
* Change to spatial_dim + assert
  Signed-off-by: Suraj <b.pai@maastrichtuniversity.nl>

Co-authored-by: Wenqi Li <wenqil@nvidia.com>
1 parent 4d42277 commit 4e1e316

File tree

4 files changed

+219
-0
lines changed

README.md

Lines changed: 5 additions & 0 deletions
@@ -179,6 +179,11 @@
 The examples are built with MONAI workflows, mainly contain: trainer/evaluator, handlers, post_transforms, etc.
 #### [3d_image_transforms](./modules/3d_image_transforms.ipynb)
 This notebook demonstrates the transformations on volumetric images.
+
+#### [2d_inference_3d_volume](./modules/2d_inference_3d_volume.ipynb)
+Tutorial demonstrating how the MONAI `SlidingWindowInferer` can be used when a 3D volume must be fed slice by slice to a 2D model and the slice predictions aggregated back into a 3D volume.
+
+
 #### [autoencoder_mednist](./modules/autoencoder_mednist.ipynb)
 This tutorial uses the MedNIST hand CT scan dataset to demonstrate MONAI's autoencoder class. The autoencoder is used with an identity encode/decode (i.e., what you put in is what you should get back), as well as demonstrating its usage for de-blurring and de-noising.
 #### [batch_output_transform](./modules/batch_output_transform.py)

figures/2d_inference_3d_input.png

18 KB

modules/2d_inference_3d_volume.ipynb

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "c408367e",
      "metadata": {},
      "source": [
        "# 2D Model Inference on a 3D Volume"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a8681db2",
      "metadata": {},
      "source": [
        "Use case: a 2D model, such as a 2D segmentation U-Net, operates on 2D inputs, which can be slices from a 3D volume (for example, a CT scan).\n",
        "\n",
        "After extending the sliding window inferer as described in this tutorial, it can handle the entire flow as shown:\n",
        "![image](../figures/2d_inference_3d_input.png)\n",
        "\n",
        "The inputs are a *3D volume* and a *2D model*; the output is a *3D volume* in which the 2D slice predictions are aggregated.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "239b0d93",
      "metadata": {},
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/2d_inference_3d_volume.ipynb)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "f2e1b91f",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Install MONAI\n",
        "!python -c \"import monai\" || pip install -q \"monai-weekly\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "e9cd1b08",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Import libraries\n",
        "from typing import Any, Callable\n",
        "\n",
        "import torch\n",
        "\n",
        "from monai.inferers import SlidingWindowInferer\n",
        "from monai.networks.nets import UNet"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "85f00a47",
      "metadata": {},
      "source": [
        "## Overriding SlidingWindowInferer\n",
        "The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "01f8bfa3",
      "metadata": {},
      "outputs": [],
      "source": [
        "class YourSlidingWindowInferer(SlidingWindowInferer):\n",
        "    def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n",
        "        # Dim to slice the volume across: `0` slides over axial slices,\n",
        "        # `1` over coronal slices and `2` over sagittal slices.\n",
        "        self.spatial_dim = spatial_dim\n",
        "\n",
        "        super().__init__(*args, **kwargs)\n",
        "\n",
        "    def __call__(\n",
        "        self,\n",
        "        inputs: torch.Tensor,\n",
        "        network: Callable[..., torch.Tensor],\n",
        "        *args: Any,\n",
        "        **kwargs: Any,\n",
        "    ) -> torch.Tensor:\n",
        "\n",
        "        assert (\n",
        "            self.spatial_dim < 3\n",
        "        ), \"`spatial_dim` can only be `0`, `1` or `2`, indexing `[D, H, W]` respectively\"\n",
        "\n",
        "        # Check whether the roi size (e.g. a 2D roi) and the input volume size (3D input) mismatch\n",
        "        if len(self.roi_size) != len(inputs.shape[2:]):\n",
        "\n",
        "            # If they mismatch and roi_size is 2D, insert a singleton dimension into roi_size\n",
        "            if len(self.roi_size) == 2:\n",
        "                self.roi_size = list(self.roi_size)\n",
        "                self.roi_size.insert(self.spatial_dim, 1)\n",
        "            else:\n",
        "                raise RuntimeError(\n",
        "                    \"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \"\n",
        "                )\n",
        "\n",
        "        return super().__call__(\n",
        "            inputs, lambda x: self.network_wrapper(network, x, *args, **kwargs)\n",
        "        )\n",
        "\n",
        "    def network_wrapper(self, network, x, *args, **kwargs):\n",
        "        \"\"\"\n",
        "        Wrapper that handles inference with 2D models over 3D volume inputs.\n",
        "        \"\"\"\n",
        "        # If the depth dim is 1 in the [D, H, W] roi size, the input is 2D and needs\n",
        "        # to be handled accordingly\n",
        "        if self.roi_size[self.spatial_dim] == 1:\n",
        "            # Pass a 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n",
        "            x = x.squeeze(dim=self.spatial_dim + 2)\n",
        "            out = network(x, *args, **kwargs)\n",
        "            # Unsqueeze the network output so it is [N, C, D, H, W], as expected by\n",
        "            # the default SlidingWindowInferer class\n",
        "            return out.unsqueeze(dim=self.spatial_dim + 2)\n",
        "\n",
        "        else:\n",
        "            return network(x, *args, **kwargs)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bb0a63dd",
      "metadata": {},
      "source": [
        "## Testing the added functionality\n",
        "Let's use `YourSlidingWindowInferer` in a dummy example to execute the workflow described above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "85b15305",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n",
            "Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
          ]
        }
      ],
      "source": [
        "# Create a 2D UNet with randomly initialized weights for testing purposes\n",
        "\n",
        "# 3-layer network with down/upsampling by a factor of 2 at each layer,\n",
        "# using 2-convolution residual units\n",
        "net = UNet(\n",
        "    spatial_dims=2,\n",
        "    in_channels=1,\n",
        "    out_channels=1,\n",
        "    channels=(4, 8, 16),\n",
        "    strides=(2, 2),\n",
        "    num_res_units=2,\n",
        ")\n",
        "\n",
        "# Initialize a dummy 3D tensor volume with shape (N, C, D, H, W)\n",
        "input_volume = torch.ones(1, 1, 64, 256, 256)\n",
        "\n",
        "# Create an instance of YourSlidingWindowInferer with a 256x256 (H x W) roi_size, sliding over the D axis\n",
        "axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n",
        "\n",
        "output = axial_inferer(input_volume, net)\n",
        "\n",
        "# Output is a 3D volume with the 2D slice predictions aggregated\n",
        "print(\"Axial Inferer Output Shape: \", output.shape)\n",
        "\n",
        "# Create an instance of YourSlidingWindowInferer with a 64x256 (D x W) roi_size, sliding over the H axis\n",
        "coronal_inferer = YourSlidingWindowInferer(\n",
        "    roi_size=(64, 256),\n",
        "    sw_batch_size=1,\n",
        "    spatial_dim=1,  # Spatial dim to slice along is added here\n",
        "    cval=-1,\n",
        ")\n",
        "\n",
        "output = coronal_inferer(input_volume, net)\n",
        "\n",
        "# Output is a 3D volume with the 2D slice predictions aggregated\n",
        "print(\"Coronal Inferer Output Shape: \", output.shape)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
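The core trick in the notebook's `network_wrapper` is squeezing the singleton slice dimension before the 2D forward pass and unsqueezing it afterwards. A minimal sketch of that shape round-trip in plain PyTorch, using a hypothetical single-conv stand-in for the 2D model (not part of the tutorial):

```python
import torch

# Hypothetical stand-in for a 2D model: one conv layer that preserves H x W.
model_2d = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)

spatial_dim = 0  # slice along D in [N, C, D, H, W]
window = torch.ones(1, 1, 1, 256, 256)  # one sliding window with roi_size (1, 256, 256)

# Drop the singleton slice axis -> 4D input [N, C, H, W] for the 2D model ...
x2d = window.squeeze(dim=spatial_dim + 2)
out2d = model_2d(x2d)

# ... then restore it so the output is 5D again, as SlidingWindowInferer expects.
out3d = out2d.unsqueeze(dim=spatial_dim + 2)
print(out3d.shape)  # torch.Size([1, 1, 1, 256, 256])
```

Because `spatial_dim` indexes into `[D, H, W]` while tensors are `[N, C, D, H, W]`, the `+ 2` offset skips the batch and channel axes.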

runner.sh

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" UNet_input_size_con
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" network_api.ipynb)
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" tcia_csv_processing.ipynb)
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" transform_visualization.ipynb)
+doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" 2d_inference_3d_volume.ipynb)

 # output formatting
 separator=""
