From 2604beb0730dbda9cc82fe0d3782c92eaf03fc52 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 5 Jun 2022 19:58:20 +0100 Subject: [PATCH 1/3] update inference 2d from 3d Signed-off-by: Wenqi Li --- modules/2d_inference_3d_volume.ipynb | 119 ++++++++++----------------- 1 file changed, 42 insertions(+), 77 deletions(-) diff --git a/modules/2d_inference_3d_volume.ipynb b/modules/2d_inference_3d_volume.ipynb index 7744461a49..4b267ac739 100644 --- a/modules/2d_inference_3d_volume.ipynb +++ b/modules/2d_inference_3d_volume.ipynb @@ -32,24 +32,24 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "id": "f2e1b91f", "metadata": {}, "outputs": [], "source": [ "# Install monai\n", - "!python -c \"import monai\" || pip install -q \"monai-weekly\"" + "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\"" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "id": "e9cd1b08", "metadata": {}, "outputs": [], "source": [ "# Import libs\n", - "from monai.inferers import SlidingWindowInferer\n", + "from monai.inferers import SliceInferer\n", "import torch\n", "from typing import Callable, Any\n", "from monai.networks.nets import UNet" @@ -60,71 +60,8 @@ "id": "85f00a47", "metadata": {}, "source": [ - "## Overiding SlidingWindowInferer\n", - "The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "01f8bfa3", - "metadata": {}, - "outputs": [], - "source": [ - "class YourSlidingWindowInferer(SlidingWindowInferer):\n", - " def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n", - " # Set dim to slice the volume across, for example, `0` could slide over axial slices,\n", - " # `1` over coronal slices\n", - " # and `2` over sagittal slices.\n", - " self.spatial_dim = spatial_dim\n", - "\n", - " super().__init__(*args, **kwargs)\n", - "\n", - " def __call__(\n", - " self,\n", - " inputs: torch.Tensor,\n", - " network: Callable[..., torch.Tensor],\n", - " slice_axis: int = 0,\n", - " *args: Any,\n", - " **kwargs: Any,\n", - " ) -> torch.Tensor:\n", - "\n", - " assert (\n", - " self.spatial_dim < 3\n", - " ), \"`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively\"\n", - "\n", - " # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch\n", - " if len(self.roi_size) != len(inputs.shape[2:]):\n", - "\n", - " # If they mismatch and roi_size is 2D add another dimension to roi size\n", - " if len(self.roi_size) == 2:\n", - " self.roi_size = list(self.roi_size)\n", - " self.roi_size.insert(self.spatial_dim, 1)\n", - " else:\n", - " raise RuntimeError(\n", - " \"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \"\n", - " )\n", - "\n", - " return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n", - "\n", - " def network_wrapper(self, network, x, *args, **kwargs):\n", - " \"\"\"\n", - " Wrapper handles cases where inference needs to be done using\n", - " 2D models over 3D volume inputs.\n", - " \"\"\"\n", - " # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs\n", - " # be handled accordingly\n", - "\n", - " if self.roi_size[self.spatial_dim] == 1:\n", - " # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n", - " x = x.squeeze(dim=self.spatial_dim + 2)\n", - " out = network(x, *args, **kwargs)\n", - " # Unsqueeze the network output so it is [N, C, D, H, W] as expected by\n", - " # the default SlidingWindowInferer class\n", - " return out.unsqueeze(dim=self.spatial_dim + 2)\n", - "\n", - " else:\n", - " return network(x, *args, **kwargs)" + "## SliceInferer\n", + "The simplest way to achieve this functionality is to extend the `SlidingWindowInferer` in `monai.inferers`. This is made available as `SliceInferer` in MONAI (https://docs.monai.io/en/latest/inferers.html#sliceinferer)." ] }, { @@ -132,21 +69,40 @@ "id": "bb0a63dd", "metadata": {}, "source": [ - "## Testing added functionality\n", - "Let's use the `YourSlidingWindowInferer` in a dummy example to execute the workflow described above." + "## Usage" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "85b15305", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:00<00:00, 107.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:01<00:00, 177.69it/s]\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n", "Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n" ] } @@ -167,19 +123,20 @@ "# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n", "input_volume = torch.ones(1, 1, 64, 256, 256)\n", "\n", - "# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n", - "axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n", + "# Create an instance of SliceInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n", + "axial_inferer = SliceInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1, progress=True)\n", "\n", "output = axial_inferer(input_volume, net)\n", "\n", "# Output is a 3D volume with 2D slices aggregated\n", "print(\"Axial Inferer Output Shape: \", output.shape)\n", "# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n", - "coronal_inferer = YourSlidingWindowInferer(\n", + "coronal_inferer = SliceInferer(\n", " roi_size=(64, 256),\n", " sw_batch_size=1,\n", " spatial_dim=1, # Spatial dim to slice along is added here\n", " cval=-1,\n", + " progress=True,\n", ")\n", "\n", "output = coronal_inferer(input_volume, net)\n", @@ -187,6 +144,14 @@ "# Output is a 3D volume with 2D slices aggregated\n", "print(\"Coronal Inferer Output Shape: \", output.shape)" ] + }, + { + "cell_type": "markdown", + "id": "f1bbf389", + "metadata": {}, + "source": [ + "Note that with `axial_inferer` and `coronal_inferer`, the number of inference iterations is 64 and 256 repectively." + ] } ], "metadata": { @@ -205,7 +170,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.11" + "version": "3.8.12" } }, "nbformat": 4, From 47ee657b9d93c2c542c07059f5453f6b8a51dae3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 5 Jun 2022 20:00:52 +0100 Subject: [PATCH 2/3] update docstring Signed-off-by: Wenqi Li --- modules/2d_inference_3d_volume.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/2d_inference_3d_volume.ipynb b/modules/2d_inference_3d_volume.ipynb index 4b267ac739..fffd3f1b72 100644 --- a/modules/2d_inference_3d_volume.ipynb +++ b/modules/2d_inference_3d_volume.ipynb @@ -130,7 +130,7 @@ "\n", "# Output is a 3D volume with 2D slices aggregated\n", "print(\"Axial Inferer Output Shape: \", output.shape)\n", - "# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n", + "# Create an instance of SliceInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n", "coronal_inferer = SliceInferer(\n", " roi_size=(64, 256),\n", " sw_batch_size=1,\n", @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "f1bbf389", + "id": "f2596d86", "metadata": {}, "source": [ "Note that with `axial_inferer` and `coronal_inferer`, the number of inference iterations is 64 and 256 repectively." From 33efbd564698e25ff22d0564280b1731f890bc07 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 5 Jun 2022 20:02:31 +0100 Subject: [PATCH 3/3] fixes pep8 Signed-off-by: Wenqi Li --- modules/2d_inference_3d_volume.ipynb | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/2d_inference_3d_volume.ipynb b/modules/2d_inference_3d_volume.ipynb index fffd3f1b72..6bdc3e4a06 100644 --- a/modules/2d_inference_3d_volume.ipynb +++ b/modules/2d_inference_3d_volume.ipynb @@ -51,7 +51,6 @@ "# Import libs\n", "from monai.inferers import SliceInferer\n", "import torch\n", - "from typing import Callable, Any\n", "from monai.networks.nets import UNet" ] },