Skip to content

update inference 2d from 3d #746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 43 additions & 79 deletions modules/2d_inference_3d_volume.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,25 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"id": "f2e1b91f",
"metadata": {},
"outputs": [],
"source": [
"# Install monai\n",
"!python -c \"import monai\" || pip install -q \"monai-weekly\""
"!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"id": "e9cd1b08",
"metadata": {},
"outputs": [],
"source": [
"# Import libs\n",
"from monai.inferers import SlidingWindowInferer\n",
"from monai.inferers import SliceInferer\n",
"import torch\n",
"from typing import Callable, Any\n",
"from monai.networks.nets import UNet"
]
},
Expand All @@ -60,93 +59,49 @@
"id": "85f00a47",
"metadata": {},
"source": [
"## Overiding SlidingWindowInferer\n",
"The simplest way to achieve this functionality is to create a class `YourSlidingWindowInferer` that inherits from `SlidingWindowInferer` in `monai.inferers`"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "01f8bfa3",
"metadata": {},
"outputs": [],
"source": [
"class YourSlidingWindowInferer(SlidingWindowInferer):\n",
" def __init__(self, spatial_dim: int = 0, *args, **kwargs):\n",
" # Set dim to slice the volume across, for example, `0` could slide over axial slices,\n",
" # `1` over coronal slices\n",
" # and `2` over sagittal slices.\n",
" self.spatial_dim = spatial_dim\n",
"\n",
" super().__init__(*args, **kwargs)\n",
"\n",
" def __call__(\n",
" self,\n",
" inputs: torch.Tensor,\n",
" network: Callable[..., torch.Tensor],\n",
" slice_axis: int = 0,\n",
" *args: Any,\n",
" **kwargs: Any,\n",
" ) -> torch.Tensor:\n",
"\n",
" assert (\n",
" self.spatial_dim < 3\n",
" ), \"`spatial_dim` can only be `[D, H, W]` with `0, 1, 2` respectively\"\n",
"\n",
" # Check if roi size (eg. 2D roi) and input volume sizes (3D input) mismatch\n",
" if len(self.roi_size) != len(inputs.shape[2:]):\n",
"\n",
" # If they mismatch and roi_size is 2D add another dimension to roi size\n",
" if len(self.roi_size) == 2:\n",
" self.roi_size = list(self.roi_size)\n",
" self.roi_size.insert(self.spatial_dim, 1)\n",
" else:\n",
" raise RuntimeError(\n",
" \"Currently, only 2D `roi_size` is supported, cannot broadcast to volume. \"\n",
" )\n",
"\n",
" return super().__call__(inputs, lambda x: self.network_wrapper(network, x))\n",
"\n",
" def network_wrapper(self, network, x, *args, **kwargs):\n",
" \"\"\"\n",
" Wrapper handles cases where inference needs to be done using\n",
" 2D models over 3D volume inputs.\n",
" \"\"\"\n",
" # If depth dim is 1 in [D, H, W] roi size, then the input is 2D and needs\n",
" # be handled accordingly\n",
"\n",
" if self.roi_size[self.spatial_dim] == 1:\n",
" # Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n",
" x = x.squeeze(dim=self.spatial_dim + 2)\n",
" out = network(x, *args, **kwargs)\n",
" # Unsqueeze the network output so it is [N, C, D, H, W] as expected by\n",
" # the default SlidingWindowInferer class\n",
" return out.unsqueeze(dim=self.spatial_dim + 2)\n",
"\n",
" else:\n",
" return network(x, *args, **kwargs)"
"## SliceInferer\n",
"The simplest way to achieve this functionality is to extend the `SlidingWindowInferer` in `monai.inferers`. This is made available as `SliceInferer` in MONAI (https://docs.monai.io/en/latest/inferers.html#sliceinferer)."
]
},
{
"cell_type": "markdown",
"id": "bb0a63dd",
"metadata": {},
"source": [
"## Testing added functionality\n",
"Let's use the `YourSlidingWindowInferer` in a dummy example to execute the workflow described above."
"## Usage"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "85b15305",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:00<00:00, 107.33it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:01<00:00, 177.69it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Axial Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n",
"Coronal Inferer Output Shape: torch.Size([1, 1, 64, 256, 256])\n"
]
}
Expand All @@ -167,26 +122,35 @@
"# Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n",
"input_volume = torch.ones(1, 1, 64, 256, 256)\n",
"\n",
"# Create an instance of YourSlidingWindowInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n",
"axial_inferer = YourSlidingWindowInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1)\n",
"# Create an instance of SliceInferer with roi_size as the 256x256 (HxW) and sliding over D axis\n",
"axial_inferer = SliceInferer(roi_size=(256, 256), sw_batch_size=1, cval=-1, progress=True)\n",
"\n",
"output = axial_inferer(input_volume, net)\n",
"\n",
"# Output is a 3D volume with 2D slices aggregated\n",
"print(\"Axial Inferer Output Shape: \", output.shape)\n",
"# Create an instance of YourSlidingWindowInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n",
"coronal_inferer = YourSlidingWindowInferer(\n",
"# Create an instance of SliceInferer with roi_size as the 64x256 (DxW) and sliding over H axis\n",
"coronal_inferer = SliceInferer(\n",
" roi_size=(64, 256),\n",
" sw_batch_size=1,\n",
" spatial_dim=1, # Spatial dim to slice along is added here\n",
" cval=-1,\n",
" progress=True,\n",
")\n",
"\n",
"output = coronal_inferer(input_volume, net)\n",
"\n",
"# Output is a 3D volume with 2D slices aggregated\n",
"print(\"Coronal Inferer Output Shape: \", output.shape)"
]
},
{
"cell_type": "markdown",
"id": "f2596d86",
"metadata": {},
"source": [
"Note that with `axial_inferer` and `coronal_inferer`, the number of inference iterations is 64 and 256 repectively."
]
}
],
"metadata": {
Expand All @@ -205,7 +169,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
"version": "3.8.12"
}
},
"nbformat": 4,
Expand Down