From 2bd177ebe4a33c9b5d1ebfb2c37ddb80dfc53a4b Mon Sep 17 00:00:00 2001 From: simben Date: Sun, 2 Feb 2025 13:11:41 +0000 Subject: [PATCH 01/12] nnUNet MONAI Bundle Notebook Tutorial --- bundle/06_nnunet_monai_bundle.ipynb | 4141 +++++++++++++++++++++++++++ 1 file changed, 4141 insertions(+) create mode 100644 bundle/06_nnunet_monai_bundle.ipynb diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb new file mode 100644 index 000000000..347cdf12d --- /dev/null +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -0,0 +1,4141 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) MONAI Consortium \n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); \n", + "you may not use this file except in compliance with the License. \n", + "You may obtain a copy of the License at \n", + "    http://www.apache.org/licenses/LICENSE-2.0 \n", + "Unless required by applicable law or agreed to in writing, software \n", + "distributed under the License is distributed on an \"AS IS\" BASIS, \n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n", + "See the License for the specific language governing permissions and \n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# nnUNet MONAI Bundle\n", + "\n", + "This notebook demonstrates how to create a MONAI Bundle for a trained nnUNet and use it for inference. This is needed when some other application from the MONAI EcoSystem require a MONAI Bundle (MONAI Label, MonaiAlgo for Federated Learning, etc).\n", + "\n", + "This notebook cover the steps to convert a trained nnUNet model to a consumable MONAI Bundle. The nnUNet training is here perfomed using the `nnUNetV2Runner`.\n", + "\n", + "Optionally, the notebook also demonstrates how to use the same nnUNet MONAI Bundle for training a new model. This might be needed in some applications where the nnUNet training needs to be performed through a MONAI Bundle (i.e., Active Learning in MONAI Label, MonaiAlgo for Federated Learning, etc)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, tqdm]\"\n", + "!python -c \"import matplotlib\" || pip install -q matplotlib\n", + "!python -c \"import nnunetv2\" || pip install -q nnunetv2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from monai.config import print_config\n", + "import os\n", + "import tempfile\n", + "from monai.bundle.config_parser import ConfigParser\n", + "from monai.apps.nnunet import nnUNetV2Runner\n", + "from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup data directory\n", + "\n", + "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. \n", + "This allows you to save results and reuse downloads. \n", + "If not specified a temporary directory will be used." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"/home/maia-user/Tutorials/MONAI/data\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "if directory is not None:\n", + " os.makedirs(directory, exist_ok=True)\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "print(root_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download Decathlon Spleen Dataset and Generate Data List\n", + "\n", + "To get the Decathlon Spleen dataset and generate the corresponding data list, you can follow the instructions in the [MSD Datalist Generator Notebook](https://github.com/Project-MONAI/tutorials/blob/main/auto3dseg/notebooks/msd_datalist_generator.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At the end of the notebook, remember to copy the generated `msd_task09_spleen_folds.json` file to the `/Task09_Spleen` directory." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## nnUNet Experiment with nnUNetV2Runner\n", + "\n", + "In the following sections, we will use the nnUNetV2Runner to train a model on the spleen dataset from the Medical Segmentation Decathlon." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We first create the Config file for the nnUNetV2Runner:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "nnunet_root_dir = os.path.join(root_dir, \"nnUNet\")\n", + "\n", + "os.makedirs(nnunet_root_dir, exist_ok=True)\n", + "\n", + "data_src_cfg = os.path.join(nnunet_root_dir, \"data_src_cfg.yaml\")\n", + "data_src = {\"modality\": \"CT\", \"datalist\": os.path.join(root_dir,\"Task09_Spleen/msd_task09_spleen_folds.json\"), \"dataroot\": os.path.join(root_dir,\"Task09_Spleen\")}\n", + "\n", + "ConfigParser.export_config_file(data_src, data_src_cfg)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=\"nnUNetTrainer_1epoch\", work_dir=nnunet_root_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "runner.run(run_train=True, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## nnUNet MONAI Bundle for Inference\n", + "\n", + "This section is the relevant part of the nnUNet MONAI Bundle for Inference, showing how to use the trained model to perform inference on new data through the use of a MONAI Bundle, wrapping the native nnUNet model and its pre- and post-processing steps." 
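Before wiring everything into the bundle configuration below, it can help to see in plain Python what the bundle ends up wrapping. The following is a minimal sketch (not part of the bundle files themselves) of driving the wrapped predictor directly with `monai.bundle.nnunet.get_nnunet_monai_predictor`. It assumes the conversion step later in this section has already populated `nnUNetBundle/models` with `model.pt`, that the predictor exposes its torch module as `network_weights` (as the `CheckpointLoader` entry in the config implies), and the input path is purely illustrative.

```python
# Minimal sketch: drive the wrapped nnUNet predictor directly, mirroring what
# inference.yaml assembles (preprocessing -> predictor -> postprocessing).
import torch
from monai.bundle.nnunet import get_nnunet_monai_predictor
from monai.transforms import Compose, LoadImaged

bundle_root = "nnUNetBundle"  # assumes the models folder has already been populated
predictor = get_nnunet_monai_predictor(model_folder=f"{bundle_root}/models")

# Restore the converted weights, mirroring the CheckpointLoader handler in the config.
# model.pt may hold the raw state dict or a {"network_weights": ...} wrapper.
state = torch.load(f"{bundle_root}/models/model.pt", map_location="cpu")
predictor.network_weights.load_state_dict(state.get("network_weights", state))

# Same preprocessing as in the config: load the image channel-first.
preprocessing = Compose([LoadImaged(keys="image", ensure_channel_first=True, image_only=False)])
sample = preprocessing({"image": "path/to/spleen_case.nii.gz"})  # illustrative input path

with torch.no_grad():
    pred = predictor(sample["image"].unsqueeze(0))  # add the batch dimension a DataLoader would
```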
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We first create the MONAI Bundle for the nnUNet model:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "rm nnUNetBundle/configs/inference.json\n", + "python -m monai.bundle init_bundle nnUNetBundle\n", + "\n", + "mkdir -p nnUNetBundle/src\n", + "touch nnUNetBundle/src/__init__.py\n", + "which tree && tree nnUNetBundle || true" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We then populate the MONAI Bundle with the configuration for inference:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/configs/inference.yaml\n", + "\n", + "imports: \n", + " - $import json\n", + " - $from pathlib import Path\n", + " - $import os\n", + " - $import monai.bundle.nnunet\n", + " - $from ignite.contrib.handlers.tqdm_logger import ProgressBar\n", + " - $import shutil\n", + "\n", + "\n", + "output_dir: \".\"\n", + "bundle_root: \".\"\n", + "data_list_file : \".\"\n", + "data_dir: \".\"\n", + "\n", + "prediction_suffix: \"prediction\"\n", + "\n", + "test_data_list: \"$monai.data.load_decathlon_datalist(@data_list_file, is_segmentation=True, data_list_key='testing', base_dir=@data_dir)\"\n", + "image_modality_keys: \"$list(@modality_conf.keys())\"\n", + "image_key: \"image\"\n", + "image_suffix: \"@image_key\"\n", + "\n", + "preprocessing:\n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: LoadImaged\n", + " keys: \"image\"\n", + " ensure_channel_first: True\n", + " image_only: False\n", + "\n", + "test_dataset:\n", + " _target_: Dataset\n", + " data: \"$@test_data_list\"\n", + " transform: \"@preprocessing\"\n", + "\n", + "test_loader:\n", + " _target_: DataLoader\n", + " dataset: \"@test_dataset\"\n", + " batch_size: 1\n", + "\n", + "\n", + "device: \"$torch.device('cuda')\"\n", + "\n", + "nnunet_config:\n", + " model_folder: \"$@bundle_root + '/models'\"\n", + "\n", + "network_def: \"$monai.bundle.nnunet.get_nnunet_monai_predictor(**@nnunet_config)\"\n", + "\n", + "postprocessing:\n", + " _target_: \"Compose\"\n", + " transforms:\n", + " - _target_: Transposed\n", + " keys: \"pred\"\n", + " indices:\n", + " - 0\n", + " - 3\n", + " - 2\n", + " - 1\n", + " - _target_: SaveImaged\n", + " keys: \"pred\"\n", + " resample: False\n", + " output_postfix: \"@prediction_suffix\"\n", + " output_dir: \"@output_dir\"\n", + " meta_keys: \"image_meta_dict\"\n", + "\n", + "\n", + "testing:\n", + " dataloader: \"$@test_loader\"\n", + " pbar:\n", + " _target_: \"ignite.contrib.handlers.tqdm_logger.ProgressBar\"\n", + " test_inferer: \"$@inferer\"\n", + "\n", + "inferer: \n", + " _target_: \"SimpleInferer\"\n", + "\n", + "validator:\n", + " _target_: \"SupervisedEvaluator\"\n", + " postprocessing: \"$@postprocessing\"\n", + " device: \"$@device\"\n", + " inferer: \"$@testing#test_inferer\"\n", + " val_data_loader: \"$@testing#dataloader\"\n", + " network: \"@network_def\"\n", + " #prepare_batch: \"$src.inferer.prepare_nnunet_inference_batch\"\n", + " val_handlers:\n", + " - _target_: \"CheckpointLoader\"\n", + " load_path: \"$@bundle_root+'/models/model.pt'\"\n", + " load_dict:\n", + " network_weights: '$@network_def.network_weights'\n", + "run:\n", + " - \"$@testing#pbar.attach(@validator)\"\n", + " - \"$@validator.run()\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### nnUnet to MONAI Bundle 
Conversion\n", + "\n", + "Finally, we convert the nnUNet Trained Model to a Bundle-compatible format using the `convert_nnunet_to_monai_bundle` function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nnunet_config = {\n", + " \"dataset_name_or_id\": \"001\",\n", + " \"nnunet_trainer\": \"nnUNetTrainer_1epoch\",\n", + "}\n", + "\n", + "bundle_root = \"nnUNetBundle\"\n", + "\n", + "convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can then inspect the content of the `models` folder to verify that the model has been converted to the MONAI Bundle format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "which tree && tree nnUNetBundle/models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Test the MONAI Bundle for Inference\n", + "\n", + "The MONAI Bundle for Inference is now ready to be used for inference on new data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Predicting image of shape torch.Size([1, 294, 584, 584]):\n", + "perform_everything_on_device: True\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration: [12/20] 60%|██████ [09:03<05:50]\n", + " 0%| | 0/378 [00:00 bundle_root: 'nnUNetBundle'\n", + "2025-01-30 16:41:59,285 - INFO - ---\n", + "\n", + "\n", + "Using device: cuda:0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "#######################################################################\n", + "Please cite the following paper when using nnU-Net:\n", + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.\n", + "#######################################################################\n", + "\n", + "2025-01-30 16:42:00.755623: do_dummy_2d_data_aug: False\n", + "2025-01-30 16:42:00.757267: Using splits from existing split file: /home/maia-user/Tutorials/MONAI/data/nnUNet/nnUNet_preprocessed/Dataset001_Task09_Spleen/splits_final.json\n", + "2025-01-30 16:42:00.758354: The split file contains 5 splits.\n", + "2025-01-30 16:42:00.759119: Desired fold for training: 0\n", + "2025-01-30 16:42:00.759716: This split has 32 training and 9 validation cases.\n", + "using pin_memory on device 0\n", + "using pin_memory on device 0\n", + "2025-01-30 16:42:06.184253: Using torch.compile...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The verbose parameter is deprecated. 
Please use get_last_lr() to access the learning rate.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "This is the configuration used by this training:\n", + "Configuration name: 3d_fullres\n", + " {'data_identifier': 'nnUNetPlans_3d_fullres', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 2, 'patch_size': [64, 192, 160], 'median_image_size_in_voxels': [187.0, 512.0, 512.0], 'spacing': [1.6000100374221802, 0.7929689884185791, 0.7929689884185791], 'normalization_schemes': ['CTNormalization'], 'use_mask_for_norm': [False], 'resampling_fn_data': 'resample_data_or_seg_to_shape', 'resampling_fn_seg': 'resample_data_or_seg_to_shape', 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_probabilities': 'resample_data_or_seg_to_shape', 'resampling_fn_probabilities_kwargs': {'is_seg': False, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'architecture': {'network_class_name': 'dynamic_network_architectures.architectures.unet.PlainConvUNet', 'arch_kwargs': {'n_stages': 6, 'features_per_stage': [32, 64, 128, 256, 320, 320], 'conv_op': 'torch.nn.modules.conv.Conv3d', 'kernel_sizes': [[1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], 'strides': [[1, 1, 1], [1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 'n_conv_per_stage': [2, 2, 2, 2, 2, 2], 'n_conv_per_stage_decoder': [2, 2, 2, 2, 2], 'conv_bias': True, 'norm_op': 'torch.nn.modules.instancenorm.InstanceNorm3d', 'norm_op_kwargs': {'eps': 1e-05, 'affine': True}, 'dropout_op': None, 'dropout_op_kwargs': None, 'nonlin': 'torch.nn.LeakyReLU', 'nonlin_kwargs': {'inplace': True}, 'deep_supervision': True}, '_kw_requires_import': ['conv_op', 'norm_op', 'dropout_op', 'nonlin']}, 'batch_dice': True} \n", + "\n", + "These are the global plan.json settings:\n", + " {'dataset_name': 'Dataset001_Task09_Spleen', 'plans_name': 'nnUNetPlans', 'original_median_spacing_after_transp': [5.0, 0.7929689884185791, 0.7929689884185791], 'original_median_shape_after_transp': [90, 512, 512], 'image_reader_writer': 'SimpleITKIO', 'transpose_forward': [0, 1, 2], 'transpose_backward': [0, 1, 2], 'experiment_planner_used': 'ExperimentPlanner', 'label_manager': 'LabelManager', 'foreground_intensity_properties_per_channel': {'0': {'max': 1038.0, 'mean': 93.1926040649414, 'median': 97.0, 'min': -620.0, 'percentile_00_5': -42.0, 'percentile_99_5': 176.0, 'std': 40.7836799621582}}} \n", + "\n", + "2025-01-30 16:42:07.315368: unpacking dataset...\n", + "2025-01-30 16:42:07.954188: unpacking done...\n", + "2025-01-30 16:42:08.038784: Unable to plot network architecture: nnUNet_compile is enabled!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n", + "`torch.cuda.amp.autocast(args...)` is deprecated. 
Please use `torch.amp.autocast('cuda', args...)` instead.\n", + "Epoch [1/1000]: [1/250] 0%| , loss=0.694 [00:00 Date: Wed, 5 Feb 2025 15:14:49 +0000 Subject: [PATCH 02/12] nnUNet MONAI Bundle Notebook --- bundle/06_nnunet_monai_bundle.ipynb | 3463 +-------------------------- 1 file changed, 47 insertions(+), 3416 deletions(-) diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb index 347cdf12d..c2d2ad60f 100644 --- a/bundle/06_nnunet_monai_bundle.ipynb +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -69,8 +69,8 @@ "import tempfile\n", "from monai.bundle.config_parser import ConfigParser\n", "from monai.apps.nnunet import nnUNetV2Runner\n", - "from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", - "\n", + "#from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", + "import nnunetv2\n", "print_config()" ] }, @@ -85,15 +85,6 @@ "If not specified a temporary directory will be used." ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"/home/maia-user/Tutorials/MONAI/data\"" - ] - }, { "cell_type": "code", "execution_count": null, @@ -113,7 +104,7 @@ "source": [ "## Download Decathlon Spleen Dataset and Generate Data List\n", "\n", - "To get the Decathlon Spleen dataset and generate the corresponding data list, you can follow the instructions in the [MSD Datalist Generator Notebook](https://github.com/Project-MONAI/tutorials/blob/main/auto3dseg/notebooks/msd_datalist_generator.ipynb)" + "To get the Decathlon Spleen dataset and generate the corresponding data list, you can follow the instructions in the [MSD Datalist Generator Notebook](../auto3dseg/notebooks/msd_datalist_generator.ipynb)" ] }, { @@ -141,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -161,7 +152,18 @@ "metadata": {}, "outputs": [], "source": [ - "runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=\"nnUNetTrainer_1epoch\", work_dir=nnunet_root_dir)" + "runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=\"nnUNetTrainer_10epochs\", work_dir=nnunet_root_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "runner.plan_and_process(npfp=2,n_proc=[2,2,2])" ] }, { @@ -169,6 +171,17 @@ "execution_count": null, "metadata": {}, "outputs": [], + "source": [ + "runner.train(configs=\"3d_fullres\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], "source": [ "runner.run(run_train=True, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)" ] @@ -191,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -368,536 +381,14 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Predicting image of shape torch.Size([1, 294, 584, 584]):\n", - "perform_everything_on_device: True\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Iteration: [12/20] 60%|██████ [09:03<05:50]\n", - " 0%| | 0/378 [00:00/training/lr_scheduler/polylr.py\n", "\n", "from 
torch.optim.lr_scheduler import _LRScheduler\n", "\n", @@ -1258,2841 +717,13 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2025-01-30 16:41:59,255 - WARNING - Default logging file in nnUNetBundle/configs/logging.conf does not exist, skipping logging.\n", - "2025-01-30 16:41:59,284 - INFO - --- input summary of monai.bundle.scripts.run ---\n", - "2025-01-30 16:41:59,285 - INFO - > bundle_root: 'nnUNetBundle'\n", - "2025-01-30 16:41:59,285 - INFO - ---\n", - "\n", - "\n", - "Using device: cuda:0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "#######################################################################\n", - "Please cite the following paper when using nnU-Net:\n", - "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.\n", - "#######################################################################\n", - "\n", - "2025-01-30 16:42:00.755623: do_dummy_2d_data_aug: False\n", - "2025-01-30 16:42:00.757267: Using splits from existing split file: /home/maia-user/Tutorials/MONAI/data/nnUNet/nnUNet_preprocessed/Dataset001_Task09_Spleen/splits_final.json\n", - "2025-01-30 16:42:00.758354: The split file contains 5 splits.\n", - "2025-01-30 16:42:00.759119: Desired fold for training: 0\n", - "2025-01-30 16:42:00.759716: This split has 32 training and 9 validation cases.\n", - "using pin_memory on device 0\n", - "using pin_memory on device 0\n", - "2025-01-30 16:42:06.184253: Using torch.compile...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The verbose parameter is deprecated. 
Please use get_last_lr() to access the learning rate.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "This is the configuration used by this training:\n", - "Configuration name: 3d_fullres\n", - " {'data_identifier': 'nnUNetPlans_3d_fullres', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 2, 'patch_size': [64, 192, 160], 'median_image_size_in_voxels': [187.0, 512.0, 512.0], 'spacing': [1.6000100374221802, 0.7929689884185791, 0.7929689884185791], 'normalization_schemes': ['CTNormalization'], 'use_mask_for_norm': [False], 'resampling_fn_data': 'resample_data_or_seg_to_shape', 'resampling_fn_seg': 'resample_data_or_seg_to_shape', 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_probabilities': 'resample_data_or_seg_to_shape', 'resampling_fn_probabilities_kwargs': {'is_seg': False, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'architecture': {'network_class_name': 'dynamic_network_architectures.architectures.unet.PlainConvUNet', 'arch_kwargs': {'n_stages': 6, 'features_per_stage': [32, 64, 128, 256, 320, 320], 'conv_op': 'torch.nn.modules.conv.Conv3d', 'kernel_sizes': [[1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], 'strides': [[1, 1, 1], [1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 'n_conv_per_stage': [2, 2, 2, 2, 2, 2], 'n_conv_per_stage_decoder': [2, 2, 2, 2, 2], 'conv_bias': True, 'norm_op': 'torch.nn.modules.instancenorm.InstanceNorm3d', 'norm_op_kwargs': {'eps': 1e-05, 'affine': True}, 'dropout_op': None, 'dropout_op_kwargs': None, 'nonlin': 'torch.nn.LeakyReLU', 'nonlin_kwargs': {'inplace': True}, 'deep_supervision': True}, '_kw_requires_import': ['conv_op', 'norm_op', 'dropout_op', 'nonlin']}, 'batch_dice': True} \n", - "\n", - "These are the global plan.json settings:\n", - " {'dataset_name': 'Dataset001_Task09_Spleen', 'plans_name': 'nnUNetPlans', 'original_median_spacing_after_transp': [5.0, 0.7929689884185791, 0.7929689884185791], 'original_median_shape_after_transp': [90, 512, 512], 'image_reader_writer': 'SimpleITKIO', 'transpose_forward': [0, 1, 2], 'transpose_backward': [0, 1, 2], 'experiment_planner_used': 'ExperimentPlanner', 'label_manager': 'LabelManager', 'foreground_intensity_properties_per_channel': {'0': {'max': 1038.0, 'mean': 93.1926040649414, 'median': 97.0, 'min': -620.0, 'percentile_00_5': -42.0, 'percentile_99_5': 176.0, 'std': 40.7836799621582}}} \n", - "\n", - "2025-01-30 16:42:07.315368: unpacking dataset...\n", - "2025-01-30 16:42:07.954188: unpacking done...\n", - "2025-01-30 16:42:08.038784: Unable to plot network architecture: nnUNet_compile is enabled!\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n", - "`torch.cuda.amp.autocast(args...)` is deprecated. 
Please use `torch.amp.autocast('cuda', args...)` instead.\n", - "Epoch [1/1000]: [1/250] 0%| , loss=0.694 [00:00 Date: Wed, 5 Feb 2025 15:38:58 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- bundle/06_nnunet_monai_bundle.ipynb | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb index c2d2ad60f..5fc8031d0 100644 --- a/bundle/06_nnunet_monai_bundle.ipynb +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -69,8 +69,10 @@ "import tempfile\n", "from monai.bundle.config_parser import ConfigParser\n", "from monai.apps.nnunet import nnUNetV2Runner\n", - "#from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", + "\n", + "# from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", "import nnunetv2\n", + "\n", "print_config()" ] }, @@ -141,9 +143,13 @@ "os.makedirs(nnunet_root_dir, exist_ok=True)\n", "\n", "data_src_cfg = os.path.join(nnunet_root_dir, \"data_src_cfg.yaml\")\n", - "data_src = {\"modality\": \"CT\", \"datalist\": os.path.join(root_dir,\"Task09_Spleen/msd_task09_spleen_folds.json\"), \"dataroot\": os.path.join(root_dir,\"Task09_Spleen\")}\n", + "data_src = {\n", + " \"modality\": \"CT\",\n", + " \"datalist\": os.path.join(root_dir, \"Task09_Spleen/msd_task09_spleen_folds.json\"),\n", + " \"dataroot\": os.path.join(root_dir, \"Task09_Spleen\"),\n", + "}\n", "\n", - "ConfigParser.export_config_file(data_src, data_src_cfg)\n" + "ConfigParser.export_config_file(data_src, data_src_cfg)" ] }, { @@ -152,7 +158,9 @@ "metadata": {}, "outputs": [], "source": [ - "runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=\"nnUNetTrainer_10epochs\", work_dir=nnunet_root_dir)" + "runner = nnUNetV2Runner(\n", + " input_config=data_src_cfg, trainer_class_name=\"nnUNetTrainer_10epochs\", work_dir=nnunet_root_dir\n", + ")" ] }, { @@ -163,7 +171,7 @@ }, "outputs": [], "source": [ - "runner.plan_and_process(npfp=2,n_proc=[2,2,2])" + "runner.plan_and_process(npfp=2, n_proc=[2, 2, 2])" ] }, { @@ -341,8 +349,8 @@ "outputs": [], "source": [ "nnunet_config = {\n", - " \"dataset_name_or_id\": \"001\",\n", - " \"nnunet_trainer\": \"nnUNetTrainer_1epoch\",\n", + " \"dataset_name_or_id\": \"001\",\n", + " \"nnunet_trainer\": \"nnUNetTrainer_1epoch\",\n", "}\n", "\n", "bundle_root = \"nnUNetBundle\"\n", @@ -661,6 +669,7 @@ "outputs": [], "source": [ "import nnunetv2\n", + "\n", "print(nnunetv2.__file__)" ] }, From 4d8ba15ebd017ddfdacd7494eb1151a8f7bc152f Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 5 Feb 2025 16:48:17 +0000 Subject: [PATCH 04/12] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: 2bd177ebe4a33c9b5d1ebfb2c37ddb80dfc53a4b I, simben , hereby add my Signed-off-by to this commit: 46d8073864288d74099a48a2762fb54ef36cbf86 Signed-off-by: simben --- bundle/06_nnunet_monai_bundle.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb index 5fc8031d0..f8c185126 100644 --- a/bundle/06_nnunet_monai_bundle.ipynb +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -70,7 +70,7 @@ "from monai.bundle.config_parser import ConfigParser\n", "from monai.apps.nnunet import nnUNetV2Runner\n", "\n", - "# from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", + "from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", "import nnunetv2\n", "\n", 
"print_config()" From 28d2b815dba97f36d3aafc6e97206a1b7cce7449 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 5 Feb 2025 16:50:53 +0000 Subject: [PATCH 05/12] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: 2bd177ebe4a33c9b5d1ebfb2c37ddb80dfc53a4b I, simben , hereby add my Signed-off-by to this commit: 46d8073864288d74099a48a2762fb54ef36cbf86 I, simben , hereby add my Signed-off-by to this commit: 4d8ba15ebd017ddfdacd7494eb1151a8f7bc152f Signed-off-by: simben --- bundle/06_nnunet_monai_bundle.ipynb | 1 - 1 file changed, 1 deletion(-) diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb index f8c185126..a472e584a 100644 --- a/bundle/06_nnunet_monai_bundle.ipynb +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -71,7 +71,6 @@ "from monai.apps.nnunet import nnUNetV2Runner\n", "\n", "from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", - "import nnunetv2\n", "\n", "print_config()" ] From e65a170f09d3041be0f362dada707b0db5a0fcff Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 12 Feb 2025 17:29:48 +0000 Subject: [PATCH 06/12] Enhance nnUNet MONAI Bundle Notebook with data handling and configuration updates --- bundle/06_nnunet_monai_bundle.ipynb | 132 +- bundle/nnUNet_Bundle.ipynb | 2461 +++++++++++++++++++++++++++ 2 files changed, 2571 insertions(+), 22 deletions(-) create mode 100644 bundle/nnUNet_Bundle.ipynb diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb index a472e584a..1482340ac 100644 --- a/bundle/06_nnunet_monai_bundle.ipynb +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -69,8 +69,10 @@ "import tempfile\n", "from monai.bundle.config_parser import ConfigParser\n", "from monai.apps.nnunet import nnUNetV2Runner\n", - "\n", + "import random\n", "from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", + "import json\n", + "from pathlib import Path\n", "\n", "print_config()" ] @@ -86,6 +88,15 @@ "If not specified a temporary directory will be used." 
] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"/home/maia-user/Documents/GitHub/tutorials/bundle/MONAI/Data\"" + ] + }, { "cell_type": "code", "execution_count": null, @@ -109,10 +120,84 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataroot = os.path.join(root_dir, \"Task09_Spleen/\")\n", + "\n", + "test_dir = os.path.join(dataroot, \"imagesTs/\")\n", + "train_dir = os.path.join(dataroot, \"imagesTr/\")\n", + "label_dir = os.path.join(dataroot, \"labelsTr/\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datalist_json = {\"testing\": [], \"training\": []}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datalist_json[\"testing\"] = [\n", + " {\"image\": \"./imagesTs/\" + file} for file in os.listdir(test_dir) if (\".nii.gz\" in file) and (\"._\" not in file)\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datalist_json[\"training\"] = [\n", + " {\"image\": \"./imagesTr/\" + file, \"label\": \"./labelsTr/\" + file, \"fold\": 0}\n", + " for file in os.listdir(train_dir)\n", + " if (\".nii.gz\" in file) and (\"._\" not in file)\n", + "] # Initialize as single fold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random.seed(42)\n", + "random.shuffle(datalist_json[\"training\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_folds = 5\n", + "fold_size = len(datalist_json[\"training\"]) // num_folds\n", + "for i in range(num_folds):\n", + " for j in range(fold_size):\n", + " datalist_json[\"training\"][i * fold_size + j][\"fold\"] = i" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "At the end of the notebook, remember to copy the generated `msd_task09_spleen_folds.json` file to the `/Task09_Spleen` directory." 
+ "datalist_file = Path(root_dir).joinpath(\"Task09_Spleen\",\"Task09_Spleen_folds.json\")\n", + "with open(datalist_file, \"w\", encoding=\"utf-8\") as f:\n", + " json.dump(datalist_json, f, ensure_ascii=False, indent=4)\n", + "print(f\"Datalist is saved to {datalist_file}\")" ] }, { @@ -144,6 +229,7 @@ "data_src_cfg = os.path.join(nnunet_root_dir, \"data_src_cfg.yaml\")\n", "data_src = {\n", " \"modality\": \"CT\",\n", + " \"dataset_name_or_id\": \"09\",\n", " \"datalist\": os.path.join(root_dir, \"Task09_Spleen/msd_task09_spleen_folds.json\"),\n", " \"dataroot\": os.path.join(root_dir, \"Task09_Spleen\"),\n", "}\n", @@ -165,12 +251,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ - "runner.plan_and_process(npfp=2, n_proc=[2, 2, 2])" + "runner.convert_dataset()" ] }, { @@ -179,18 +263,16 @@ "metadata": {}, "outputs": [], "source": [ - "runner.train(configs=\"3d_fullres\")" + "runner.plan_and_process(npfp=2, n_proc=[2, 2, 2])" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ - "runner.run(run_train=True, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)" + "runner.train_single_model(config=\"3d_fullres\", fold=0)" ] }, { @@ -217,9 +299,10 @@ "source": [ "%%bash\n", "\n", - "rm nnUNetBundle/configs/inference.json\n", + "\n", "python -m monai.bundle init_bundle nnUNetBundle\n", "\n", + "rm nnUNetBundle/configs/inference.json\n", "mkdir -p nnUNetBundle/src\n", "touch nnUNetBundle/src/__init__.py\n", "which tree && tree nnUNetBundle || true" @@ -348,8 +431,8 @@ "outputs": [], "source": [ "nnunet_config = {\n", - " \"dataset_name_or_id\": \"001\",\n", - " \"nnunet_trainer\": \"nnUNetTrainer_1epoch\",\n", + " \"dataset_name_or_id\": \"009\",\n", + " \"nnunet_trainer\": \"nnUNetTrainer_10epochs\",\n", "}\n", "\n", "bundle_root = \"nnUNetBundle\"\n", @@ -392,13 +475,17 @@ "source": [ "%%bash\n", "\n", + "\n", + "BUNDLE_ROOT=nnUNetBundle\n", + "MONAI_DATA_DIRECTORY=/home/maia-user/Documents/GitHub/tutorials/bundle/MONAI/Data\n", + "\n", "python -m monai.bundle run \\\n", - " --config-file nnUNetBundle/configs/inference.yaml \\\n", - " --bundle-root nnUNetBundle \\\n", + " --config-file $BUNDLE_ROOT/configs/inference.yaml \\\n", + " --bundle-root $BUNDLE_ROOT \\\n", " --data_list_file $MONAI_DATA_DIRECTORY/Task09_Spleen/msd_task09_spleen_folds.json \\\n", - " --output-dir nnUNetBundle/pred_output \\\n", - " --data_dir /home/maia-user/Tutorials/MONAI/data/Task09_Spleen \\\n", - " --logging-file nnUNetBundle/configs/logging.conf" + " --output-dir $BUNDLE_ROOT/pred_output \\\n", + " --data_dir $MONAI_DATA_DIRECTORY/Task09_Spleen \\\n", + " --logging-file$BUNDLE_ROOT/configs/logging.conf" ] }, { @@ -410,6 +497,8 @@ "In some cases, you may want to train the nnUNet model from the MONAI Bundle (i.e., without using the nnUNetV2Runner).\n", "This is usually the case when the specific training logic is designed to be used with the MONAI Bundle, such as the Active Learning in MONAI Label or Federated Learning in NVFLare using the MONAI Algo implementation.\n", "\n", + "For more details on how to create the nnUNet MONAI Bundle and test all the different components, you can follow the instructions in the [nnUNet MONAI Bundle Notebook](./nnUNet_Bundle.ipynb)\n", + "\n", "This can be done by following the steps below:" ] }, @@ -432,7 +521,6 @@ " - $import pathlib\n", "\n", "\n", - "pymaia_config_dict: 
\"$json.load(open(@pymaia_config_file))\"\n", "bundle_root: .\n", "ckpt_dir: \"$@bundle_root + '/models'\"\n", "num_classes: 2\n", @@ -760,7 +848,7 @@ "kernelspec": { "display_name": "MONAI", "language": "python", - "name": "monai" + "name": "python3" }, "language_info": { "codemirror_mode": { diff --git a/bundle/nnUNet_Bundle.ipynb b/bundle/nnUNet_Bundle.ipynb new file mode 100644 index 000000000..02c5b4602 --- /dev/null +++ b/bundle/nnUNet_Bundle.ipynb @@ -0,0 +1,2461 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bec25bff", + "metadata": {}, + "source": [ + "# nnUNet MONAI Bundle\n", + "\n", + "In this notebook, we will demonstrate how to create a MONAI Bundle supporting nnUNet experiment for training and inference. In this step-by step tutorial, we will describe how to create all the required python code and YAML configuration files needed to train and evaluate a nnUNet model using the MONAI Bundle format.\n", + "\n", + "The tutorial assumes that the Spleen Dataset has been already downloaded and preprocessed as described in the [nnUNet MONAI Bundle Notebook](./06_nnunet_monai_bundle.ipynb)." + ] + }, + { + "cell_type": "markdown", + "id": "70a2adb6", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63d65cf0-56ea-4c45-88bc-e271a5e2195c", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from monai.data import Dataset\n", + "from monai.handlers import StatsHandler, from_engine, MeanDice, ValidationHandler, LrScheduleHandler, CheckpointSaver, CheckpointLoader, TensorBoardStatsHandler, MLFlowHandler\n", + "from monai.engines import SupervisedTrainer, SupervisedEvaluator\n", + "\n", + "from monai.transforms import Compose, Lambdad, Activationsd, AsDiscreted\n", + "\n", + "import re\n", + "import pathlib\n", + "import os\n", + "import yaml\n", + "import json\n", + "from monai.bundle import ConfigParser\n", + "import monai\n", + "from pathlib import Path\n", + "from odict import odict\n" + ] + }, + { + "cell_type": "markdown", + "id": "297a2bb9", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d9c2ae8", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"MONAI/Data\"\n", + "\n", + "work_dir = os.path.join(os.environ[\"MONAI_DATA_DIRECTORY\"], \"nnUNet\")\n", + "\n", + "nnunet_raw = os.path.join(work_dir, \"nnUNet_raw_data_base\")\n", + "nnunet_preprocessed = os.path.join(\".\", work_dir, \"nnUNet_preprocessed\")\n", + "nnunet_results = os.path.join(\".\", work_dir, \"nnUNet_trained_models\")\n", + "\n", + "if not os.path.exists(nnunet_raw):\n", + " os.makedirs(nnunet_raw)\n", + "\n", + "if not os.path.exists(nnunet_preprocessed):\n", + " os.makedirs(nnunet_preprocessed)\n", + "\n", + "if not os.path.exists(nnunet_results):\n", + " os.makedirs(nnunet_results)\n", + "\n", + "# claim environment variable\n", + "os.environ[\"nnUNet_raw\"] = nnunet_raw\n", + "os.environ[\"nnUNet_preprocessed\"] = nnunet_preprocessed\n", + "os.environ[\"nnUNet_results\"] = nnunet_results\n", + "os.environ[\"OMP_NUM_THREADS\"] = str(1)" + ] + }, + { + "cell_type": "markdown", + "id": "c6757489", + "metadata": {}, + "source": [ + "## nnUNet Trainer\n", + "\n", + "The core component for the nnUNet MONAI Bundle is the `get_nnunet_trainer` function. This function is responsible for creating the nnUNet trainer object from the native nnUNetv2 implementation. 
From the nnUNet trainer object, we can access the training components, such as the data loaders, model, learning rate scheduler, optimizer, and loss function, and perform training and inference tasks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4090170-b7c4-402b-9d70-9b59c463354b", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.bundle.nnunet import get_nnunet_trainer\n", + "\n", + "nnunet_config = {\n", + " \"dataset_name_or_id\": \"009\",\n", + " \"configuration\": \"3d_fullres\",\n", + " \"trainer_class_name\": \"nnUNetTrainer_10epochs\",\n", + " \"plans_identifier\": \"nnUNetPlans\",\n", + " \"fold\": 0,\n", + "}\n", + "\n", + "\n", + "nnunet_trainer = get_nnunet_trainer(**nnunet_config)" + ] + }, + { + "cell_type": "markdown", + "id": "f355baa4", + "metadata": {}, + "source": [ + "The function `get_nnunet_trainer` accepts the following parameters:\n", + "\n", + "- `dataset_name_or_id`: The dataset name or ID to be used for training and evaluation.\n", + "- `fold`: The fold number for the cross-validation experiment.\n", + "- `configuration`: The training configuration for the nnUNet trainer, usually `3d_fullres`.\n", + "- `trainer_class_name`: The nnUNet trainer class name to be used for training, e.g. `nnUNetTrainer`.\n", + "- `plans_identifier`: The nnUNet plans identifier for the dataset, e.g. `nnUNetPlans`." + ] + }, + { + "cell_type": "markdown", + "id": "765619ea", + "metadata": {}, + "source": [ + "## Train and Val Data Loaders" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8e60cdb", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataloader = nnunet_trainer.dataloader_train\n", + "train_data = [{'case_identifier':k} for k in nnunet_trainer.dataloader_train.generator._data.dataset.keys()]\n", + "train_dataset = Dataset(data=train_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b80cc9a", + "metadata": {}, + "outputs": [], + "source": [ + "val_dataloader = nnunet_trainer.dataloader_val\n", + "val_data = [{'case_identifier':k} for k in nnunet_trainer.dataloader_val.generator._data.dataset.keys()]\n", + "val_dataset = Dataset(data=val_data)" + ] + }, + { + "cell_type": "markdown", + "id": "3c7756d7", + "metadata": {}, + "source": [ + "## Network, Optimizer, and Loss Function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26b88a8e", + "metadata": {}, + "outputs": [], + "source": [ + "device = nnunet_trainer.device\n", + "\n", + "network = nnunet_trainer.network\n", + "optimizer = nnunet_trainer.optimizer\n", + "lr_scheduler = nnunet_trainer.lr_scheduler\n", + "loss = nnunet_trainer.loss" + ] + }, + { + "cell_type": "markdown", + "id": "5d7d6023", + "metadata": {}, + "source": [ + "## Prepare Batch Function\n", + "\n", + "The nnUnet `DataLoader` returns a dictionary with the `data` and `target` keys. Since the `SupervisedTrainer` used in the MONAI Bundle expects the data and target to be separate tensors, we need to create a custom prepare batch function to extract the data and target tensors from the dictionary." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "155e9460-9f69-4b15-bfb9-eae032afbc92", + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_nnunet_batch(batch, device, non_blocking):\n", + " data = batch[\"data\"].to(device, non_blocking=non_blocking)\n", + " if isinstance(batch[\"target\"], list):\n", + " target = [i.to(device, non_blocking=non_blocking) for i in batch[\"target\"]]\n", + " else:\n", + " target = batch[\"target\"].to(device, non_blocking=non_blocking)\n", + " return data, target" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70774a2a-84ff-4297-98e5-6452567f13a1", + "metadata": {}, + "outputs": [], + "source": [ + "image, label = prepare_nnunet_batch(next(iter(train_dataloader)),device=\"cpu\",non_blocking=True)" + ] + }, + { + "cell_type": "markdown", + "id": "54b5c684", + "metadata": {}, + "source": [ + "## MONAI Supervised Trainer\n", + "\n", + "The `SupervisedTrainer` class from MONAI is used to train the nnUNet model. For a minimal setup, we need to provide the model, optimizer, loss function, data loaders, number of epochs and the device to run the training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d480fbe6", + "metadata": {}, + "outputs": [], + "source": [ + "train_handlers = [\n", + " StatsHandler(\n", + " output_transform= from_engine(['loss'], first=True),\n", + " tag_name= \"train_loss\"\n", + " )\n", + "]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "844bb28a", + "metadata": {}, + "outputs": [], + "source": [ + "iterations = 100\n", + "epochs = 50" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "415bfc68", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = SupervisedTrainer(\n", + " amp= True,\n", + " device = device,\n", + " epoch_length = iterations,\n", + " loss_function = loss,\n", + " max_epochs = epochs,\n", + " network = network,\n", + " prepare_batch = prepare_nnunet_batch,\n", + " optimizer = optimizer,\n", + " train_data_loader = train_dataloader,\n", + " train_handlers= train_handlers\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba2ce831", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "c41fcf2a", + "metadata": {}, + "source": [ + "## Adding Validation and Validation Metrics\n", + "\n", + "For a complete training setup, we need to add the validation data loader and the validation metrics to the `SupervisedTrainer`. Using the MONAI class `SupervisedEvaluator`, we can evaluate the model on the validation data loader and calculate the validation metrics (`Dice Score`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5f713a2", + "metadata": {}, + "outputs": [], + "source": [ + "val_key_metric = MeanDice(\n", + " output_transform = from_engine(['pred', 'label']),\n", + " reduction = \"mean\",\n", + " include_background = False\n", + "\n", + ")\n", + "\n", + "additional_metrics = {\n", + " \"Val_Dice_Per_Class\": MeanDice(\n", + " output_transform = from_engine(['pred', 'label']),\n", + " reduction = \"mean_batch\",\n", + " include_background = False,\n", + " )\n", + " }" + ] + }, + { + "cell_type": "markdown", + "id": "fa8287fb", + "metadata": {}, + "source": [ + "Additionally, in order to compute the Mean Dice score over the batch, we need to apply a pos-processing transformtation to the nnUNet model output. 
Since `MeanDice` accepts `y` and `y_preds` as Batch-first tensors (BCHW[D]), we need to create a custom post-processing transform to convert the nnUNet model output to the required format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e28ec37", + "metadata": {}, + "outputs": [], + "source": [ + "num_classes = 2\n", + "\n", + "postprocessing = Compose(\n", + " transforms=[\n", + " ## Extract only high-res predictions from Deep Supervision\n", + " Lambdad( \n", + " keys= [\"pred\",\"label\"],\n", + " func = lambda x: x[0]\n", + " ),\n", + " ## Apply Softmax to the predictions\n", + " Activationsd(\n", + " keys= \"pred\",\n", + " softmax= True\n", + " ),\n", + " ## Binarize the predictions\n", + " AsDiscreted(\n", + " keys= \"pred\",\n", + " threshold= 0.5\n", + " ),\n", + " ## Convert the labels to one-hot\n", + " AsDiscreted(\n", + " keys= \"label\",\n", + " to_onehot= num_classes\n", + " )\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a90e728", + "metadata": {}, + "outputs": [], + "source": [ + "val_handlers = [StatsHandler(\n", + " iteration_log = False\n", + ")]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9586476", + "metadata": {}, + "outputs": [], + "source": [ + "val_iterations = 100\n", + "val_interval = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8081ce4", + "metadata": {}, + "outputs": [], + "source": [ + "evaluator = SupervisedEvaluator(\n", + " amp= True,\n", + " device = device,\n", + " epoch_length = val_iterations,\n", + " network = network,\n", + " key_val_metric={\"Val_Dice\": val_key_metric},\n", + " prepare_batch= prepare_nnunet_batch,\n", + " val_data_loader = val_dataloader,\n", + " val_handlers= val_handlers,\n", + " postprocessing= postprocessing,\n", + " additional_metrics= additional_metrics,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "aadfd315", + "metadata": {}, + "source": [ + "And finally, we add the evaluator to the `SupervisedTrainer` to calculate the validation metrics during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2bc29ad", + "metadata": {}, + "outputs": [], + "source": [ + "train_handlers.append(\n", + " ValidationHandler(\n", + " epoch_level = True,\n", + " interval= val_interval,\n", + " validator = evaluator\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3c904b0a", + "metadata": {}, + "source": [ + "We can also add the `MeanDice` metric to the `SupervisedTrainer` to calculate the mean dice score over the batch during training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9d44de3", + "metadata": {}, + "outputs": [], + "source": [ + "train_key_metric = MeanDice(\n", + " output_transform = from_engine(['pred', 'label']),\n", + " reduction = \"mean\",\n", + " include_background = False\n", + "\n", + ")\n", + "\n", + "additional_metrics = {\n", + " \"Train_Dice_Per_Class\": MeanDice(\n", + " output_transform = from_engine(['pred', 'label']),\n", + " reduction = \"mean_batch\",\n", + " include_background = False,\n", + " )\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a339901b", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = SupervisedTrainer(\n", + " amp= True,\n", + " device = device,\n", + " epoch_length = iterations,\n", + " loss_function = loss,\n", + " max_epochs = epochs,\n", + " network = network,\n", + " prepare_batch = prepare_nnunet_batch,\n", + " optimizer = optimizer,\n", + " train_data_loader = train_dataloader,\n", + " train_handlers= train_handlers,\n", + " key_train_metric = {\"Train_Dice\": train_key_metric},\n", + " postprocessing= postprocessing,\n", + " additional_metrics = additional_metrics\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f1869f5", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "fbfb0762", + "metadata": {}, + "source": [ + "## Learning Rate Scheduler\n", + "\n", + "One last component to add to the `SupervisedTrainer`, in order to replicate the training behaviour of the native nnUNet, is the learning rate scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b92598c", + "metadata": {}, + "outputs": [], + "source": [ + "train_handlers.append(\n", + " LrScheduleHandler(\n", + " lr_scheduler = lr_scheduler,\n", + " print_lr = True\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54efe274", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = SupervisedTrainer(\n", + " amp= True,\n", + " device = device,\n", + " epoch_length = iterations,\n", + " loss_function = loss,\n", + " max_epochs = epochs,\n", + " network = network,\n", + " prepare_batch = prepare_nnunet_batch,\n", + " optimizer = optimizer,\n", + " train_data_loader = train_dataloader,\n", + " train_handlers= train_handlers,\n", + " key_train_metric = {\"Train_Dice\": train_key_metric},\n", + " postprocessing= postprocessing,\n", + " additional_metrics = additional_metrics\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f687bc6", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fac3c8c5", + "metadata": {}, + "outputs": [], + "source": [ + "train_handlers[-1].lr_scheduler.get_last_lr()" + ] + }, + { + "cell_type": "markdown", + "id": "52a36367", + "metadata": {}, + "source": [ + "## Checkpointing\n", + "\n", + "To save the model weights during training, we can use the `CheckpointSaver` callback from MONAI. This callback saves the model weights after each epoch.\n", + "We can later use the `CheckpointLoader` to load the model weights and perform inference or resume training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54c4fb21", + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_dir = \"nnUNetBundle/models\"\n", + "\n", + "val_handlers.append(\n", + " CheckpointSaver(\n", + " save_dir= ckpt_dir,\n", + " save_dict= {\"network_weights\": nnunet_trainer.network._orig_mod, \"optimizer_state\": nnunet_trainer.optimizer, \"scheduler\": nnunet_trainer.lr_scheduler},\n", + " #save_final= True,\n", + " save_interval= 1,\n", + " save_key_metric= True,\n", + " #final_filename= \"model_final.pt\",\n", + " #key_metric_filename= \"model.pt\",\n", + " n_saved= 1\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2a10c5d9", + "metadata": {}, + "source": [ + "## Reload Checkpoint\n", + "\n", + "When resuming the training from a checkpoint, we also want to restart the training from the same epoch. To do this, we need to load the checkpoint and update the `trainer.state.epoch` and `trainer.state.iteration` parameter in the `SupervisedTrainer`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91c0a0ce", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def subfiles(folder, prefix=None, suffix=None, join=True, sort=True):\n", + " files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]\n", + " if prefix is not None:\n", + " files = [f for f in files if f.startswith(prefix)]\n", + " if suffix is not None:\n", + " files = [f for f in files if f.endswith(suffix)]\n", + " if sort:\n", + " files.sort()\n", + " if join:\n", + " files = [os.path.join(folder, f) for f in files]\n", + " return files\n", + "\n", + "def get_checkpoint(epoch, ckpt_dir):\n", + " if epoch == \"latest\":\n", + "\n", + " latest_checkpoints = subfiles(ckpt_dir, prefix=\"checkpoint_epoch\", sort=True,\n", + " join=False)\n", + " epochs = []\n", + " for latest_checkpoint in latest_checkpoints:\n", + " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\"):-len(\".pt\")]))\n", + "\n", + " epochs.sort()\n", + " latest_epoch = epochs[-1]\n", + " return latest_epoch\n", + " else:\n", + " return epoch\n", + "\n", + "def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir):\n", + "\n", + " epoch_to_load = get_checkpoint(epoch, ckpt_dir)\n", + " trainer.state.epoch = epoch_to_load\n", + " trainer.state.iteration = (epoch_to_load* num_train_batches_per_epoch) +1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ece9e988", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "reload_checkpoint_epoch = \"latest\"\n", + "\n", + "train_handlers.append(\n", + " CheckpointLoader(\n", + " load_path= os.path.join(ckpt_dir,'checkpoint_epoch='+str(get_checkpoint(reload_checkpoint_epoch, ckpt_dir))+'.pt'),\n", + " load_dict= {\"network_weights\": nnunet_trainer.network._orig_mod, \"optimizer_state\": nnunet_trainer.optimizer, \"scheduler\": nnunet_trainer.lr_scheduler},\n", + " map_location= device\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "78db7528", + "metadata": {}, + "source": [ + "## Initial nnUNet Checkpoint\n", + "\n", + "In order to provide compatibility with the native nnUNet, we need to save the nnUNet-specific configuration, together the regular MONAI checkpoint. This is done only once, before the training starts. At the end of the training, we will have a MONAI checkpoint and a nnUNet checkpoint. To be able to convert the MONAI checkpoint to a nnUNet checkpoint at any time, we can then combine the two checkpoints." 
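The snippet below is an illustrative sketch of that combination step, not an official API: it merges the nnUNet-specific metadata saved above with a MONAI checkpoint written by the `CheckpointSaver` handler (using the `checkpoint_epoch=N.pt` naming assumed by the reload helper above) into a single nnUNet-style checkpoint file. The epoch number and output file name are hypothetical, and the exact key set expected by a given nnUNet version should be verified against its own checkpoint files.

```python
# Illustrative sketch: fold the MONAI training checkpoint back into one nnUNet-style file.
import os
import torch

ckpt_dir = "nnUNetBundle/models"
epoch = 10  # hypothetical epoch to export

# Metadata saved once before training (trainer_name, init_args, mirroring axes).
nnunet_ckpt = torch.load(os.path.join(ckpt_dir, "nnunet_checkpoint.pth"), map_location="cpu")
# Per-epoch checkpoint written by CheckpointSaver (state dicts keyed by save_dict).
monai_ckpt = torch.load(os.path.join(ckpt_dir, f"checkpoint_epoch={epoch}.pt"), map_location="cpu")

combined = dict(nnunet_ckpt)
combined["network_weights"] = monai_ckpt["network_weights"]   # model state dict
combined["optimizer_state"] = monai_ckpt["optimizer_state"]   # optimizer state dict
combined["current_epoch"] = epoch

torch.save(combined, os.path.join(ckpt_dir, "checkpoint_final.pth"))
```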
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8131d003", + "metadata": {}, + "outputs": [], + "source": [ + "checkpoint = {\n", + " \"inference_allowed_mirroring_axes\": nnunet_trainer.inference_allowed_mirroring_axes,\n", + " \"init_args\": nnunet_trainer.my_init_kwargs,\n", + " \"trainer_name\": nnunet_trainer.__class__.__name__\n", + "}\n", + "checkpoint_filename = os.path.join(ckpt_dir,'nnunet_checkpoint.pth')\n", + "\n", + "torch.save(checkpoint, checkpoint_filename)" + ] + }, + { + "cell_type": "markdown", + "id": "c0e26cdb", + "metadata": {}, + "source": [ + "## MLFlow and Tensorboard Monitoring\n", + "\n", + "To monitor the training process, we can use MLFlow and Tensorboard. We can log the training metrics, hyperparameters, and model weights to MLFlow, and visualize the training metrics using Tensorboard." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2976402", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "log_dir = \"nnUNetBundle/logs\"\n", + "\n", + "train_handlers.append(\n", + " TensorBoardStatsHandler(\n", + " log_dir= log_dir,\n", + " output_transform= from_engine(['loss'], first=True),\n", + " tag_name = \"train_loss\"\n", + " )\n", + ")\n", + "\n", + "val_handlers.append(\n", + " TensorBoardStatsHandler(\n", + " log_dir= log_dir,\n", + " iteration_log = False\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86d54487", + "metadata": {}, + "outputs": [], + "source": [ + "def mlflow_transform(state_output):\n", + " return state_output[0]['loss']\n", + "\n", + "class MLFlownnUNetHandler(MLFlowHandler):\n", + " def __init__(self, label_dict, **kwargs):\n", + " super(MLFlownnUNetHandler, self).__init__(**kwargs)\n", + " self.label_dict = label_dict\n", + " \n", + " def _default_epoch_log(self, engine) -> None:\n", + " \"\"\"\n", + " Execute epoch level log operation.\n", + " Default to track the values from Ignite `engine.state.metrics` dict and\n", + " track the values of specified attributes of `engine.state`.\n", + "\n", + " Args:\n", + " engine: Ignite Engine, it can be a trainer, validator or evaluator.\n", + "\n", + " \"\"\"\n", + " log_dict = engine.state.metrics\n", + " if not log_dict:\n", + " return\n", + "\n", + " current_epoch = self.global_epoch_transform(engine.state.epoch)\n", + "\n", + " new_log_dict = {}\n", + "\n", + " for metric in log_dict:\n", + " if type(log_dict[metric]) == torch.Tensor:\n", + " for i,val in enumerate(log_dict[metric]):\n", + " new_log_dict[metric+\"_{}\".format(list(self.label_dict.keys())[i+1])] = val\n", + " else:\n", + " new_log_dict[metric] = log_dict[metric]\n", + " self._log_metrics(new_log_dict, step=current_epoch)\n", + "\n", + " if self.state_attributes is not None:\n", + " attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}\n", + " self._log_metrics(attrs, step=current_epoch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e8fdc2e", + "metadata": {}, + "outputs": [], + "source": [ + "def create_mlflow_experiment_params(params_file, custom_params=None):\n", + " params_dict = {}\n", + " config_values = monai.config.deviceconfig.get_config_values()\n", + " for k in config_values:\n", + " params_dict[re.sub(\"[()]\",\" \",str(k))] = config_values[k]\n", + "\n", + " optional_config_values = monai.config.deviceconfig.get_optional_config_values()\n", + " for k in optional_config_values:\n", + " params_dict[re.sub(\"[()]\",\" \",str(k))] = optional_config_values[k]\n", + 
"\n", + " gpu_info = monai.config.deviceconfig.get_gpu_info()\n", + " for k in gpu_info:\n", + " params_dict[re.sub(\"[()]\",\" \",str(k))] = str(gpu_info[k])\n", + "\n", + " yaml_config_files = [params_file]\n", + " # %%\n", + " monai_config = {}\n", + " for config_file in yaml_config_files:\n", + " with open(config_file, 'r') as file:\n", + " monai_config.update(yaml.safe_load(file))\n", + "\n", + " monai_config[\"bundle_root\"] = str(Path(Path(params_file).parent).parent)\n", + "\n", + " parser = ConfigParser(monai_config, globals={\"os\": \"os\",\n", + " \"pathlib\": \"pathlib\",\n", + " \"json\": \"json\",\n", + " \"ignite\": \"ignite\"\n", + " })\n", + "\n", + " parser.parse(True)\n", + "\n", + " for k in monai_config:\n", + " params_dict[k] = parser.get_parsed_content(k,instantiate=True)\n", + "\n", + " if custom_params is not None:\n", + " for k in custom_params:\n", + " params_dict[k] = custom_params[k]\n", + " return params_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8fb3858f", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/mlflow_params.yaml\n", + "\n", + "dataset_name_or_id: \"009\"\n", + "nnunet_trainer_class_name: \"nnUNetTrainer\"\n", + "nnunet_plans_identifier: \"nnUNetPlans\"\n", + "\n", + "num_classes: 2\n", + "label_dict:\n", + " 0: \"background\"\n", + " 1: \"spleen\"\n", + " \n", + "tracking_uri: \"http://localhost:5000\"\n", + "mlflow_experiment_name: \"nnUNet_Bundle_Spleen\"\n", + "mlflow_run_name: \"nnUNet_Bundle_Spleen\"\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "174505fe", + "metadata": {}, + "outputs": [], + "source": [ + "mlflow_experiment_name = \"nnUNet_Bundle_Spleen\"\n", + "mlflow_run_name = \"nnUNet_Bundle_Spleen\"\n", + "label_dict = {0: \"background\", 1: \"Spleen\"}\n", + "tracking_uri = \"http://localhost:5000\"\n", + "\n", + "params_file = \"nnUNetBundle/mlflow_params.yaml\"\n", + "\n", + "\n", + "train_handlers.append(\n", + " MLFlownnUNetHandler(\n", + " dataset_dict = {\"train\": train_dataset},\n", + " dataset_keys = \"case_identifier\",\n", + " experiment_param = create_mlflow_experiment_params(params_file),\n", + " experiment_name= mlflow_experiment_name,\n", + " label_dict = label_dict,\n", + " output_transform = mlflow_transform,\n", + " run_name = mlflow_run_name,\n", + " state_attributes = [\"best_metric\", \"best_metric_epoch\"],\n", + " tag_name = \"Train_Loss\",\n", + " tracking_uri = tracking_uri,\n", + " )\n", + ")\n", + "\n", + "val_handlers.append(\n", + " MLFlownnUNetHandler(\n", + " experiment_name= mlflow_experiment_name,\n", + " iteration_log = False,\n", + " label_dict = label_dict,\n", + " output_transform = mlflow_transform,\n", + " run_name = mlflow_run_name,\n", + " state_attributes = [\"best_metric\", \"best_metric_epoch\"],\n", + " tracking_uri = tracking_uri,\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "20722504", + "metadata": {}, + "source": [ + "To start the MLFlow server, we can run the following command in the terminal:\n", + "\n", + "```bash\n", + "cd nnUNetBundle/MLFlow && mlflow server\n", + "```\n", + "To run Tensorboard, we can use the following command:\n", + "\n", + "```bash\n", + "tensorboard --logdir Bundle/logs\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d11d8c8", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = SupervisedTrainer(\n", + " amp= True,\n", + " device = device,\n", + " epoch_length = iterations,\n", + " 
loss_function = loss,\n", + " max_epochs = epochs,\n", + " network = network,\n", + " prepare_batch = prepare_nnunet_batch,\n", + " optimizer = optimizer,\n", + " train_data_loader = train_dataloader,\n", + " train_handlers= train_handlers,\n", + " key_train_metric = {\"Train_Dice\": train_key_metric},\n", + " postprocessing= postprocessing,\n", + " additional_metrics = additional_metrics\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fcc921bf", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "b353be42", + "metadata": {}, + "source": [ + "## Create MONAI Bundle" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "564700ff", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "python -m monai.bundle init_bundle nnUNetBundle\n", + "\n", + "mkdir -p nnUNetBundle/nnUNet\n", + "mkdir -p nnUNetBundle/src\n", + "mkdir -p nnUNetBundle/nnUNet/evaluator\n", + "which tree && tree nnUNetBundle || true" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb7aa3da", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/global.yaml\n", + "\n", + "iterations: $@nnunet_trainer.num_iterations_per_epoch\n", + "device: $@nnunet_trainer.device\n", + "epochs: $@nnunet_trainer.num_epochs\n", + "\n", + "bundle_root: .\n", + "ckpt_dir: \"$@bundle_root + '/models'\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33f32c17", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/params.yaml\n", + "\n", + "\n", + "dataset_name_or_id: \"009\"\n", + "nnunet_trainer_class_name: \"nnUNetTrainer\"\n", + "nnunet_plans_identifier: \"nnUNetPlans\"\n", + "nnunet_configuration: \"3d_fullres\"\n", + "\n", + "num_classes: 2\n", + "label_dict:\n", + " 0: \"background\"\n", + " 1: \"spleen\"\n", + " \n", + "tracking_uri: \"http://localhost:5000\"\n", + "mlflow_experiment_name: \"nnUNet_Bundle_Spleen\"\n", + "mlflow_run_name: \"nnUNet_Bundle_Spleen\"\n", + "log_dir: \"$@bundle_root + '/logs'\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e31c3314", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/imports.yaml\n", + "\n", + "imports:\n", + "- $import glob\n", + "- $import os\n", + "- $import ignite\n", + "- $import torch\n", + "- $import shutil\n", + "- $import json\n", + "- $import src\n", + "- $import nnunetv2\n", + "- $import src.mlflow\n", + "- $import src.trainer\n", + "- $from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc7fc76e", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/run.yaml\n", + "\n", + "run:\n", + "- \"$torch.save(@checkpoint,@checkpoint_filename)\"\n", + "- \"$shutil.copy(Path(@nnunet_model_folder).joinpath('dataset.json'), @bundle_root+'/models/dataset.json')\"\n", + "- \"$shutil.copy(Path(@nnunet_model_folder).joinpath('plans.json'), @bundle_root+'/models/plans.json')\"\n", + "- \"$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})\"\n", + "- \"$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])\"\n", + "- $@train#trainer.run()\n", + "\n", + "initialize:\n", + "- $monai.utils.set_determinism(seed=123)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6302d85", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile 
nnUNetBundle/nnUNet/nnunet_trainer.yaml\n", + "\n", + "nnunet_trainer:\n", + " _target_ : get_nnunet_trainer\n", + " dataset_name_or_id: \"@dataset_name_or_id\"\n", + " configuration: \"@nnunet_configuration\"\n", + " fold: 0\n", + " trainer_class_name: \"@nnunet_trainer_class_name\"\n", + " plans_identifier: \"@nnunet_plans_identifier\"\n", + "\n", + "loss: $@nnunet_trainer.loss\n", + "lr_scheduler: $@nnunet_trainer.lr_scheduler\n", + "\n", + "network: $@nnunet_trainer.network\n", + "\n", + "optimizer: $@nnunet_trainer.optimizer\n", + "\n", + "checkpoint:\n", + " init_args: '$@nnunet_trainer.my_init_kwargs'\n", + " trainer_name: '$@nnunet_trainer.__class__.__name__'\n", + " inference_allowed_mirroring_axes: '$@nnunet_trainer.inference_allowed_mirroring_axes'\n", + "\n", + "checkpoint_filename: \"$@bundle_root+'/models/nnunet_checkpoint.pth'\"\n", + "\n", + "dataset_name: \"$nnunetv2.utilities.dataset_name_id_conversion.maybe_convert_to_dataset_name(@dataset_name_or_id)\"\n", + "nnunet_model_folder: \"$os.path.join(os.environ['nnUNet_results'], @dataset_name, @nnunet_trainer_class_name+'__'+@nnunet_plans_identifier+'__'+@nnunet_configuration)\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8718d0b", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/train_metrics.yaml\n", + "\n", + "train_key_metric:\n", + " Train_Dice:\n", + " _target_: \"MeanDice\"\n", + " include_background: False\n", + " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", + " reduction: \"mean\"\n", + "\n", + "train_additional_metrics:\n", + " Train_Dice_per_class:\n", + " _target_: \"MeanDice\"\n", + " include_background: False\n", + " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", + " reduction: \"mean_batch\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bd28f89", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/train_postprocessing.yaml\n", + "\n", + "train_postprocessing:\n", + " _target_: \"Compose\"\n", + " transforms:\n", + " - _target_: Lambdad\n", + " keys:\n", + " - \"pred\"\n", + " - \"label\"\n", + " func: \"$lambda x: x[0]\"\n", + " - _target_: Activationsd\n", + " keys:\n", + " - \"pred\"\n", + " softmax: True\n", + " - _target_: AsDiscreted\n", + " keys:\n", + " - \"pred\"\n", + " threshold: 0.5\n", + " - _target_: AsDiscreted\n", + " keys:\n", + " - \"label\"\n", + " to_onehot: \"@num_classes\"\n", + " \n", + "train_postprocessing_region_based:\n", + " _target_: \"Compose\"\n", + " transforms:\n", + " - _target_: Lambdad\n", + " keys:\n", + " - \"pred\"\n", + " - \"label\"\n", + " func: \"$lambda x: x[0]\"\n", + " - _target_: Activationsd\n", + " keys:\n", + " - \"pred\"\n", + " sigmoid: True\n", + " - _target_: AsDiscreted\n", + " keys:\n", + " - \"pred\"\n", + " threshold: 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7268a30a", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/train.yaml\n", + "\n", + "train:\n", + " pbar:\n", + " _target_: \"ignite.contrib.handlers.tqdm_logger.ProgressBar\"\n", + " dataloader: $@nnunet_trainer.dataloader_train\n", + " train_data: \"$[{'case_identifier':k} for k in @nnunet_trainer.dataloader_train.generator._data.dataset.keys()]\"\n", + " train_dataset:\n", + " _target_: Dataset\n", + " data: \"@train#train_data\"\n", + " inferer:\n", + " _target_: SimpleInferer\n", + " trainer:\n", + " _target_: SupervisedTrainer\n", + " amp: 
true\n", + " device: '@device'\n", + " additional_metrics: \"@train_additional_metrics\"\n", + " epoch_length: \"@iterations\"\n", + " inferer: '@train#inferer'\n", + " key_train_metric: '@train_key_metric'\n", + " loss_function: '@loss'\n", + " max_epochs: '@epochs'\n", + " network: '@network'\n", + " prepare_batch: \"$src.trainer.prepare_nnunet_batch\"\n", + " optimizer: '@optimizer'\n", + " postprocessing: '@train_postprocessing'\n", + " train_data_loader: '@train#dataloader'\n", + " train_handlers: '@train_handlers#handlers'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "944f75b6", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/train_handlers.yaml\n", + "\n", + "train_handlers:\n", + " handlers:\n", + " - _target_: \"$src.mlflow.MLFlownnUNetHandler\"\n", + " label_dict: \"@label_dict\"\n", + " tracking_uri: \"@tracking_uri\"\n", + " experiment_name: \"@mlflow_experiment_name\"\n", + " run_name: \"@mlflow_run_name\"\n", + " output_transform: \"$src.mlflow.mlflow_transform\"\n", + " dataset_dict:\n", + " train: \"@train#train_dataset\"\n", + " dataset_keys: 'case_identifier'\n", + " state_attributes:\n", + " - \"iteration\"\n", + " - \"epoch\"\n", + " tag_name: 'Train_Loss'\n", + " experiment_param: \"$src.mlflow.create_mlflow_experiment_params( @bundle_root + '/nnUNet/params.yaml')\"\n", + " #artifacts=None\n", + " optimizer_param_names: 'lr'\n", + " #close_on_complete: False\n", + " - _target_: LrScheduleHandler\n", + " lr_scheduler: '@lr_scheduler'\n", + " print_lr: true\n", + " - _target_: ValidationHandler\n", + " epoch_level: true\n", + " interval: '@val_interval'\n", + " validator: '@validate#evaluator'\n", + " #- _target_: StatsHandler\n", + " # output_transform: $monai.handlers.from_engine(['loss'], first=True)\n", + " # tag_name: train_loss\n", + " - _target_: TensorBoardStatsHandler\n", + " log_dir: '@log_dir'\n", + " output_transform: $monai.handlers.from_engine(['loss'], first=True)\n", + " tag_name: train_loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4933773b", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/configs/train_resume.yaml\n", + "\n", + "run:\n", + "- '$src.trainer.reload_checkpoint(@train#trainer,@reload_checkpoint_epoch,@nnunet_trainer.num_iterations_per_epoch,@bundle_root+\"/models\")'\n", + "- \"$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})\"\n", + "- \"$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])\"\n", + "- $@train#trainer.run()\n", + "\n", + "train_handlers:\n", + " handlers:\n", + " - _target_: \"$src.mlflow.MLFlownnUNetHandler\"\n", + " label_dict: \"@label_dict\"\n", + " tracking_uri: \"@tracking_uri\"\n", + " experiment_name: \"@mlflow_experiment_name\"\n", + " run_name: \"@mlflow_run_name\"\n", + " output_transform: \"$src.mlflow.mlflow_transform\"\n", + " dataset_dict:\n", + " train: \"@train#train_dataset\"\n", + " dataset_keys: 'case_identifier'\n", + " state_attributes:\n", + " - \"iteration\"\n", + " - \"epoch\"\n", + " tag_name: 'Train_Loss'\n", + " experiment_param: \"$src.mlflow.create_mlflow_experiment_params( @bundle_root + '/nnUNet/params.yaml')\"\n", + " #artifacts=None\n", + " optimizer_param_names: 'lr'\n", + " #close_on_complete: False\n", + " - _target_: LrScheduleHandler\n", + " lr_scheduler: '@lr_scheduler'\n", + " print_lr: true\n", + " - _target_: ValidationHandler\n", + " epoch_level: true\n", + " interval: '@val_interval'\n", + " validator: 
'@validate#evaluator'\n", + " #- _target_: StatsHandler\n", + " # output_transform: $monai.handlers.from_engine(['loss'], first=True)\n", + " # tag_name: train_loss\n", + " - _target_: TensorBoardStatsHandler\n", + " log_dir: '@log_dir'\n", + " output_transform: $monai.handlers.from_engine(['loss'], first=True)\n", + " tag_name: train_loss\n", + " - _target_: CheckpointLoader\n", + " load_dict:\n", + " network_weights: '$@nnunet_trainer.network._orig_mod'\n", + " optimizer_state: '$@nnunet_trainer.optimizer'\n", + " scheduler: '$@nnunet_trainer.lr_scheduler'\n", + " load_path: '$@bundle_root + \"/models/checkpoint_epoch=\"+str(src.trainer.get_checkpoint(@reload_checkpoint_epoch, @bundle_root+\"/models\"))+\".pt\"'\n", + " map_location: '@device'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f55c6569", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/val_metrics.yaml\n", + "\n", + "val_key_metric:\n", + " Val_Dice:\n", + " _target_: \"MeanDice\"\n", + " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", + " reduction: \"mean\"\n", + " include_background: False\n", + " \n", + "val_additional_metrics:\n", + " Val_Dice_per_class:\n", + " _target_: \"MeanDice\"\n", + " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", + " reduction: \"mean_batch\"\n", + " include_background: False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21af5ce1", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/val_handlers.yaml\n", + "\n", + "val_handlers:\n", + " handlers:\n", + " - _target_: StatsHandler\n", + " iteration_log: false\n", + " - _target_: TensorBoardStatsHandler\n", + " iteration_log: false\n", + " log_dir: '@log_dir'\n", + " - _target_: \"$src.mlflow.MLFlownnUNetHandler\"\n", + " label_dict: \"@label_dict\"\n", + " tracking_uri: \"@tracking_uri\"\n", + " experiment_name: \"@mlflow_experiment_name\"\n", + " run_name: \"@mlflow_run_name\"\n", + " output_transform: \"$src.mlflow.mlflow_transform\"\n", + " iteration_log: False\n", + " state_attributes:\n", + " - \"best_metric\"\n", + " - \"best_metric_epoch\"\n", + " - _target_: \"CheckpointSaver\"\n", + " save_dir: \"$str(@bundle_root)+'/models'\"\n", + " save_interval: 1\n", + " n_saved: 1\n", + " save_key_metric: true\n", + " save_dict:\n", + " network_weights: '$@nnunet_trainer.network._orig_mod'\n", + " optimizer_state: '$@nnunet_trainer.optimizer'\n", + " scheduler: '$@nnunet_trainer.lr_scheduler'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d3b2a5f", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/validate.yaml\n", + "\n", + "val_interval: 1\n", + "validate:\n", + " pbar:\n", + " _target_: \"ignite.contrib.handlers.tqdm_logger.ProgressBar\"\n", + " dataloader: $@nnunet_trainer.dataloader_val\n", + " evaluator:\n", + " _target_: SupervisedEvaluator\n", + " additional_metrics: '@val_additional_metrics'\n", + " amp: true\n", + " epoch_length: $@nnunet_trainer.num_val_iterations_per_epoch\n", + " device: '@device'\n", + " inferer: '@validate#inferer'\n", + " key_val_metric: '@val_key_metric'\n", + " network: '@network'\n", + " postprocessing: '@train_postprocessing'\n", + " val_data_loader: '@validate#dataloader'\n", + " val_handlers: '@val_handlers#handlers'\n", + " prepare_batch: \"$src.trainer.prepare_nnunet_batch\"\n", + " inferer:\n", + " _target_: SimpleInferer\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "51fae1b1", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/nnUNet/evaluator/evaluator.yaml\n", + "\n", + "#Remove CheckpointSaver from val_handlers\n", + "\n", + "run:\n", + "- \"$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])\"\n", + "- $@validate#evaluator.run()\n", + "\n", + "initialize:\n", + "- \"$setattr(torch.backends.cudnn, 'benchmark', True)\"" + ] + }, + { + "cell_type": "markdown", + "id": "fd192db9", + "metadata": {}, + "source": [ + "## Adding Python Utility Scripts\n", + "\n", + "We finally add the MLFlow and Training utility scripts to the MONAI Bundle." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f4fb6da", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/src/__init__.py\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cea0e36", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/src/mlflow.py\n", + "\n", + "import re\n", + "from monai.handlers import MLFlowHandler\n", + "import yaml\n", + "from monai.bundle import ConfigParser\n", + "from pathlib import Path\n", + "import monai\n", + "import torch\n", + "\n", + "def mlflow_transform(state_output):\n", + " \"\"\"\n", + " Extracts the 'loss' value from the first element of the state_output list.\n", + "\n", + " Parameters\n", + " ----------\n", + " state_output : list of dict\n", + " A list where each element is a dictionary containing various metrics, including 'loss'.\n", + "\n", + " Returns\n", + " -------\n", + " float\n", + " The 'loss' value from the first element of the state_output list.\n", + " \"\"\"\n", + " return state_output[0]['loss']\n", + "\n", + "class MLFlownnUNetHandler(MLFlowHandler):\n", + " \"\"\"\n", + " A handler for logging nnUNet metrics to MLFlow.\n", + " Parameters\n", + " ----------\n", + " label_dict : dict\n", + " A dictionary mapping label indices to label names.\n", + " **kwargs : dict\n", + " Additional keyword arguments passed to the parent class.\n", + " \"\"\"\n", + " def __init__(self, label_dict, **kwargs):\n", + " super(MLFlownnUNetHandler, self).__init__(**kwargs)\n", + " self.label_dict = label_dict\n", + " \n", + " def _default_epoch_log(self, engine) -> None:\n", + " \"\"\"\n", + " Logs the metrics and state attributes at the end of each epoch.\n", + "\n", + " Parameters\n", + " ----------\n", + " engine : Engine\n", + " The engine object that contains the state and metrics to be logged.\n", + "\n", + " Returns\n", + " -------\n", + " None\n", + " \"\"\"\n", + " log_dict = engine.state.metrics\n", + " if not log_dict:\n", + " return\n", + "\n", + " current_epoch = self.global_epoch_transform(engine.state.epoch)\n", + "\n", + " new_log_dict = {}\n", + "\n", + " for metric in log_dict:\n", + " if type(log_dict[metric]) == torch.Tensor:\n", + " for i,val in enumerate(log_dict[metric]):\n", + " new_log_dict[metric+\"_{}\".format(list(self.label_dict.keys())[i+1])] = val\n", + " else:\n", + " new_log_dict[metric] = log_dict[metric]\n", + " self._log_metrics(new_log_dict, step=current_epoch)\n", + "\n", + " if self.state_attributes is not None:\n", + " attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}\n", + " self._log_metrics(attrs, step=current_epoch)\n", + "\n", + "def create_mlflow_experiment_params(params_file, custom_params=None):\n", + " \"\"\"\n", + " Create a dictionary of parameters for an MLflow experiment.\n", + "\n", + " This function reads 
configuration values from MONAI, GPU information, and a YAML configuration file,\n", + " and combines them into a single dictionary. Optionally, custom parameters can also be added to the dictionary.\n", + "\n", + " Parameters\n", + " ----------\n", + " params_file : str\n", + " Path to the YAML configuration file.\n", + " custom_params : dict, optional\n", + " A dictionary of custom parameters to be added to the final parameters dictionary (default is None).\n", + "\n", + " Returns\n", + " -------\n", + " dict\n", + " A dictionary containing all the combined parameters.\n", + " \"\"\"\n", + " params_dict = {}\n", + " config_values = monai.config.deviceconfig.get_config_values()\n", + " for k in config_values:\n", + " params_dict[re.sub(\"[()]\",\" \",str(k))] = config_values[k]\n", + "\n", + " optional_config_values = monai.config.deviceconfig.get_optional_config_values()\n", + " for k in optional_config_values:\n", + " params_dict[re.sub(\"[()]\",\" \",str(k))] = optional_config_values[k]\n", + "\n", + " gpu_info = monai.config.deviceconfig.get_gpu_info()\n", + " for k in gpu_info:\n", + " params_dict[re.sub(\"[()]\",\" \",str(k))] = str(gpu_info[k])\n", + "\n", + " yaml_config_files = [params_file]\n", + " # %%\n", + " monai_config = {}\n", + " for config_file in yaml_config_files:\n", + " with open(config_file, 'r') as file:\n", + " monai_config.update(yaml.safe_load(file))\n", + "\n", + " monai_config[\"bundle_root\"] = str(Path(Path(params_file).parent).parent)\n", + "\n", + " parser = ConfigParser(monai_config, globals={\"os\": \"os\",\n", + " \"pathlib\": \"pathlib\",\n", + " \"json\": \"json\",\n", + " \"ignite\": \"ignite\"\n", + " })\n", + "\n", + " parser.parse(True)\n", + "\n", + " for k in monai_config:\n", + " params_dict[k] = parser.get_parsed_content(k,instantiate=True)\n", + "\n", + " if custom_params is not None:\n", + " for k in custom_params:\n", + " params_dict[k] = custom_params[k]\n", + " return params_dict\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d679a387", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/src/trainer.py\n", + "\n", + "import os\n", + "\n", + "def subfiles(directory, prefix=None, suffix=None, join=True, sort=True):\n", + " \"\"\"\n", + " List files in a directory with optional filtering by prefix and/or suffix.\n", + " \n", + " Parameters\n", + " ----------\n", + " directory : str\n", + " The path to the directory to list files from.\n", + " prefix : str, optional\n", + " If specified, only files starting with this prefix will be included.\n", + " suffix : str, optional\n", + " If specified, only files ending with this suffix will be included.\n", + " join : bool, optional\n", + " If True, the directory path will be joined with the filenames. Default is True.\n", + " sort : bool, optional\n", + " If True, the list of files will be sorted. 
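Sorting is plain lexicographic string ordering, so callers needing numeric order (e.g. by epoch number) should parse the numbers themselves.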
Default is True.\n", + " \n", + " Returns\n", + " -------\n", + " list of str\n", + " A list of filenames (with full paths if `join` is True) that match the specified criteria.\n", + " \"\"\"\n", + "\n", + " \n", + " files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]\n", + " if prefix is not None:\n", + " files = [f for f in files if f.startswith(prefix)]\n", + " if suffix is not None:\n", + " files = [f for f in files if f.endswith(suffix)]\n", + " if join:\n", + " files = [os.path.join(directory, f) for f in files]\n", + " if sort:\n", + " files.sort()\n", + " return files\n", + "\n", + "def prepare_nnunet_batch(batch, device, non_blocking):\n", + " \"\"\"\n", + " Prepares a batch of data and targets for nnU-Net training by transferring them to the specified device.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch : dict\n", + " A dictionary containing the data and target tensors. The key \"data\" corresponds to the input data tensor,\n", + " and the key \"target\" corresponds to the target tensor or a list of target tensors.\n", + " device : torch.device\n", + " The device to which the data and target tensors should be transferred (e.g., 'cuda' or 'cpu').\n", + " non_blocking : bool\n", + " If True, allows non-blocking data transfer to the device.\n", + "\n", + " Returns\n", + " -------\n", + " tuple\n", + " A tuple containing the data tensor and the target tensor(s) after being transferred to the specified device.\n", + " \"\"\"\n", + " data = batch[\"data\"].to(device, non_blocking=non_blocking)\n", + " if isinstance(batch[\"target\"], list):\n", + " target = [i.to(device, non_blocking=non_blocking) for i in batch[\"target\"]]\n", + " else:\n", + " target = batch[\"target\"].to(device, non_blocking=non_blocking)\n", + " return data, target\n", + "\n", + "def get_checkpoint(epoch, ckpt_dir):\n", + " \"\"\"\n", + " Retrieves the checkpoint for a given epoch from the checkpoint directory.\n", + "\n", + " Parameters\n", + " ----------\n", + " epoch : int or str\n", + " The epoch number to retrieve. If 'latest', the function will return the latest checkpoint.\n", + " ckpt_dir : str\n", + " The directory where checkpoints are stored.\n", + "\n", + " Returns\n", + " -------\n", + " int\n", + " The epoch number of the checkpoint to be retrieved. 
If 'latest', returns the latest epoch number.\n", + " \"\"\"\n", + " if epoch == \"latest\":\n", + "\n", + " latest_checkpoints = subfiles(ckpt_dir, prefix=\"checkpoint_epoch\", sort=True,\n", + " join=False)\n", + " epochs = []\n", + " for latest_checkpoint in latest_checkpoints:\n", + " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\"):-len(\".pt\")]))\n", + "\n", + " epochs.sort()\n", + " latest_epoch = epochs[-1]\n", + " return latest_epoch\n", + " else:\n", + " return epoch\n", + "\n", + "def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir):\n", + " \"\"\"\n", + " Reloads the checkpoint for a given epoch and updates the trainer's state.\n", + "\n", + " Parameters\n", + " ----------\n", + " trainer : object\n", + " The trainer object whose state needs to be updated.\n", + " epoch : int\n", + " The epoch number to load the checkpoint from.\n", + " num_train_batches_per_epoch : int\n", + " The number of training batches per epoch.\n", + " ckpt_dir : str\n", + " The directory where the checkpoints are stored.\n", + "\n", + " Returns\n", + " -------\n", + " None\n", + " \"\"\"\n", + "\n", + " epoch_to_load = get_checkpoint(epoch, ckpt_dir)\n", + " trainer.state.epoch = epoch_to_load\n", + " trainer.state.iteration = (epoch_to_load* num_train_batches_per_epoch) +1\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7122458f", + "metadata": {}, + "outputs": [], + "source": [ + "def create_config(config_folder, output_file):\n", + " config_files = [f.path for f in os.scandir(config_folder) if f.path.endswith(\".yaml\")]\n", + " config = {}\n", + " for config_file in config_files:\n", + " with open(config_file, 'r') as file:\n", + " config.update(yaml.safe_load(file))\n", + "\n", + " if output_file.endswith(\".yaml\"):\n", + " with open(output_file, 'w') as file:\n", + " yaml.dump(config, file)\n", + " if output_file.endswith(\".json\"):\n", + " with open(output_file, 'w') as file:\n", + " json.dump(config, file)\n", + "\n", + " return config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8218fc23", + "metadata": {}, + "outputs": [], + "source": [ + "config = create_config(\"nnUNetBundle/nnUNet\", \"nnUNetBundle/configs/train.yaml\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6453aa35", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "export BUNDLE_ROOT=nnUNetBundle\n", + "export PYTHONPATH=$PYTHONPATH:$BUNDLE_ROOT\n", + "\n", + "python -m monai.bundle run \\\n", + " --bundle_root $BUNDLE_ROOT \\\n", + " --config_file $BUNDLE_ROOT/configs/train.yaml\n", + "\n", + "#Option to resume training\n", + "#--config_file \"['$BUNDLE_ROOT/configs/train.yaml','$BUNDLE_ROOT/configs/train_resume.yaml']\"\n", + "\n", + "# Log to Local MLFlow\n", + "#--tracking_uri mlruns\n" + ] + }, + { + "cell_type": "markdown", + "id": "3f221410", + "metadata": {}, + "source": [ + "## Inference\n", + "\n", + "After training the nnUNet model, we can then perform inference on new data. We use a `ModelnnUNetWrapper` as a wrapper around the nnUNet model to perform inference from the MONAI Bundle. 
In this way, the nnUNet preprocessing, inference and postprocessing steps are handled by the `ModelnnUNetWrapper`, with the Bundle blocks only needing to handle the input data loading and sending to the nnUnet block and the nnUNet prediction postprocessing.\n", + "\n", + "The `ModelnnUNetWrapper` receives as input the data dictionary loaded by the DataLoader, and returns the model predictions as a MetaTensor.\n", + "\n", + "To get the `ModelnnUNetWrapper` object, we can use the `get_nnunet_monai_predictor` function, which receives the following parameters:\n", + "\n", + "- `model_folder`: The path to the nnUNet model folder.\n", + "- `model_name`: [Optional] The name of the model to be loaded. If not provided, the function will load the checkpoint named `model.pt`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f22a7868", + "metadata": {}, + "outputs": [], + "source": [ + "# To Select the lastest checkpoint\n", + "\n", + "from nnUNetBundle.src.trainer import get_checkpoint\n", + "\n", + "ckpt_epoch = get_checkpoint(\"latest\", \"nnUNetBundle/models\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6adfa0a0", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.bundle.nnunet import get_nnunet_monai_predictor\n", + "\n", + "nnunet_config = {\n", + " \"model_folder\": \"nnUNetBundle/models\",\n", + "}\n", + "\n", + "monai_predictor = get_nnunet_monai_predictor(**nnunet_config, model_name=f\"checkpoint_epoch={ckpt_epoch}.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "d9258c3b", + "metadata": {}, + "source": [ + "## Test Data Preparation\n", + "\n", + "The Bundle accepts the test dataset in the following format:\n", + "\n", + "```bash\n", + "Dataset\n", + "├── Case1\n", + "│ └── Case1.nii.gz\n", + "├── Case2\n", + "│ └── Case2.nii.gz\n", + "└── Case3\n", + " └── Case3.nii.gz\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07229e86", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "mkdir -p nnUNetBundle/test_input/spleen_1\n", + "mkdir -p nnUNetBundle/test_output\n", + "\n", + "cp MONAI/Data/Task09_Spleen/imagesTs/spleen_1.nii.gz nnUNetBundle/test_input/spleen_1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "308abf67", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "tree nnUNetBundle/test_input" + ] + }, + { + "cell_type": "markdown", + "id": "6d7e5b73", + "metadata": {}, + "source": [ + "### Data Loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f10d59d7", + "metadata": {}, + "outputs": [], + "source": [ + "def get_subfolder_dataset(data_dir,modality_conf):\n", + " data_list = []\n", + " for f in os.scandir(data_dir):\n", + "\n", + " if f.is_dir():\n", + " subject_dict = {key:str(pathlib.Path(f.path).joinpath(f.name+modality_conf[key]['suffix'])) for key in modality_conf}\n", + " data_list.append(subject_dict)\n", + " return data_list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f3972b9", + "metadata": {}, + "outputs": [], + "source": [ + "modalities = {\n", + " \"image\": {\"suffix\": \".nii.gz\"},\n", + "}\n", + "\n", + "data = get_subfolder_dataset(\"nnUNetBundle/test_input\",modalities)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e24a629", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.transforms import LoadImaged\n", + "from monai.data import Dataset, DataLoader\n", + "\n", + "preprocessing = 
LoadImaged(keys=[\"image\"],ensure_channel_first=True, image_only=False)\n", + "\n", + "\n", + "test_dataset = Dataset(data,transform=preprocessing)\n", + "\n", + "test_loader = DataLoader(test_dataset, batch_size=1)" + ] + }, + { + "cell_type": "markdown", + "id": "8c637fba", + "metadata": {}, + "source": [ + "### Test ModelnnUNetWrapper\n", + "\n", + "To test the `ModelnnUNetWrapper`, we can provide a test case to the `ModelnnUNetWrapper` and extract the model predictions returned by the wrapper." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2386fc9c", + "metadata": {}, + "outputs": [], + "source": [ + "batch = next(iter(test_loader))\n", + "\n", + "pred = monai_predictor(batch[\"image\"])" + ] + }, + { + "cell_type": "markdown", + "id": "3e010b7e", + "metadata": {}, + "source": [ + "### Postprocessing and Save Predictions\n", + "\n", + "After obtaining the model predictions, we can apply postprocessing transformations to the predictions and save the results to disk.\n", + "\n", + "The `Transposed` transform is required to unify the axis order convention between MONAI and nnUNet. The nnUNet model uses the `zyx` axis order, while MONAI uses the `xyz` axis order." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccd5a438", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.transforms import Compose, Transposed, SaveImaged, Decollated\n", + "\n", + "\n", + "postprocessing = Compose([\n", + " #Decollated(keys=None, detach=True),\n", + " Transposed(keys=\"pred\",indices=[0,3,2,1]),\n", + " SaveImaged(keys=\"pred\",\n", + " output_dir=\"nnUNetBundle/test_output\",\n", + " output_postfix=\"prediction\",\n", + " meta_keys=\"image_meta_dict\",\n", + " )\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d82112a2", + "metadata": {}, + "outputs": [], + "source": [ + "postprocessing({\"pred\":pred})" + ] + }, + { + "cell_type": "markdown", + "id": "9c85dd88", + "metadata": {}, + "source": [ + "## Evaluator\n", + "\n", + "Combining everything together, we can create an `Evaluator` that encapsulates the data loading, model inference, postprocessing, and evaluation steps. The `Evaluator` can be used to evaluate the model on the test dataset ." 
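+    "\n",
+    "If ground-truth labels were also available for the test cases, a key metric could be attached to this evaluator. The snippet below is only a minimal sketch (it assumes the wrapper returns a discrete, channel-first label map and that each data dictionary additionally provides a `label` entry):\n",
+    "\n",
+    "```python\n",
+    "from monai.handlers import MeanDice, from_engine\n",
+    "from monai.transforms import AsDiscreted, Compose, Transposed\n",
+    "\n",
+    "# hypothetical post-processing for a labelled test set: align axes, then one-hot both keys for Dice\n",
+    "labelled_postprocessing = Compose(\n",
+    "    [\n",
+    "        Transposed(keys='pred', indices=[0, 3, 2, 1]),\n",
+    "        AsDiscreted(keys=['pred', 'label'], to_onehot=2),\n",
+    "    ]\n",
+    ")\n",
+    "\n",
+    "key_test_metric = {'Test_Dice': MeanDice(include_background=False, output_transform=from_engine(['pred', 'label']))}\n",
+    "# key_test_metric and labelled_postprocessing could then be passed to the SupervisedEvaluator created below\n",
+    "```"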
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbf1fec9", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.engines import SupervisedEvaluator\n", + "\n", + "validator = SupervisedEvaluator(\n", + " val_data_loader=test_loader,\n", + " device = \"cuda:0\",\n", + " network = monai_predictor,\n", + " postprocessing= postprocessing\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67970bc2", + "metadata": {}, + "outputs": [], + "source": [ + "validator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f63d8dd8", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/configs/inference.yaml\n", + "\n", + "imports: \n", + " - $import json\n", + " - $from pathlib import Path\n", + " - $import os\n", + " - $import monai.bundle.nnunet\n", + " - $from ignite.contrib.handlers.tqdm_logger import ProgressBar\n", + " - $import shutil\n", + " - $import src\n", + " - $import src.dataset\n", + "\n", + "\n", + "output_dir: \".\"\n", + "bundle_root: \".\"\n", + "data_list_file : \".\"\n", + "data_dir: \".\"\n", + "\n", + "prediction_suffix: \"prediction\"\n", + "\n", + "\n", + "modality_conf:\n", + " image:\n", + " suffix: \".nii.gz\"\n", + "\n", + "test_data_list: \"$src.dataset.get_subfolder_dataset(@data_dir,@modality_conf)\"\n", + "#test_data_list: \"$monai.data.load_decathlon_datalist(@data_list_file, is_segmentation=True, data_list_key='testing', base_dir=@data_dir)\"\n", + "image_modality_keys: \"$list(@modality_conf.keys())\"\n", + "image_key: \"image\"\n", + "image_suffix: \"@image_key\"\n", + "\n", + "preprocessing:\n", + " _target_: Compose\n", + " transforms:\n", + " - _target_: LoadImaged\n", + " keys: \"image\"\n", + " ensure_channel_first: True\n", + " image_only: False\n", + "\n", + "test_dataset:\n", + " _target_: Dataset\n", + " data: \"$@test_data_list\"\n", + " transform: \"@preprocessing\"\n", + "\n", + "test_loader:\n", + " _target_: DataLoader\n", + " dataset: \"@test_dataset\"\n", + " batch_size: 1\n", + "\n", + "\n", + "device: \"$torch.device('cuda')\"\n", + "\n", + "nnunet_config:\n", + " model_folder: \"$@bundle_root + '/models'\"\n", + "\n", + "network_def: \"$monai.bundle.nnunet.get_nnunet_monai_predictor(**@nnunet_config)\"\n", + "\n", + "postprocessing:\n", + " _target_: \"Compose\"\n", + " transforms:\n", + " - _target_: Transposed\n", + " keys: \"pred\"\n", + " indices:\n", + " - 0\n", + " - 3\n", + " - 2\n", + " - 1\n", + " - _target_: SaveImaged\n", + " keys: \"pred\"\n", + " resample: False\n", + " output_postfix: \"@prediction_suffix\"\n", + " output_dir: \"@output_dir\"\n", + " meta_keys: \"image_meta_dict\"\n", + "\n", + "\n", + "testing:\n", + " dataloader: \"$@test_loader\"\n", + " pbar:\n", + " _target_: \"ignite.contrib.handlers.tqdm_logger.ProgressBar\"\n", + " test_inferer: \"$@inferer\"\n", + "\n", + "inferer: \n", + " _target_: \"SimpleInferer\"\n", + "\n", + "validator:\n", + " _target_: \"SupervisedEvaluator\"\n", + " postprocessing: \"$@postprocessing\"\n", + " device: \"$@device\"\n", + " inferer: \"$@testing#test_inferer\"\n", + " val_data_loader: \"$@testing#dataloader\"\n", + " network: \"@network_def\"\n", + " val_handlers:\n", + " - _target_: \"CheckpointLoader\"\n", + " load_path: \"$@bundle_root+'/models/model.pt'\"\n", + " load_dict:\n", + " network_weights: '$@network_def.network_weights'\n", + "run:\n", + " - \"$@testing#pbar.attach(@validator)\"\n", + " - \"$@validator.run()\"" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "62668710", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/src/dataset.py\n", + "\n", + "import pathlib\n", + "import os\n", + "\n", + "def get_subfolder_dataset(data_dir,modality_conf):\n", + " data_list = []\n", + " for f in os.scandir(data_dir):\n", + "\n", + " if f.is_dir():\n", + " subject_dict = {key:str(pathlib.Path(f.path).joinpath(f.name+modality_conf[key]['suffix'])) for key in modality_conf}\n", + " data_list.append(subject_dict)\n", + " return data_list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b218dfd3", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "\n", + "export BUNDLE_ROOT=nnUNetBundle\n", + "export PYTHONPATH=$PYTHONPATH:$BUNDLE_ROOT\n", + "\n", + "python -m monai.bundle run \\\n", + " --config-file $BUNDLE_ROOT/configs/inference.yaml \\\n", + " --bundle-root $BUNDLE_ROOT \\\n", + " --data-dir $BUNDLE_ROOT/test_input \\\n", + " --output-dir $BUNDLE_ROOT/test_output \\\n", + " --logging-file $BUNDLE_ROOT/configs/logging.conf" + ] + }, + { + "cell_type": "markdown", + "id": "0f80f44e", + "metadata": {}, + "source": [ + "## Utilities" + ] + }, + { + "cell_type": "markdown", + "id": "e7d75b56-80b8-412c-97d8-b79e1111caa0", + "metadata": {}, + "source": [ + "### MONAI Bundle to nnUNet Conversion" + ] + }, + { + "cell_type": "markdown", + "id": "d65a50b3-063d-412b-a8e0-5ec45c003925", + "metadata": {}, + "source": [ + "To convert a MONAI Bundle to a nnUNet Bundle, we need to combine the MONAI checkpoint with the nnUNet checkpoint. This is done by loading the MONAI checkpoint and the nnUNet checkpoint, and updating the nnUNet model weights with the MONAI model weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba864564", + "metadata": {}, + "outputs": [], + "source": [ + "from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n", + "from nnunetv2.training.logging.nnunet_logger import nnUNetLogger\n", + "import shutil\n", + "\n", + "def subfiles(directory, prefix=None, suffix=None, join=True, sort=True):\n", + " \"\"\"\n", + " List files in a directory with optional filtering by prefix and/or suffix.\n", + " \n", + " Parameters\n", + " ----------\n", + " directory : str\n", + " The path to the directory to list files from.\n", + " prefix : str, optional\n", + " If specified, only files starting with this prefix will be included.\n", + " suffix : str, optional\n", + " If specified, only files ending with this suffix will be included.\n", + " join : bool, optional\n", + " If True, the directory path will be joined with the filenames. Default is True.\n", + " sort : bool, optional\n", + " If True, the list of files will be sorted. 
Default is True.\n", + " \n", + " Returns\n", + " -------\n", + " list of str\n", + " A list of filenames (with full paths if `join` is True) that match the specified criteria.\n", + " \"\"\"\n", + "\n", + " \n", + " files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]\n", + " if prefix is not None:\n", + " files = [f for f in files if f.startswith(prefix)]\n", + " if suffix is not None:\n", + " files = [f for f in files if f.endswith(suffix)]\n", + " if join:\n", + " files = [os.path.join(directory, f) for f in files]\n", + " if sort:\n", + " files.sort()\n", + " return files\n", + "\n", + "\n", + "def convert_monai_bundle_to_nnunet(nnunet_config, bundle_path):\n", + "\n", + " nnunet_trainer = \"nnUNetTrainer\"\n", + " nnunet_plans = \"nnUNetPlans\"\n", + "\n", + " if \"nnunet_trainer\" in nnunet_config:\n", + " nnunet_trainer = nnunet_config[\"nnunet_trainer\"]\n", + "\n", + " if \"nnunet_plans\" in nnunet_config:\n", + " nnunet_plans = nnunet_config[\"nnunet_plans\"]\n", + "\n", + " nnunet_model_folder = Path(os.environ[\"nnUNet_results\"]).joinpath(\n", + " maybe_convert_to_dataset_name(nnunet_config[\"dataset_name_or_id\"]),\n", + " f\"{nnunet_trainer}__{nnunet_plans}__3d_fullres\")\n", + " \n", + " nnunet_preprocess_model_folder = Path(os.environ[\"nnUNet_preprocessed\"]).joinpath(\n", + " maybe_convert_to_dataset_name(nnunet_config[\"dataset_name_or_id\"]))\n", + " \n", + " Path(nnunet_model_folder).joinpath(\"fold_0\").mkdir(parents=True, exist_ok=True)\n", + "\n", + "\n", + " nnunet_checkpoint = torch.load(f\"{bundle_path}/models/nnunet_checkpoint.pth\")\n", + " latest_checkpoints = subfiles(Path(bundle_path).joinpath(\"models\"),prefix=\"checkpoint_epoch\",sort=True,join=False)\n", + " epochs = []\n", + " for latest_checkpoint in latest_checkpoints:\n", + " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\"):-len(\".pt\")]))\n", + "\n", + " epochs.sort()\n", + " final_epoch = epochs[-1]\n", + " monai_last_checkpoint = torch.load(f\"{bundle_path}/models/checkpoint_epoch={final_epoch}.pt\")\n", + "\n", + " best_checkpoints = subfiles(Path(bundle_path).joinpath(\"models\"), prefix=\"checkpoint_key_metric\", sort=True,\n", + " join=False)\n", + " key_metrics = []\n", + " for best_checkpoint in best_checkpoints:\n", + " key_metrics.append(str(best_checkpoint[len(\"checkpoint_key_metric=\"):-len(\".pt\")]))\n", + "\n", + " key_metrics.sort()\n", + " best_key_metric = key_metrics[-1]\n", + " monai_best_checkpoint = torch.load(f\"{bundle_path}/models/checkpoint_key_metric={best_key_metric}.pt\")\n", + "\n", + " nnunet_checkpoint['optimizer_state'] = monai_last_checkpoint['optimizer_state']\n", + "\n", + "\n", + "\n", + " nnunet_checkpoint['network_weights'] = odict()\n", + "\n", + " for key in monai_last_checkpoint['network_weights']:\n", + " nnunet_checkpoint['network_weights'][key] = monai_last_checkpoint['network_weights'][key]\n", + "\n", + " nnunet_checkpoint['current_epoch'] = final_epoch\n", + " nnunet_checkpoint['logging'] = nnUNetLogger().get_checkpoint()\n", + " nnunet_checkpoint['_best_ema'] = 0\n", + " nnunet_checkpoint['grad_scaler_state'] = None\n", + "\n", + "\n", + "\n", + " torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(\"fold_0\",\"checkpoint_final.pth\"))\n", + "\n", + " nnunet_checkpoint['network_weights'] = odict()\n", + "\n", + " nnunet_checkpoint['optimizer_state'] = monai_best_checkpoint['optimizer_state']\n", + "\n", + " for key in monai_best_checkpoint['network_weights']:\n", + " 
nnunet_checkpoint['network_weights'][key] = \\\n", + " monai_best_checkpoint['network_weights'][key]\n", + "\n", + " torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(\"fold_0\", \"checkpoint_best.pth\"))\n", + "\n", + "\n", + " shutil.move(f\"{bundle_path}/models/dataset.json\",nnunet_model_folder)\n", + " shutil.move(f\"{bundle_path}/models/plans.json\",nnunet_model_folder)\n", + " shutil.move(f\"{nnunet_preprocess_model_folder}/dataset_fingerprint.json\",nnunet_model_folder)\n", + " shutil.move(f\"{bundle_path}/models/nnunet_checkpoint.pth\",nnunet_model_folder)\n", + " shutil.move(f\"{bundle_path}/models/checkpoint_epoch={final_epoch}.pt\",f\"{bundle_path}/models/model.pt\")\n", + " shutil.move(f\"{bundle_path}/models/checkpoint_key_metric={best_key_metric}.pt\",f\"{bundle_path}/models/best_model.pt\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de36e0d6", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"nnUNet_results\"] = \"MONAI/Data/nnUNet/nnUNet_trained_models\"\n", + "os.environ[\"nnUNet_raw\"] = \"MONAI/Data/nnUNet/nnUNet_raw_data_base\"\n", + "os.environ[\"nnUNet_preprocessed\"] = \"MONAI/Data/nnUNet/nnUNet_preprocessed\"\n", + "\n", + "nnunet_config = {\n", + " \"dataset_name_or_id\": \"009\",\n", + " \"nnunet_trainer\": \"nnUNetTrainer\",\n", + "}\n", + "\n", + "convert_monai_bundle_to_nnunet(nnunet_config, \"nnUNetBundle\")" + ] + }, + { + "cell_type": "markdown", + "id": "47355821", + "metadata": {}, + "source": [ + "### Testing the nnUNet Model\n", + "\n", + "We now test the nnUNet model by performing inference on the test dataset and evaluating the model predictions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2dd2920", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.bundle.config_parser import ConfigParser\n", + "from monai.apps.nnunet import nnUNetV2Runner\n", + "\n", + "\n", + "root_dir = \"MONAI/Data\"\n", + "nnunet_root_dir = os.path.join(root_dir, \"nnUNet\")\n", + "\n", + "os.makedirs(nnunet_root_dir, exist_ok=True)\n", + "\n", + "data_src_cfg = os.path.join(nnunet_root_dir, \"data_src_cfg.yaml\")\n", + "data_src = {\n", + " \"modality\": \"CT\",\n", + " \"dataset_name_or_id\": \"09\",\n", + " \"datalist\": os.path.join(root_dir, \"Task09_Spleen/msd_task09_spleen_folds.json\"),\n", + " \"dataroot\": os.path.join(root_dir, \"Task09_Spleen\"),\n", + "}\n", + "\n", + "ConfigParser.export_config_file(data_src, data_src_cfg)\n", + "\n", + "runner = nnUNetV2Runner(\n", + " input_config=data_src_cfg, trainer_class_name=\"nnUNetTrainer\", work_dir=nnunet_root_dir\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3559f95d", + "metadata": {}, + "outputs": [], + "source": [ + "runner.train_single_model(config=\"3d_fullres\", fold=0, val=\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4802257", + "metadata": {}, + "outputs": [], + "source": [ + "runner.find_best_configuration(configs=[\"3d_fullres\"],folds=[0],allow_ensembling=False,num_processes=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efd19e64", + "metadata": {}, + "outputs": [], + "source": [ + "runner.predict_ensemble_postprocessing(folds=[0],run_ensemble=False,run_postprocessing=False)" + ] + }, + { + "cell_type": "markdown", + "id": "8190d533", + "metadata": {}, + "source": [ + "### nnUNet to MONAI Bundle Conversion\n", + "\n", + "To convert a nnUNet trained Model to a MONAI Bundle, we need to 
separate the MONAI checkpoint from the nnUNet checkpoint. This is done by loading the nnUNet checkpoint and the MONAI checkpoint, and updating the MONAI model weights with the nnUNet model weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4e79429", + "metadata": {}, + "outputs": [], + "source": [ + "from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", + "import os\n", + "\n", + "os.environ[\"nnUNet_results\"] = \"MONAI/Data/nnUNet/nnUNet_trained_models\"\n", + "os.environ[\"nnUNet_raw\"] = \"MONAI/Data/nnUNet/nnUNet_raw_data_base\"\n", + "os.environ[\"nnUNet_preprocessed\"] = \"MONAI/Data/nnUNet/nnUNet_preprocessed\"\n", + "\n", + "nnunet_config = {\n", + " \"dataset_name_or_id\": \"009\",\n", + " \"nnunet_trainer\": \"nnUNetTrainer_10epochs\",\n", + "}\n", + "\n", + "bundle_root = \"nnUNetBundle\"\n", + "\n", + "convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "MONAI", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a46cb8c1875e8cb628bfd17203dedcc097e76f54 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 12 Feb 2025 17:33:17 +0000 Subject: [PATCH 07/12] DCO Remediation Commit for simben I, simben , hereby add my Signed-off-by to this commit: e65a170f09d3041be0f362dada707b0db5a0fcff Signed-off-by: simben --- bundle/nnUNet_Bundle.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bundle/nnUNet_Bundle.ipynb b/bundle/nnUNet_Bundle.ipynb index 02c5b4602..2e1857be0 100644 --- a/bundle/nnUNet_Bundle.ipynb +++ b/bundle/nnUNet_Bundle.ipynb @@ -2433,7 +2433,7 @@ "\n", "bundle_root = \"nnUNetBundle\"\n", "\n", - "convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)" + "convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)\n" ] } ], From 7be4718dd7c698bdd861b2db038d43138f906bd1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Feb 2025 17:36:12 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- bundle/06_nnunet_monai_bundle.ipynb | 2 +- bundle/nnUNet_Bundle.ipynb | 480 +++++++++++++--------------- 2 files changed, 224 insertions(+), 258 deletions(-) diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb index 1482340ac..cbd9998ef 100644 --- a/bundle/06_nnunet_monai_bundle.ipynb +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -194,7 +194,7 @@ "metadata": {}, "outputs": [], "source": [ - "datalist_file = Path(root_dir).joinpath(\"Task09_Spleen\",\"Task09_Spleen_folds.json\")\n", + "datalist_file = Path(root_dir).joinpath(\"Task09_Spleen\", \"Task09_Spleen_folds.json\")\n", "with open(datalist_file, \"w\", encoding=\"utf-8\") as f:\n", " json.dump(datalist_json, f, ensure_ascii=False, indent=4)\n", "print(f\"Datalist is saved to {datalist_file}\")" diff --git a/bundle/nnUNet_Bundle.ipynb b/bundle/nnUNet_Bundle.ipynb index 2e1857be0..df5ed364d 100644 --- a/bundle/nnUNet_Bundle.ipynb +++ b/bundle/nnUNet_Bundle.ipynb @@ -29,7 +29,17 @@ "source": [ "import torch\n", "from monai.data import Dataset\n", - "from monai.handlers import StatsHandler, from_engine, 
MeanDice, ValidationHandler, LrScheduleHandler, CheckpointSaver, CheckpointLoader, TensorBoardStatsHandler, MLFlowHandler\n", + "from monai.handlers import (\n", + " StatsHandler,\n", + " from_engine,\n", + " MeanDice,\n", + " ValidationHandler,\n", + " LrScheduleHandler,\n", + " CheckpointSaver,\n", + " CheckpointLoader,\n", + " TensorBoardStatsHandler,\n", + " MLFlowHandler,\n", + ")\n", "from monai.engines import SupervisedTrainer, SupervisedEvaluator\n", "\n", "from monai.transforms import Compose, Lambdad, Activationsd, AsDiscreted\n", @@ -42,7 +52,7 @@ "from monai.bundle import ConfigParser\n", "import monai\n", "from pathlib import Path\n", - "from odict import odict\n" + "from odict import odict" ] }, { @@ -147,7 +157,7 @@ "outputs": [], "source": [ "train_dataloader = nnunet_trainer.dataloader_train\n", - "train_data = [{'case_identifier':k} for k in nnunet_trainer.dataloader_train.generator._data.dataset.keys()]\n", + "train_data = [{\"case_identifier\": k} for k in nnunet_trainer.dataloader_train.generator._data.dataset.keys()]\n", "train_dataset = Dataset(data=train_data)" ] }, @@ -159,7 +169,7 @@ "outputs": [], "source": [ "val_dataloader = nnunet_trainer.dataloader_val\n", - "val_data = [{'case_identifier':k} for k in nnunet_trainer.dataloader_val.generator._data.dataset.keys()]\n", + "val_data = [{\"case_identifier\": k} for k in nnunet_trainer.dataloader_val.generator._data.dataset.keys()]\n", "val_dataset = Dataset(data=val_data)" ] }, @@ -219,7 +229,7 @@ "metadata": {}, "outputs": [], "source": [ - "image, label = prepare_nnunet_batch(next(iter(train_dataloader)),device=\"cpu\",non_blocking=True)" + "image, label = prepare_nnunet_batch(next(iter(train_dataloader)), device=\"cpu\", non_blocking=True)" ] }, { @@ -239,12 +249,7 @@ "metadata": {}, "outputs": [], "source": [ - "train_handlers = [\n", - " StatsHandler(\n", - " output_transform= from_engine(['loss'], first=True),\n", - " tag_name= \"train_loss\"\n", - " )\n", - "]\n" + "train_handlers = [StatsHandler(output_transform=from_engine([\"loss\"], first=True), tag_name=\"train_loss\")]" ] }, { @@ -266,16 +271,16 @@ "outputs": [], "source": [ "trainer = SupervisedTrainer(\n", - " amp= True,\n", - " device = device,\n", - " epoch_length = iterations,\n", - " loss_function = loss,\n", - " max_epochs = epochs,\n", - " network = network,\n", - " prepare_batch = prepare_nnunet_batch,\n", - " optimizer = optimizer,\n", - " train_data_loader = train_dataloader,\n", - " train_handlers= train_handlers\n", + " amp=True,\n", + " device=device,\n", + " epoch_length=iterations,\n", + " loss_function=loss,\n", + " max_epochs=epochs,\n", + " network=network,\n", + " prepare_batch=prepare_nnunet_batch,\n", + " optimizer=optimizer,\n", + " train_data_loader=train_dataloader,\n", + " train_handlers=train_handlers,\n", ")" ] }, @@ -306,20 +311,15 @@ "metadata": {}, "outputs": [], "source": [ - "val_key_metric = MeanDice(\n", - " output_transform = from_engine(['pred', 'label']),\n", - " reduction = \"mean\",\n", - " include_background = False\n", - "\n", - ")\n", + "val_key_metric = MeanDice(output_transform=from_engine([\"pred\", \"label\"]), reduction=\"mean\", include_background=False)\n", "\n", "additional_metrics = {\n", - " \"Val_Dice_Per_Class\": MeanDice(\n", - " output_transform = from_engine(['pred', 'label']),\n", - " reduction = \"mean_batch\",\n", - " include_background = False,\n", - " )\n", - " }" + " \"Val_Dice_Per_Class\": MeanDice(\n", + " output_transform=from_engine([\"pred\", \"label\"]),\n", + " 
reduction=\"mean_batch\",\n", + " include_background=False,\n", + " )\n", + "}" ] }, { @@ -341,26 +341,14 @@ "\n", "postprocessing = Compose(\n", " transforms=[\n", - " ## Extract only high-res predictions from Deep Supervision\n", - " Lambdad( \n", - " keys= [\"pred\",\"label\"],\n", - " func = lambda x: x[0]\n", - " ),\n", - " ## Apply Softmax to the predictions\n", - " Activationsd(\n", - " keys= \"pred\",\n", - " softmax= True\n", - " ),\n", - " ## Binarize the predictions\n", - " AsDiscreted(\n", - " keys= \"pred\",\n", - " threshold= 0.5\n", - " ),\n", - " ## Convert the labels to one-hot\n", - " AsDiscreted(\n", - " keys= \"label\",\n", - " to_onehot= num_classes\n", - " )\n", + " ## Extract only high-res predictions from Deep Supervision\n", + " Lambdad(keys=[\"pred\", \"label\"], func=lambda x: x[0]),\n", + " ## Apply Softmax to the predictions\n", + " Activationsd(keys=\"pred\", softmax=True),\n", + " ## Binarize the predictions\n", + " AsDiscreted(keys=\"pred\", threshold=0.5),\n", + " ## Convert the labels to one-hot\n", + " AsDiscreted(keys=\"label\", to_onehot=num_classes),\n", " ]\n", ")" ] @@ -372,9 +360,7 @@ "metadata": {}, "outputs": [], "source": [ - "val_handlers = [StatsHandler(\n", - " iteration_log = False\n", - ")]" + "val_handlers = [StatsHandler(iteration_log=False)]" ] }, { @@ -396,16 +382,16 @@ "outputs": [], "source": [ "evaluator = SupervisedEvaluator(\n", - " amp= True,\n", - " device = device,\n", - " epoch_length = val_iterations,\n", - " network = network,\n", + " amp=True,\n", + " device=device,\n", + " epoch_length=val_iterations,\n", + " network=network,\n", " key_val_metric={\"Val_Dice\": val_key_metric},\n", - " prepare_batch= prepare_nnunet_batch,\n", - " val_data_loader = val_dataloader,\n", - " val_handlers= val_handlers,\n", - " postprocessing= postprocessing,\n", - " additional_metrics= additional_metrics,\n", + " prepare_batch=prepare_nnunet_batch,\n", + " val_data_loader=val_dataloader,\n", + " val_handlers=val_handlers,\n", + " postprocessing=postprocessing,\n", + " additional_metrics=additional_metrics,\n", ")" ] }, @@ -424,13 +410,7 @@ "metadata": {}, "outputs": [], "source": [ - "train_handlers.append(\n", - " ValidationHandler(\n", - " epoch_level = True,\n", - " interval= val_interval,\n", - " validator = evaluator\n", - " )\n", - ")" + "train_handlers.append(ValidationHandler(epoch_level=True, interval=val_interval, validator=evaluator))" ] }, { @@ -448,19 +428,14 @@ "metadata": {}, "outputs": [], "source": [ - "train_key_metric = MeanDice(\n", - " output_transform = from_engine(['pred', 'label']),\n", - " reduction = \"mean\",\n", - " include_background = False\n", - "\n", - ")\n", + "train_key_metric = MeanDice(output_transform=from_engine([\"pred\", \"label\"]), reduction=\"mean\", include_background=False)\n", "\n", "additional_metrics = {\n", " \"Train_Dice_Per_Class\": MeanDice(\n", - " output_transform = from_engine(['pred', 'label']),\n", - " reduction = \"mean_batch\",\n", - " include_background = False,\n", - " )\n", + " output_transform=from_engine([\"pred\", \"label\"]),\n", + " reduction=\"mean_batch\",\n", + " include_background=False,\n", + " )\n", "}" ] }, @@ -472,19 +447,19 @@ "outputs": [], "source": [ "trainer = SupervisedTrainer(\n", - " amp= True,\n", - " device = device,\n", - " epoch_length = iterations,\n", - " loss_function = loss,\n", - " max_epochs = epochs,\n", - " network = network,\n", - " prepare_batch = prepare_nnunet_batch,\n", - " optimizer = optimizer,\n", - " train_data_loader = train_dataloader,\n", - 
" train_handlers= train_handlers,\n", - " key_train_metric = {\"Train_Dice\": train_key_metric},\n", - " postprocessing= postprocessing,\n", - " additional_metrics = additional_metrics\n", + " amp=True,\n", + " device=device,\n", + " epoch_length=iterations,\n", + " loss_function=loss,\n", + " max_epochs=epochs,\n", + " network=network,\n", + " prepare_batch=prepare_nnunet_batch,\n", + " optimizer=optimizer,\n", + " train_data_loader=train_dataloader,\n", + " train_handlers=train_handlers,\n", + " key_train_metric={\"Train_Dice\": train_key_metric},\n", + " postprocessing=postprocessing,\n", + " additional_metrics=additional_metrics,\n", ")" ] }, @@ -515,12 +490,7 @@ "metadata": {}, "outputs": [], "source": [ - "train_handlers.append(\n", - " LrScheduleHandler(\n", - " lr_scheduler = lr_scheduler,\n", - " print_lr = True\n", - " )\n", - ")" + "train_handlers.append(LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True))" ] }, { @@ -531,19 +501,19 @@ "outputs": [], "source": [ "trainer = SupervisedTrainer(\n", - " amp= True,\n", - " device = device,\n", - " epoch_length = iterations,\n", - " loss_function = loss,\n", - " max_epochs = epochs,\n", - " network = network,\n", - " prepare_batch = prepare_nnunet_batch,\n", - " optimizer = optimizer,\n", - " train_data_loader = train_dataloader,\n", - " train_handlers= train_handlers,\n", - " key_train_metric = {\"Train_Dice\": train_key_metric},\n", - " postprocessing= postprocessing,\n", - " additional_metrics = additional_metrics\n", + " amp=True,\n", + " device=device,\n", + " epoch_length=iterations,\n", + " loss_function=loss,\n", + " max_epochs=epochs,\n", + " network=network,\n", + " prepare_batch=prepare_nnunet_batch,\n", + " optimizer=optimizer,\n", + " train_data_loader=train_dataloader,\n", + " train_handlers=train_handlers,\n", + " key_train_metric={\"Train_Dice\": train_key_metric},\n", + " postprocessing=postprocessing,\n", + " additional_metrics=additional_metrics,\n", ")" ] }, @@ -589,14 +559,18 @@ "\n", "val_handlers.append(\n", " CheckpointSaver(\n", - " save_dir= ckpt_dir,\n", - " save_dict= {\"network_weights\": nnunet_trainer.network._orig_mod, \"optimizer_state\": nnunet_trainer.optimizer, \"scheduler\": nnunet_trainer.lr_scheduler},\n", - " #save_final= True,\n", - " save_interval= 1,\n", - " save_key_metric= True,\n", - " #final_filename= \"model_final.pt\",\n", - " #key_metric_filename= \"model.pt\",\n", - " n_saved= 1\n", + " save_dir=ckpt_dir,\n", + " save_dict={\n", + " \"network_weights\": nnunet_trainer.network._orig_mod,\n", + " \"optimizer_state\": nnunet_trainer.optimizer,\n", + " \"scheduler\": nnunet_trainer.lr_scheduler,\n", + " },\n", + " # save_final= True,\n", + " save_interval=1,\n", + " save_key_metric=True,\n", + " # final_filename= \"model_final.pt\",\n", + " # key_metric_filename= \"model.pt\",\n", + " n_saved=1,\n", " )\n", ")" ] @@ -618,7 +592,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def subfiles(folder, prefix=None, suffix=None, join=True, sort=True):\n", " files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]\n", " if prefix is not None:\n", @@ -631,14 +604,14 @@ " files = [os.path.join(folder, f) for f in files]\n", " return files\n", "\n", + "\n", "def get_checkpoint(epoch, ckpt_dir):\n", " if epoch == \"latest\":\n", "\n", - " latest_checkpoints = subfiles(ckpt_dir, prefix=\"checkpoint_epoch\", sort=True,\n", - " join=False)\n", + " latest_checkpoints = subfiles(ckpt_dir, prefix=\"checkpoint_epoch\", sort=True, join=False)\n", " epochs = []\n", " 
for latest_checkpoint in latest_checkpoints:\n", - " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\"):-len(\".pt\")]))\n", + " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\") : -len(\".pt\")]))\n", "\n", " epochs.sort()\n", " latest_epoch = epochs[-1]\n", @@ -646,11 +619,12 @@ " else:\n", " return epoch\n", "\n", + "\n", "def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir):\n", "\n", " epoch_to_load = get_checkpoint(epoch, ckpt_dir)\n", " trainer.state.epoch = epoch_to_load\n", - " trainer.state.iteration = (epoch_to_load* num_train_batches_per_epoch) +1" + " trainer.state.iteration = (epoch_to_load * num_train_batches_per_epoch) + 1" ] }, { @@ -660,14 +634,19 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "reload_checkpoint_epoch = \"latest\"\n", "\n", "train_handlers.append(\n", " CheckpointLoader(\n", - " load_path= os.path.join(ckpt_dir,'checkpoint_epoch='+str(get_checkpoint(reload_checkpoint_epoch, ckpt_dir))+'.pt'),\n", - " load_dict= {\"network_weights\": nnunet_trainer.network._orig_mod, \"optimizer_state\": nnunet_trainer.optimizer, \"scheduler\": nnunet_trainer.lr_scheduler},\n", - " map_location= device\n", + " load_path=os.path.join(\n", + " ckpt_dir, \"checkpoint_epoch=\" + str(get_checkpoint(reload_checkpoint_epoch, ckpt_dir)) + \".pt\"\n", + " ),\n", + " load_dict={\n", + " \"network_weights\": nnunet_trainer.network._orig_mod,\n", + " \"optimizer_state\": nnunet_trainer.optimizer,\n", + " \"scheduler\": nnunet_trainer.lr_scheduler,\n", + " },\n", + " map_location=device,\n", " )\n", ")" ] @@ -690,11 +669,11 @@ "outputs": [], "source": [ "checkpoint = {\n", - " \"inference_allowed_mirroring_axes\": nnunet_trainer.inference_allowed_mirroring_axes,\n", + " \"inference_allowed_mirroring_axes\": nnunet_trainer.inference_allowed_mirroring_axes,\n", " \"init_args\": nnunet_trainer.my_init_kwargs,\n", - " \"trainer_name\": nnunet_trainer.__class__.__name__\n", + " \"trainer_name\": nnunet_trainer.__class__.__name__,\n", "}\n", - "checkpoint_filename = os.path.join(ckpt_dir,'nnunet_checkpoint.pth')\n", + "checkpoint_filename = os.path.join(ckpt_dir, \"nnunet_checkpoint.pth\")\n", "\n", "torch.save(checkpoint, checkpoint_filename)" ] @@ -716,23 +695,13 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "log_dir = \"nnUNetBundle/logs\"\n", "\n", "train_handlers.append(\n", - " TensorBoardStatsHandler(\n", - " log_dir= log_dir,\n", - " output_transform= from_engine(['loss'], first=True),\n", - " tag_name = \"train_loss\"\n", - " )\n", + " TensorBoardStatsHandler(log_dir=log_dir, output_transform=from_engine([\"loss\"], first=True), tag_name=\"train_loss\")\n", ")\n", "\n", - "val_handlers.append(\n", - " TensorBoardStatsHandler(\n", - " log_dir= log_dir,\n", - " iteration_log = False\n", - " )\n", - ")" + "val_handlers.append(TensorBoardStatsHandler(log_dir=log_dir, iteration_log=False))" ] }, { @@ -743,13 +712,14 @@ "outputs": [], "source": [ "def mlflow_transform(state_output):\n", - " return state_output[0]['loss']\n", + " return state_output[0][\"loss\"]\n", + "\n", "\n", "class MLFlownnUNetHandler(MLFlowHandler):\n", " def __init__(self, label_dict, **kwargs):\n", " super(MLFlownnUNetHandler, self).__init__(**kwargs)\n", " self.label_dict = label_dict\n", - " \n", + "\n", " def _default_epoch_log(self, engine) -> None:\n", " \"\"\"\n", " Execute epoch level log operation.\n", @@ -770,8 +740,8 @@ "\n", " for metric in log_dict:\n", " if type(log_dict[metric]) == torch.Tensor:\n", - " for i,val in 
enumerate(log_dict[metric]):\n", - " new_log_dict[metric+\"_{}\".format(list(self.label_dict.keys())[i+1])] = val\n", + " for i, val in enumerate(log_dict[metric]):\n", + " new_log_dict[metric + \"_{}\".format(list(self.label_dict.keys())[i + 1])] = val\n", " else:\n", " new_log_dict[metric] = log_dict[metric]\n", " self._log_metrics(new_log_dict, step=current_epoch)\n", @@ -792,35 +762,31 @@ " params_dict = {}\n", " config_values = monai.config.deviceconfig.get_config_values()\n", " for k in config_values:\n", - " params_dict[re.sub(\"[()]\",\" \",str(k))] = config_values[k]\n", + " params_dict[re.sub(\"[()]\", \" \", str(k))] = config_values[k]\n", "\n", " optional_config_values = monai.config.deviceconfig.get_optional_config_values()\n", " for k in optional_config_values:\n", - " params_dict[re.sub(\"[()]\",\" \",str(k))] = optional_config_values[k]\n", + " params_dict[re.sub(\"[()]\", \" \", str(k))] = optional_config_values[k]\n", "\n", " gpu_info = monai.config.deviceconfig.get_gpu_info()\n", " for k in gpu_info:\n", - " params_dict[re.sub(\"[()]\",\" \",str(k))] = str(gpu_info[k])\n", + " params_dict[re.sub(\"[()]\", \" \", str(k))] = str(gpu_info[k])\n", "\n", " yaml_config_files = [params_file]\n", " # %%\n", " monai_config = {}\n", " for config_file in yaml_config_files:\n", - " with open(config_file, 'r') as file:\n", + " with open(config_file, \"r\") as file:\n", " monai_config.update(yaml.safe_load(file))\n", "\n", " monai_config[\"bundle_root\"] = str(Path(Path(params_file).parent).parent)\n", "\n", - " parser = ConfigParser(monai_config, globals={\"os\": \"os\",\n", - " \"pathlib\": \"pathlib\",\n", - " \"json\": \"json\",\n", - " \"ignite\": \"ignite\"\n", - " })\n", + " parser = ConfigParser(monai_config, globals={\"os\": \"os\", \"pathlib\": \"pathlib\", \"json\": \"json\", \"ignite\": \"ignite\"})\n", "\n", " parser.parse(True)\n", "\n", " for k in monai_config:\n", - " params_dict[k] = parser.get_parsed_content(k,instantiate=True)\n", + " params_dict[k] = parser.get_parsed_content(k, instantiate=True)\n", "\n", " if custom_params is not None:\n", " for k in custom_params:\n", @@ -870,28 +836,28 @@ "\n", "train_handlers.append(\n", " MLFlownnUNetHandler(\n", - " dataset_dict = {\"train\": train_dataset},\n", - " dataset_keys = \"case_identifier\",\n", - " experiment_param = create_mlflow_experiment_params(params_file),\n", - " experiment_name= mlflow_experiment_name,\n", - " label_dict = label_dict,\n", - " output_transform = mlflow_transform,\n", - " run_name = mlflow_run_name,\n", - " state_attributes = [\"best_metric\", \"best_metric_epoch\"],\n", - " tag_name = \"Train_Loss\",\n", - " tracking_uri = tracking_uri,\n", + " dataset_dict={\"train\": train_dataset},\n", + " dataset_keys=\"case_identifier\",\n", + " experiment_param=create_mlflow_experiment_params(params_file),\n", + " experiment_name=mlflow_experiment_name,\n", + " label_dict=label_dict,\n", + " output_transform=mlflow_transform,\n", + " run_name=mlflow_run_name,\n", + " state_attributes=[\"best_metric\", \"best_metric_epoch\"],\n", + " tag_name=\"Train_Loss\",\n", + " tracking_uri=tracking_uri,\n", " )\n", ")\n", "\n", "val_handlers.append(\n", " MLFlownnUNetHandler(\n", - " experiment_name= mlflow_experiment_name,\n", - " iteration_log = False,\n", - " label_dict = label_dict,\n", - " output_transform = mlflow_transform,\n", - " run_name = mlflow_run_name,\n", - " state_attributes = [\"best_metric\", \"best_metric_epoch\"],\n", - " tracking_uri = tracking_uri,\n", + " 
experiment_name=mlflow_experiment_name,\n", + " iteration_log=False,\n", + " label_dict=label_dict,\n", + " output_transform=mlflow_transform,\n", + " run_name=mlflow_run_name,\n", + " state_attributes=[\"best_metric\", \"best_metric_epoch\"],\n", + " tracking_uri=tracking_uri,\n", " )\n", ")" ] @@ -921,19 +887,19 @@ "outputs": [], "source": [ "trainer = SupervisedTrainer(\n", - " amp= True,\n", - " device = device,\n", - " epoch_length = iterations,\n", - " loss_function = loss,\n", - " max_epochs = epochs,\n", - " network = network,\n", - " prepare_batch = prepare_nnunet_batch,\n", - " optimizer = optimizer,\n", - " train_data_loader = train_dataloader,\n", - " train_handlers= train_handlers,\n", - " key_train_metric = {\"Train_Dice\": train_key_metric},\n", - " postprocessing= postprocessing,\n", - " additional_metrics = additional_metrics\n", + " amp=True,\n", + " device=device,\n", + " epoch_length=iterations,\n", + " loss_function=loss,\n", + " max_epochs=epochs,\n", + " network=network,\n", + " prepare_batch=prepare_nnunet_batch,\n", + " optimizer=optimizer,\n", + " train_data_loader=train_dataloader,\n", + " train_handlers=train_handlers,\n", + " key_train_metric={\"Train_Dice\": train_key_metric},\n", + " postprocessing=postprocessing,\n", + " additional_metrics=additional_metrics,\n", ")" ] }, @@ -1712,14 +1678,14 @@ " config_files = [f.path for f in os.scandir(config_folder) if f.path.endswith(\".yaml\")]\n", " config = {}\n", " for config_file in config_files:\n", - " with open(config_file, 'r') as file:\n", + " with open(config_file, \"r\") as file:\n", " config.update(yaml.safe_load(file))\n", "\n", " if output_file.endswith(\".yaml\"):\n", - " with open(output_file, 'w') as file:\n", + " with open(output_file, \"w\") as file:\n", " yaml.dump(config, file)\n", " if output_file.endswith(\".json\"):\n", - " with open(output_file, 'w') as file:\n", + " with open(output_file, \"w\") as file:\n", " json.dump(config, file)\n", "\n", " return config" @@ -1802,7 +1768,7 @@ " \"model_folder\": \"nnUNetBundle/models\",\n", "}\n", "\n", - "monai_predictor = get_nnunet_monai_predictor(**nnunet_config, model_name=f\"checkpoint_epoch={ckpt_epoch}.pt\")" + "monai_predictor = get_nnunet_monai_predictor(**nnunet_config, model_name=f\"checkpoint_epoch={ckpt_epoch}.pt\")" ] }, { @@ -1867,12 +1833,14 @@ "metadata": {}, "outputs": [], "source": [ - "def get_subfolder_dataset(data_dir,modality_conf):\n", + "def get_subfolder_dataset(data_dir, modality_conf):\n", " data_list = []\n", " for f in os.scandir(data_dir):\n", "\n", " if f.is_dir():\n", - " subject_dict = {key:str(pathlib.Path(f.path).joinpath(f.name+modality_conf[key]['suffix'])) for key in modality_conf}\n", + " subject_dict = {\n", + " key: str(pathlib.Path(f.path).joinpath(f.name + modality_conf[key][\"suffix\"])) for key in modality_conf\n", + " }\n", " data_list.append(subject_dict)\n", " return data_list" ] @@ -1888,7 +1856,7 @@ " \"image\": {\"suffix\": \".nii.gz\"},\n", "}\n", "\n", - "data = get_subfolder_dataset(\"nnUNetBundle/test_input\",modalities)" + "data = get_subfolder_dataset(\"nnUNetBundle/test_input\", modalities)" ] }, { @@ -1901,10 +1869,10 @@ "from monai.transforms import LoadImaged\n", "from monai.data import Dataset, DataLoader\n", "\n", - "preprocessing = LoadImaged(keys=[\"image\"],ensure_channel_first=True, image_only=False)\n", + "preprocessing = LoadImaged(keys=[\"image\"], ensure_channel_first=True, image_only=False)\n", "\n", "\n", - "test_dataset = Dataset(data,transform=preprocessing)\n", + "test_dataset 
= Dataset(data, transform=preprocessing)\n", "\n", "test_loader = DataLoader(test_dataset, batch_size=1)" ] @@ -1953,15 +1921,18 @@ "from monai.transforms import Compose, Transposed, SaveImaged, Decollated\n", "\n", "\n", - "postprocessing = Compose([\n", - " #Decollated(keys=None, detach=True),\n", - " Transposed(keys=\"pred\",indices=[0,3,2,1]),\n", - " SaveImaged(keys=\"pred\",\n", - " output_dir=\"nnUNetBundle/test_output\",\n", - " output_postfix=\"prediction\",\n", - " meta_keys=\"image_meta_dict\",\n", - " )\n", - "])" + "postprocessing = Compose(\n", + " [\n", + " # Decollated(keys=None, detach=True),\n", + " Transposed(keys=\"pred\", indices=[0, 3, 2, 1]),\n", + " SaveImaged(\n", + " keys=\"pred\",\n", + " output_dir=\"nnUNetBundle/test_output\",\n", + " output_postfix=\"prediction\",\n", + " meta_keys=\"image_meta_dict\",\n", + " ),\n", + " ]\n", + ")" ] }, { @@ -1971,7 +1942,7 @@ "metadata": {}, "outputs": [], "source": [ - "postprocessing({\"pred\":pred})" + "postprocessing({\"pred\": pred})" ] }, { @@ -1994,10 +1965,7 @@ "from monai.engines import SupervisedEvaluator\n", "\n", "validator = SupervisedEvaluator(\n", - " val_data_loader=test_loader,\n", - " device = \"cuda:0\",\n", - " network = monai_predictor,\n", - " postprocessing= postprocessing\n", + " val_data_loader=test_loader, device=\"cuda:0\", network=monai_predictor, postprocessing=postprocessing\n", ")" ] }, @@ -2196,10 +2164,11 @@ "from nnunetv2.training.logging.nnunet_logger import nnUNetLogger\n", "import shutil\n", "\n", + "\n", "def subfiles(directory, prefix=None, suffix=None, join=True, sort=True):\n", " \"\"\"\n", " List files in a directory with optional filtering by prefix and/or suffix.\n", - " \n", + "\n", " Parameters\n", " ----------\n", " directory : str\n", @@ -2212,14 +2181,13 @@ " If True, the directory path will be joined with the filenames. Default is True.\n", " sort : bool, optional\n", " If True, the list of files will be sorted. 
Default is True.\n", - " \n", + "\n", " Returns\n", " -------\n", " list of str\n", " A list of filenames (with full paths if `join` is True) that match the specified criteria.\n", " \"\"\"\n", "\n", - " \n", " files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]\n", " if prefix is not None:\n", " files = [f for f in files if f.startswith(prefix)]\n", @@ -2245,69 +2213,69 @@ "\n", " nnunet_model_folder = Path(os.environ[\"nnUNet_results\"]).joinpath(\n", " maybe_convert_to_dataset_name(nnunet_config[\"dataset_name_or_id\"]),\n", - " f\"{nnunet_trainer}__{nnunet_plans}__3d_fullres\")\n", - " \n", + " f\"{nnunet_trainer}__{nnunet_plans}__3d_fullres\",\n", + " )\n", + "\n", " nnunet_preprocess_model_folder = Path(os.environ[\"nnUNet_preprocessed\"]).joinpath(\n", - " maybe_convert_to_dataset_name(nnunet_config[\"dataset_name_or_id\"]))\n", - " \n", - " Path(nnunet_model_folder).joinpath(\"fold_0\").mkdir(parents=True, exist_ok=True)\n", + " maybe_convert_to_dataset_name(nnunet_config[\"dataset_name_or_id\"])\n", + " )\n", "\n", + " Path(nnunet_model_folder).joinpath(\"fold_0\").mkdir(parents=True, exist_ok=True)\n", "\n", " nnunet_checkpoint = torch.load(f\"{bundle_path}/models/nnunet_checkpoint.pth\")\n", - " latest_checkpoints = subfiles(Path(bundle_path).joinpath(\"models\"),prefix=\"checkpoint_epoch\",sort=True,join=False)\n", + " latest_checkpoints = subfiles(\n", + " Path(bundle_path).joinpath(\"models\"), prefix=\"checkpoint_epoch\", sort=True, join=False\n", + " )\n", " epochs = []\n", " for latest_checkpoint in latest_checkpoints:\n", - " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\"):-len(\".pt\")]))\n", + " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\") : -len(\".pt\")]))\n", "\n", " epochs.sort()\n", " final_epoch = epochs[-1]\n", " monai_last_checkpoint = torch.load(f\"{bundle_path}/models/checkpoint_epoch={final_epoch}.pt\")\n", "\n", - " best_checkpoints = subfiles(Path(bundle_path).joinpath(\"models\"), prefix=\"checkpoint_key_metric\", sort=True,\n", - " join=False)\n", + " best_checkpoints = subfiles(\n", + " Path(bundle_path).joinpath(\"models\"), prefix=\"checkpoint_key_metric\", sort=True, join=False\n", + " )\n", " key_metrics = []\n", " for best_checkpoint in best_checkpoints:\n", - " key_metrics.append(str(best_checkpoint[len(\"checkpoint_key_metric=\"):-len(\".pt\")]))\n", + " key_metrics.append(str(best_checkpoint[len(\"checkpoint_key_metric=\") : -len(\".pt\")]))\n", "\n", " key_metrics.sort()\n", " best_key_metric = key_metrics[-1]\n", " monai_best_checkpoint = torch.load(f\"{bundle_path}/models/checkpoint_key_metric={best_key_metric}.pt\")\n", "\n", - " nnunet_checkpoint['optimizer_state'] = monai_last_checkpoint['optimizer_state']\n", - "\n", - "\n", + " nnunet_checkpoint[\"optimizer_state\"] = monai_last_checkpoint[\"optimizer_state\"]\n", "\n", - " nnunet_checkpoint['network_weights'] = odict()\n", + " nnunet_checkpoint[\"network_weights\"] = odict()\n", "\n", - " for key in monai_last_checkpoint['network_weights']:\n", - " nnunet_checkpoint['network_weights'][key] = monai_last_checkpoint['network_weights'][key]\n", + " for key in monai_last_checkpoint[\"network_weights\"]:\n", + " nnunet_checkpoint[\"network_weights\"][key] = monai_last_checkpoint[\"network_weights\"][key]\n", "\n", - " nnunet_checkpoint['current_epoch'] = final_epoch\n", - " nnunet_checkpoint['logging'] = nnUNetLogger().get_checkpoint()\n", - " nnunet_checkpoint['_best_ema'] = 0\n", - " nnunet_checkpoint['grad_scaler_state'] 
= None\n", + " nnunet_checkpoint[\"current_epoch\"] = final_epoch\n", + " nnunet_checkpoint[\"logging\"] = nnUNetLogger().get_checkpoint()\n", + " nnunet_checkpoint[\"_best_ema\"] = 0\n", + " nnunet_checkpoint[\"grad_scaler_state\"] = None\n", "\n", + " torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(\"fold_0\", \"checkpoint_final.pth\"))\n", "\n", + " nnunet_checkpoint[\"network_weights\"] = odict()\n", "\n", - " torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(\"fold_0\",\"checkpoint_final.pth\"))\n", + " nnunet_checkpoint[\"optimizer_state\"] = monai_best_checkpoint[\"optimizer_state\"]\n", "\n", - " nnunet_checkpoint['network_weights'] = odict()\n", - "\n", - " nnunet_checkpoint['optimizer_state'] = monai_best_checkpoint['optimizer_state']\n", - "\n", - " for key in monai_best_checkpoint['network_weights']:\n", - " nnunet_checkpoint['network_weights'][key] = \\\n", - " monai_best_checkpoint['network_weights'][key]\n", + " for key in monai_best_checkpoint[\"network_weights\"]:\n", + " nnunet_checkpoint[\"network_weights\"][key] = monai_best_checkpoint[\"network_weights\"][key]\n", "\n", " torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(\"fold_0\", \"checkpoint_best.pth\"))\n", "\n", - "\n", - " shutil.move(f\"{bundle_path}/models/dataset.json\",nnunet_model_folder)\n", - " shutil.move(f\"{bundle_path}/models/plans.json\",nnunet_model_folder)\n", - " shutil.move(f\"{nnunet_preprocess_model_folder}/dataset_fingerprint.json\",nnunet_model_folder)\n", - " shutil.move(f\"{bundle_path}/models/nnunet_checkpoint.pth\",nnunet_model_folder)\n", - " shutil.move(f\"{bundle_path}/models/checkpoint_epoch={final_epoch}.pt\",f\"{bundle_path}/models/model.pt\")\n", - " shutil.move(f\"{bundle_path}/models/checkpoint_key_metric={best_key_metric}.pt\",f\"{bundle_path}/models/best_model.pt\")\n" + " shutil.move(f\"{bundle_path}/models/dataset.json\", nnunet_model_folder)\n", + " shutil.move(f\"{bundle_path}/models/plans.json\", nnunet_model_folder)\n", + " shutil.move(f\"{nnunet_preprocess_model_folder}/dataset_fingerprint.json\", nnunet_model_folder)\n", + " shutil.move(f\"{bundle_path}/models/nnunet_checkpoint.pth\", nnunet_model_folder)\n", + " shutil.move(f\"{bundle_path}/models/checkpoint_epoch={final_epoch}.pt\", f\"{bundle_path}/models/model.pt\")\n", + " shutil.move(\n", + " f\"{bundle_path}/models/checkpoint_key_metric={best_key_metric}.pt\", f\"{bundle_path}/models/best_model.pt\"\n", + " )" ] }, { @@ -2367,9 +2335,7 @@ "\n", "ConfigParser.export_config_file(data_src, data_src_cfg)\n", "\n", - "runner = nnUNetV2Runner(\n", - " input_config=data_src_cfg, trainer_class_name=\"nnUNetTrainer\", work_dir=nnunet_root_dir\n", - ")" + "runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name=\"nnUNetTrainer\", work_dir=nnunet_root_dir)" ] }, { @@ -2389,7 +2355,7 @@ "metadata": {}, "outputs": [], "source": [ - "runner.find_best_configuration(configs=[\"3d_fullres\"],folds=[0],allow_ensembling=False,num_processes=1)" + "runner.find_best_configuration(configs=[\"3d_fullres\"], folds=[0], allow_ensembling=False, num_processes=1)" ] }, { @@ -2399,7 +2365,7 @@ "metadata": {}, "outputs": [], "source": [ - "runner.predict_ensemble_postprocessing(folds=[0],run_ensemble=False,run_postprocessing=False)" + "runner.predict_ensemble_postprocessing(folds=[0], run_ensemble=False, run_postprocessing=False)" ] }, { @@ -2433,7 +2399,7 @@ "\n", "bundle_root = \"nnUNetBundle\"\n", "\n", - "convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)\n" + 
"convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)" ] } ], From 482bb70a92eecb107f6ab1e7f427f0e3282fbecf Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 26 Feb 2025 10:15:59 +0000 Subject: [PATCH 09/12] Enhance nnUNet bundle notebook with DecathlonDataset integration and logging configuration --- bundle/06_nnunet_monai_bundle.ipynb | 32 +++++++++++++------- bundle/nnUNet_Bundle.ipynb | 45 ++++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb index cbd9998ef..b9572849e 100644 --- a/bundle/06_nnunet_monai_bundle.ipynb +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -65,6 +65,7 @@ "outputs": [], "source": [ "from monai.config import print_config\n", + "from monai.apps import DecathlonDataset\n", "import os\n", "import tempfile\n", "from monai.bundle.config_parser import ConfigParser\n", @@ -73,6 +74,7 @@ "from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", "import json\n", "from pathlib import Path\n", + "import nnunetv2\n", "\n", "print_config()" ] @@ -94,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"/home/maia-user/Documents/GitHub/tutorials/bundle/MONAI/Data\"" + "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"/home/maia-user/Documents/MONAI/Data\"" ] }, { @@ -119,6 +121,15 @@ "To get the Decathlon Spleen dataset and generate the corresponding data list, you can follow the instructions in the [MSD Datalist Generator Notebook](../auto3dseg/notebooks/msd_datalist_generator.ipynb)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DecathlonDataset(root_dir, \"Task09_Spleen\", \"training\", download=True)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -230,7 +241,7 @@ "data_src = {\n", " \"modality\": \"CT\",\n", " \"dataset_name_or_id\": \"09\",\n", - " \"datalist\": os.path.join(root_dir, \"Task09_Spleen/msd_task09_spleen_folds.json\"),\n", + " \"datalist\": str(datalist_file),\n", " \"dataroot\": os.path.join(root_dir, \"Task09_Spleen\"),\n", "}\n", "\n", @@ -477,12 +488,12 @@ "\n", "\n", "BUNDLE_ROOT=nnUNetBundle\n", - "MONAI_DATA_DIRECTORY=/home/maia-user/Documents/GitHub/tutorials/bundle/MONAI/Data\n", + "MONAI_DATA_DIRECTORY=/home/maia-user/Documents/MONAI/Data\n", "\n", "python -m monai.bundle run \\\n", " --config-file $BUNDLE_ROOT/configs/inference.yaml \\\n", " --bundle-root $BUNDLE_ROOT \\\n", - " --data_list_file $MONAI_DATA_DIRECTORY/Task09_Spleen/msd_task09_spleen_folds.json \\\n", + " --data_list_file $MONAI_DATA_DIRECTORY/Task09_Spleen/Task09_Spleen_folds.json \\\n", " --output-dir $BUNDLE_ROOT/pred_output \\\n", " --data_dir $MONAI_DATA_DIRECTORY/Task09_Spleen \\\n", " --logging-file$BUNDLE_ROOT/configs/logging.conf" @@ -755,9 +766,7 @@ "metadata": {}, "outputs": [], "source": [ - "import nnunetv2\n", - "\n", - "print(nnunetv2.__file__)" + "nnunet_training_file = Path(nnunetv2.training.__file__).parent.joinpath(\"lr_scheduler\", \"polylr.py\")" ] }, { @@ -773,7 +782,7 @@ "metadata": {}, "outputs": [], "source": [ - "%%writefile /training/lr_scheduler/polylr.py\n", + "%%writefile $nnunet_training_file\n", "\n", "from torch.optim.lr_scheduler import _LRScheduler\n", "\n", @@ -816,7 +825,7 @@ "outputs": [], "source": [ "%%bash\n", - "\n", + "export MONAI_DATA_DIRECTORY=/home/maia-user/Documents/MONAI/Data\n", "export nnUNet_raw=$MONAI_DATA_DIRECTORY\"/nnUNet/nnUNet_raw_data_base\"\n", "export 
nnUNet_preprocessed=$MONAI_DATA_DIRECTORY\"/nnUNet/nnUNet_preprocessed\"\n", "export nnUNet_results=$MONAI_DATA_DIRECTORY\"/nnUNet/nnUNet_trained_models\"\n", @@ -824,11 +833,12 @@ "export BUNDLE=nnUNetBundle\n", "export PYTHONPATH=$BUNDLE\n", "\n", - "#export nnUNet_def_n_proc=2\n", - "#export nnUNet_n_proc_DA=2\n", + "export nnUNet_def_n_proc=2\n", + "export nnUNet_n_proc_DA=2\n", "\n", "python -m monai.bundle run \\\n", "--bundle-root nnUNetBundle \\\n", + "--dataset_name_or_id 009 \\\n", "--config-file nnUNetBundle/configs/train.yaml" ] }, diff --git a/bundle/nnUNet_Bundle.ipynb b/bundle/nnUNet_Bundle.ipynb index df5ed364d..eb85bb6f5 100644 --- a/bundle/nnUNet_Bundle.ipynb +++ b/bundle/nnUNet_Bundle.ipynb @@ -938,6 +938,49 @@ "which tree && tree nnUNetBundle || true" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "70a37817", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/configs/logging.conf\n", + "[loggers]\n", + "keys=root\n", + "\n", + "[handlers]\n", + "keys=consoleHandler\n", + "\n", + "[formatters]\n", + "keys=fullFormatter\n", + "\n", + "[logger_root]\n", + "level=INFO\n", + "handlers=consoleHandler\n", + "\n", + "[handler_consoleHandler]\n", + "class=StreamHandler\n", + "level=INFO\n", + "formatter=fullFormatter\n", + "args=(sys.stdout,)\n", + "\n", + "[formatter_fullFormatter]\n", + "format=%(asctime)s - %(name)s - %(levelname)s - %(message)s\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c25feaf", + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile nnUNetBundle/configs/metadata.json\n", + "\n", + "##TODO: Add metadata, following the instructions in https://docs.monai.io/en/stable/mb_specification.html" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1717,7 +1760,7 @@ " --bundle_root $BUNDLE_ROOT \\\n", " --config_file $BUNDLE_ROOT/configs/train.yaml\n", "\n", - "#Option to resume training\n", + "# Option to resume training\n", "#--config_file \"['$BUNDLE_ROOT/configs/train.yaml','$BUNDLE_ROOT/configs/train_resume.yaml']\"\n", "\n", "# Log to Local MLFlow\n", From 3f71b085029c532ac7fff37379e151d5f662e9d4 Mon Sep 17 00:00:00 2001 From: simben Date: Wed, 26 Feb 2025 10:19:44 +0000 Subject: [PATCH 10/12] Update MONAI_DATA_DIRECTORY path in nnUNet bundle notebook to relative path --- bundle/06_nnunet_monai_bundle.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bundle/06_nnunet_monai_bundle.ipynb b/bundle/06_nnunet_monai_bundle.ipynb index b9572849e..ee6f22b0a 100644 --- a/bundle/06_nnunet_monai_bundle.ipynb +++ b/bundle/06_nnunet_monai_bundle.ipynb @@ -96,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"/home/maia-user/Documents/MONAI/Data\"" + "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"MONAI/Data\"" ] }, { @@ -488,7 +488,7 @@ "\n", "\n", "BUNDLE_ROOT=nnUNetBundle\n", - "MONAI_DATA_DIRECTORY=/home/maia-user/Documents/MONAI/Data\n", + "MONAI_DATA_DIRECTORY=MONAI/Data\n", "\n", "python -m monai.bundle run \\\n", " --config-file $BUNDLE_ROOT/configs/inference.yaml \\\n", @@ -825,7 +825,7 @@ "outputs": [], "source": [ "%%bash\n", - "export MONAI_DATA_DIRECTORY=/home/maia-user/Documents/MONAI/Data\n", + "export MONAI_DATA_DIRECTORY=MONAI/Data\n", "export nnUNet_raw=$MONAI_DATA_DIRECTORY\"/nnUNet/nnUNet_raw_data_base\"\n", "export nnUNet_preprocessed=$MONAI_DATA_DIRECTORY\"/nnUNet/nnUNet_preprocessed\"\n", "export nnUNet_results=$MONAI_DATA_DIRECTORY\"/nnUNet/nnUNet_trained_models\"\n", From 
2bb8b6a319b51562cc75feb582ab88777137a98a Mon Sep 17 00:00:00 2001 From: simben Date: Fri, 28 Feb 2025 07:47:34 +0000 Subject: [PATCH 11/12] Add copyright notice and enhance nnUNet bundle notebook with environment setup instructions --- bundle/nnUNet_Bundle.ipynb | 90 +++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 41 deletions(-) diff --git a/bundle/nnUNet_Bundle.ipynb b/bundle/nnUNet_Bundle.ipynb index eb85bb6f5..a767e5b45 100644 --- a/bundle/nnUNet_Bundle.ipynb +++ b/bundle/nnUNet_Bundle.ipynb @@ -1,5 +1,22 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "d8f10c45", + "metadata": {}, + "source": [ + "Copyright (c) MONAI Consortium \n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); \n", + "you may not use this file except in compliance with the License. \n", + "You may obtain a copy of the License at \n", + "    http://www.apache.org/licenses/LICENSE-2.0 \n", + "Unless required by applicable law or agreed to in writing, software \n", + "distributed under the License is distributed on an \"AS IS\" BASIS, \n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n", + "See the License for the specific language governing permissions and \n", + "limitations under the License." + ] + }, { "cell_type": "markdown", "id": "bec25bff", @@ -12,6 +29,25 @@ "The tutorial assumes that the Spleen Dataset has been already downloaded and preprocessed as described in the [nnUNet MONAI Bundle Notebook](./06_nnunet_monai_bundle.ipynb)." ] }, + { + "cell_type": "markdown", + "id": "a597f118", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e6fdd58", + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, tqdm]\"\n", + "!python -c \"import nnunetv2\" || pip install -q nnunetv2" + ] + }, { "cell_type": "markdown", "id": "70a2adb6", @@ -28,7 +64,7 @@ "outputs": [], "source": [ "import torch\n", - "from monai.data import Dataset\n", + "from monai.data import Dataset, DataLoader\n", "from monai.handlers import (\n", " StatsHandler,\n", " from_engine,\n", @@ -41,8 +77,7 @@ " MLFlowHandler,\n", ")\n", "from monai.engines import SupervisedTrainer, SupervisedEvaluator\n", - "\n", - "from monai.transforms import Compose, Lambdad, Activationsd, AsDiscreted\n", + "from monai.transforms import Compose, Lambdad, Activationsd, AsDiscreted, Transposed, SaveImaged, LoadImaged\n", "\n", "import re\n", "import pathlib\n", @@ -52,15 +87,16 @@ "from monai.bundle import ConfigParser\n", "import monai\n", "from pathlib import Path\n", - "from odict import odict" - ] - }, - { - "cell_type": "markdown", - "id": "297a2bb9", - "metadata": {}, - "source": [ - "## Setup environment" + "from odict import odict\n", + "\n", + "from monai.bundle.nnunet import get_nnunet_trainer, get_nnunet_monai_predictor, convert_nnunet_to_monai_bundle\n", + "\n", + "from monai.apps.nnunet import nnUNetV2Runner\n", + "\n", + "from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n", + "from nnunetv2.training.logging.nnunet_logger import nnUNetLogger\n", + "import shutil\n", + "\n" ] }, { @@ -70,8 +106,6 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "\n", "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"MONAI/Data\"\n", "\n", "work_dir = os.path.join(os.environ[\"MONAI_DATA_DIRECTORY\"], \"nnUNet\")\n", @@ -113,8 +147,6 @@ "metadata": {}, "outputs": [], "source": [ - "from 
monai.bundle.nnunet import get_nnunet_trainer\n", - "\n", "nnunet_config = {\n", " \"dataset_name_or_id\": \"009\",\n", " \"configuration\": \"3d_fullres\",\n", @@ -978,7 +1010,7 @@ "source": [ "%%writefile nnUNetBundle/configs/metadata.json\n", "\n", - "##TODO: Add metadata, following the instructions in https://docs.monai.io/en/stable/mb_specification.html" + "#TODO: Add MetaData, following the instructions in https://docs.monai.io/en/stable/mb_specification.html" ] }, { @@ -1805,8 +1837,6 @@ "metadata": {}, "outputs": [], "source": [ - "from monai.bundle.nnunet import get_nnunet_monai_predictor\n", - "\n", "nnunet_config = {\n", " \"model_folder\": \"nnUNetBundle/models\",\n", "}\n", @@ -1909,9 +1939,6 @@ "metadata": {}, "outputs": [], "source": [ - "from monai.transforms import LoadImaged\n", - "from monai.data import Dataset, DataLoader\n", - "\n", "preprocessing = LoadImaged(keys=[\"image\"], ensure_channel_first=True, image_only=False)\n", "\n", "\n", @@ -1961,9 +1988,6 @@ "metadata": {}, "outputs": [], "source": [ - "from monai.transforms import Compose, Transposed, SaveImaged, Decollated\n", - "\n", - "\n", "postprocessing = Compose(\n", " [\n", " # Decollated(keys=None, detach=True),\n", @@ -2005,8 +2029,6 @@ "metadata": {}, "outputs": [], "source": [ - "from monai.engines import SupervisedEvaluator\n", - "\n", "validator = SupervisedEvaluator(\n", " val_data_loader=test_loader, device=\"cuda:0\", network=monai_predictor, postprocessing=postprocessing\n", ")" @@ -2203,11 +2225,6 @@ "metadata": {}, "outputs": [], "source": [ - "from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n", - "from nnunetv2.training.logging.nnunet_logger import nnUNetLogger\n", - "import shutil\n", - "\n", - "\n", "def subfiles(directory, prefix=None, suffix=None, join=True, sort=True):\n", " \"\"\"\n", " List files in a directory with optional filtering by prefix and/or suffix.\n", @@ -2328,8 +2345,6 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "\n", "os.environ[\"nnUNet_results\"] = \"MONAI/Data/nnUNet/nnUNet_trained_models\"\n", "os.environ[\"nnUNet_raw\"] = \"MONAI/Data/nnUNet/nnUNet_raw_data_base\"\n", "os.environ[\"nnUNet_preprocessed\"] = \"MONAI/Data/nnUNet/nnUNet_preprocessed\"\n", @@ -2359,10 +2374,6 @@ "metadata": {}, "outputs": [], "source": [ - "from monai.bundle.config_parser import ConfigParser\n", - "from monai.apps.nnunet import nnUNetV2Runner\n", - "\n", - "\n", "root_dir = \"MONAI/Data\"\n", "nnunet_root_dir = os.path.join(root_dir, \"nnUNet\")\n", "\n", @@ -2428,9 +2439,6 @@ "metadata": {}, "outputs": [], "source": [ - "from monai.bundle.nnunet import convert_nnunet_to_monai_bundle\n", - "import os\n", - "\n", "os.environ[\"nnUNet_results\"] = \"MONAI/Data/nnUNet/nnUNet_trained_models\"\n", "os.environ[\"nnUNet_raw\"] = \"MONAI/Data/nnUNet/nnUNet_raw_data_base\"\n", "os.environ[\"nnUNet_preprocessed\"] = \"MONAI/Data/nnUNet/nnUNet_preprocessed\"\n", From 37772336711d51ee544f97885d1b2aab956c4b52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Feb 2025 07:48:39 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- bundle/nnUNet_Bundle.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bundle/nnUNet_Bundle.ipynb b/bundle/nnUNet_Bundle.ipynb index a767e5b45..41cd299e6 100644 --- a/bundle/nnUNet_Bundle.ipynb +++ b/bundle/nnUNet_Bundle.ipynb @@ -95,8 +95,7 @@ 
"\n", "from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n", "from nnunetv2.training.logging.nnunet_logger import nnUNetLogger\n", - "import shutil\n", - "\n" + "import shutil" ] }, {