|
32 | 32 | },
|
33 | 33 | {
|
34 | 34 | "cell_type": "code",
|
35 | | - "execution_count": 1, |
| 35 | + "execution_count": 4, |
36 | 36 | "id": "f2e1b91f",
|
37 | 37 | "metadata": {},
|
38 | 38 | "outputs": [],
|
39 | 39 | "source": [
|
40 | 40 | "# Install monai\n",
|
41 | | - "!python -c \"import monai\" || pip install -q \"monai-weekly\"" |
| 41 | + "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\"" |
42 | 42 | ]
|
43 | 43 | },
|
44 | 44 | {
|
45 | 45 | "cell_type": "code",
|
46 | | - "execution_count": 2, |
| 46 | + "execution_count": 5, |
47 | 47 | "id": "e9cd1b08",
|
48 | 48 | "metadata": {},
|
49 | 49 | "outputs": [],
|
50 | 50 | "source": [
|
51 | 51 | "# Import libs\n",
|
52 | | - "from monai.inferers import SlidingWindowInferer\n", |
| 52 | + "from monai.inferers import SliceInferer\n", |
53 | 53 | "import torch\n",
|
54 | | - "from typing import Callable, Any\n", |
55 | 54 | "from monai.networks.nets import UNet"
|
56 | 55 | ]
|
57 | 56 | },
58 | 57 | {
59 | 58 | "cell_type": "markdown",
60 | 59 | "id": "85f00a47",
|
61 | 60 | "metadata": {},
|
62 | 61 | "source": [
|
63 | | - "## Overiding SlidingWindowInferer\n", |
64 | | - "The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`" |
65 | | - ] |
66 | | - }, |
67 | | - { |
68 | | - "cell_type": "code", |
69 | | - "execution_count": 3, |
70 | | - "id": "01f8bfa3", |
71 | | - "metadata": {}, |
72 | | - "outputs": [], |
73 | | - "source": [ |
74 | | - "class YourSlidingWindowInferer(SlidingWindowInferer):\n", |
75 | | - "    def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n", |
76 | | - "        # Set dim to slice the volume across, for example, `0` could slide over axial slices,\n", |
77 | | - "        # `1` over coronal slices\n", |
78 | | - "        # and `2` over sagittal slices.\n", |
79 | | - "        self.spatial_dim = spatial_dim\n", |
80 | | - "\n", |
81 | | - "        super().__init__(*args, **kwargs)\n", |
82 | | - "\n", |
83 | | - "    def __call__(\n", |
84 | | - "        self,\n", |
85 | | - "        inputs: torch.Tensor,\n", |
86 | | - "        network: Callable[..., torch.Tensor],\n", |
87 | | - "        slice_axis: int = 0,\n", |
88 | | - "        *args: Any,\n", |
89 | | - "        **kwargs: Any,\n", |
90 | | - "    ) -> torch.Tensor:\n", |
91 | | - "\n", |
92 | | - "        assert (\n", |
93 | | - "            self.spatial_dim < 3\n", |
94 | | - "        ), \"`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively\"\n", |
95 | | - "\n", |
96 | | - "        # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch\n", |
97 | | - "        if len(self.roi_size) != len(inputs.shape[2:]):\n", |
98 | | - "\n", |
99 | | - "            # If they mismatch and roi_size is 2D add another dimension to roi size\n", |
100 | | - "            if len(self.roi_size) == 2:\n", |
101 | | - "                self.roi_size = list(self.roi_size)\n", |
102 | | - "                self.roi_size.insert(self.spatial_dim, 1)\n", |
103 | | - "            else:\n", |
104 | | - "                raise RuntimeError(\n", |
105 | | - "                    \"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \"\n", |
106 | | - "                )\n", |
107 | | - "\n", |
108 | | - "        return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n", |
109 | | - "\n", |
110 | | - "    def network_wrapper(self, network, x, *args, **kwargs):\n", |
111 | | - "        \"\"\"\n", |
112 | | - "        Wrapper handles cases where inference needs to be done using\n", |
113 | | - "        2D models over 3D volume inputs.\n", |
114 | | - "        \"\"\"\n", |
115 | | - "        # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs\n", |
116 | | - "        # be handled accordingly\n", |
117 | | - "\n", |
118 | | - "        if self.roi_size[self.spatial_dim] == 1:\n", |
119 | | - "            # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n", |
120 | | - "            x = x.squeeze(dim=self.spatial_dim + 2)\n", |
121 | | - "            out = network(x, *args, **kwargs)\n", |
122 | | - "            # Unsqueeze the network output so it is [N, C, D, H, W] as expected by\n", |
123 | | - "            # the default SlidingWindowInferer class\n", |
124 | | - "            return out.unsqueeze(dim=self.spatial_dim + 2)\n", |
125 | | - "\n", |
126 | | - "        else:\n", |
127 | | - "            return network(x, *args, **kwargs)" |
| 62 | + "## SliceInferer\n", |
| 63 | + "The simplest way to achieve this functionality is to extend the `SlidingWindowInferer` in `monai.inferers`. This is made available as `SliceInferer` in MONAI (https://docs.monai.io/en/latest/inferers.html#sliceinferer)." |
128 | 64 | ]
|
129 | 65 | },
|
130 | 66 | {
|
131 | 67 | "cell_type": "markdown",
|
132 | 68 | "id": "bb0a63dd",
|
133 | 69 | "metadata": {},
|
134 | 70 | "source": [
|
135 | | - "## Testing added functionality\n", |
136 | | - "Let's use the `YourSlidingWindowInferer` in a dummy example to execute the workflow described above." |
| 71 | + "## Usage" |
137 | 72 | ]
|
138 | 73 | },
|
139 | 74 | {
|
140 | 75 | "cell_type": "code",
|
141 | | - "execution_count": 4, |
| 76 | + "execution_count": 6, |
142 | 77 | "id": "85b15305",
|
143 | 78 | "metadata": {},
|
144 | 79 | "outputs": [
|
| 80 | + { |
| 81 | + "name": "stderr", |
| 82 | + "output_type": "stream", |
| 83 | + "text": [ |
| 84 | + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:00<00:00, 107.33it/s]\n" |
| 85 | + ] |
| 86 | + }, |
| 87 | + { |
| 88 | + "name": "stdout", |
| 89 | + "output_type": "stream", |
| 90 | + "text": [ |
| 91 | + "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n" |
| 92 | + ] |
| 93 | + }, |
| 94 | + { |
| 95 | + "name": "stderr", |
| 96 | + "output_type": "stream", |
| 97 | + "text": [ |
| 98 | + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:01<00:00, 177.69it/s]\n" |
| 99 | + ] |
| 100 | + }, |
145 | 101 | {
|
146 | 102 | "name": "stdout",
|
147 | 103 | "output_type": "stream",
|
148 | 104 | "text": [
|
149 | | - "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n", |
150 | 105 | "Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
|
151 | 106 | ]
|
152 | 107 | }
|
|
167 | 122 | "# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n",
|
168 | 123 | "input_volume = torch.ones(1, 1, 64, 256, 256)\n",
|
169 | 124 | "\n",
|
170 | | - "# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n", |
171 | | - "axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n", |
| 125 | + "# Create an instance of SliceInferer with a roi_size of 256x256 (HxW), sliding over the D axis\n", |
| 126 | + "axial_inferer = SliceInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1, progress=True)\n", |
172 | 127 | "\n",
|
173 | 128 | "output = axial_inferer(input_volume, net)\n",
|
174 | 129 | "\n",
|
175 | 130 | "# Output is a 3D volume with 2D slices aggregated\n",
|
176 | 131 | "print(\"Axial Inferer Output Shape: \", output.shape)\n",
|
177 | | - "# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n", |
178 | | - "coronal_inferer = YourSlidingWindowInferer(\n", |
| 132 | + "# Create an instance of SliceInferer with a roi_size of 64x256 (DxW), sliding over the H axis\n", |
| 133 | + "coronal_inferer = SliceInferer(\n", |
179 | 134 | " roi_size=(64, 256),\n",
|
180 | 135 | " sw_batch_size=1,\n",
|
181 | 136 | " spatial_dim=1, # Spatial dim to slice along is added here\n",
|
182 | 137 | " cval=-1,\n",
|
| 138 | + " progress=True,\n", |
183 | 139 | ")\n",
|
184 | 140 | "\n",
|
185 | 141 | "output = coronal_inferer(input_volume, net)\n",
|
186 | 142 | "\n",
|
187 | 143 | "# Output is a 3D volume with 2D slices aggregated\n",
|
188 | 144 | "print(\"Coronal Inferer Output Shape: \", output.shape)"
|
189 | 145 | ]
|
| 146 | + }, |
| 147 | + { |
| 148 | + "cell_type": "markdown", |
| 149 | + "id": "f2596d86", |
| 150 | + "metadata": {}, |
| 151 | + "source": [ |
| 152 | + "Note that with `axial_inferer` and `coronal_inferer`, the number of inference iterations is 64 and 256, respectively." |
| 153 | + ] |
190 | 154 | }
|
191 | 155 | ],
|
192 | 156 | "metadata": {
|
|
205 | 169 | "name": "python",
|
206 | 170 | "nbconvert_exporter": "python",
|
207 | 171 | "pygments_lexer": "ipython3",
|
208 | | - "version": "3.7.11" |
| 172 | + "version": "3.8.12" |
209 | 173 | }
|
210 | 174 | },
|
211 | 175 | "nbformat": 4,
|
|
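For quick reference, here is a minimal plain-Python sketch of what the updated cells do end to end, outside the notebook JSON. It is a sketch under assumptions, not part of the diff: the `UNet` configuration below is illustrative only, and it assumes a recent MONAI release that ships `SliceInferer`.

```python
# Minimal sketch of the SliceInferer workflow shown in the updated notebook cells.
# Assumes a recent MONAI with SliceInferer; the UNet configuration is illustrative only.
import torch
from monai.inferers import SliceInferer
from monai.networks.nets import UNet

# A 2D network that will be applied slice-by-slice to a 3D volume.
net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2))
net.eval()

# Dummy 3D volume with shape (N, C, D, H, W).
input_volume = torch.ones(1, 1, 64, 256, 256)

with torch.no_grad():
    # Slide 256x256 (H, W) windows along D (spatial_dim=0): one iteration per axial slice.
    axial_inferer = SliceInferer(roi_size=(256, 256), sw_batch_size=1, spatial_dim=0, cval=-1)
    print(axial_inferer(input_volume, net).shape)  # torch.Size([1, 1, 64, 256, 256])

    # Slide 64x256 (D, W) windows along H (spatial_dim=1): one iteration per coronal slice.
    coronal_inferer = SliceInferer(roi_size=(64, 256), sw_batch_size=1, spatial_dim=1, cval=-1)
    print(coronal_inferer(input_volume, net).shape)  # torch.Size([1, 1, 64, 256, 256])
```

As the note in the diff points out, the axial inferer runs 64 iterations (one per D slice) and the coronal inferer 256 (one per H slice).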