|
41 | 41 | "output_type": "stream",
|
42 | 42 | "text": [
|
43 | 43 | "Requirement already satisfied: monai in /home/suraj/miniconda3/envs/monai/lib/python3.7/site-packages (0.8.0)\n",
|
44 |
| - "Requirement already satisfied: numpy>=1.17 in /home/suraj/miniconda3/envs/monai/lib/python3.7/site-packages (from monai) (1.21.4)\n", |
45 | 44 | "Requirement already satisfied: torch>=1.6 in /home/suraj/miniconda3/envs/monai/lib/python3.7/site-packages (from monai) (1.10.0)\n",
|
| 45 | + "Requirement already satisfied: numpy>=1.17 in /home/suraj/miniconda3/envs/monai/lib/python3.7/site-packages (from monai) (1.21.4)\n", |
46 | 46 | "Requirement already satisfied: typing-extensions in /home/suraj/miniconda3/envs/monai/lib/python3.7/site-packages (from torch>=1.6->monai) (4.0.1)\n"
|
47 | 47 | ]
|
48 | 48 | }
|
|
74 | 74 | "\n",
|
75 | 75 | "\n",
|
76 | 76 | "class YourSlidingWindowInferer(SlidingWindowInferer):\n",
|
77 |
| - " def __init__(self, *args, **kwargs):\n", |
| 77 | + " def __init__(self, slice_axis: int = 0, \n", |
| 78 | + " *args, **kwargs):\n", |
| 79 | + " # Set axis to slice the volume across, for example, `0` could slide over axial slices, `1` over coronal slices \n", |
| 80 | + " # and `2` over sagittal slices.\n", |
| 81 | + " self.slice_axis = slice_axis\n", |
| 82 | + "\n", |
78 | 83 | " super().__init__(*args, **kwargs)\n",
|
79 | 84 | "\n",
|
80 | 85 | " def __call__(\n",
|
81 | 86 | " self,\n",
|
82 | 87 | " inputs: torch.Tensor,\n",
|
83 | 88 | " network: Callable[..., torch.Tensor],\n",
|
| 89 | + " slice_axis: int = 0,\n", |
84 | 90 | " *args: Any,\n",
|
85 | 91 | " **kwargs: Any,\n",
|
86 | 92 | " ) -> torch.Tensor:\n",
|
|
90 | 96 | "\n",
|
91 | 97 | " # If they mismatch and roi_size is 2D add another dimension to roi size\n",
|
92 | 98 | " if len(self.roi_size) == 2:\n",
|
93 |
| - " self.roi_size = [1, *self.roi_size]\n", |
| 99 | + " self.roi_size = list(self.roi_size)\n", |
| 100 | + " self.roi_size.insert(self.slice_axis, 1)\n", |
94 | 101 | " else:\n",
|
95 |
| - " raise RuntimeError(\"Unsupported roi size, cannot broadcast to volume. \")\n", |
96 |
| - "\n", |
| 102 | + " raise RuntimeError(\"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \")\n", |
| 103 | + " \n", |
97 | 104 | " return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n",
|
98 | 105 | "\n",
|
99 | 106 | " def network_wrapper(self, network, x, *args, **kwargs):\n",
|
|
103 | 110 | " \"\"\"\n",
|
104 | 111 | " # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs\n",
|
105 | 112 | " # be handled accordingly\n",
|
106 |
| - " if self.roi_size[0] == 1:\n", |
107 |
| - " # Pass [N, C, H, W] to the model as it is 2D.\n", |
108 |
| - " x = x.squeeze(dim=2)\n", |
| 113 | + "\n", |
| 114 | + " if self.roi_size[self.slice_axis] == 1:\n", |
| 115 | + " # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n", |
| 116 | + " x = x.squeeze(dim=self.slice_axis + 2)\n", |
109 | 117 | " out = network(x, *args, **kwargs)\n",
|
110 | 118 | " # Unsqueeze the network output so it is [N, C, D, H, W] as expected by the default SlidingWindowInferer class\n",
|
111 |
| - " return out.unsqueeze(dim=2)\n", |
| 119 | + " return out.unsqueeze(dim=self.slice_axis + 2)\n", |
112 | 120 | "\n",
|
113 | 121 | " else:\n",
|
114 | 122 | " return network(x, *args, **kwargs)"
|
|
128 | 136 | "execution_count": 3,
|
129 | 137 | "id": "85b15305",
|
130 | 138 | "metadata": {},
|
131 |
| - "outputs": [], |
| 139 | + "outputs": [ |
| 140 | + { |
| 141 | + "name": "stdout", |
| 142 | + "output_type": "stream", |
| 143 | + "text": [ |
| 144 | + "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n", |
| 145 | + "Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n" |
| 146 | + ] |
| 147 | + } |
| 148 | + ], |
132 | 149 | "source": [
|
133 | 150 | "# Create a 2D UNet with randomly initialized weights for testing purposes\n",
|
134 | 151 | "from monai.networks.nets import UNet\n",
|
|
144 | 161 | ")\n",
|
145 | 162 | "\n",
|
146 | 163 | "# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n",
|
147 |
| - "input_volume = torch.ones(1, 1, 30, 256, 256)\n", |
| 164 | + "input_volume = torch.ones(1, 1, 64, 256, 256)\n", |
148 | 165 | "\n",
|
149 |
| - "# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW)\n", |
150 |
| - "inferer = YourSlidingWindowInferer(roi_size=(256, 256),\n", |
| 166 | + "# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n", |
| 167 | + "axial_inferer = YourSlidingWindowInferer(roi_size=(32, 256),\n", |
151 | 168 | " sw_batch_size=1,\n",
|
152 | 169 | " cval=-1)\n",
|
153 | 170 | "\n",
|
154 |
| - "output = inferer(input_volume, net)" |
155 |
| - ] |
156 |
| - }, |
157 |
| - { |
158 |
| - "cell_type": "code", |
159 |
| - "execution_count": 4, |
160 |
| - "id": "5ad96534", |
161 |
| - "metadata": {}, |
162 |
| - "outputs": [ |
163 |
| - { |
164 |
| - "data": { |
165 |
| - "text/plain": [ |
166 |
| - "torch.Size([1, 1, 30, 256, 256])" |
167 |
| - ] |
168 |
| - }, |
169 |
| - "execution_count": 4, |
170 |
| - "metadata": {}, |
171 |
| - "output_type": "execute_result" |
172 |
| - } |
173 |
| - ], |
174 |
| - "source": [ |
| 171 | + "output = axial_inferer(input_volume, net)\n", |
| 172 | + "\n", |
| 173 | + "# Output is a 3D volume with 2D slices aggregated\n", |
| 174 | + "print(\"Axial Inferer Output Shape: \", output.shape)\n", |
| 175 | + "\n", |
| 176 | + "# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n", |
| 177 | + "coronal_inferer = YourSlidingWindowInferer(roi_size=(64, 256),\n", |
| 178 | + " sw_batch_size=1, slice_axis=1, # Slice axis is added here\n", |
| 179 | + " cval=-1)\n", |
| 180 | + "\n", |
| 181 | + "output = coronal_inferer(input_volume, net)\n", |
| 182 | + "\n", |
175 | 183 | "# Output is a 3D volume with 2D slices aggregated\n",
|
176 |
| - "output.shape" |
| 184 | + "print(\"Coronal Inferer Output Shape: \", output.shape)" |
177 | 185 | ]
|
178 | 186 | }
|
179 | 187 | ],
|
|
0 commit comments