diff --git a/README.md b/README.md
index 836567d49a..8db4ad4b2f 100644
--- a/README.md
+++ b/README.md
@@ -179,6 +179,10 @@ Training and evaluation examples of 3D segmentation based on UNet3D and syntheti
 The examples are built with MONAI workflows, mainly contain: trainer/evaluator, handlers, post_transforms, etc.
 #### [3d_image_transforms](./modules/3d_image_transforms.ipynb)
 This notebook demonstrates the transformations on volumetric images.
+
+#### [2d_inference_3d_volume](./modules/2d_inference_3d_volume.ipynb)
+Tutorial that demonstrates how the MONAI `SlidingWindowInferer` can be used when a 3D volume needs to be fed slice-by-slice to a 2D model, with the 2D predictions then aggregated back into a 3D volume.
+
 #### [autoencoder_mednist](./modules/autoencoder_mednist.ipynb)
 This tutorial uses the MedNIST hand CT scan dataset to demonstrate MONAI's autoencoder class. The autoencoder is used with an identity encode/decode (i.e., what you put in is what you should get back), as well as demonstrating its usage for de-blurring and de-noising.
 #### [batch_output_transform](./modules/batch_output_transform.py)
diff --git a/figures/2d_inference_3d_input.png b/figures/2d_inference_3d_input.png
new file mode 100644
index 0000000000..d344079dea
Binary files /dev/null and b/figures/2d_inference_3d_input.png differ
diff --git a/modules/2d_inference_3d_volume.ipynb b/modules/2d_inference_3d_volume.ipynb
new file mode 100644
index 0000000000..99c34c4c8b
--- /dev/null
+++ b/modules/2d_inference_3d_volume.ipynb
@@ -0,0 +1,213 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "c408367e",
+   "metadata": {},
+   "source": [
+    "# 2D Model Inference on a 3D Volume"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a8681db2",
+   "metadata": {},
+   "source": [
+    "Use case: a 2D model, such as a 2D segmentation U-Net, operates on 2D inputs, which can be slices from a 3D volume (for example, a CT scan).\n",
+    "\n",
+    "After extending the sliding window inferer as described in this tutorial, it can handle the entire flow shown below:\n",
+    "![image](../figures/2d_inference_3d_input.png)\n",
+    "\n",
+    "The inputs are a *3D volume* and a *2D model*; the output is a *3D volume* with the 2D slice predictions aggregated.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "239b0d93",
+   "metadata": {},
+   "source": [
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/2d_inference_3d_volume.ipynb)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "f2e1b91f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Install MONAI if it is not already present\n",
+    "!python -c \"import monai\" || pip install -q \"monai-weekly\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "e9cd1b08",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Import libs\n",
+    "from monai.inferers import SlidingWindowInferer\n",
+    "import torch\n",
+    "from typing import Callable, Any\n",
+    "from monai.networks.nets import UNet"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "85f00a47",
+   "metadata": {},
+   "source": [
+    "## Overriding SlidingWindowInferer\n",
+    "The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "01f8bfa3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class YourSlidingWindowInferer(SlidingWindowInferer):\n",
+    "    def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n",
+    "        # Set the dim to slice the volume across: `0` slides over axial slices,\n",
+    "        # `1` over coronal slices and `2` over sagittal slices.\n",
+    "        self.spatial_dim = spatial_dim\n",
+    "\n",
+    "        super().__init__(*args, **kwargs)\n",
+    "\n",
+    "    def __call__(\n",
+    "        self,\n",
+    "        inputs: torch.Tensor,\n",
+    "        network: Callable[..., torch.Tensor],\n",
+    "        *args: Any,\n",
+    "        **kwargs: Any,\n",
+    "    ) -> torch.Tensor:\n",
+    "\n",
+    "        assert (\n",
+    "            self.spatial_dim < 3\n",
+    "        ), \"`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively\"\n",
+    "\n",
+    "        # Check whether the roi size (e.g. a 2D roi) and the input volume size (3D input) mismatch\n",
+    "        if len(self.roi_size) != len(inputs.shape[2:]):\n",
+    "\n",
+    "            # If they mismatch and roi_size is 2D, add a singleton dimension to roi_size\n",
+    "            if len(self.roi_size) == 2:\n",
+    "                self.roi_size = list(self.roi_size)\n",
+    "                self.roi_size.insert(self.spatial_dim, 1)\n",
+    "            else:\n",
+    "                raise RuntimeError(\n",
+    "                    \"Currently, only a 2D `roi_size` is supported, cannot broadcast to volume.\"\n",
+    "                )\n",
+    "\n",
+    "        return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n",
+    "\n",
+    "    def network_wrapper(self, network, x, *args, **kwargs):\n",
+    "        \"\"\"\n",
+    "        Wrapper handles cases where inference needs to be done using\n",
+    "        2D models over 3D volume inputs.\n",
+    "        \"\"\"\n",
+    "        # If the depth dim is 1 in the [D, H, W] roi size, then the input is 2D and needs\n",
+    "        # to be handled accordingly\n",
+    "        if self.roi_size[self.spatial_dim] == 1:\n",
+    "            # Pass a 4D input ([N, C, H, W], [N, C, D, W] or [N, C, D, H]) to the model as it is 2D\n",
+    "            x = x.squeeze(dim=self.spatial_dim + 2)\n",
+    "            out = network(x, *args, **kwargs)\n",
+    "            # Unsqueeze the network output so that it is [N, C, D, H, W], as expected by\n",
+    "            # the default SlidingWindowInferer class\n",
+    "            return out.unsqueeze(dim=self.spatial_dim + 2)\n",
+    "        else:\n",
+    "            return network(x, *args, **kwargs)"
+   ]
+  },
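+  {
+   "cell_type": "markdown",
+   "id": "a9e25e9b",
+   "metadata": {},
+   "source": [
+    "Note on the design: `network_wrapper` squeezes away the singleton slicing dimension before calling the 2D network and unsqueezes the output again, so that the parent `SlidingWindowInferer` can aggregate the per-slice 2D predictions back into a `[N, C, D, H, W]` volume. The `spatial_dim + 2` offset accounts for the leading batch (`N`) and channel (`C`) dimensions."
+   ]
+  },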
+  {
+   "cell_type": "markdown",
+   "id": "bb0a63dd",
+   "metadata": {},
+   "source": [
+    "## Testing the added functionality\n",
+    "Let's use `YourSlidingWindowInferer` in a dummy example to execute the workflow described above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "85b15305",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n",
+      "Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Create a 2D UNet with randomly initialized weights for testing purposes\n",
+    "\n",
+    "# 3-layer network with down/upsampling by a factor of 2 at each layer and 2-convolution residual units\n",
+    "net = UNet(\n",
+    "    spatial_dims=2,\n",
+    "    in_channels=1,\n",
+    "    out_channels=1,\n",
+    "    channels=(4, 8, 16),\n",
+    "    strides=(2, 2),\n",
+    "    num_res_units=2,\n",
+    ")\n",
+    "\n",
+    "# Initialize a dummy 3D tensor volume with shape (N, C, D, H, W)\n",
+    "input_volume = torch.ones(1, 1, 64, 256, 256)\n",
+    "\n",
+    "# Create an instance of YourSlidingWindowInferer with a 256x256 (H x W) roi_size, sliding over the D axis\n",
+    "axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n",
+    "\n",
+    "output = axial_inferer(input_volume, net)\n",
+    "\n",
+    "# Output is a 3D volume with the 2D slice predictions aggregated\n",
+    "print(\"Axial Inferer Output Shape: \", output.shape)\n",
+    "\n",
+    "# Create an instance of YourSlidingWindowInferer with a 64x256 (D x W) roi_size, sliding over the H axis\n",
+    "coronal_inferer = YourSlidingWindowInferer(\n",
+    "    roi_size=(64, 256),\n",
+    "    sw_batch_size=1,\n",
+    "    spatial_dim=1,  # Spatial dim to slice along is added here\n",
+    "    cval=-1,\n",
+    ")\n",
+    "\n",
+    "output = coronal_inferer(input_volume, net)\n",
+    "\n",
+    "# Output is a 3D volume with the 2D slice predictions aggregated\n",
+    "print(\"Coronal Inferer Output Shape: \", output.shape)"
+   ]
+  },
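+  {
+   "cell_type": "markdown",
+   "id": "f3b9c1d2",
+   "metadata": {},
+   "source": [
+    "As a further usage sketch, the same inferer can slide over the sagittal axis by setting `spatial_dim=2`, so the 2D roi covers the D and H axes. The cell below reuses the `net` and `input_volume` defined above; it is not executed here, so treat it as illustrative rather than verified output."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c7a41e05",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: slide over the W axis (sagittal slices) with a 64x256 (D x H) roi_size\n",
+    "sagittal_inferer = YourSlidingWindowInferer(\n",
+    "    roi_size=(64, 256),\n",
+    "    sw_batch_size=1,\n",
+    "    spatial_dim=2,  # Slice along the last spatial axis\n",
+    "    cval=-1,\n",
+    ")\n",
+    "\n",
+    "# Expected to print torch.Size([1, 1, 64, 256, 256]), matching the input volume\n",
+    "print(\"Sagittal Inferer Output Shape: \", sagittal_inferer(input_volume, net).shape)"
+   ]
+  }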
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/runner.sh b/runner.sh
index a2b17d6c7b..46981768aa 100755
--- a/runner.sh
+++ b/runner.sh
@@ -36,6 +36,7 @@ doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" UNet_input_size_con
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" network_api.ipynb)
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" tcia_csv_processing.ipynb)
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" transform_visualization.ipynb)
+doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" 2d_inference_3d_volume.ipynb)
 
 # output formatting
 separator=""