Add demo for 2d inference on 3D volume #479

Merged: 6 commits, Dec 16, 2021
5 changes: 5 additions & 0 deletions README.md
@@ -179,6 +179,11 @@ Training and evaluation examples of 3D segmentation based on UNet3D and syntheti
The examples are built with MONAI workflows, mainly contain: trainer/evaluator, handlers, post_transforms, etc.
#### [3d_image_transforms](./modules/3d_image_transforms.ipynb)
This notebook demonstrates the transformations on volumetric images.

#### [2d_inference_3d_volume](./modules/2d_inference_3d_volume.ipynb)
Tutorial that demonstrates how MONAI's `SlidingWindowInferer` can be used when a 3D volume must be fed slice by slice to a 2D model, with the 2D predictions aggregated back into a 3D volume.
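Conceptually, the inferer built in this tutorial wraps the following slice-by-slice loop. Here is a minimal sketch in plain PyTorch (the function name, the `model` argument, and the shapes are illustrative assumptions, not part of the tutorial):

```python
import torch

def slicewise_inference(volume, model, spatial_dim=0):
    """Run a 2D `model` over every slice of a 5D volume and reassemble.

    volume: [N, C, D, H, W]; spatial_dim 0/1/2 selects the D/H/W axis.
    """
    dim = spatial_dim + 2  # skip the batch (N) and channel (C) dims
    outputs = []
    for i in range(volume.shape[dim]):
        sl = volume.select(dim, i)          # 4D slice, e.g. [N, C, H, W]
        out = model(sl)                     # 2D model prediction
        outputs.append(out.unsqueeze(dim))  # restore the sliced axis
    return torch.cat(outputs, dim=dim)      # aggregate back into a volume
```

The notebook achieves the same effect without an explicit loop by letting `SlidingWindowInferer` drive the slicing and aggregation.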


#### [autoencoder_mednist](./modules/autoencoder_mednist.ipynb)
This tutorial uses the MedNIST hand CT scan dataset to demonstrate MONAI's autoencoder class. The autoencoder is used with an identity encode/decode (i.e., what you put in is what you should get back), as well as demonstrating its usage for de-blurring and de-noising.
#### [batch_output_transform](./modules/batch_output_transform.py)
Binary file added figures/2d_inference_3d_input.png
213 changes: 213 additions & 0 deletions modules/2d_inference_3d_volume.ipynb
@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c408367e",
"metadata": {},
"source": [
"# 2D Model Inference on a 3D Volume "
]
},
{
"cell_type": "markdown",
"id": "a8681db2",
"metadata": {},
"source": [
"Use case: a 2D model, such as a 2D segmentation U-Net, operates on 2D inputs, which can be slices from a 3D volume (for example, a CT scan). \n",
"\n",
"After extending the sliding window inferer as described in this tutorial, it can handle the entire flow shown below:\n",
"![image](../figures/2d_inference_3d_input.png)\n",
"\n",
"The inputs are a *3D volume* and a *2D model*; the output is a *3D volume* assembled from the aggregated 2D slice predictions. \n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "239b0d93",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/2d_inference_3d_volume.ipynb)\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f2e1b91f",
"metadata": {},
"outputs": [],
"source": [
"# Install monai\n",
"!python -c \"import monai\" || pip install -q \"monai-weekly\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e9cd1b08",
"metadata": {},
"outputs": [],
"source": [
"# Import libs\n",
"from monai.inferers import SlidingWindowInferer\n",
"import torch\n",
"from typing import Callable, Any\n",
"from monai.networks.nets import UNet"
]
},
{
"cell_type": "markdown",
"id": "85f00a47",
"metadata": {},
"source": [
"## Overriding SlidingWindowInferer\n",
"The simplest way to add this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "01f8bfa3",
"metadata": {},
"outputs": [],
"source": [
"class YourSlidingWindowInferer(SlidingWindowInferer):\n",
" def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n",
" # Set the spatial dim to slice the volume along: `0` slides over axial (D) slices,\n",
" # `1` over coronal (H) slices, and `2` over sagittal (W) slices.\n",
" self.spatial_dim = spatial_dim\n",
"\n",
" super().__init__(*args, **kwargs)\n",
"\n",
" def __call__(\n",
" self,\n",
" inputs: torch.Tensor,\n",
" network: Callable[..., torch.Tensor],\n",
" *args: Any,\n",
" **kwargs: Any,\n",
" ) -> torch.Tensor:\n",
"\n",
" assert (\n",
" 0 <= self.spatial_dim <= 2\n",
" ), \"`spatial_dim` must be 0, 1, or 2, indexing `[D, H, W]` respectively\"\n",
"\n",
" # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch\n",
" if len(self.roi_size) != len(inputs.shape[2:]):\n",
"\n",
" # If they mismatch and roi_size is 2D add another dimension to roi size\n",
" if len(self.roi_size) == 2:\n",
" self.roi_size = list(self.roi_size)\n",
" self.roi_size.insert(self.spatial_dim, 1)\n",
" else:\n",
" raise RuntimeError(\n",
" \"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \"\n",
" )\n",
"\n",
" return super().__call__(inputs, lambda x: self.network_wrapper(network, x, *args, **kwargs))\n",
"\n",
" def network_wrapper(self, network, x, *args, **kwargs):\n",
" \"\"\"\n",
" Wrapper handles cases where inference needs to be done using\n",
" 2D models over 3D volume inputs.\n",
" \"\"\"\n",
" # If the depth dim is 1 in the [D, H, W] roi size, then the input is 2D and needs\n",
" # to be handled accordingly\n",
"\n",
" if self.roi_size[self.spatial_dim] == 1:\n",
" # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n",
" x = x.squeeze(dim=self.spatial_dim + 2)\n",
" out = network(x, *args, **kwargs)\n",
" # Unsqueeze the network output so it is [N, C, D, H, W] as expected by\n",
" # the default SlidingWindowInferer class\n",
" return out.unsqueeze(dim=self.spatial_dim + 2)\n",
"\n",
" else:\n",
" return network(x, *args, **kwargs)"
]
},
{
"cell_type": "markdown",
"id": "bb0a63dd",
"metadata": {},
"source": [
"## Testing added functionality\n",
"Let's use the `YourSlidingWindowInferer` in a dummy example to execute the workflow described above."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "85b15305",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n",
"Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
]
}
],
"source": [
"# Create a 2D UNet with randomly initialized weights for testing purposes\n",
"\n",
"# 3 layer network with down/upsampling by a factor of 2 at each layer with 2-convolution residual units\n",
"net = UNet(\n",
" spatial_dims=2,\n",
" in_channels=1,\n",
" out_channels=1,\n",
" channels=(4, 8, 16),\n",
" strides=(2, 2),\n",
" num_res_units=2,\n",
")\n",
"\n",
"# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n",
"input_volume = torch.ones(1, 1, 64, 256, 256)\n",
"\n",
"# Create an instance of YourSlidingWindowInferer with a 256x256 (HxW) roi_size, sliding over the D axis\n",
"axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n",
"\n",
"output = axial_inferer(input_volume, net)\n",
"\n",
"# Output is a 3D volume with 2D slices aggregated\n",
"print(\"Axial Inferer Output Shape: \", output.shape)\n",
"# Create an instance of YourSlidingWindowInferer with a 64x256 (DxW) roi_size, sliding over the H axis\n",
"coronal_inferer = YourSlidingWindowInferer(\n",
" roi_size=(64, 256),\n",
" sw_batch_size=1,\n",
" spatial_dim=1, # Spatial dim to slice along is added here\n",
" cval=-1,\n",
")\n",
"\n",
"output = coronal_inferer(input_volume, net)\n",
"\n",
"# Output is a 3D volume with 2D slices aggregated\n",
"print(\"Coronal Inferer Output Shape: \", output.shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
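The key shape-handling step in the notebook's `__call__` is broadcasting a 2D `roi_size` to 3D by inserting a singleton dimension at `spatial_dim`. That step can be isolated as a standalone pure-Python sketch (the function name here is illustrative, not part of the notebook):

```python
def broadcast_roi_size(roi_size, spatial_shape, spatial_dim):
    """Insert a singleton dim into a 2D roi_size so it matches a 3D input."""
    roi = list(roi_size)
    if len(roi) != len(spatial_shape):
        if len(roi) == 2:
            # e.g. (256, 256) with spatial_dim=0 becomes [1, 256, 256]
            roi.insert(spatial_dim, 1)
        else:
            raise RuntimeError("Only a 2D roi_size can be broadcast to a volume.")
    return roi
```

A singleton roi dimension is what later tells `network_wrapper` to squeeze the input down to 4D before calling the 2D model.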
1 change: 1 addition & 0 deletions runner.sh
@@ -36,6 +36,7 @@ doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" UNet_input_size_con
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" network_api.ipynb)
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" tcia_csv_processing.ipynb)
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" transform_visualization.ipynb)
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" 2d_inference_3d_volume.ipynb)

# output formatting
separator=""