Skip to content

Commit e463f9c

Browse files
committed
Add slice_axis to SlidingWindowInferer
Signed-off-by: Suraj <b.pai@maastrichtuniversity.nl>
1 parent c2e45f3 commit e463f9c

File tree

1 file changed

+47
-31
lines changed

1 file changed

+47
-31
lines changed

modules/2d_inference_3d_volume.ipynb

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,19 @@
7474
"\n",
7575
"\n",
7676
"class YourSlidingWindowInferer(SlidingWindowInferer):\n",
77-
" def __init__(self, *args, **kwargs):\n",
77+
" def __init__(self, slice_axis: int = 0, \n",
78+
" *args, **kwargs):\n",
79+
" # Set axis to slice the volume across, for example, `0` could slide over axial slices, `1` over coronal slices \n",
80+
" # and `2` over sagittal slices.\n",
81+
" self.slice_axis = slice_axis\n",
82+
"\n",
7883
" super().__init__(*args, **kwargs)\n",
7984
"\n",
8085
" def __call__(\n",
8186
" self,\n",
8287
" inputs: torch.Tensor,\n",
8388
" network: Callable[..., torch.Tensor],\n",
89+
" slice_axis: int = 0,\n",
8490
" *args: Any,\n",
8591
" **kwargs: Any,\n",
8692
" ) -> torch.Tensor:\n",
@@ -90,10 +96,11 @@
9096
"\n",
9197
" # If they mismatch and roi_size is 2D add another dimension to roi size\n",
9298
" if len(self.roi_size) == 2:\n",
93-
" self.roi_size = [1, *self.roi_size]\n",
99+
" self.roi_size = list(self.roi_size)\n",
100+
" self.roi_size.insert(self.slice_axis, 1)\n",
94101
" else:\n",
95-
" raise RuntimeError(\"Unsupported roi size, cannot broadcast to volume. \")\n",
96-
"\n",
102+
" raise RuntimeError(\"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \")\n",
103+
" \n",
97104
" return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n",
98105
"\n",
99106
" def network_wrapper(self, network, x, *args, **kwargs):\n",
@@ -103,12 +110,13 @@
103110
" \"\"\"\n",
104111
" # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs\n",
105112
" # be handled accordingly\n",
106-
" if self.roi_size[0] == 1:\n",
107-
" # Pass [N, C, H, W] to the model as it is 2D.\n",
108-
" x = x.squeeze(dim=2)\n",
113+
"\n",
114+
" if self.roi_size[self.slice_axis] == 1:\n",
115+
" # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n",
116+
" x = x.squeeze(dim=self.slice_axis + 2)\n",
109117
" out = network(x, *args, **kwargs)\n",
110118
" # Unsqueeze the network output so it is [N, C, D, H, W] as expected by the default SlidingWindowInferer class\n",
111-
" return out.unsqueeze(dim=2)\n",
119+
" return out.unsqueeze(dim=self.slice_axis + 2)\n",
112120
"\n",
113121
" else:\n",
114122
" return network(x, *args, **kwargs)"
@@ -128,7 +136,16 @@
128136
"execution_count": 3,
129137
"id": "85b15305",
130138
"metadata": {},
131-
"outputs": [],
139+
"outputs": [
140+
{
141+
"name": "stdout",
142+
"output_type": "stream",
143+
"text": [
144+
"Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n",
145+
"Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
146+
]
147+
}
148+
],
132149
"source": [
133150
"# Create a 2D UNet with randomly initialized weights for testing purposes\n",
134151
"from monai.networks.nets import UNet\n",
@@ -144,37 +161,36 @@
144161
")\n",
145162
"\n",
146163
"# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n",
147-
"input_volume = torch.ones(1, 1, 30, 256, 256)\n",
164+
"input_volume = torch.ones(1, 1, 64, 256, 256)\n",
148165
"\n",
149-
"# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW)\n",
150-
"inferer = YourSlidingWindowInferer(roi_size=(256, 256),\n",
166+
"# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n",
167+
"axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256),\n",
151168
" sw_batch_size=1,\n",
152169
" cval=-1)\n",
153170
"\n",
154-
"output = inferer(input_volume, net)"
171+
"output = axial_inferer(input_volume, net)\n",
172+
"\n",
173+
"# Output is a 3D volume with 2D slices aggregated\n",
174+
"print(\"Axial Inferer Output Shape: \", output.shape)\n",
175+
"\n",
176+
"# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n",
177+
"coronal_inferer = YourSlidingWindowInferer(roi_size=(64, 256),\n",
178+
" sw_batch_size=1, slice_axis=1, # Slice axis is added here\n",
179+
" cval=-1)\n",
180+
"\n",
181+
"output = coronal_inferer(input_volume, net)\n",
182+
"\n",
183+
"# Output is a 3D volume with 2D slices aggregated\n",
184+
"print(\"Coronal Inferer Output Shape: \", output.shape)"
155185
]
156186
},
157187
{
158188
"cell_type": "code",
159-
"execution_count": 4,
160-
"id": "5ad96534",
189+
"execution_count": null,
190+
"id": "454c353f",
161191
"metadata": {},
162-
"outputs": [
163-
{
164-
"data": {
165-
"text/plain": [
166-
"torch.Size([1, 1, 30, 256, 256])"
167-
]
168-
},
169-
"execution_count": 4,
170-
"metadata": {},
171-
"output_type": "execute_result"
172-
}
173-
],
174-
"source": [
175-
"# Output is a 3D volume with 2D slices aggregated\n",
176-
"output.shape"
177-
]
192+
"outputs": [],
193+
"source": []
178194
}
179195
],
180196
"metadata": {

0 commit comments

Comments
 (0)