-
Notifications
You must be signed in to change notification settings - Fork 738
Add demo for 2d inference on 3D volume #479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Nic-Ma
merged 6 commits into
Project-MONAI:master
from
surajpaib:2d_inference_3d_volume
Dec 16, 2021
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
c2e45f3
Add demo for 2d inference on 3D volume
surajpaib e463f9c
Add slice_axis to SlidingWindowInferer
surajpaib fa986bf
Merge branch 'master' into 2d_inference_3d_volume
wyli 252e329
Address comments + PEP compliance
surajpaib 72309b6
Merge branch '2d_inference_3d_volume' of github.com:surajpaib/tutoria…
surajpaib 71d5141
Change to spatial_dim + assert
surajpaib File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
{ | ||
surajpaib marked this conversation as resolved.
Show resolved
Hide resolved
surajpaib marked this conversation as resolved.
Show resolved
Hide resolved
surajpaib marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c408367e", | ||
"metadata": {}, | ||
"source": [ | ||
"# 2D Model Inference on a 3D Volume " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a8681db2", | ||
"metadata": {}, | ||
"source": [ | ||
"Usecase: A 2D Model, such as, a 2D segmentation U-Net operates on 2D input which can be slices from a 3D volume (for example, a CT scan). \n", | ||
"\n", | ||
"After editing sliding window inferer as described in this tutorial, it can handle the entire flow as shown:\n", | ||
"\n", | ||
"\n", | ||
"The input is a *3D Volume*, a *2D model* and the output is a *3D volume* with 2D slice predictions aggregated. \n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "239b0d93", | ||
"metadata": {}, | ||
"source": [ | ||
"[](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/2d_inference_3d_volume.ipynb)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "f2e1b91f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Install monai\n", | ||
"!python -c \"import monai\" || pip install -q \"monai-weekly\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "e9cd1b08", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Import libs\n", | ||
"from monai.inferers import SlidingWindowInferer\n", | ||
"import torch\n", | ||
"from typing import Callable, Any\n", | ||
"from monai.networks.nets import UNet" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "85f00a47", | ||
"metadata": {}, | ||
"source": [ | ||
"## Overiding SlidingWindowInferer\n", | ||
"The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "01f8bfa3", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class YourSlidingWindowInferer(SlidingWindowInferer):\n", | ||
" def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n", | ||
" # Set dim to slice the volume across, for example, `0` could slide over axial slices,\n", | ||
" # `1` over coronal slices\n", | ||
" # and `2` over sagittal slices.\n", | ||
" self.spatial_dim = spatial_dim\n", | ||
"\n", | ||
" super().__init__(*args, **kwargs)\n", | ||
"\n", | ||
" def __call__(\n", | ||
" self,\n", | ||
" inputs: torch.Tensor,\n", | ||
" network: Callable[..., torch.Tensor],\n", | ||
" slice_axis: int = 0,\n", | ||
" *args: Any,\n", | ||
" **kwargs: Any,\n", | ||
" ) -> torch.Tensor:\n", | ||
"\n", | ||
" assert (\n", | ||
" self.spatial_dim < 3\n", | ||
" ), \"`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively\"\n", | ||
"\n", | ||
" # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch\n", | ||
" if len(self.roi_size) != len(inputs.shape[2:]):\n", | ||
"\n", | ||
" # If they mismatch and roi_size is 2D add another dimension to roi size\n", | ||
" if len(self.roi_size) == 2:\n", | ||
" self.roi_size = list(self.roi_size)\n", | ||
" self.roi_size.insert(self.spatial_dim, 1)\n", | ||
" else:\n", | ||
" raise RuntimeError(\n", | ||
" \"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \"\n", | ||
" )\n", | ||
"\n", | ||
" return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n", | ||
"\n", | ||
" def network_wrapper(self, network, x, *args, **kwargs):\n", | ||
" \"\"\"\n", | ||
" Wrapper handles cases where inference needs to be done using\n", | ||
" 2D models over 3D volume inputs.\n", | ||
" \"\"\"\n", | ||
" # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs\n", | ||
" # be handled accordingly\n", | ||
"\n", | ||
" if self.roi_size[self.spatial_dim] == 1:\n", | ||
" # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n", | ||
" x = x.squeeze(dim=self.spatial_dim + 2)\n", | ||
" out = network(x, *args, **kwargs)\n", | ||
" # Unsqueeze the network output so it is [N, C, D, H, W] as expected by\n", | ||
" # the default SlidingWindowInferer class\n", | ||
" return out.unsqueeze(dim=self.spatial_dim + 2)\n", | ||
"\n", | ||
" else:\n", | ||
" return network(x, *args, **kwargs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bb0a63dd", | ||
"metadata": {}, | ||
"source": [ | ||
"## Testing added functionality\n", | ||
"Let's use the `YourSlidingWindowInferer` in a dummy example to execute the workflow described above." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "85b15305", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n", | ||
"Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Create a 2D UNet with randomly initialized weights for testing purposes\n", | ||
"\n", | ||
"# 3 layer network with down/upsampling by a factor of 2 at each layer with 2-convolution residual units\n", | ||
"net = UNet(\n", | ||
" spatial_dims=2,\n", | ||
" in_channels=1,\n", | ||
" out_channels=1,\n", | ||
" channels=(4, 8, 16),\n", | ||
" strides=(2, 2),\n", | ||
" num_res_units=2,\n", | ||
")\n", | ||
"\n", | ||
"# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n", | ||
"input_volume = torch.ones(1, 1, 64, 256, 256)\n", | ||
"\n", | ||
"# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n", | ||
"axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n", | ||
"\n", | ||
"output = axial_inferer(input_volume, net)\n", | ||
"\n", | ||
"# Output is a 3D volume with 2D slices aggregated\n", | ||
"print(\"Axial Inferer Output Shape: \", output.shape)\n", | ||
"# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n", | ||
"coronal_inferer = YourSlidingWindowInferer(\n", | ||
" roi_size=(64, 256),\n", | ||
" sw_batch_size=1,\n", | ||
" spatial_dim=1, # Spatial dim to slice along is added here\n", | ||
" cval=-1,\n", | ||
")\n", | ||
"\n", | ||
"output = coronal_inferer(input_volume, net)\n", | ||
"\n", | ||
"# Output is a 3D volume with 2D slices aggregated\n", | ||
"print(\"Coronal Inferer Output Shape: \", output.shape)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.