
Commit c2e45f3

Add demo for 2d inference on 3D volume

Signed-off-by: Suraj Pai <b.pai@maastrichtuniversity.nl>
1 parent: bbc440f

File tree: 4 files changed, +207 -0 lines changed

README.md

Lines changed: 5 additions & 0 deletions
@@ -176,6 +176,11 @@ Training and evaluation examples of 3D segmentation based on UNet3D and synthetic
 The examples are built with MONAI workflows, mainly containing trainer/evaluator, handlers, post_transforms, etc.
 #### [3d_image_transforms](./modules/3d_image_transforms.ipynb)
 This notebook demonstrates the transformations on volumetric images.
+
+#### [2d_inference_3d_volume](./modules/2d_inference_3d_volume.ipynb)
+This tutorial demonstrates how MONAI's `SlidingWindowInferer` can be used when a 3D volume needs to be provided slice-by-slice to a 2D model, with the slice predictions aggregated back into a 3D volume.
+
+
 #### [autoencoder_mednist](./modules/autoencoder_mednist.ipynb)
 This tutorial uses the MedNIST hand CT scan dataset to demonstrate MONAI's autoencoder class. The autoencoder is used with an identity encode/decode (i.e., what you put in is what you should get back), as well as demonstrating its usage for de-blurring and de-noising.
 #### [batch_output_transform](./modules/batch_output_transform.py)

figures/2d_inference_3d_input.png

18 KB

modules/2d_inference_3d_volume.ipynb

Lines changed: 201 additions & 0 deletions
# 2D Model Inference on a 3D Volume

Use case: a 2D model, such as a 2D segmentation U-Net, operates on 2D inputs, which can be slices from a 3D volume (for example, a CT scan).

After extending the sliding window inferer as described in this tutorial, it can handle the entire flow shown below:

![image](../figures/2d_inference_3d_input.png)

The input is a *3D volume* and a *2D model*; the output is a *3D volume* with the 2D slice predictions aggregated.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/2d_inference_3d_volume.ipynb)
```python
# Install MONAI
!pip install monai
```
## Overriding SlidingWindowInferer

The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`.
```python
from typing import Any, Callable

import torch

from monai.inferers import SlidingWindowInferer


class YourSlidingWindowInferer(SlidingWindowInferer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(
        self,
        inputs: torch.Tensor,
        network: Callable[..., torch.Tensor],
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        # Check whether the roi size (e.g. a 2D roi) and the input volume size (3D input) mismatch
        if len(self.roi_size) != len(inputs.shape[2:]):
            # If they mismatch and roi_size is 2D, prepend a depth dimension of 1 to the roi size
            if len(self.roi_size) == 2:
                self.roi_size = [1, *self.roi_size]
            else:
                raise RuntimeError("Unsupported roi size, cannot broadcast to volume.")

        return super().__call__(inputs, lambda x: self.network_wrapper(network, x))

    def network_wrapper(self, network, x, *args, **kwargs):
        """
        Wrapper that handles cases where inference needs to be done using
        2D models over 3D volume inputs.
        """
        # If the depth dim of the [D, H, W] roi size is 1, the window is a single
        # slice and needs to be handled accordingly
        if self.roi_size[0] == 1:
            # Squeeze out the depth dim and pass [N, C, H, W] to the 2D model
            x = x.squeeze(dim=2)
            out = network(x, *args, **kwargs)
            # Unsqueeze the output back to [N, C, D, H, W], as expected by the
            # default SlidingWindowInferer aggregation
            return out.unsqueeze(dim=2)
        else:
            return network(x, *args, **kwargs)
```
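The whole trick is the squeeze/unsqueeze round trip inside `network_wrapper`: each `[N, C, 1, H, W]` sliding window loses its singleton depth dimension before the 2D forward pass and regains it afterwards, so the parent class can aggregate the windows into a 3D output. Here is a minimal, self-contained sketch of just that round trip (the `Conv2d` stand-in model is purely illustrative, not part of the tutorial):

```python
import torch

# Illustrative stand-in for a 2D model: a single conv layer (not from the tutorial)
model_2d = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)

window = torch.ones(1, 1, 1, 64, 64)  # one [N, C, D=1, H, W] sliding window
x = window.squeeze(dim=2)             # -> [N, C, H, W], what a 2D model expects
out = model_2d(x)                     # 2D prediction for this slice
out = out.unsqueeze(dim=2)            # -> [N, C, 1, H, W] for 3D aggregation
print(out.shape)                      # torch.Size([1, 1, 1, 64, 64])
```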
## Testing the added functionality

Let's use `YourSlidingWindowInferer` in a dummy example to execute the workflow described above.
```python
# Create a 2D UNet with randomly initialized weights for testing purposes
from monai.networks.nets import UNet

# 3-layer network with down/upsampling by a factor of 2 at each layer and 2-convolution residual units
net = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(4, 8, 16),
    strides=(2, 2),
    num_res_units=2,
)

# Initialize a dummy 3D tensor volume with shape (N, C, D, H, W)
input_volume = torch.ones(1, 1, 30, 256, 256)

# Create an instance of YourSlidingWindowInferer with a 256x256 (H x W) roi size
inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)

output = inferer(input_volume, net)
```
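Nothing requires the 2D roi to cover a full slice. With a smaller roi, the inferer also tiles each slice in-plane and blends the overlapping patch predictions through the same wrapper. A sketch assuming the `net` and `input_volume` defined above (the 128x128 roi is an illustrative choice, not a value from the tutorial):

```python
# Sketch: in-plane tiling with an illustrative 128x128 roi. Each [1, 128, 128]
# window is squeezed to a 2D patch, predicted, and aggregated back into 3D.
patch_inferer = YourSlidingWindowInferer(roi_size=(128, 128), sw_batch_size=4)
patch_output = patch_inferer(input_volume, net)
print(patch_output.shape)  # torch.Size([1, 1, 30, 256, 256])
```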
```python
# Output is a 3D volume with the 2D slice predictions aggregated
output.shape
```

Output:

torch.Size([1, 1, 30, 256, 256])
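As a sanity check (a sketch, not part of the notebook), the same result can be computed with a plain per-slice loop; with a full-slice roi there is no window overlap, so the two outputs should agree up to floating-point accumulation:

```python
# Sketch: manual per-slice inference (assumes `net`, `input_volume`, and
# `output` from the cells above).
with torch.no_grad():
    manual = torch.stack(
        [net(input_volume[:, :, d]) for d in range(input_volume.shape[2])],
        dim=2,  # re-stack the per-slice predictions along the depth dim
    )

print(torch.allclose(output, manual, atol=1e-5))  # should print True
```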

runner.sh

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" UNet_input_size_con
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" network_api.ipynb)
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" tcia_csv_processing.ipynb)
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" transform_visualization.ipynb)
+doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" 2d_inference_3d_volume.ipynb)
 
 # output formatting
 separator=""
