Skip to content

Commit 90cdef2

Browse files
authored
update inference 2d from 3d (#746)
* update inference 2d from 3d Signed-off-by: Wenqi Li <wenqil@nvidia.com> * update docstring Signed-off-by: Wenqi Li <wenqil@nvidia.com> * fixes pep8 Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent ffa6c1d commit 90cdef2

File tree

1 file changed

+43
-79
lines changed

1 file changed

+43
-79
lines changed

modules/2d_inference_3d_volume.ipynb

Lines changed: 43 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,25 @@
3232
},
3333
{
3434
"cell_type": "code",
35-
"execution_count": 1,
35+
"execution_count": 4,
3636
"id": "f2e1b91f",
3737
"metadata": {},
3838
"outputs": [],
3939
"source": [
4040
"# Install monai\n",
41-
"!python -c \"import monai\" || pip install -q \"monai-weekly\""
41+
"!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\""
4242
]
4343
},
4444
{
4545
"cell_type": "code",
46-
"execution_count": 2,
46+
"execution_count": 5,
4747
"id": "e9cd1b08",
4848
"metadata": {},
4949
"outputs": [],
5050
"source": [
5151
"# Import libs\n",
52-
"from monai.inferers import SlidingWindowInferer\n",
52+
"from monai.inferers import SliceInferer\n",
5353
"import torch\n",
54-
"from typing import Callable, Any\n",
5554
"from monai.networks.nets import UNet"
5655
]
5756
},
@@ -60,93 +59,49 @@
6059
"id": "85f00a47",
6160
"metadata": {},
6261
"source": [
63-
"## Overiding SlidingWindowInferer\n",
64-
"The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`"
65-
]
66-
},
67-
{
68-
"cell_type": "code",
69-
"execution_count": 3,
70-
"id": "01f8bfa3",
71-
"metadata": {},
72-
"outputs": [],
73-
"source": [
74-
"class YourSlidingWindowInferer(SlidingWindowInferer):\n",
75-
" def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n",
76-
" # Set dim to slice the volume across, for example, `0` could slide over axial slices,\n",
77-
" # `1` over coronal slices\n",
78-
" # and `2` over sagittal slices.\n",
79-
" self.spatial_dim = spatial_dim\n",
80-
"\n",
81-
" super().__init__(*args, **kwargs)\n",
82-
"\n",
83-
" def __call__(\n",
84-
" self,\n",
85-
" inputs: torch.Tensor,\n",
86-
" network: Callable[..., torch.Tensor],\n",
87-
" slice_axis: int = 0,\n",
88-
" *args: Any,\n",
89-
" **kwargs: Any,\n",
90-
" ) -> torch.Tensor:\n",
91-
"\n",
92-
" assert (\n",
93-
" self.spatial_dim < 3\n",
94-
" ), \"`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively\"\n",
95-
"\n",
96-
" # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch\n",
97-
" if len(self.roi_size) != len(inputs.shape[2:]):\n",
98-
"\n",
99-
" # If they mismatch and roi_size is 2D add another dimension to roi size\n",
100-
" if len(self.roi_size) == 2:\n",
101-
" self.roi_size = list(self.roi_size)\n",
102-
" self.roi_size.insert(self.spatial_dim, 1)\n",
103-
" else:\n",
104-
" raise RuntimeError(\n",
105-
" \"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \"\n",
106-
" )\n",
107-
"\n",
108-
" return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n",
109-
"\n",
110-
" def network_wrapper(self, network, x, *args, **kwargs):\n",
111-
" \"\"\"\n",
112-
" Wrapper handles cases where inference needs to be done using\n",
113-
" 2D models over 3D volume inputs.\n",
114-
" \"\"\"\n",
115-
" # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs\n",
116-
" # be handled accordingly\n",
117-
"\n",
118-
" if self.roi_size[self.spatial_dim] == 1:\n",
119-
" # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n",
120-
" x = x.squeeze(dim=self.spatial_dim + 2)\n",
121-
" out = network(x, *args, **kwargs)\n",
122-
" # Unsqueeze the network output so it is [N, C, D, H, W] as expected by\n",
123-
" # the default SlidingWindowInferer class\n",
124-
" return out.unsqueeze(dim=self.spatial_dim + 2)\n",
125-
"\n",
126-
" else:\n",
127-
" return network(x, *args, **kwargs)"
62+
"## SliceInferer\n",
63+
"The simplest way to achieve this functionality is to extend the `SlidingWindowInferer` in `monai.inferers`. This is made available as `SliceInferer` in MONAI (https://docs.monai.io/en/latest/inferers.html#sliceinferer)."
12864
]
12965
},
13066
{
13167
"cell_type": "markdown",
13268
"id": "bb0a63dd",
13369
"metadata": {},
13470
"source": [
135-
"## Testing added functionality\n",
136-
"Let's use the `YourSlidingWindowInferer` in a dummy example to execute the workflow described above."
71+
"## Usage"
13772
]
13873
},
13974
{
14075
"cell_type": "code",
141-
"execution_count": 4,
76+
"execution_count": 6,
14277
"id": "85b15305",
14378
"metadata": {},
14479
"outputs": [
80+
{
81+
"name": "stderr",
82+
"output_type": "stream",
83+
"text": [
84+
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:00<00:00, 107.33it/s]\n"
85+
]
86+
},
87+
{
88+
"name": "stdout",
89+
"output_type": "stream",
90+
"text": [
91+
"Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
92+
]
93+
},
94+
{
95+
"name": "stderr",
96+
"output_type": "stream",
97+
"text": [
98+
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:01<00:00, 177.69it/s]\n"
99+
]
100+
},
145101
{
146102
"name": "stdout",
147103
"output_type": "stream",
148104
"text": [
149-
"Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n",
150105
"Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
151106
]
152107
}
@@ -167,26 +122,35 @@
167122
"# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n",
168123
"input_volume = torch.ones(1, 1, 64, 256, 256)\n",
169124
"\n",
170-
"# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n",
171-
"axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n",
125+
"# Create an instance of SliceInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n",
126+
"axial_inferer = SliceInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1, progress=True)\n",
172127
"\n",
173128
"output = axial_inferer(input_volume, net)\n",
174129
"\n",
175130
"# Output is a 3D volume with 2D slices aggregated\n",
176131
"print(\"Axial Inferer Output Shape: \", output.shape)\n",
177-
"# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n",
178-
"coronal_inferer = YourSlidingWindowInferer(\n",
132+
"# Create an instance of SliceInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n",
133+
"coronal_inferer = SliceInferer(\n",
179134
" roi_size=(64, 256),\n",
180135
" sw_batch_size=1,\n",
181136
" spatial_dim=1, # Spatial dim to slice along is added here\n",
182137
" cval=-1,\n",
138+
" progress=True,\n",
183139
")\n",
184140
"\n",
185141
"output = coronal_inferer(input_volume, net)\n",
186142
"\n",
187143
"# Output is a 3D volume with 2D slices aggregated\n",
188144
"print(\"Coronal Inferer Output Shape: \", output.shape)"
189145
]
146+
},
147+
{
148+
"cell_type": "markdown",
149+
"id": "f2596d86",
150+
"metadata": {},
151+
"source": [
152+
"Note that with `axial_inferer` and `coronal_inferer`, the number of inference iterations is 64 and 256 repectively."
153+
]
190154
}
191155
],
192156
"metadata": {
@@ -205,7 +169,7 @@
205169
"name": "python",
206170
"nbconvert_exporter": "python",
207171
"pygments_lexer": "ipython3",
208-
"version": "3.7.11"
172+
"version": "3.8.12"
209173
}
210174
},
211175
"nbformat": 4,

0 commit comments

Comments
 (0)