diff --git a/generation/maisi/README.md b/generation/maisi/README.md
index 727b1a2667..5c596b75d2 100644
--- a/generation/maisi/README.md
+++ b/generation/maisi/README.md
@@ -2,14 +2,25 @@
This example demonstrates the applications of training and validating NVIDIA MAISI, a 3D Latent Diffusion Model (LDM) capable of generating large CT images accompanied by corresponding segmentation masks. It supports variable volume size and voxel spacing and allows for the precise control of organ/tumor size.
## MAISI Model Highlight
+**Initial Version (August 2024):** First release `maisi3d-ddpm`.
+
- A Foundation Variational Auto-Encoder (VAE) model for latent feature compression that works for both CT and MRI with flexible volume size and voxel size. Tensor parallelism is included to reduce GPU memory usage.
- A Foundation Diffusion model that can generate large CT volumes up to 512 × 512 × 768 in size, with flexible volume size and voxel size.
- A ControlNet that generates image/mask pairs with controllable organ/tumor size, which can improve downstream tasks.
More details can be found in our WACV 2025 paper:
-[Guo, P., Zhao, C., Yang, D., Xu, Z., Nath, V., Tang, Y., ... & Xu, D. (2024). MAISI: Medical AI for Synthetic Imaging. arXiv preprint arXiv:2409.11169](https://arxiv.org/pdf/2409.11169)
-Welcome to try our GUI demo at [https://build.nvidia.com/nvidia/maisi](https://build.nvidia.com/nvidia/maisi).
+[Guo, P., Zhao, C., Yang, D., Xu, Z., Nath, V., Tang, Y., ... & Xu, D. (2024). MAISI: Medical AI for Synthetic Imaging. WACV 2025](https://arxiv.org/pdf/2409.11169)
+
+**Release Note (March 2025):**
+
+We are excited to announce the new MAISI version `maisi3d-rflow`. Compared with the previous version `maisi3d-ddpm`, **it accelerates latent diffusion model inference by 33x**. The MAISI VAE is unchanged. The differences are:
+- `maisi3d-ddpm` uses the basic DDPM noise scheduler, while `maisi3d-rflow` uses the Rectified Flow scheduler. The diffusion model inference can be 33 times faster.
+- `maisi3d-ddpm` requires training images to be labeled with body regions (`"top_region_index"` and `"bottom_region_index"`), while `maisi3d-rflow` has no such requirement. In other words, it is easier to prepare training data for `maisi3d-rflow`.
+- For the released model weights, `maisi3d-rflow` generates better-quality images than `maisi3d-ddpm` for the head region and for small output volumes; the two have comparable quality in other cases.
+- `maisi3d-rflow` adds a diffusion model input, `modality`, which gives it the flexibility to extend to other modalities. Currently it is always set to 1, since this version only supports CT generation. We predefined some modalities in [./configs/modality_mapping.json](./configs/modality_mapping.json).
+
+**GUI demo:** Try our GUI demo at [https://build.nvidia.com/nvidia/maisi](https://build.nvidia.com/nvidia/maisi).
The GUI is only a demo for toy examples. This GitHub repo is the full version.
@@ -29,7 +40,8 @@ We retrained several state-of-the-art diffusion model-based methods using our da
| [DDPM](https://proceedings.neurips.cc/paper_files/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf) | 18.524 | 23.696 | 25.604 | 22.608 |
| [LDM](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf) | 16.853 | 10.191 | 10.093 | 12.379 |
| [HA-GAN](https://ieeexplore.ieee.org/document/9770375) | 17.432 | 10.266 | 13.572 | 13.757 |
-| MAISI | 3.301 | 5.838 | 9.109 | 6.083 |
+| MAISI (`maisi3d-ddpm`) | 3.301 | 5.838 | 9.109 | 6.083 |
+| MAISI (`maisi3d-rflow`) | 2.685 | 4.723 | 7.963 | 5.124 |
**Table 1.** Comparison of Fréchet Inception Distance scores between our foundation model and retrained baseline methods
using the unseen public dataset [autoPET 2023](https://www.nature.com/articles/s41597-022-01718-3) as the reference.
@@ -39,7 +51,7 @@ We retrained several state-of-the-art diffusion model-based methods using our da

-**Figure 1.** Qualitative comparison of generated images between baseline methods
(retrained using our large-scale dataset) and our method.
+**Figure 1.** Qualitative comparison of generated images between baseline methods
(retrained using our large-scale dataset) and our method. The MAISI here refers to `maisi3d-ddpm`.
@@ -58,21 +70,20 @@ We retrained several state-of-the-art diffusion model-based methods using our da
## Time Cost and GPU Memory Usage
### Inference Time Cost and GPU Memory Usage
-| `output_size` | latent size |`autoencoder_sliding_window_infer_size` | `autoencoder_tp_num_splits` | Peak Memory | DM Time | VAE Time |
-|---------------|:--------------------------------------:|:--------------------------------------:|:---------------------------:|:-----------:|:-------:|:--------:|
-| [256x256x128](./configs/config_infer_16g_256x256x128.json) |4x64x64x32| >=[64,64,32], not used | 2 | 14G | 57s | 1s |
-| [256x256x256](./configs/config_infer_16g_256x256x256.json) |4x64x64x64| [48,48,64], 4 patches | 2 | 14G | 81s | 7s |
-| [512x512x128](./configs/config_infer_16g_512x512x128.json) |4x128x128x32| [64,64,32], 9 patches | 1 | 14G | 138s | 7s |
-| | | | | | |
-| [256x256x256](./configs/config_infer_24g_256x256x256.json) |4x64x64x64| >=[64,64,64], not used | 4 | 22G | 81s | 2s |
-| [512x512x128](./configs/config_infer_24g_512x512x128.json) |4x128x128x32| [80,80,32], 4 patches | 1 | 18G | 138s | 9s |
-| [512x512x512](./configs/config_infer_24g_512x512x512.json) |4x128x128x128| [64,64,48], 36 patches | 2 | 22G | 569s | 29s |
-| | | | | | |
-| [512x512x512](./configs/config_infer_32g_512x512x512.json) |4x128x128x128| [64,64,64], 27 patches | 2 | 26G | 569s | 40s |
-| | | | | | |
-| [512x512x128](./configs/config_infer_80g_512x512x128.json) |4x128x128x32| >=[128,128,32], not used | 4 | 37G | 138s | 140s |
-| [512x512x512](./configs/config_infer_80g_512x512x512.json) |4x128x128x128| [80,80,80], 8 patches | 2 | 44G | 569s | 30s |
-| [512x512x768](./configs/config_infer_24g_512x512x768.json) |4x128x128x192| [80,80,112], 8 patches | 4 | 55G | 904s | 48s |
+| `output_size` | Peak Memory | VAE Time + DM Time (`maisi3d-ddpm`) | VAE Time + DM Time (`maisi3d-rflow`) | latent size | `autoencoder_sliding_window_infer_size` | `autoencoder_tp_num_splits` | VAE Time | DM Time (`maisi3d-ddpm`) | DM Time (`maisi3d-rflow`) |
+|---------------|:-----------:|:------------------------:|:------------------------:|:--------------------------------------:|:--------------------------------------:|:---------------------------:|:--------:|:---------------:|:---------------:|
+| [256x256x128](./configs/config_infer_16g_256x256x128.json) | 15.0G | 58s | 3s | 4x64x64x32 | >=[64,64,32], not used | 2 | 1s | 57s | 2s |
+| [256x256x256](./configs/config_infer_16g_256x256x256.json) | 15.4G | 86s | 8s | 4x64x64x64 | [48,48,64], 4 patches | 4 | 5s | 81s | 3s |
+| [512x512x128](./configs/config_infer_16g_512x512x128.json) | 15.7G | 146s | 13s | 4x128x128x32 | [64,64,32], 9 patches | 2 | 8s | 138s | 5s |
+| | | | | | | | | | |
+| [256x256x256](./configs/config_infer_24g_256x256x256.json) | 22.7G | 83s | 5s | 4x64x64x64 | >=[64,64,64], not used | 4 | 2s | 81s | 3s |
+| [512x512x128](./configs/config_infer_24g_512x512x128.json) | 21.0G | 144s | 11s | 4x128x128x32 | [80,80,32], 4 patches | 2 | 6s | 138s | 5s |
+| [512x512x512](./configs/config_infer_24g_512x512x512.json) | 22.8G | 598s | 48s | 4x128x128x128 | [64,64,48], 36 patches | 2 | 29s | 569s | 19s |
+| | | | | | | | | | |
+| [512x512x512](./configs/config_infer_32g_512x512x512.json) | 28.4G | 599s | 49s | 4x128x128x128 | [80,80,48], 16 patches | 4 | 30s | 569s | 19s |
+| | | | | | | | | | |
+| [512x512x512](./configs/config_infer_80g_512x512x512.json) | 45.3G | 601s | 51s | 4x128x128x128 | [80,80,80], 8 patches | 2 | 32s | 569s | 19s |
+| [512x512x768](./configs/config_infer_80g_512x512x768.json) | 49.7G | 961s | 87s | 4x128x128x192 | [80,80,96], 12 patches | 4 | 57s | 904s | 30s |
**Table 3:** Inference Time Cost and GPU Memory Usage. `DM Time` refers to the time required for diffusion model inference. `VAE Time` refers to the time required for VAE decoder inference. The total inference time is the sum of `DM Time` and `VAE Time`. The experiment was conducted on an A100 80G GPU.
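
As a rough illustration of how the columns in Table 3 relate, the sketch below derives the latent size from `output_size` and sums the two timing columns. The 4x spatial compression factor is an assumption inferred from the table rows (for example, 256x256x128 maps to 4x64x64x32), and the function names are hypothetical; they are not part of the MAISI scripts.

```python
def latent_size(output_size, latent_channels=4, compression=4):
    """Return [channels, x, y, z] of the latent for a given output volume size.

    Assumption: each spatial axis is downsampled by `compression`, as the
    rows of Table 3 suggest. This helper is illustrative only.
    """
    return [latent_channels] + [dim // compression for dim in output_size]


def total_inference_time(vae_seconds, dm_seconds):
    """Total inference time is the sum of VAE decoding time and DM time."""
    return vae_seconds + dm_seconds


# Last row of Table 3 with maisi3d-rflow: VAE Time 57s, DM Time 30s.
print(latent_size([512, 512, 768]))
print(total_inference_time(57, 30))
```

The same helpers can be used to sanity-check a custom `output_size` before choosing `autoencoder_sliding_window_infer_size`, since the sliding window operates on the latent volume.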
@@ -86,7 +97,7 @@ When `autoencoder_sliding_window_infer_size` is equal to or larger than the late
### Training GPU Memory Usage
The VAE is trained on patches and can be trained using a 16G GPU if the patch size is set to a small value, such as [64, 64, 64]. Users can adjust the patch size to fit the available GPU memory. For the released model, we initially trained the autoencoder on 16G V100 GPUs with a small patch size of [64, 64, 64], and then continued training on 32G V100 GPUs with a larger patch size of [128, 128, 128].
-The DM and ControlNet are trained on whole images rather than patches. The GPU memory usage during training depends on the size of the input images.
+The DM and ControlNet are trained on whole images rather than patches. The GPU memory usage during training depends on the size of the input images. There is no significant difference in memory usage between `maisi3d-ddpm` and `maisi3d-rflow`.
| image size | latent size | Peak Memory |
|--------------|:------------- |:-----------:|
@@ -104,13 +115,13 @@ The DM and ControlNet are trained on whole images rather than patches. The GPU m
## MAISI Model Workflow
The training and inference workflows of MAISI are depicted in the figure below. It begins by training an autoencoder in pixel space to encode images into latent features. Following that, it trains a diffusion model in the latent space to denoise the noisy latent features. During inference, it first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Finally, it decodes the denoised latent features into images using the trained autoencoder.
-
+
Figure 1: MAISI training scheme
-
Figure 2: MAISI inference scheme
@@ -120,6 +131,8 @@ MAISI is based on the following papers:
[**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; "Adding Conditional Control to Text-to-Image Diffusion Models." ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf)
+[**Rectified Flow:** Liu, Xingchao, Chengyue Gong, and Qiang Liu. "Flow straight and fast: Learning to generate and transfer data with rectified flow." ICLR 2023](https://arxiv.org/pdf/2209.03003).
+
### 1. Network Definition
Network definition is stored in [./configs/config_maisi.json](./configs/config_maisi.json). Training and inference should use the same [./configs/config_maisi.json](./configs/config_maisi.json).
@@ -131,7 +144,7 @@ The information for the inference input, such as the body region and anatomy to
- `"spacing"`: The voxel size of the generated images. For example, if set to `[1.5, 1.5, 2.0]`, it generates images with a resolution of 1.5x1.5x2.0 mm.
- `"output_size"`: The volume size of the generated images. For example, if set to `[512, 512, 256]`, it generates images of size 512x512x256. The values must be divisible by 16. If GPU memory is limited, adjust these to smaller numbers. Note that `"spacing"` and `"output_size"` together determine the output field of view (FOV). For example, if set to `[1.5, 1.5, 2.0]` mm and `[512, 512, 256]`, the FOV is 768x768x512 mm. We recommend the FOV in the x and y axes to be at least 256 mm for the head and at least 384 mm for other body regions like the abdomen. There is no restriction for the z-axis.
- `"controllable_anatomy_size"`: A list specifying controllable anatomy and their size scale (0–1). For example, if set to `[["liver", 0.5], ["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a relatively small hepatic tumor (around the 30th percentile). The output will include paired images and segmentation masks for the controllable anatomy.
-- `"body_region"`: If `"controllable_anatomy_size"` is not specified, `"body_region"` will constrain the region of the generated images. It must be chosen from `"head"`, `"chest"`, `"thorax"`, `"abdomen"`, `"pelvis"`, or `"lower"`. Please set a reasonable `"body_region"` for the given FOV determined by `"spacing"` and `"output_size"`. For example, if FOV is only 128mm in z-axis, we should not expect `"body_region"` to contain all of [`"head"`, `"chest"`, `"thorax"`, `"abdomen"`, `"pelvis"`, `"lower"`].
+- `"body_region"`: For `maisi3d-rflow`, it is deprecated and can be set to `[]`; the output body region is determined by `"anatomy_list"`. For `maisi3d-ddpm`, if `"controllable_anatomy_size"` is not specified, `"body_region"` constrains the region of the generated images. It must be chosen from `"head"`, `"chest"`, `"thorax"`, `"abdomen"`, `"pelvis"`, or `"lower"`. Please set a reasonable `"body_region"` for the given FOV determined by `"spacing"` and `"output_size"`. For example, if the FOV is only 128 mm along the z-axis, we should not expect `"body_region"` to contain all of [`"head"`, `"chest"`, `"thorax"`, `"abdomen"`, `"pelvis"`, `"lower"`].
- `"anatomy_list"`: If `"controllable_anatomy_size"` is not specified, the output will include paired images and segmentation masks for the anatomy listed in `"./configs/label_dict.json"`.
- `"autoencoder_sliding_window_infer_size"`: To save GPU memory, sliding window inference is used when decoding latents into images if `"output_size"` is large. This parameter specifies the patch size of the sliding window. Smaller values reduce GPU memory usage but increase the time cost. The values must be divisible by 16. If GPU memory is sufficient, select a larger value for this parameter.
- `"autoencoder_sliding_window_infer_overlap"`: A float between 0 and 1. Larger values reduce stitching artifacts when patches are stitched during sliding window inference but increase the time cost. If you do not observe seam lines in the generated image, you can use a smaller value to save inference time.
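
The relationship between `"spacing"`, `"output_size"`, and the field of view described in the list above can be sketched as follows. This is a minimal illustration, not code from the MAISI scripts: the helper names are hypothetical, and the checks only encode the divisibility rule and the recommended x/y FOV values stated above.

```python
def field_of_view(spacing, output_size):
    """Return the FOV in mm per axis: voxel spacing times number of voxels."""
    if len(spacing) != len(output_size):
        raise ValueError("spacing and output_size must have the same length")
    bad = [size for size in output_size if size % 16 != 0]
    if bad:
        raise ValueError(f"output_size values must be divisible by 16, got {bad}")
    return [sp * size for sp, size in zip(spacing, output_size)]


def meets_recommended_fov(fov, is_head):
    """Check the recommended minimum x/y FOV: 256 mm for head, 384 mm otherwise.

    The z-axis is not checked because there is no restriction on it.
    """
    minimum = 256 if is_head else 384
    return min(fov[0], fov[1]) >= minimum


fov = field_of_view([1.5, 1.5, 2.0], [512, 512, 256])
print(fov)  # [768.0, 768.0, 512.0]
print(meets_recommended_fov(fov, is_head=False))  # True
```

Running a check like this before inference makes it easier to spot a combination of `"spacing"` and `"output_size"` that is too small for the requested anatomy.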
@@ -164,12 +177,21 @@ For example,
|[512, 512, 512] | [1.0, 1.0, 1.0] |
#### Execute Inference:
-To run the inference script, please run:
+To run the inference script with MAISI DDPM, please set `"num_inference_steps": 1000` in `./configs/config_infer.json`, and run:
```bash
export MONAI_DATA_DIRECTORY=
-python -m scripts.inference -c ./configs/config_maisi.json -i ./configs/config_infer.json -e ./configs/environment.json --random-seed 0
+python -m scripts.inference -c ./configs/config_maisi3d-ddpm.json -i ./configs/config_infer.json -e ./configs/environment_maisi3d-ddpm.json --random-seed 0 --version maisi3d-ddpm
```
+To run the inference script with MAISI RFlow, please set `"num_inference_steps": 30` in `./configs/config_infer.json`, and run:
+```bash
+export MONAI_DATA_DIRECTORY=
+python -m scripts.inference -c ./configs/config_maisi3d-rflow.json -i ./configs/config_infer.json -e ./configs/environment_maisi3d-rflow.json --random-seed 0 --version maisi3d-rflow
+```
+
+If a GPU out-of-memory (OOM) error occurs, increase `autoencoder_tp_num_splits` or reduce `autoencoder_sliding_window_infer_size` in `./configs/config_infer.json`.
+To reduce the time cost, reduce `autoencoder_sliding_window_infer_overlap` in `./configs/config_infer.json`, and check the output for stitching artifacts.
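
Because the two versions expect different values of `"num_inference_steps"` (1000 for `maisi3d-ddpm`, 30 for `maisi3d-rflow`), it can help to set the value programmatically before launching inference. The sketch below is a hypothetical helper, not part of the MAISI scripts; it works on an in-memory dictionary such as one loaded from `./configs/config_infer.json` with `json.load`.

```python
import json

# Step counts stated in this README for each released version.
STEPS_BY_VERSION = {"maisi3d-ddpm": 1000, "maisi3d-rflow": 30}


def set_inference_steps(config, version):
    """Return a copy of `config` with num_inference_steps set for `version`."""
    if version not in STEPS_BY_VERSION:
        raise ValueError(f"unknown version: {version!r}")
    updated = dict(config)  # shallow copy so the caller's dict is untouched
    updated["num_inference_steps"] = STEPS_BY_VERSION[version]
    return updated


# Illustrative subset of the inference config.
config = {"num_output_samples": 1, "num_inference_steps": 1000, "modality": 1}
rflow_config = set_inference_steps(config, "maisi3d-rflow")
print(json.dumps(rflow_config, indent=4))
```

The updated dictionary can then be written back with `json.dump` to the file passed to the `-i` option.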
+
Please refer to [maisi_inference_tutorial.ipynb](maisi_inference_tutorial.ipynb) for the tutorial for MAISI model inference.
@@ -177,7 +199,12 @@ Please refer to [maisi_inference_tutorial.ipynb](maisi_inference_tutorial.ipynb)
To run the inference script with TensorRT acceleration, please run:
```bash
export MONAI_DATA_DIRECTORY=
-python -m scripts.inference -c ./configs/config_maisi.json -i ./configs/config_infer.json -e ./configs/environment.json -x ./configs/config_trt.json --random-seed 0
+python -m scripts.inference -c ./configs/config_maisi3d-ddpm.json -i ./configs/config_infer.json -e ./configs/environment_maisi3d-ddpm.json -x ./configs/config_trt.json --random-seed 0 --version maisi3d-ddpm
+```
+
+```bash
+export MONAI_DATA_DIRECTORY=
+python -m scripts.inference -c ./configs/config_maisi3d-rflow.json -i ./configs/config_infer.json -e ./configs/environment_maisi3d-rflow.json -x ./configs/config_trt.json --random-seed 0 --version maisi3d-rflow
```
The extra config file, [./configs/config_trt.json](./configs/config_trt.json), uses the `trt_compile()` utility from MONAI to convert selected modules to TensorRT by overriding their definitions from [./configs/config_infer.json](./configs/config_infer.json).
@@ -215,7 +242,7 @@ Please refer to [maisi_train_vae_tutorial.ipynb](maisi_train_vae_tutorial.ipynb)
#### [3.2 3D Latent Diffusion Training](./scripts/diff_model_train.py)
-Please refer to [maisi_diff_unet_training_tutorial.ipynb](maisi_diff_unet_training_tutorial.ipynb) for the tutorial for MAISI diffusion model training.
+Please refer to [maisi_train_diff_unet_tutorial.ipynb](maisi_train_diff_unet_tutorial.ipynb) for the tutorial for MAISI diffusion model training.
#### [3.3 3D ControlNet Training](./scripts/train_controlnet.py)
@@ -233,7 +260,11 @@ The training was performed with the following:
#### Execute Training:
To train with a single GPU, please run:
```bash
-python -m scripts.train_controlnet -c ./configs/config_maisi.json -t ./configs/config_maisi_controlnet_train.json -e ./configs/environment_maisi_controlnet_train.json -g 1
+python -m scripts.train_controlnet -c ./configs/config_maisi3d-ddpm.json -t ./configs/config_maisi_controlnet_train.json -e ./configs/environment_maisi_controlnet_train.json -g 1
+```
+
+```bash
+python -m scripts.train_controlnet -c ./configs/config_maisi3d-rflow.json -t ./configs/config_maisi_controlnet_train.json -e ./configs/environment_maisi_controlnet_train.json -g 1
```
The training script also enables multi-GPU training. For instance, if you are using eight GPUs, you can run the training script with the following command:
@@ -243,7 +274,16 @@ torchrun \
--nproc_per_node=${NUM_GPUS_PER_NODE} \
--nnodes=1 \
--master_addr=localhost --master_port=1234 \
- -m scripts.train_controlnet -c ./configs/config_maisi.json -t ./configs/config_maisi_controlnet_train.json -e ./configs/environment_maisi_controlnet_train.json -g ${NUM_GPUS_PER_NODE}
+ -m scripts.train_controlnet -c ./configs/config_maisi3d-ddpm.json -t ./configs/config_maisi_controlnet_train.json -e ./configs/environment_maisi_controlnet_train.json -g ${NUM_GPUS_PER_NODE}
+```
+
+```bash
+export NUM_GPUS_PER_NODE=8
+torchrun \
+ --nproc_per_node=${NUM_GPUS_PER_NODE} \
+ --nnodes=1 \
+ --master_addr=localhost --master_port=1234 \
+ -m scripts.train_controlnet -c ./configs/config_maisi3d-rflow.json -t ./configs/config_maisi_controlnet_train.json -e ./configs/environment_maisi_controlnet_train.json -g ${NUM_GPUS_PER_NODE}
```
Please also check [maisi_train_controlnet_tutorial.ipynb](./maisi_train_controlnet_tutorial.ipynb) for more details about data preparation and training parameters.
diff --git a/generation/maisi/configs/config_infer.json b/generation/maisi/configs/config_infer.json
index fc08a7bda4..19b12a6d22 100644
--- a/generation/maisi/configs/config_infer.json
+++ b/generation/maisi/configs/config_infer.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
256,
@@ -18,10 +18,11 @@
2.0
],
"autoencoder_sliding_window_infer_size": [48,48,48],
- "autoencoder_sliding_window_infer_overlap": 0.25,
+ "autoencoder_sliding_window_infer_overlap": 0.6666,
"controlnet": "$@controlnet_def",
"diffusion_unet": "$@diffusion_unet_def",
"autoencoder": "$@autoencoder_def",
"mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
- "mask_generation_diffusion": "$@mask_generation_diffusion_def"
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_16g_256x256x128.json b/generation/maisi/configs/config_infer_16g_256x256x128.json
index 72933304ba..1c6d424f2e 100644
--- a/generation/maisi/configs/config_infer_16g_256x256x128.json
+++ b/generation/maisi/configs/config_infer_16g_256x256x128.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
256,
@@ -19,5 +19,11 @@
],
"autoencoder_sliding_window_infer_size": [96,96,96],
"autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 2
+ "autoencoder_tp_num_splits": 2,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_16g_256x256x256.json b/generation/maisi/configs/config_infer_16g_256x256x256.json
index d4ec9e1a88..8ccd0bc2ca 100644
--- a/generation/maisi/configs/config_infer_16g_256x256x256.json
+++ b/generation/maisi/configs/config_infer_16g_256x256x256.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
256,
@@ -18,6 +18,12 @@
2.0
],
"autoencoder_sliding_window_infer_size": [48,48,64],
- "autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 2
+ "autoencoder_sliding_window_infer_overlap": 0.6666,
+ "autoencoder_tp_num_splits": 4,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_16g_512x512x128.json b/generation/maisi/configs/config_infer_16g_512x512x128.json
index 5e067cd4b4..ec80d72a84 100644
--- a/generation/maisi/configs/config_infer_16g_512x512x128.json
+++ b/generation/maisi/configs/config_infer_16g_512x512x128.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
512,
@@ -18,6 +18,12 @@
4.0
],
"autoencoder_sliding_window_infer_size": [64,64,32],
- "autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 1
+ "autoencoder_sliding_window_infer_overlap": 0.5,
+ "autoencoder_tp_num_splits": 2,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_24g_256x256x256.json b/generation/maisi/configs/config_infer_24g_256x256x256.json
index bb0806f635..a0be706b1c 100644
--- a/generation/maisi/configs/config_infer_24g_256x256x256.json
+++ b/generation/maisi/configs/config_infer_24g_256x256x256.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
256,
@@ -19,5 +19,11 @@
],
"autoencoder_sliding_window_infer_size": [64,64,64],
"autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 4
+ "autoencoder_tp_num_splits": 4,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_24g_512x512x128.json b/generation/maisi/configs/config_infer_24g_512x512x128.json
index 6d2b9d7eab..95bd38795a 100644
--- a/generation/maisi/configs/config_infer_24g_512x512x128.json
+++ b/generation/maisi/configs/config_infer_24g_512x512x128.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
512,
@@ -18,6 +18,12 @@
4.0
],
"autoencoder_sliding_window_infer_size": [80,80,32],
- "autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 1
+ "autoencoder_sliding_window_infer_overlap": 0.4,
+ "autoencoder_tp_num_splits": 2,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_24g_512x512x512.json b/generation/maisi/configs/config_infer_24g_512x512x512.json
index 2cbfb9573f..0606c5e945 100644
--- a/generation/maisi/configs/config_infer_24g_512x512x512.json
+++ b/generation/maisi/configs/config_infer_24g_512x512x512.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
512,
@@ -18,6 +18,12 @@
1.0
],
"autoencoder_sliding_window_infer_size": [64,64,48],
- "autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 2
+ "autoencoder_sliding_window_infer_overlap": 0.4,
+ "autoencoder_tp_num_splits": 2,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_32g_512x512x512.json b/generation/maisi/configs/config_infer_32g_512x512x512.json
index 5dcbcacbe0..5044955c99 100644
--- a/generation/maisi/configs/config_infer_32g_512x512x512.json
+++ b/generation/maisi/configs/config_infer_32g_512x512x512.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
512,
@@ -17,7 +17,13 @@
0.75,
1.0
],
- "autoencoder_sliding_window_infer_size": [64,64,64],
- "autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 2
+ "autoencoder_sliding_window_infer_size": [80,80,48],
+ "autoencoder_sliding_window_infer_overlap": 0.4,
+ "autoencoder_tp_num_splits": 4,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_80g_512x512x128.json b/generation/maisi/configs/config_infer_80g_512x512x128.json
deleted file mode 100644
index d20dbbc76b..0000000000
--- a/generation/maisi/configs/config_infer_80g_512x512x128.json
+++ /dev/null
@@ -1,23 +0,0 @@
-{
- "num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
- "controllable_anatomy_size": [],
- "num_inference_steps": 1000,
- "mask_generation_num_inference_steps": 1000,
- "output_size": [
- 512,
- 512,
- 128
- ],
- "image_output_ext": ".nii.gz",
- "label_output_ext": ".nii.gz",
- "spacing": [
- 0.75,
- 0.75,
- 4.0
- ],
- "autoencoder_sliding_window_infer_size": [128,128,32],
- "autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 4
-}
diff --git a/generation/maisi/configs/config_infer_80g_512x512x512.json b/generation/maisi/configs/config_infer_80g_512x512x512.json
index bfcd6b7dc7..71f5d031e3 100644
--- a/generation/maisi/configs/config_infer_80g_512x512x512.json
+++ b/generation/maisi/configs/config_infer_80g_512x512x512.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
- "body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "body_region": ["chest"],
+ "anatomy_list": ["lung tumor"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
512,
@@ -18,6 +18,12 @@
1.0
],
"autoencoder_sliding_window_infer_size": [80,80,80],
- "autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 2
+ "autoencoder_sliding_window_infer_overlap": 0.4,
+ "autoencoder_tp_num_splits": 2,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_infer_80g_512x512x768.json b/generation/maisi/configs/config_infer_80g_512x512x768.json
index 9cb7e61b61..9d1bee4cd2 100644
--- a/generation/maisi/configs/config_infer_80g_512x512x768.json
+++ b/generation/maisi/configs/config_infer_80g_512x512x768.json
@@ -1,9 +1,9 @@
{
"num_output_samples": 1,
"body_region": ["abdomen"],
- "anatomy_list": ["liver","hepatic tumor"],
+ "anatomy_list": ["liver"],
"controllable_anatomy_size": [],
- "num_inference_steps": 1000,
+ "num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"output_size": [
512,
@@ -17,7 +17,13 @@
0.75,
0.66667
],
- "autoencoder_sliding_window_infer_size": [80,80,112],
- "autoencoder_sliding_window_infer_overlap": 0.25,
- "autoencoder_tp_num_splits": 4
+ "autoencoder_sliding_window_infer_size": [80,80,96],
+ "autoencoder_sliding_window_infer_overlap": 0.4,
+ "autoencoder_tp_num_splits": 4,
+ "controlnet": "$@controlnet_def",
+ "diffusion_unet": "$@diffusion_unet_def",
+ "autoencoder": "$@autoencoder_def",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+ "mask_generation_diffusion": "$@mask_generation_diffusion_def",
+ "modality": 1
}
diff --git a/generation/maisi/configs/config_maisi.json b/generation/maisi/configs/config_maisi3d-ddpm.json
similarity index 96%
rename from generation/maisi/configs/config_maisi.json
rename to generation/maisi/configs/config_maisi3d-ddpm.json
index 8a781ca5b4..67b38e9c15 100644
--- a/generation/maisi/configs/config_maisi.json
+++ b/generation/maisi/configs/config_maisi3d-ddpm.json
@@ -2,6 +2,7 @@
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 4,
+ "include_body_region": true,
"mask_generation_latent_shape": [
4,
64,
@@ -60,8 +61,8 @@
],
"num_res_blocks": 2,
"use_flash_attention": true,
- "include_top_region_index_input": true,
- "include_bottom_region_index_input": true,
+ "include_top_region_index_input": "@include_body_region",
+ "include_bottom_region_index_input": "@include_body_region",
"include_spacing_input": true
},
"controlnet_def": {
diff --git a/generation/maisi/configs/config_maisi3d-rflow.json b/generation/maisi/configs/config_maisi3d-rflow.json
new file mode 100644
index 0000000000..d76da08bf3
--- /dev/null
+++ b/generation/maisi/configs/config_maisi3d-rflow.json
@@ -0,0 +1,150 @@
+{
+ "spatial_dims": 3,
+ "image_channels": 1,
+ "latent_channels": 4,
+ "include_body_region": false,
+ "mask_generation_latent_shape": [
+ 4,
+ 64,
+ 64,
+ 64
+ ],
+ "autoencoder_def": {
+ "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@image_channels",
+ "out_channels": "@image_channels",
+ "latent_channels": "@latent_channels",
+ "num_channels": [
+ 64,
+ 128,
+ 256
+ ],
+ "num_res_blocks": [2,2,2],
+ "norm_num_groups": 32,
+ "norm_eps": 1e-06,
+ "attention_levels": [
+ false,
+ false,
+ false
+ ],
+ "with_encoder_nonlocal_attn": false,
+ "with_decoder_nonlocal_attn": false,
+ "use_checkpointing": false,
+ "use_convtranspose": false,
+ "norm_float16": true,
+ "num_splits": 4,
+ "dim_split": 1
+ },
+ "diffusion_unet_def": {
+ "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@latent_channels",
+ "out_channels": "@latent_channels",
+ "num_channels": [64, 128, 256, 512],
+ "attention_levels": [
+ false,
+ false,
+ true,
+ true
+ ],
+ "num_head_channels": [
+ 0,
+ 0,
+ 32,
+ 32
+ ],
+ "num_res_blocks": 2,
+ "use_flash_attention": true,
+ "include_top_region_index_input": "@include_body_region",
+ "include_bottom_region_index_input": "@include_body_region",
+ "include_spacing_input": true,
+ "num_class_embeds": 128,
+ "resblock_updown": true,
+ "include_fc": true
+ },
+ "controlnet_def": {
+ "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@latent_channels",
+ "num_channels": [64, 128, 256, 512],
+ "attention_levels": [
+ false,
+ false,
+ true,
+ true
+ ],
+ "num_head_channels": [
+ 0,
+ 0,
+ 32,
+ 32
+ ],
+ "num_res_blocks": 2,
+ "use_flash_attention": true,
+ "conditioning_embedding_in_channels": 8,
+ "conditioning_embedding_num_channels": [8, 32, 64],
+ "num_class_embeds": 128,
+ "resblock_updown": true,
+ "include_fc": true
+ },
+ "mask_generation_autoencoder_def": {
+ "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": 8,
+ "out_channels": 125,
+ "latent_channels": "@latent_channels",
+ "num_channels": [
+ 32,
+ 64,
+ 128
+ ],
+ "num_res_blocks": [1, 2, 2],
+ "norm_num_groups": 32,
+ "norm_eps": 1e-06,
+ "attention_levels": [
+ false,
+ false,
+ false
+ ],
+ "with_encoder_nonlocal_attn": false,
+ "with_decoder_nonlocal_attn": false,
+ "use_flash_attention": false,
+ "use_checkpointing": true,
+ "use_convtranspose": true,
+ "norm_float16": true,
+ "num_splits": 8,
+ "dim_split": 1
+ },
+ "mask_generation_diffusion_def": {
+ "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@latent_channels",
+ "out_channels": "@latent_channels",
+ "channels":[64, 128, 256, 512],
+ "attention_levels":[false, false, true, true],
+ "num_head_channels":[0, 0, 32, 32],
+ "num_res_blocks": 2,
+ "use_flash_attention": true,
+ "with_conditioning": true,
+ "upcast_attention": true,
+ "cross_attention_dim": 10
+ },
+ "mask_generation_scale_factor": 1.0055984258651733,
+ "noise_scheduler": {
+ "_target_": "monai.networks.schedulers.rectified_flow.RFlowScheduler",
+ "num_train_timesteps": 1000,
+ "use_discrete_timesteps": false,
+ "use_timestep_transform": true,
+ "sample_method": "uniform",
+ "scale":1.4
+ },
+ "mask_generation_noise_scheduler": {
+ "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
+ "num_train_timesteps": 1000,
+ "beta_start": 0.0015,
+ "beta_end": 0.0195,
+ "schedule": "scaled_linear_beta",
+ "clip_sample": false
+ }
+}
diff --git a/generation/maisi/configs/config_maisi_controlnet_train.json b/generation/maisi/configs/config_maisi_controlnet_train.json
index 4ac94efe63..9bca444bd4 100644
--- a/generation/maisi/configs/config_maisi_controlnet_train.json
+++ b/generation/maisi/configs/config_maisi_controlnet_train.json
@@ -9,7 +9,9 @@
"weighted_loss": 100
},
"controlnet_infer": {
- "num_inference_steps": 1000,
- "autoencoder_sliding_window_infer_size": [96, 96, 96]
+ "num_inference_steps": 10,
+ "autoencoder_sliding_window_infer_size": [80, 80, 80],
+ "autoencoder_sliding_window_infer_overlap": 0.4,
+ "modality": 1
}
}
diff --git a/generation/maisi/configs/config_maisi_diff_model.json b/generation/maisi/configs/config_maisi_diff_model.json
index 8407dbdcc1..f97a749c89 100644
--- a/generation/maisi/configs/config_maisi_diff_model.json
+++ b/generation/maisi/configs/config_maisi_diff_model.json
@@ -29,6 +29,7 @@
0
],
"random_seed": 0,
- "num_inference_steps": 10
+ "num_inference_steps": 10,
+ "modality": 1
}
}
diff --git a/generation/maisi/configs/environment.json b/generation/maisi/configs/environment_maisi3d-ddpm.json
similarity index 75%
rename from generation/maisi/configs/environment.json
rename to generation/maisi/configs/environment_maisi3d-ddpm.json
index 5e017645f1..97ba2b590a 100644
--- a/generation/maisi/configs/environment.json
+++ b/generation/maisi/configs/environment_maisi3d-ddpm.json
@@ -1,8 +1,8 @@
{
"output_dir": "output",
"trained_autoencoder_path": "models/autoencoder_epoch273.pt",
- "trained_diffusion_path": "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt",
- "trained_controlnet_path": "models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt",
+ "trained_diffusion_path": "models/diff_unet_3d_ddpm.pt",
+ "trained_controlnet_path": "models/controlnet_3d_ddpm.pt",
"trained_mask_generation_autoencoder_path": "models/mask_generation_autoencoder.pt",
"trained_mask_generation_diffusion_path": "models/mask_generation_diffusion_unet.pt",
"all_mask_files_base_dir": "datasets/all_masks_flexible_size_and_spacing_3000",
diff --git a/generation/maisi/configs/environment_maisi3d-rflow.json b/generation/maisi/configs/environment_maisi3d-rflow.json
new file mode 100644
index 0000000000..c1be37b592
--- /dev/null
+++ b/generation/maisi/configs/environment_maisi3d-rflow.json
@@ -0,0 +1,13 @@
+{
+ "output_dir": "output",
+ "trained_autoencoder_path": "models/autoencoder_epoch273.pt",
+ "trained_diffusion_path": "models/diff_unet_3d_rflow.pt",
+ "trained_controlnet_path": "models/controlnet_3d_rflow.pt",
+ "trained_mask_generation_autoencoder_path": "models/mask_generation_autoencoder.pt",
+ "trained_mask_generation_diffusion_path": "models/mask_generation_diffusion_unet.pt",
+ "all_mask_files_base_dir": "datasets/all_masks_flexible_size_and_spacing_4000",
+ "all_mask_files_json": "./configs/candidate_masks_flexible_size_and_spacing_4000.json",
+ "all_anatomy_size_conditions_json": "./configs/all_anatomy_size_condtions.json",
+ "label_dict_json": "./configs/label_dict.json",
+ "label_dict_remap_json": "./configs/label_dict_124_to_132.json"
+}
diff --git a/generation/maisi/configs/modality_mapping.json b/generation/maisi/configs/modality_mapping.json
new file mode 100644
index 0000000000..38bd3ee321
--- /dev/null
+++ b/generation/maisi/configs/modality_mapping.json
@@ -0,0 +1,15 @@
+{
+ "unknown":0,
+ "ct":1,
+ "ct_wo_contrast":2,
+ "ct_contrast":3,
+ "mri":8,
+ "mri_t1":9,
+ "mri_t2":10,
+ "mri_flair":11,
+ "mri_pd":12,
+ "mri_dwi":13,
+ "mri_adc":14,
+ "mri_ssfp":15,
+ "mri_mra":16
+}
diff --git a/generation/maisi/figures/maisi_infer.jpg b/generation/maisi/figures/maisi_infer.jpg
deleted file mode 100644
index 9210da5fdd..0000000000
Binary files a/generation/maisi/figures/maisi_infer.jpg and /dev/null differ
diff --git a/generation/maisi/figures/maisi_infer.png b/generation/maisi/figures/maisi_infer.png
new file mode 100644
index 0000000000..4bd18ea188
Binary files /dev/null and b/generation/maisi/figures/maisi_infer.png differ
diff --git a/generation/maisi/figures/maisi_train.jpg b/generation/maisi/figures/maisi_train.jpg
deleted file mode 100644
index 8c4936456d..0000000000
Binary files a/generation/maisi/figures/maisi_train.jpg and /dev/null differ
diff --git a/generation/maisi/figures/maisi_train.png b/generation/maisi/figures/maisi_train.png
new file mode 100644
index 0000000000..d0ec7fda8e
Binary files /dev/null and b/generation/maisi/figures/maisi_train.png differ
diff --git a/generation/maisi/maisi_inference_tutorial.ipynb b/generation/maisi/maisi_inference_tutorial.ipynb
index d121f886b5..69ce91b9b5 100644
--- a/generation/maisi/maisi_inference_tutorial.ipynb
+++ b/generation/maisi/maisi_inference_tutorial.ipynb
@@ -18,7 +18,9 @@
"\n",
"# MAISI Inference Tutorial\n",
"\n",
- "This tutorial illustrates how to use trained MAISI model and codebase to generate synthetic 3D images and paired masks."
+ "This tutorial illustrates how to use the trained MAISI model and codebase to generate synthetic 3D images and paired masks.\n",
+ "\n",
+ "`[Release Note (March 2025)]:` We are excited to announce the new MAISI version `'maisi3d-rflow'`. Compared with the previous version `'maisi3d-ddpm'`, it accelerates latent diffusion model inference by 33x. See the following section for the detailed differences."
]
},
{
@@ -61,32 +63,32 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "MONAI version: 1.4.0rc10\n",
- "Numpy version: 1.24.4\n",
- "Pytorch version: 2.5.0a0+872d972e41.nv24.08.01\n",
+ "MONAI version: 1.4.1rc1+32.g34f37973\n",
+ "Numpy version: 1.26.4\n",
+ "Pytorch version: 2.5.0+cu124\n",
"MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
- "MONAI rev id: cac21f6936a2e8d6e4e57e4e958f8e32aae1585e\n",
- "MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py\n",
+ "MONAI rev id: 34f379735c5e18e7f809453eb1b3606c225c788b\n",
+ "MONAI __file__: /localhome//.local/lib/python3.10/site-packages/monai/__init__.py\n",
"\n",
"Optional dependencies:\n",
"Pytorch Ignite version: 0.4.11\n",
"ITK version: 5.4.0\n",
- "Nibabel version: 5.2.1\n",
- "scikit-image version: 0.23.2\n",
- "scipy version: 1.13.1\n",
- "Pillow version: 10.4.0\n",
- "Tensorboard version: 2.17.0\n",
+ "Nibabel version: 5.3.2\n",
+ "scikit-image version: 0.24.0\n",
+ "scipy version: 1.14.1\n",
+ "Pillow version: 11.0.0\n",
+ "Tensorboard version: 2.18.0\n",
"gdown version: 5.2.0\n",
- "TorchVision version: 0.20.0a0\n",
- "tqdm version: 4.66.4\n",
+ "TorchVision version: 0.20.0+cu124\n",
+ "tqdm version: 4.66.5\n",
"lmdb version: 1.5.1\n",
- "psutil version: 5.9.8\n",
- "pandas version: 2.2.2\n",
- "einops version: 0.7.0\n",
+ "psutil version: 6.1.0\n",
+ "pandas version: 2.2.3\n",
+ "einops version: 0.8.0\n",
"transformers version: 4.40.2\n",
- "mlflow version: 2.16.0\n",
+ "mlflow version: 2.17.1\n",
"pynrrd version: 1.0.0\n",
- "clearml version: 1.16.3\n",
+ "clearml version: 1.16.5rc2\n",
"\n",
"For details about installing the optional dependencies, please visit:\n",
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
@@ -109,8 +111,52 @@
"from scripts.sample import LDMSampler, check_input\n",
"from scripts.utils import define_instance\n",
"from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image\n",
+ "from scripts.diff_model_setting import setup_logging\n",
+ "\n",
+ "print_config()\n",
+ "\n",
+ "logger = setup_logging(\"notebook\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "37c2759e-d1fa-42d7-b208-fbe306ac1e06",
+ "metadata": {},
+ "source": [
+ "## Set up the MAISI version\n",
"\n",
- "print_config()"
+ "Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n",
+ "- `'maisi3d-ddpm'` uses the basic DDPM noise scheduler, while `'maisi3d-rflow'` uses the Rectified Flow scheduler, which can be 33 times faster during inference.\n",
+ "- `'maisi3d-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` has no such requirement. In other words, it is easier to prepare training data for `'maisi3d-rflow'`.\n",
+ "- For the released model weights, `'maisi3d-rflow'` generates images with better quality for the head region and small output volumes, and comparable quality in other cases, compared with `'maisi3d-ddpm'`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "bf4252b1-089d-48c1-b6d6-aa24a93a5839",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[2025-03-19 05:06:22.467][ INFO](notebook) - MAISI version is maisi3d-rflow, whether to use body_region is False\n"
+ ]
+ }
+ ],
+ "source": [
+ "maisi_version = \"maisi3d-rflow\"\n",
+ "if maisi_version == \"maisi3d-ddpm\":\n",
+ " model_def_path = \"./configs/config_maisi3d-ddpm.json\"\n",
+ "elif maisi_version == \"maisi3d-rflow\":\n",
+ " model_def_path = \"./configs/config_maisi3d-rflow.json\"\n",
+ "else:\n",
+ " raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n",
+ "with open(model_def_path, \"r\") as f:\n",
+ " model_def = json.load(f)\n",
+ "include_body_region = model_def[\"include_body_region\"]\n",
+ "logger.info(f\"MAISI version is {maisi_version}, whether to use body_region is {include_body_region}\")"
]
},
{
@@ -119,14 +165,12 @@
"metadata": {},
"source": [
"## Setup data directory\n",
- "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.\n",
- "This allows you to save results and reuse downloads.\n",
- "If not specified a temporary directory will be used."
+ "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. This allows you to save results and reuse downloads. If it is not specified, a temporary directory will be used."
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"id": "e3c12dcc",
"metadata": {},
"outputs": [
@@ -134,26 +178,27 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "2024-09-30 06:40:49,932 - INFO - Expected md5 is None, skip md5 check for file models/autoencoder_epoch273.pt.\n",
- "2024-09-30 06:40:49,933 - INFO - File exists: models/autoencoder_epoch273.pt, skipped downloading.\n",
- "2024-09-30 06:40:49,933 - INFO - Expected md5 is None, skip md5 check for file models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt.\n",
- "2024-09-30 06:40:49,933 - INFO - File exists: models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt, skipped downloading.\n",
- "2024-09-30 06:40:49,934 - INFO - Expected md5 is None, skip md5 check for file models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt.\n",
- "2024-09-30 06:40:49,934 - INFO - File exists: models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt, skipped downloading.\n",
- "2024-09-30 06:40:49,934 - INFO - Expected md5 is None, skip md5 check for file models/mask_generation_autoencoder.pt.\n",
- "2024-09-30 06:40:49,934 - INFO - File exists: models/mask_generation_autoencoder.pt, skipped downloading.\n",
- "2024-09-30 06:40:49,935 - INFO - Expected md5 is None, skip md5 check for file models/mask_generation_diffusion_unet.pt.\n",
- "2024-09-30 06:40:49,935 - INFO - File exists: models/mask_generation_diffusion_unet.pt, skipped downloading.\n",
- "2024-09-30 06:40:49,935 - INFO - Expected md5 is None, skip md5 check for file configs/candidate_masks_flexible_size_and_spacing_3000.json.\n",
- "2024-09-30 06:40:49,935 - INFO - File exists: configs/candidate_masks_flexible_size_and_spacing_3000.json, skipped downloading.\n",
- "2024-09-30 06:40:49,936 - INFO - Expected md5 is None, skip md5 check for file configs/all_anatomy_size_condtions.json.\n",
- "2024-09-30 06:40:49,936 - INFO - File exists: configs/all_anatomy_size_condtions.json, skipped downloading.\n",
- "2024-09-30 06:40:49,936 - INFO - Expected md5 is None, skip md5 check for file /workspace/data/datasets/all_masks_flexible_size_and_spacing_3000.zip.\n",
- "2024-09-30 06:40:49,936 - INFO - File exists: /workspace/data/datasets/all_masks_flexible_size_and_spacing_3000.zip, skipped downloading.\n"
+ "2025-03-19 05:06:22,476 - INFO - Expected md5 is None, skip md5 check for file models/autoencoder_epoch273.pt.\n",
+ "2025-03-19 05:06:22,477 - INFO - File exists: models/autoencoder_epoch273.pt, skipped downloading.\n",
+ "2025-03-19 05:06:22,478 - INFO - Expected md5 is None, skip md5 check for file models/mask_generation_autoencoder.pt.\n",
+ "2025-03-19 05:06:22,478 - INFO - File exists: models/mask_generation_autoencoder.pt, skipped downloading.\n",
+ "2025-03-19 05:06:22,479 - INFO - Expected md5 is None, skip md5 check for file models/mask_generation_diffusion_unet.pt.\n",
+ "2025-03-19 05:06:22,480 - INFO - File exists: models/mask_generation_diffusion_unet.pt, skipped downloading.\n",
+ "2025-03-19 05:06:22,481 - INFO - Expected md5 is None, skip md5 check for file configs/all_anatomy_size_condtions.json.\n",
+ "2025-03-19 05:06:22,482 - INFO - File exists: configs/all_anatomy_size_condtions.json, skipped downloading.\n",
+ "2025-03-19 05:06:22,482 - INFO - Expected md5 is None, skip md5 check for file temp_work_dir_inference_demo/datasets/all_masks_flexible_size_and_spacing_4000.zip.\n",
+ "2025-03-19 05:06:22,483 - INFO - File exists: temp_work_dir_inference_demo/datasets/all_masks_flexible_size_and_spacing_4000.zip, skipped downloading.\n",
+ "2025-03-19 05:06:22,483 - INFO - Expected md5 is None, skip md5 check for file models/diff_unet_3d_rflow.pt.\n",
+ "2025-03-19 05:06:22,484 - INFO - File exists: models/diff_unet_3d_rflow.pt, skipped downloading.\n",
+ "2025-03-19 05:06:22,484 - INFO - Expected md5 is None, skip md5 check for file models/controlnet_3d_rflow.pt.\n",
+ "2025-03-19 05:06:22,485 - INFO - File exists: models/controlnet_3d_rflow.pt, skipped downloading.\n",
+ "2025-03-19 05:06:22,485 - INFO - Expected md5 is None, skip md5 check for file configs/candidate_masks_flexible_size_and_spacing_4000.json.\n",
+ "2025-03-19 05:06:22,486 - INFO - File exists: configs/candidate_masks_flexible_size_and_spacing_4000.json, skipped downloading.\n"
]
}
],
"source": [
+ "os.environ[\"MONAI_DATA_DIRECTORY\"] = \"temp_work_dir_inference_demo\"\n",
"directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
"if directory is not None:\n",
" os.makedirs(directory, exist_ok=True)\n",
@@ -167,16 +212,6 @@
" \"/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt\",\n",
" },\n",
" {\n",
- " \"path\": \"models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt\",\n",
- " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo\"\n",
- " \"/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt\",\n",
- " },\n",
- " {\n",
- " \"path\": \"models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt\",\n",
- " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo\"\n",
- " \"/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt\",\n",
- " },\n",
- " {\n",
" \"path\": \"models/mask_generation_autoencoder.pt\",\n",
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\" \"/tutorials/mask_generation_autoencoder.pt\",\n",
" },\n",
@@ -186,21 +221,54 @@
" \"/tutorials/model_zoo/model_maisi_mask_generation_diffusion_unet_v2.pt\",\n",
" },\n",
" {\n",
- " \"path\": \"configs/candidate_masks_flexible_size_and_spacing_3000.json\",\n",
- " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\"\n",
- " \"/tutorials/candidate_masks_flexible_size_and_spacing_3000.json\",\n",
- " },\n",
- " {\n",
" \"path\": \"configs/all_anatomy_size_condtions.json\",\n",
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/all_anatomy_size_condtions.json\",\n",
" },\n",
" {\n",
- " \"path\": \"datasets/all_masks_flexible_size_and_spacing_3000.zip\",\n",
+ " \"path\": \"datasets/all_masks_flexible_size_and_spacing_4000.zip\",\n",
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\"\n",
- " \"/tutorials/model_zoo/model_maisi_all_masks_flexible_size_and_spacing_3000.zip\",\n",
+ " \"/tutorials/all_masks_flexible_size_and_spacing_4000.zip\",\n",
" },\n",
"]\n",
"\n",
+ "if maisi_version == \"maisi3d-ddpm\":\n",
+ " files += [\n",
+ " {\n",
+ " \"path\": \"models/diff_unet_3d_ddpm.pt\",\n",
+ " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo\"\n",
+ " \"/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt\",\n",
+ " },\n",
+ " {\n",
+ " \"path\": \"models/controlnet_3d_ddpm.pt\",\n",
+ " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo\"\n",
+ " \"/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt\",\n",
+ " },\n",
+ " {\n",
+ " \"path\": \"configs/candidate_masks_flexible_size_and_spacing_3000.json\",\n",
+ " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\"\n",
+ " \"/tutorials/candidate_masks_flexible_size_and_spacing_3000.json\",\n",
+ " },\n",
+ " ]\n",
+ "elif maisi_version == \"maisi3d-rflow\":\n",
+ " files += [\n",
+ " {\n",
+ " \"path\": \"models/diff_unet_3d_rflow.pt\",\n",
+ " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/\"\n",
+ " \"diff_unet_ckpt_rflow_epoch19350.pt\",\n",
+ " },\n",
+ " {\n",
+ " \"path\": \"models/controlnet_3d_rflow.pt\",\n",
+ " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/controlnet_rflow_epoch60.pt\",\n",
+ " },\n",
+ " {\n",
+ " \"path\": \"configs/candidate_masks_flexible_size_and_spacing_4000.json\",\n",
+ " \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\"\n",
+ " \"/tutorials/candidate_masks_flexible_size_and_spacing_4000.json\",\n",
+ " },\n",
+ " ]\n",
+ "else:\n",
+ " raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n",
+ "\n",
"for file in files:\n",
" file[\"path\"] = file[\"path\"] if \"datasets/\" not in file[\"path\"] else os.path.join(root_dir, file[\"path\"])\n",
" download_url(url=file[\"url\"], filepath=file[\"path\"])"
@@ -218,43 +286,49 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"id": "c38b4c33",
"metadata": {
"scrolled": true
},
"outputs": [
{
- "name": "stdout",
+ "name": "stderr",
"output_type": "stream",
"text": [
- "output_dir: output\n",
- "trained_autoencoder_path: models/autoencoder_epoch273.pt\n",
- "trained_diffusion_path: models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt\n",
- "trained_controlnet_path: models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt\n",
- "trained_mask_generation_autoencoder_path: models/mask_generation_autoencoder.pt\n",
- "trained_mask_generation_diffusion_path: models/mask_generation_diffusion_unet.pt\n",
- "all_mask_files_base_dir: /workspace/data/datasets/all_masks_flexible_size_and_spacing_3000\n",
- "all_mask_files_json: ./configs/candidate_masks_flexible_size_and_spacing_3000.json\n",
- "all_anatomy_size_conditions_json: ./configs/all_anatomy_size_condtions.json\n",
- "label_dict_json: ./configs/label_dict.json\n",
- "label_dict_remap_json: ./configs/label_dict_124_to_132.json\n",
- "Global config variables have been loaded.\n"
+ "[2025-03-19 05:06:22.493][ INFO](notebook) - output_dir: output\n",
+ "[2025-03-19 05:06:22.494][ INFO](notebook) - trained_autoencoder_path: models/autoencoder_epoch273.pt\n",
+ "[2025-03-19 05:06:22.494][ INFO](notebook) - trained_diffusion_path: models/diff_unet_3d_rflow.pt\n",
+ "[2025-03-19 05:06:22.495][ INFO](notebook) - trained_controlnet_path: models/controlnet_3d_rflow.pt\n",
+ "[2025-03-19 05:06:22.496][ INFO](notebook) - trained_mask_generation_autoencoder_path: models/mask_generation_autoencoder.pt\n",
+ "[2025-03-19 05:06:22.497][ INFO](notebook) - trained_mask_generation_diffusion_path: models/mask_generation_diffusion_unet.pt\n",
+ "[2025-03-19 05:06:22.497][ INFO](notebook) - all_mask_files_base_dir: temp_work_dir_inference_demo/datasets/all_masks_flexible_size_and_spacing_4000\n",
+ "[2025-03-19 05:06:22.498][ INFO](notebook) - all_mask_files_json: ./configs/candidate_masks_flexible_size_and_spacing_4000.json\n",
+ "[2025-03-19 05:06:22.498][ INFO](notebook) - all_anatomy_size_conditions_json: ./configs/all_anatomy_size_condtions.json\n",
+ "[2025-03-19 05:06:22.499][ INFO](notebook) - label_dict_json: ./configs/label_dict.json\n",
+ "[2025-03-19 05:06:22.500][ INFO](notebook) - label_dict_remap_json: ./configs/label_dict_124_to_132.json\n",
+ "[2025-03-19 05:06:22.501][ INFO](notebook) - Global config variables have been loaded.\n"
]
}
],
"source": [
"args = argparse.Namespace()\n",
"\n",
- "environment_file = \"./configs/environment.json\"\n",
+ "if maisi_version == \"maisi3d-ddpm\":\n",
+ " environment_file = \"./configs/environment_maisi3d-ddpm.json\"\n",
+ "elif maisi_version == \"maisi3d-rflow\":\n",
+ " environment_file = \"./configs/environment_maisi3d-rflow.json\"\n",
+ "else:\n",
+ " raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n",
+ "\n",
"with open(environment_file, \"r\") as f:\n",
" env_dict = json.load(f)\n",
"for k, v in env_dict.items():\n",
" # Update the path to the downloaded dataset in MONAI_DATA_DIRECTORY\n",
" val = v if \"datasets/\" not in v else os.path.join(root_dir, v)\n",
" setattr(args, k, val)\n",
- " print(f\"{k}: {val}\")\n",
- "print(\"Global config variables have been loaded.\")"
+ " logger.info(f\"{k}: {val}\")\n",
+ "logger.info(\"Global config variables have been loaded.\")"
]
},
{
@@ -269,7 +343,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
"id": "533414f3-bef5-49f7-b082-f803b5e494bf",
"metadata": {},
"outputs": [
@@ -277,36 +351,35 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:root:`controllable_anatomy_size` is empty.\n",
- "We will synthesize based on `body_region`: (['abdomen']) and `anatomy_list`: (['liver', 'hepatic tumor']).\n",
- "INFO:root:The generate results will have voxel size to be [1.5, 1.5, 2.0]mm, volume size to be [256, 256, 256].\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "num_output_samples: 1\n",
- "body_region: ['abdomen']\n",
- "anatomy_list: ['liver', 'hepatic tumor']\n",
- "controllable_anatomy_size: []\n",
- "num_inference_steps: 1000\n",
- "mask_generation_num_inference_steps: 1000\n",
- "output_size: [256, 256, 256]\n",
- "image_output_ext: .nii.gz\n",
- "label_output_ext: .nii.gz\n",
- "spacing: [1.5, 1.5, 2.0]\n",
- "autoencoder_sliding_window_infer_size: [48, 48, 48]\n",
- "autoencoder_sliding_window_infer_overlap: 0.25\n",
- "Network definition and inference inputs have been loaded.\n"
+ "[2025-03-19 05:06:22.508][ INFO](notebook) - num_output_samples: 1\n",
+ "[2025-03-19 05:06:22.509][ INFO](notebook) - body_region: ['chest']\n",
+ "[2025-03-19 05:06:22.509][ INFO](notebook) - anatomy_list: ['lung tumor']\n",
+ "[2025-03-19 05:06:22.510][ INFO](notebook) - controllable_anatomy_size: []\n",
+ "[2025-03-19 05:06:22.510][ INFO](notebook) - num_inference_steps: 30\n",
+ "[2025-03-19 05:06:22.512][ INFO](notebook) - mask_generation_num_inference_steps: 1000\n",
+ "[2025-03-19 05:06:22.512][ INFO](notebook) - output_size: [256, 256, 256]\n",
+ "[2025-03-19 05:06:22.513][ INFO](notebook) - image_output_ext: .nii.gz\n",
+ "[2025-03-19 05:06:22.514][ INFO](notebook) - label_output_ext: .nii.gz\n",
+ "[2025-03-19 05:06:22.514][ INFO](notebook) - spacing: [1.5, 1.5, 2.0]\n",
+ "[2025-03-19 05:06:22.515][ INFO](notebook) - autoencoder_sliding_window_infer_size: [48, 48, 48]\n",
+ "[2025-03-19 05:06:22.515][ INFO](notebook) - autoencoder_sliding_window_infer_overlap: 0.6666\n",
+ "[2025-03-19 05:06:22.516][ INFO](notebook) - controlnet: $@controlnet_def\n",
+ "[2025-03-19 05:06:22.517][ INFO](notebook) - diffusion_unet: $@diffusion_unet_def\n",
+ "[2025-03-19 05:06:22.518][ INFO](notebook) - autoencoder: $@autoencoder_def\n",
+ "[2025-03-19 05:06:22.518][ INFO](notebook) - mask_generation_autoencoder: $@mask_generation_autoencoder_def\n",
+ "[2025-03-19 05:06:22.519][ INFO](notebook) - mask_generation_diffusion: $@mask_generation_diffusion_def\n",
+ "[2025-03-19 05:06:22.520][ INFO](notebook) - modality: 1\n",
+ "[2025-03-19 05:06:22.521][ INFO](root) - `controllable_anatomy_size` is empty.\n",
+ "We will synthesize based on `body_region`: (['chest']) and `anatomy_list`: (['lung tumor']).\n",
+ "[2025-03-19 05:06:22.522][ INFO](root) - The generate results will have voxel size to be [1.5, 1.5, 2.0]mm, volume size to be [256, 256, 256].\n",
+ "[2025-03-19 05:06:22.522][ INFO](notebook) - Network definition and inference inputs have been loaded.\n"
]
}
],
"source": [
- "config_file = \"./configs/config_maisi.json\"\n",
- "with open(config_file, \"r\") as f:\n",
- " config_dict = json.load(f)\n",
- "for k, v in config_dict.items():\n",
+ "with open(model_def_path, \"r\") as f:\n",
+ " model_def = json.load(f)\n",
+ "for k, v in model_def.items():\n",
" setattr(args, k, v)\n",
"\n",
"# check the format of inference inputs\n",
@@ -315,7 +388,7 @@
" config_infer_dict = json.load(f)\n",
"for k, v in config_infer_dict.items():\n",
" setattr(args, k, v)\n",
- " print(f\"{k}: {v}\")\n",
+ " logger.info(f\"{k}: {v}\")\n",
"\n",
"check_input(\n",
" args.body_region,\n",
@@ -326,7 +399,7 @@
" args.controllable_anatomy_size,\n",
")\n",
"latent_shape = [args.latent_channels, args.output_size[0] // 4, args.output_size[1] // 4, args.output_size[2] // 4]\n",
- "print(\"Network definition and inference inputs have been loaded.\")"
+ "logger.info(\"Network definition and inference inputs have been loaded.\")"
]
},
{
@@ -339,7 +412,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"id": "87ba613d-a2f5-4afc-95df-65ad21fafedd",
"metadata": {},
"outputs": [],
@@ -360,7 +433,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
"id": "d499f7b1",
"metadata": {
"lines_to_next_cell": 2
@@ -370,8 +443,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "2024-09-30 06:34:48,434 - INFO - 'dst' model updated: 158 of 206 variables.\n",
- "All the trained model weights have been loaded.\n"
+ "2025-03-19 05:06:28,553 - INFO - 'dst' model updated: 180 of 231 variables.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[2025-03-19 05:06:30.944][ INFO](notebook) - All the trained model weights have been loaded.\n"
]
}
],
@@ -404,7 +483,7 @@
"mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet[\"unet_state_dict\"])\n",
"mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet[\"scale_factor\"]\n",
"\n",
- "print(\"All the trained model weights have been loaded.\")"
+ "logger.info(\"All the trained model weights have been loaded.\")"
]
},
{
@@ -417,7 +496,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 9,
"id": "8685da6e",
"metadata": {},
"outputs": [
@@ -425,7 +504,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:root:LDM sampler initialized.\n"
+ "[2025-03-19 05:06:30.969][ INFO](root) - LDM sampler initialized.\n"
]
}
],
@@ -456,6 +535,7 @@
" image_output_ext=args.image_output_ext,\n",
" label_output_ext=args.label_output_ext,\n",
" spacing=args.spacing,\n",
+ " modality=args.modality,\n",
" num_inference_steps=args.num_inference_steps,\n",
" mask_generation_num_inference_steps=args.mask_generation_num_inference_steps,\n",
" random_seed=args.random_seed,\n",
@@ -477,95 +557,96 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 10,
"id": "271f91bf-1c55-46e2-ae56-8677cd8eb81f",
"metadata": {
"scrolled": true
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The generated image/mask pairs will be saved in output.\n",
- "Extracting /workspace/data/datasets/all_masks_flexible_size_and_spacing_3000.zip to /workspace/data/datasets\n",
- "2024-09-30 06:34:50,652 - INFO - Writing into directory: /workspace/data/datasets.\n"
- ]
- },
{
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:root:Resample mask file to get desired output size and spacing\n"
+ "[2025-03-19 05:06:30.974][ INFO](notebook) - The generated image/mask pairs will be saved in output.\n",
+ "[2025-03-19 05:06:31.018][ INFO](root) - Resample mask file to get desired output size and spacing\n",
+ "[2025-03-19 05:06:32.950][ INFO](root) - Resampling mask to target shape and spacing\n",
+ "[2025-03-19 05:06:32.953][ INFO](root) - Resize Spacing: [tensor(0.7988, dtype=torch.float64), tensor(0.7988, dtype=torch.float64), tensor(1.1016, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
+ "[2025-03-19 05:06:32.954][ INFO](root) - Output size: [512, 512, 384] -> [256, 256, 256]\n",
+ "[2025-03-19 05:06:35.876][ INFO](root) - Resampling mask to target shape and spacing\n",
+ "[2025-03-19 05:06:35.878][ INFO](root) - Resize Spacing: [tensor(0.7031, dtype=torch.float64), tensor(0.7031, dtype=torch.float64), tensor(1.4795, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
+ "[2025-03-19 05:06:35.879][ INFO](root) - Output size: [512, 512, 256] -> [256, 256, 256]\n",
+ "[2025-03-19 05:06:38.564][ INFO](root) - Resampling mask to target shape and spacing\n",
+ "[2025-03-19 05:06:38.566][ INFO](root) - Resize Spacing: [tensor(0.7617, dtype=torch.float64), tensor(0.7617, dtype=torch.float64), tensor(1.2939, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
+ "[2025-03-19 05:06:38.567][ INFO](root) - Output size: [512, 512, 256] -> [256, 256, 256]\n",
+ "[2025-03-19 05:06:41.321][ INFO](root) - Resampling mask to target shape and spacing\n",
+ "[2025-03-19 05:06:41.324][ INFO](root) - Resize Spacing: [tensor(0.7031, dtype=torch.float64), tensor(0.7031, dtype=torch.float64), tensor(1.4209, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
+ "[2025-03-19 05:06:41.325][ INFO](root) - Output size: [512, 512, 256] -> [256, 256, 256]\n",
+ "[2025-03-19 05:06:46.041][ INFO](root) - Resampling mask to target shape and spacing\n",
+ "[2025-03-19 05:06:46.043][ INFO](root) - Resize Spacing: [tensor(0.7422, dtype=torch.float64), tensor(0.7422, dtype=torch.float64), tensor(0.5752, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
+ "[2025-03-19 05:06:46.044][ INFO](root) - Output size: [512, 512, 640] -> [256, 256, 256]\n",
+ "[2025-03-19 05:06:47.854][ INFO](root) - Images will be generated based on [{'mask_file': {'pseudo_label': 'temp_work_dir_inference_demo/datasets/all_masks_flexible_size_and_spacing_4000/./Task06/labelsTr/lung_070_133combined_aug_wbdm.nii.gz', 'spacing': [1.5, 1.5, 2.0], 'dim': [256, 256, 256], 'top_region_index': [1, 0, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}, 'if_aug': True}, {'mask_file': {'pseudo_label': 'temp_work_dir_inference_demo/datasets/all_masks_flexible_size_and_spacing_4000/./Task06/labelsTr/lung_031_133combined_aug_wbdm.nii.gz', 'spacing': [1.5, 1.5, 2.0], 'dim': [256, 256, 256], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}, 'if_aug': True}, {'mask_file': {'pseudo_label': 'temp_work_dir_inference_demo/datasets/all_masks_flexible_size_and_spacing_4000/./Task06/labelsTr/lung_075_133combined_aug_wbdm.nii.gz', 'spacing': [1.5, 1.5, 2.0], 'dim': [256, 256, 256], 'top_region_index': [1, 0, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}, 'if_aug': True}, {'mask_file': {'pseudo_label': 'temp_work_dir_inference_demo/datasets/all_masks_flexible_size_and_spacing_4000/./Task06/labelsTr/lung_014_133combined_aug_wbdm.nii.gz', 'spacing': [1.5, 1.5, 2.0], 'dim': [256, 256, 256], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}, 'if_aug': True}, {'mask_file': {'pseudo_label': 'temp_work_dir_inference_demo/datasets/all_masks_flexible_size_and_spacing_4000/./Task06/labelsTr/lung_049_133combined_aug_wbdm.nii.gz', 'spacing': [1.5, 1.5, 2.0], 'dim': [256, 256, 256], 'top_region_index': [1, 0, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}, 'if_aug': True}].\n",
+ "[2025-03-19 05:06:47.855][ INFO](root) - ---- Start preparing masks... ----\n",
+ "[2025-03-19 05:06:49.096][ INFO](root) - Resampling mask to target shape and spacing\n",
+ "[2025-03-19 05:06:49.100][ INFO](root) - Resize Spacing: [tensor(0.7617, dtype=torch.float64), tensor(0.7617, dtype=torch.float64), tensor(1.2939, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
+ "[2025-03-19 05:06:49.100][ INFO](root) - Output size: [512, 512, 256] -> [256, 256, 256]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Unzipped /workspace/data/datasets/all_masks_flexible_size_and_spacing_3000.zip to /workspace/data/datasets/all_masks_flexible_size_and_spacing_3000.\n"
+ "augmenting lung tumor\n",
+ "28\n",
+ "metatensor(180., device='cuda:0') | metatensor(547.4000, device='cuda:0')\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:root:Resampling mask to target shape and spacing\n",
- "INFO:root:Resize Spacing: [tensor(0.7988, dtype=torch.float64), tensor(0.7988, dtype=torch.float64), tensor(1.9062, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
- "INFO:root:Output size: [512, 512, 256] -> [256, 256, 256]\n",
- "INFO:root:Resampling mask to target shape and spacing\n",
- "INFO:root:Resize Spacing: [tensor(0.7852, dtype=torch.float64), tensor(0.7852, dtype=torch.float64), tensor(1.9336, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
- "INFO:root:Output size: [512, 512, 256] -> [256, 256, 256]\n",
- "INFO:root:Resampling mask to target shape and spacing\n",
- "INFO:root:Resize Spacing: [tensor(0.8027, dtype=torch.float64), tensor(0.8027, dtype=torch.float64), tensor(1.8672, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
- "INFO:root:Output size: [512, 512, 256] -> [256, 256, 256]\n",
- "INFO:root:Resampling mask to target shape and spacing\n",
- "INFO:root:Resize Spacing: [tensor(0.9062, dtype=torch.float64), tensor(0.9062, dtype=torch.float64), tensor(2.3438, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
- "INFO:root:Output size: [512, 512, 256] -> [256, 256, 256]\n",
- "INFO:root:Resampling mask to target shape and spacing\n",
- "INFO:root:Resize Spacing: [tensor(0.9551, dtype=torch.float64), tensor(0.9551, dtype=torch.float64), tensor(2.4805, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
- "INFO:root:Output size: [512, 512, 256] -> [256, 256, 256]\n",
- "INFO:root:Images will be generated based on [{'mask_file': {'pseudo_label': '/workspace/data/datasets/all_masks_flexible_size_and_spacing_3000/./Task03/labelsTr/liver_56_133combined_aug_wbdm.nii.gz', 'spacing': [1.5, 1.5, 2.0], 'dim': [256, 256, 256], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 0, 1]}, 'if_aug': True}].\n",
- "INFO:root:---- Start preparing masks... ----\n",
- "INFO:root:Resampling mask to target shape and spacing\n",
- "INFO:root:Resize Spacing: [tensor(0.8027, dtype=torch.float64), tensor(0.8027, dtype=torch.float64), tensor(1.8672, dtype=torch.float64)] -> [1.5, 1.5, 2.0]\n",
- "INFO:root:Output size: [512, 512, 256] -> [256, 256, 256]\n"
+ "[2025-03-19 05:06:52.345][ INFO](root) - ---- Mask preparation time: 4.489694595336914 seconds ----\n",
+ "[2025-03-19 05:06:52.363][ INFO](root) - ---- Start generating latent features... ----\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "augmenting liver tumor\n"
+ "metatensor(687., device='cuda:0') | metatensor(547.4000, device='cuda:0')\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:root:---- Mask preparation time: 10.283801794052124 seconds ----\n",
- "INFO:root:---- Start generating latent features... ----\n",
- "100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [02:15<00:00, 7.41it/s]\n",
- "INFO:root:---- Latent features generation time: 135.05768537521362 seconds ----\n",
- "INFO:root:---- Start decoding latent features into images... ----\n",
- "100%|██████████| 8/8 [00:08<00:00, 1.09s/it]\n",
- "INFO:root:---- Image decoding time: 8.780551671981812 seconds ----\n"
+ "100%|███████████████████████████████████████████| 30/30 [00:06<00:00, 4.58it/s]\n",
+ "[2025-03-19 05:06:58.962][ INFO](root) - ---- DM/ControlNet Latent features generation time: 6.598896265029907 seconds ----\n",
+ "[2025-03-19 05:06:59.051][ INFO](root) - ---- Start decoding latent features into images... ----\n",
+ "100%|█████████████████████████████████████████████| 8/8 [00:11<00:00, 1.42s/it]\n",
+ "[2025-03-19 05:07:10.494][ INFO](root) - ---- Image VAE decoding time: 11.442715883255005 seconds ----\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "2024-09-30 06:38:43,064 INFO image_writer.py:197 - writing: output/sample_20240930_063843_052831_image.nii.gz\n",
- "2024-09-30 06:38:44,661 INFO image_writer.py:197 - writing: output/sample_20240930_063843_052831_label.nii.gz\n",
- "MAISI image/mask generation finished\n"
+ "1 5\n",
+ "2025-03-19 05:07:10,994 INFO image_writer.py:197 - writing: output/sample_20250319_050710_975233_image.nii.gz\n",
+ "2025-03-19 05:07:12,461 INFO image_writer.py:197 - writing: output/sample_20250319_050710_975233_label.nii.gz\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[2025-03-19 05:07:13.422][ INFO](notebook) - MAISI image/mask generation finished\n"
]
}
],
"source": [
- "print(f\"The generated image/mask pairs will be saved in {args.output_dir}.\")\n",
+ "logger.info(f\"The generated image/mask pairs will be saved in {args.output_dir}.\")\n",
"output_filenames = ldm_sampler.sample_multiple_images(args.num_output_samples)\n",
- "print(\"MAISI image/mask generation finished\")"
+ "logger.info(\"MAISI image/mask generation finished\")"
]
},
{
@@ -578,22 +659,20 @@
},
{
"cell_type": "code",
- "execution_count": 10,
- "id": "e0453d9f-1614-4c84-aef1-77b6339d8c12",
- "metadata": {
- "scrolled": true
- },
+ "execution_count": 11,
+ "id": "dfd2ebf9-04f9-498f-982e-9daf44602bee",
+ "metadata": {},
"outputs": [
{
- "name": "stdout",
+ "name": "stderr",
"output_type": "stream",
"text": [
- "Visualizing output/sample_20240930_063843_052831_image.nii.gz and output/sample_20240930_063843_052831_label.nii.gz...\n"
+ "[2025-03-19 05:07:13.430][ INFO](notebook) - Visualizing output/sample_20250319_050710_975233_image.nii.gz and output/sample_20250319_050710_975233_label.nii.gz...\n"
]
},
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -603,7 +682,7 @@
},
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -615,7 +694,7 @@
"source": [
"visualize_image_filename = output_filenames[0][0]\n",
"visualize_mask_filename = output_filenames[0][1]\n",
- "print(f\"Visualizing {visualize_image_filename} and {visualize_mask_filename}...\")\n",
+ "logger.info(f\"Visualizing {visualize_image_filename} and {visualize_mask_filename}...\")\n",
"\n",
"# load image/mask pairs\n",
"loader = LoadImage(image_only=True, ensure_channel_first=True)\n",
@@ -624,7 +703,7 @@
"mask_volume = orientation(loader(visualize_mask_filename)).to(torch.uint8)\n",
"\n",
"# visualize for CT HU intensity between [-200, 500]\n",
- "image_volume = torch.clip(image_volume, -200, 500)\n",
+ "image_volume = torch.clip(image_volume, -1000, 300)\n",
"image_volume = image_volume - torch.min(image_volume)\n",
"image_volume = image_volume / torch.max(image_volume)\n",
"\n",
@@ -651,7 +730,7 @@
"formats": "py:percent,ipynb"
},
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
diff --git a/generation/maisi/maisi_train_controlnet_tutorial.ipynb b/generation/maisi/maisi_train_controlnet_tutorial.ipynb
index 37becc2e43..565adb7cd4 100644
--- a/generation/maisi/maisi_train_controlnet_tutorial.ipynb
+++ b/generation/maisi/maisi_train_controlnet_tutorial.ipynb
@@ -26,7 +26,9 @@
"\n",
"\n",
"\n",
- "In this notebook, we detail the procedure for training a 3D ControlNet to generate high-dimensional 3D medical images. Due to the potential for out-of-memory issues on most GPUs when generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater), we have structured the training process into two primary steps: 1) preparing training data, 2) training config preparation, and 3) launch training of 3D ControlNet. The subsequent sections will demonstrate the entire process using a simulated dataset. We also provide the real preprocessed dataset used in the finetuning config `environment_maisi_controlnet_train.json`. More instructions about how to preprocess real data can be found in the [README](./data/README.md) in `data` folder.\n"
+ "In this notebook, we detail the procedure for training a 3D ControlNet to generate high-dimensional 3D medical images. Due to the potential for out-of-memory issues on most GPUs when generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater), we have structured the training process into three primary steps: 1) preparing the training data, 2) preparing the training config, and 3) launching the training of the 3D ControlNet. The subsequent sections demonstrate the entire process using a simulated dataset. We also provide the real preprocessed dataset used in the finetuning config `environment_maisi_controlnet_train.json`. More instructions about how to preprocess real data can be found in the [README](./data/README.md) in the `data` folder.\n",
+ "\n",
+ "`[Release Note (March 2025)]:` We are excited to announce the new MAISI version `'maisi3d-rflow'`. Compared with the previous version `'maisi3d-ddpm'`, it accelerates latent diffusion model inference by 33x. Please see the detailed differences in the following section."
]
},
{
@@ -57,46 +59,38 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
"id": "e3bf0346",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "MONAI version: 1.4.0rc12\n",
- "Numpy version: 1.24.4\n",
- "Pytorch version: 2.5.0a0+872d972e41.nv24.08\n",
+ "MONAI version: 1.4.1rc1+32.g34f37973\n",
+ "Numpy version: 1.26.4\n",
+ "Pytorch version: 2.5.0+cu124\n",
"MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
- "MONAI rev id: 76ef9f40c8da626928238c91eacddc789b0b4530\n",
- "MONAI __file__: /workspace/Code/MONAI/monai/__init__.py\n",
+ "MONAI rev id: 34f379735c5e18e7f809453eb1b3606c225c788b\n",
+ "MONAI __file__: /localhome//.local/lib/python3.10/site-packages/monai/__init__.py\n",
"\n",
"Optional dependencies:\n",
"Pytorch Ignite version: 0.4.11\n",
"ITK version: 5.4.0\n",
- "Nibabel version: 5.2.1\n",
- "scikit-image version: 0.23.2\n",
- "scipy version: 1.14.0\n",
- "Pillow version: 10.4.0\n",
- "Tensorboard version: 2.16.2\n",
+ "Nibabel version: 5.3.2\n",
+ "scikit-image version: 0.24.0\n",
+ "scipy version: 1.14.1\n",
+ "Pillow version: 11.0.0\n",
+ "Tensorboard version: 2.18.0\n",
"gdown version: 5.2.0\n",
- "TorchVision version: 0.20.0a0\n",
+ "TorchVision version: 0.20.0+cu124\n",
"tqdm version: 4.66.5\n",
"lmdb version: 1.5.1\n",
- "psutil version: 6.0.0\n",
- "pandas version: 2.2.2\n",
+ "psutil version: 6.1.0\n",
+ "pandas version: 2.2.3\n",
"einops version: 0.8.0\n",
"transformers version: 4.40.2\n",
- "mlflow version: 2.16.2\n",
+ "mlflow version: 2.17.1\n",
"pynrrd version: 1.0.0\n",
"clearml version: 1.16.5rc2\n",
"\n",
@@ -124,6 +118,47 @@
"logger = setup_logging(\"notebook\")"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "b16b92a4-2039-42b7-bf77-68851d25701b",
+ "metadata": {},
+ "source": [
+ "## Set up the MAISI version\n",
+ "\n",
+ "Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n",
+ "- `'maisi3d-ddpm'` uses the basic DDPM noise scheduler, while `'maisi3d-rflow'` uses the Rectified Flow scheduler, which can be 33 times faster during inference.\n",
+ "- `'maisi3d-ddpm'` requires training images to be labeled with body regions (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` has no such requirement. In other words, it is easier to prepare training data for `'maisi3d-rflow'`.\n",
+ "- For the released model weights, `'maisi3d-rflow'` can generate images with better quality for the head region and small output volumes, and with comparable quality in other cases, compared with `'maisi3d-ddpm'`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "3d233abe-d69c-4b57-9655-33c2c3da6c96",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[2025-03-14 16:29:13.938][ INFO](notebook) - MAISI version is maisi3d-rflow, whether to use body_region is False\n"
+ ]
+ }
+ ],
+ "source": [
+ "maisi_version = \"maisi3d-rflow\"\n",
+ "if maisi_version == \"maisi3d-ddpm\":\n",
+ " model_def_path = \"./configs/config_maisi3d-ddpm.json\"\n",
+ "elif maisi_version == \"maisi3d-rflow\":\n",
+ " model_def_path = \"./configs/config_maisi3d-rflow.json\"\n",
+ "else:\n",
+ " raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n",
+ "with open(model_def_path, \"r\") as f:\n",
+ " model_def = json.load(f)\n",
+ "include_body_region = model_def[\"include_body_region\"]\n",
+ "logger.info(f\"MAISI version is {maisi_version}, whether to use body_region is {include_body_region}\")"
+ ]
+ },
{
"cell_type": "markdown",
"id": "671e7f10",
@@ -154,7 +189,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 4,
"id": "fc32a7fe",
"metadata": {},
"outputs": [],
@@ -169,8 +204,6 @@
" \"fold\": 0, # fold index for cross validation, fold 0 is used for training\n",
" \"dim\": sim_dim, # the dimension of image\n",
" \"spacing\": [1.5, 1.5, 1.5], # the spacing of image\n",
- " \"top_region_index\": [0, 1, 0, 0], # the top region index of the image\n",
- " \"bottom_region_index\": [0, 0, 0, 1], # the bottom region index of the image\n",
" },\n",
" {\n",
" \"image\": \"tr_image_002_emb.nii.gz\",\n",
@@ -178,8 +211,6 @@
" \"fold\": 1,\n",
" \"dim\": sim_dim,\n",
" \"spacing\": [1.5, 1.5, 1.5],\n",
- " \"top_region_index\": [0, 1, 0, 0],\n",
- " \"bottom_region_index\": [0, 0, 0, 1],\n",
" },\n",
" {\n",
" \"image\": \"tr_image_003_emb.nii.gz\",\n",
@@ -187,11 +218,14 @@
" \"fold\": 1,\n",
" \"dim\": sim_dim,\n",
" \"spacing\": [1.5, 1.5, 1.5],\n",
- " \"top_region_index\": [0, 1, 0, 0],\n",
- " \"bottom_region_index\": [0, 0, 0, 1],\n",
" },\n",
" ]\n",
- "}"
+ "}\n",
+ "if include_body_region:\n",
+ "    for i in range(len(sim_datalist[\"training\"])):\n",
+ "        # body region index\n",
+ "        sim_datalist[\"training\"][i][\"top_region_index\"] = [0, 1, 0, 0]  # the top region index of the image\n",
+ "        sim_datalist[\"training\"][i][\"bottom_region_index\"] = [0, 0, 0, 1]  # the bottom region index of the image"
]
},
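The cell above attaches body-region keys only when the selected model definition asks for them. A minimal, self-contained sketch of that pattern follows; `add_body_region` is a hypothetical helper (not part of the MAISI scripts), and the region vectors are the demo values used in this notebook.

```python
import copy


def add_body_region(datalist, include_body_region, top=(0, 1, 0, 0), bottom=(0, 0, 0, 1)):
    """Return a copy of `datalist` with region keys added only when required."""
    out = copy.deepcopy(datalist)
    if include_body_region:
        for item in out["training"]:
            item["top_region_index"] = list(top)
            item["bottom_region_index"] = list(bottom)
    return out


sim_datalist = {
    "training": [
        {"image": "tr_image_001_emb.nii.gz", "fold": 0, "spacing": [1.5, 1.5, 1.5]},
    ]
}

# 'maisi3d-ddpm' style entry (body regions required) vs. 'maisi3d-rflow' style entry.
ddpm_list = add_body_region(sim_datalist, include_body_region=True)
rflow_list = add_body_region(sim_datalist, include_body_region=False)
```

Returning a copy instead of mutating in place is a design choice of this sketch; the notebook cell mutates `sim_datalist` directly, which is equally valid for a one-off demo.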
{
@@ -206,7 +240,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"id": "1b199078",
"metadata": {},
"outputs": [
@@ -214,9 +248,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:Generated simulated images.\n",
- "INFO:notebook:img_emb shape: (64, 64, 32, 4)\n",
- "INFO:notebook:label shape: (256, 256, 128)\n"
+ "[2025-03-14 16:29:13.952][ INFO](notebook) - Save data list json file to ./temp_work_dir_controlnet_train_demo/sim_datalist.json\n",
+ "[2025-03-14 16:29:16.033][ INFO](notebook) - Generated simulated images.\n",
+ "[2025-03-14 16:29:16.034][ INFO](notebook) - img_emb shape: (64, 64, 32, 4)\n",
+ "[2025-03-14 16:29:16.035][ INFO](notebook) - label shape: (256, 256, 128)\n"
]
}
],
@@ -232,6 +267,7 @@
"datalist_file = os.path.join(work_dir, \"sim_datalist.json\")\n",
"with open(datalist_file, \"w\") as f:\n",
" json.dump(sim_datalist, f, indent=4)\n",
+ "logger.info(f\"Save data list json file to {datalist_file}\")\n",
"\n",
"for d in sim_datalist[\"training\"]:\n",
" # The image embedding is downsampled twice by Autoencoder.\n",
@@ -280,7 +316,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 6,
"id": "6c7b434c",
"metadata": {},
"outputs": [
@@ -288,15 +324,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:files and folders under work_dir: ['config_maisi.json', 'sim_dataroot', 'sim_datalist.json', 'models', 'outputs', 'environment_maisi_controlnet_train.json', 'config_maisi_controlnet_train.json'].\n",
- "INFO:notebook:number of GPUs: 1.\n"
+ "[2025-03-14 16:29:16.049][ INFO](notebook) - files and folders under work_dir: ['config_maisi.json', 'models', 'config_maisi_controlnet_train.json', 'outputs', 'sim_dataroot', '.ipynb_checkpoints', 'environment_maisi_controlnet_train.json', 'sim_datalist.json'].\n",
+ "[2025-03-14 16:29:16.050][ INFO](notebook) - number of GPUs: 1.\n"
]
}
],
"source": [
"env_config_path = \"./configs/environment_maisi_controlnet_train.json\"\n",
"train_config_path = \"./configs/config_maisi_controlnet_train.json\"\n",
- "model_def_path = \"./configs/config_maisi.json\"\n",
"\n",
"# Load environment configuration, model configuration and model definition\n",
"with open(env_config_path, \"r\") as f:\n",
@@ -367,7 +402,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 7,
"id": "95ea6972",
"metadata": {},
"outputs": [],
@@ -427,7 +462,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
"id": "ade6389d",
"metadata": {},
"outputs": [
@@ -435,30 +470,30 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:Training the model...\n"
+ "[2025-03-14 16:29:16.061][ INFO](notebook) - Training the model...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "2024-09-24 02:33:40,881 - INFO - 'dst' model updated: 158 of 206 variables.\n",
- "\n",
- "INFO:maisi.controlnet.training:Number of GPUs: 2\n",
- "INFO:maisi.controlnet.training:World_size: 1\n",
- "INFO:maisi.controlnet.training:trained diffusion model is not loaded.\n",
- "INFO:maisi.controlnet.training:set scale_factor -> 1.0.\n",
- "INFO:maisi.controlnet.training:train controlnet model from scratch.\n",
- "INFO:maisi.controlnet.training:total number of training steps: 4.0.\n",
- "INFO:maisi.controlnet.training:\n",
- "[Epoch 1/2] [Batch 1/2] [LR: 0.00000563] [loss: 0.7981] ETA: 0:00:01.501654 \n",
- "INFO:maisi.controlnet.training:\n",
- "[Epoch 1/2] [Batch 2/2] [LR: 0.00000250] [loss: 0.7976] ETA: 0:00:00 \n",
- "INFO:maisi.controlnet.training:best loss -> 0.7978459596633911.\n",
- "INFO:maisi.controlnet.training:\n",
- "[Epoch 2/2] [Batch 1/2] [LR: 0.00000063] [loss: 0.7982] ETA: 0:00:01.988772 \n",
- "INFO:maisi.controlnet.training:\n",
- "[Epoch 2/2] [Batch 2/2] [LR: 0.00000000] [loss: 0.7998] ETA: 0:00:00 \n",
+ "[2025-03-14 16:29:23.336][ INFO](maisi.controlnet.training) - Number of GPUs: 8\n",
+ "[2025-03-14 16:29:23.336][ INFO](maisi.controlnet.training) - World_size: 1\n",
+ "[2025-03-14 16:29:24.771][ INFO](maisi.controlnet.training) - trained diffusion model is not loaded.\n",
+ "[2025-03-14 16:29:24.771][ INFO](maisi.controlnet.training) - set scale_factor -> 1.0.\n",
+ "2025-03-14 16:29:25,271 - INFO - 'dst' model updated: 180 of 231 variables.\n",
+ "[2025-03-14 16:29:25.277][ INFO](maisi.controlnet.training) - train controlnet model from scratch.\n",
+ "[2025-03-14 16:29:25.300][ INFO](maisi.controlnet.training) - total number of training steps: 4.0.\n",
+ "[2025-03-14 16:29:26.826][ INFO](maisi.controlnet.training) -\n",
+ "[Epoch 1/2] [Batch 1/2] [LR: 0.00000563] [loss: 0.8278] ETA: 0:00:01.523338\n",
+ "[2025-03-14 16:29:26.974][ INFO](maisi.controlnet.training) -\n",
+ "[Epoch 1/2] [Batch 2/2] [LR: 0.00000250] [loss: 0.8289] ETA: 0:00:00\n",
+ "[2025-03-14 16:29:27.585][ INFO](maisi.controlnet.training) - best loss -> 0.8283329606056213.\n",
+ "[2025-03-14 16:29:28.909][ INFO](maisi.controlnet.training) -\n",
+ "[Epoch 2/2] [Batch 1/2] [LR: 0.00000063] [loss: 0.8288] ETA: 0:00:01.934548\n",
+ "[2025-03-14 16:29:29.052][ INFO](maisi.controlnet.training) -\n",
+ "[Epoch 2/2] [Batch 2/2] [LR: 0.00000000] [loss: 0.8277] ETA: 0:00:00\n",
+ "[2025-03-14 16:29:29.716][ INFO](maisi.controlnet.training) - best loss -> 0.8282470703125.\n",
"\n"
]
}
@@ -493,7 +528,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"id": "936360c8",
"metadata": {},
"outputs": [
@@ -501,32 +536,32 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:Inference...\n"
+ "[2025-03-14 16:29:32.229][ INFO](notebook) - Inference...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "2024-09-24 02:34:03,472 - INFO - 'dst' model updated: 158 of 206 variables.\n",
- "2024-09-24 02:34:06,052 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20240924_023406_038072_image.nii.gz\n",
- "2024-09-24 02:34:06,437 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20240924_023406_038072_label.nii.gz\n",
+ "[2025-03-14 16:29:39.519][ INFO](maisi.controlnet.infer) - Number of GPUs: 8\n",
+ "[2025-03-14 16:29:39.519][ INFO](maisi.controlnet.infer) - World_size: 1\n",
+ "[2025-03-14 16:29:39.990][ INFO](maisi.controlnet.infer) - trained autoencoder model is not loaded.\n",
+ "[2025-03-14 16:29:41.213][ INFO](maisi.controlnet.infer) - trained diffusion model is not loaded.\n",
+ "[2025-03-14 16:29:41.213][ INFO](maisi.controlnet.infer) - set scale_factor -> 1.0.\n",
+ "2025-03-14 16:29:41,716 - INFO - 'dst' model updated: 180 of 231 variables.\n",
+ "[2025-03-14 16:29:41.721][ INFO](maisi.controlnet.infer) - trained controlnet is not loaded.\n",
+ "[2025-03-14 16:29:42.102][ INFO](root) - `controllable_anatomy_size` is not provided.\n",
+ "[2025-03-14 16:29:42.104][ INFO](root) - ---- Start generating latent features... ----\n",
+ "[2025-03-14 16:29:42.670][ INFO](root) - ---- DM/ControlNet Latent features generation time: 0.565190315246582 seconds ----\n",
+ "[2025-03-14 16:29:42.672][ INFO](root) - ---- Start decoding latent features into images... ----\n",
+ "[2025-03-14 16:29:43.314][ INFO](root) - ---- Image VAE decoding time: 0.6416211128234863 seconds ----\n",
+ "2025-03-14 16:29:43,602 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20250314_162943_586788_image.nii.gz\n",
+ "2025-03-14 16:29:43,940 INFO image_writer.py:197 - writing: temp_work_dir_controlnet_train_demo/outputs/sample_20250314_162943_586788_label.nii.gz\n",
"\n",
- "INFO:maisi.controlnet.infer:Number of GPUs: 2\n",
- "INFO:maisi.controlnet.infer:World_size: 1\n",
- "INFO:maisi.controlnet.infer:trained autoencoder model is not loaded.\n",
- "INFO:maisi.controlnet.infer:trained diffusion model is not loaded.\n",
- "INFO:maisi.controlnet.infer:set scale_factor -> 1.0.\n",
- "INFO:maisi.controlnet.infer:trained controlnet is not loaded.\n",
- "INFO:root:`controllable_anatomy_size` is not provided.\n",
- "INFO:root:---- Start generating latent features... ----\n",
"\n",
- " 0%| | 0/1 [00:00, ?it/s]\n",
- "100%|βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 1/1 [00:00<00:00, 2.57it/s]\n",
- "100%|βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 1/1 [00:00<00:00, 2.57it/s]\n",
- "INFO:root:---- Latent features generation time: 0.4557678699493408 seconds ----\n",
- "INFO:root:---- Start decoding latent features into images... ----\n",
- "INFO:root:---- Image decoding time: 1.2888050079345703 seconds ----\n",
+ " 0%| | 0/1 [00:00, ?it/s]\n",
+ "100%|██████████| 1/1 [00:00<00:00, 2.02it/s]\n",
+ "100%|██████████| 1/1 [00:00<00:00, 2.02it/s]\n",
"\n"
]
}
@@ -558,7 +593,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 10,
"id": "459af453",
"metadata": {},
"outputs": [
diff --git a/generation/maisi/maisi_diff_unet_training_tutorial.ipynb b/generation/maisi/maisi_train_diff_unet_tutorial.ipynb
similarity index 97%
rename from generation/maisi/maisi_diff_unet_training_tutorial.ipynb
rename to generation/maisi/maisi_train_diff_unet_tutorial.ipynb
index 6effcee52b..03bba663fa 100644
--- a/generation/maisi/maisi_diff_unet_training_tutorial.ipynb
+++ b/generation/maisi/maisi_train_diff_unet_tutorial.ipynb
@@ -26,7 +26,9 @@
"\n",
"\n",
"\n",
- "In this notebook, we detail the procedure for training a 3D latent diffusion model to generate high-dimensional 3D medical images. Due to the potential for out-of-memory issues on most GPUs when generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater), we have structured the training process into two primary steps: 1) generating image embeddings and 2) training 3D latent diffusion models. The subsequent sections will demonstrate the entire process using a simulated dataset."
+ "In this notebook, we detail the procedure for training a 3D latent diffusion model to generate high-dimensional 3D medical images. Due to the potential for out-of-memory issues on most GPUs when generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater), we have structured the training process into two primary steps: 1) generating image embeddings and 2) training 3D latent diffusion models. The subsequent sections will demonstrate the entire process using a simulated dataset.\n",
+ "\n",
+ "`[Release Note (March 2025)]:` We are excited to announce the new MAISI version `'maisi3d-rflow'`. Compared with the previous version `'maisi3d-ddpm'`, it accelerates latent diffusion model inference by 33x. See the following section for the detailed differences."
]
},
{
@@ -60,27 +62,105 @@
"execution_count": 2,
"id": "e3bf0346",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "MONAI version: 1.4.1rc1+32.g34f37973\n",
+ "Numpy version: 1.26.4\n",
+ "Pytorch version: 2.5.0+cu124\n",
+ "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
+ "MONAI rev id: 34f379735c5e18e7f809453eb1b3606c225c788b\n",
+ "MONAI __file__: /localhome//.local/lib/python3.10/site-packages/monai/__init__.py\n",
+ "\n",
+ "Optional dependencies:\n",
+ "Pytorch Ignite version: 0.4.11\n",
+ "ITK version: 5.4.0\n",
+ "Nibabel version: 5.3.2\n",
+ "scikit-image version: 0.24.0\n",
+ "scipy version: 1.14.1\n",
+ "Pillow version: 11.0.0\n",
+ "Tensorboard version: 2.18.0\n",
+ "gdown version: 5.2.0\n",
+ "TorchVision version: 0.20.0+cu124\n",
+ "tqdm version: 4.66.5\n",
+ "lmdb version: 1.5.1\n",
+ "psutil version: 6.1.0\n",
+ "pandas version: 2.2.3\n",
+ "einops version: 0.8.0\n",
+ "transformers version: 4.40.2\n",
+ "mlflow version: 2.17.1\n",
+ "pynrrd version: 1.0.0\n",
+ "clearml version: 1.16.5rc2\n",
+ "\n",
+ "For details about installing the optional dependencies, please visit:\n",
+ " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
+ "\n"
+ ]
+ }
+ ],
"source": [
- "from scripts.diff_model_setting import setup_logging\n",
"import copy\n",
"import os\n",
"import json\n",
"import numpy as np\n",
"import nibabel as nib\n",
"import subprocess\n",
+ "from IPython.display import Image, display\n",
"\n",
"from monai.apps import download_url\n",
"from monai.data import create_test_image_3d\n",
"from monai.config import print_config\n",
"\n",
- "from IPython.display import Image, display\n",
+ "from scripts.diff_model_setting import setup_logging\n",
"\n",
"print_config()\n",
"\n",
"logger = setup_logging(\"notebook\")"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "c2389853",
+ "metadata": {},
+ "source": [
+ "## Set up the MAISI version\n",
+ "\n",
+ "Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n",
+ "- `'maisi3d-ddpm'` uses the basic DDPM noise scheduler, while `'maisi3d-rflow'` uses a Rectified Flow scheduler, which can be 33 times faster during inference.\n",
+ "- `'maisi3d-ddpm'` requires training images to be labeled with body regions (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` has no such requirement, so training data is easier to prepare for `'maisi3d-rflow'`.\n",
+ "- For the released model weights, `'maisi3d-rflow'` generates higher-quality images for the head region and for small output volumes, and images of comparable quality to `'maisi3d-ddpm'` in other cases."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "31684f74",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[2025-03-14 16:14:21.679][ INFO](notebook) - MAISI version is maisi3d-rflow, whether to use body_region is False\n"
+ ]
+ }
+ ],
+ "source": [
+ "maisi_version = \"maisi3d-rflow\"\n",
+ "if maisi_version == \"maisi3d-ddpm\":\n",
+ "    model_def_path = \"./configs/config_maisi3d-ddpm.json\"\n",
+ "elif maisi_version == \"maisi3d-rflow\":\n",
+ "    model_def_path = \"./configs/config_maisi3d-rflow.json\"\n",
+ "else:\n",
+ "    raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n",
+ "with open(model_def_path, \"r\") as f:\n",
+ "    model_def = json.load(f)\n",
+ "include_body_region = model_def[\"include_body_region\"]\n",
+ "logger.info(f\"MAISI version is {maisi_version}, whether to use body_region is {include_body_region}\")"
+ ]
+ },
{
"cell_type": "markdown",
"id": "d8e29c23",
@@ -95,7 +175,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"id": "fc32a7fe",
"metadata": {},
"outputs": [],
@@ -117,7 +197,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"id": "1b199078",
"metadata": {},
"outputs": [
@@ -125,7 +205,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:Generated simulated images.\n"
+ "[2025-03-14 16:14:22.301][ INFO](notebook) - Generated simulated images.\n"
]
}
],
@@ -154,7 +234,7 @@
},
{
"cell_type": "markdown",
- "id": "c2389853",
+ "id": "a059ddcf-8525-4241-9fe3-b661c4bdd336",
"metadata": {},
"source": [
"### Set up directories and configurations\n",
@@ -164,7 +244,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
"id": "6c7b434c",
"metadata": {},
"outputs": [
@@ -172,15 +252,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:files and folders under work_dir: ['environment_maisi_diff_model.json', 'config_maisi.json', 'sim_dataroot', 'sim_datalist.json', 'models', 'embeddings', 'config_maisi_diff_model.json', 'predictions'].\n",
- "INFO:notebook:number of GPUs: 1.\n"
+ "[2025-03-14 16:14:22.313][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n",
+ "[2025-03-14 16:14:22.314][ INFO](notebook) - number of GPUs: 1.\n"
]
}
],
"source": [
"env_config_path = \"./configs/environment_maisi_diff_model.json\"\n",
"model_config_path = \"./configs/config_maisi_diff_model.json\"\n",
- "model_def_path = \"./configs/config_maisi.json\"\n",
"\n",
"# Load environment configuration, model configuration and model definition\n",
"with open(env_config_path, \"r\") as f:\n",
@@ -189,9 +268,6 @@
"with open(model_config_path, \"r\") as f:\n",
" model_config = json.load(f)\n",
"\n",
- "with open(model_def_path, \"r\") as f:\n",
- " model_def = json.load(f)\n",
- "\n",
"env_config_out = copy.deepcopy(env_config)\n",
"model_config_out = copy.deepcopy(model_config)\n",
"model_def_out = copy.deepcopy(model_def)\n",
@@ -229,7 +305,7 @@
" json.dump(model_config_out, f, sort_keys=True, indent=4)\n",
"\n",
"# Update model definition for demo\n",
- "model_def_out[\"autoencoder_def\"][\"num_splits\"] = 4\n",
+ "model_def_out[\"autoencoder_def\"][\"num_splits\"] = 2\n",
"model_def_filepath = os.path.join(work_dir, \"config_maisi.json\")\n",
"with open(model_def_filepath, \"w\") as f:\n",
" json.dump(model_def_out, f, sort_keys=True, indent=4)\n",
@@ -244,7 +320,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"id": "95ea6972",
"metadata": {},
"outputs": [],
@@ -304,7 +380,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
"id": "f45ea863",
"metadata": {},
"outputs": [
@@ -312,7 +388,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:Creating training data...\n"
+ "[2025-03-14 16:14:22.326][ INFO](notebook) - Creating training data...\n"
]
},
{
@@ -320,8 +396,8 @@
"output_type": "stream",
"text": [
"\n",
- "INFO:creating training data:Using device cuda:0\n",
- "INFO:creating training data:filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n",
+ "[2025-03-14 16:14:29.646][ INFO](creating training data) - Using device cuda:0\n",
+ "[2025-03-14 16:14:30.160][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n",
"\n"
]
}
@@ -357,7 +433,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 9,
"id": "0221a658",
"metadata": {},
"outputs": [
@@ -365,9 +441,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
- "INFO:notebook:data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n",
- "INFO:notebook:Completed creating .json files for all embedding files.\n"
+ "[2025-03-14 16:14:32.560][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75]}.\n",
+ "[2025-03-14 16:14:32.562][ INFO](notebook) - Save json file to ./temp_work_dir/./embeddings/tr_image_001_emb.nii.gz.json\n",
+ "[2025-03-14 16:14:32.563][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75]}.\n",
+ "[2025-03-14 16:14:32.564][ INFO](notebook) - Save json file to ./temp_work_dir/./embeddings/tr_image_002_emb.nii.gz.json\n",
+ "[2025-03-14 16:14:32.565][ INFO](notebook) - Completed creating .json files for all embedding files.\n"
]
}
],
@@ -396,15 +474,13 @@
" spacing = [float(_item) for _item in spacing]\n",
"\n",
" # Create the dictionary with the specified keys and values\n",
- " # The region can be selected from one of four regions from top to bottom.\n",
- " # [1,0,0,0] is the head and neck, [0,1,0,0] is the chest region, [0,0,1,0]\n",
- " # is the abdomen region, and [0,0,0,1] is the lower body region.\n",
- " data = {\n",
- " \"dim\": dimensions,\n",
- " \"spacing\": spacing,\n",
- " \"top_region_index\": [0, 1, 0, 0], # chest region\n",
- " \"bottom_region_index\": [0, 0, 1, 0], # abdomen region\n",
- " }\n",
+ " data = {\"dim\": dimensions, \"spacing\": spacing}\n",
+ " if include_body_region:\n",
+ " # The region can be selected from one of four regions from top to bottom.\n",
+ " # [1,0,0,0] is the head and neck, [0,1,0,0] is the chest region, [0,0,1,0]\n",
+ " # is the abdomen region, and [0,0,0,1] is the lower body region.\n",
+ " data[\"top_region_index\"] = [0, 1, 0, 0] # chest region\n",
+ " data[\"bottom_region_index\"] = [0, 0, 1, 0] # abdomen region\n",
" logger.info(f\"data: {data}.\")\n",
"\n",
" # Create the .json filename\n",
@@ -413,6 +489,7 @@
" # Write the dictionary to the .json file\n",
" with open(json_filename, \"w\") as json_file:\n",
" json.dump(data, json_file, indent=4)\n",
+ " logger.info(f\"Save json file to {json_filename}\")\n",
"\n",
"\n",
"folder_path = env_config_out[\"embedding_base_dir\"]\n",
@@ -438,7 +515,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 10,
"id": "ade6389d",
"metadata": {},
"outputs": [
@@ -446,7 +523,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:Training the model...\n"
+ "[2025-03-14 16:14:32.570][ INFO](notebook) - Training the model...\n"
]
},
{
@@ -454,26 +531,26 @@
"output_type": "stream",
"text": [
"\n",
- "INFO:training:Using cuda:0 of 1\n",
- "INFO:training:[config] ckpt_folder -> ./temp_work_dir/./models.\n",
- "INFO:training:[config] data_root -> ./temp_work_dir/./embeddings.\n",
- "INFO:training:[config] data_list -> ./temp_work_dir/sim_datalist.json.\n",
- "INFO:training:[config] lr -> 0.0001.\n",
- "INFO:training:[config] num_epochs -> 2.\n",
- "INFO:training:[config] num_train_timesteps -> 1000.\n",
- "INFO:training:num_files_train: 2\n",
- "INFO:training:Training from scratch.\n",
- "INFO:training:Scaling factor set to 0.8903454542160034.\n",
- "INFO:training:scale_factor -> 0.8903454542160034.\n",
- "INFO:training:torch.set_float32_matmul_precision -> highest.\n",
- "INFO:training:Epoch 1, lr 0.0001.\n",
- "INFO:training:[2024-09-30 06:30:33] epoch 1, iter 1/2, loss: 0.7974, lr: 0.000100000000.\n",
- "INFO:training:[2024-09-30 06:30:33] epoch 1, iter 2/2, loss: 0.7939, lr: 0.000056250000.\n",
- "INFO:training:epoch 1 average loss: 0.7957.\n",
- "INFO:training:Epoch 2, lr 2.5e-05.\n",
- "INFO:training:[2024-09-30 06:30:35] epoch 2, iter 1/2, loss: 0.7902, lr: 0.000025000000.\n",
- "INFO:training:[2024-09-30 06:30:35] epoch 2, iter 2/2, loss: 0.7889, lr: 0.000006250000.\n",
- "INFO:training:epoch 2 average loss: 0.7895.\n",
+ "[2025-03-14 16:14:39.869][ INFO](training) - Using cuda:0 of 1\n",
+ "[2025-03-14 16:14:39.869][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n",
+ "[2025-03-14 16:14:39.869][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n",
+ "[2025-03-14 16:14:39.869][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n",
+ "[2025-03-14 16:14:39.869][ INFO](training) - [config] lr -> 0.0001.\n",
+ "[2025-03-14 16:14:39.869][ INFO](training) - [config] num_epochs -> 2.\n",
+ "[2025-03-14 16:14:39.869][ INFO](training) - [config] num_train_timesteps -> 1000.\n",
+ "[2025-03-14 16:14:41.316][ INFO](training) - Training from scratch.\n",
+ "[2025-03-14 16:14:41.337][ INFO](training) - num_files_train: 2\n",
+ "[2025-03-14 16:14:41.634][ INFO](training) - Scaling factor set to 1.159693956375122.\n",
+ "[2025-03-14 16:14:41.634][ INFO](training) - scale_factor -> 1.159693956375122.\n",
+ "[2025-03-14 16:14:41.637][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n",
+ "[2025-03-14 16:14:41.637][ INFO](training) - Epoch 1, lr 0.0001.\n",
+ "[2025-03-14 16:14:42.627][ INFO](training) - [2025-03-14 16:14:42] epoch 1, iter 1/2, loss: 1.1344, lr: 0.000100000000.\n",
+ "[2025-03-14 16:14:42.739][ INFO](training) - [2025-03-14 16:14:42] epoch 1, iter 2/2, loss: 1.1275, lr: 0.000056250000.\n",
+ "[2025-03-14 16:14:42.783][ INFO](training) - epoch 1 average loss: 1.1310.\n",
+ "[2025-03-14 16:14:44.540][ INFO](training) - Epoch 2, lr 2.5e-05.\n",
+ "[2025-03-14 16:14:44.981][ INFO](training) - [2025-03-14 16:14:44] epoch 2, iter 1/2, loss: 1.1254, lr: 0.000025000000.\n",
+ "[2025-03-14 16:14:45.106][ INFO](training) - [2025-03-14 16:14:45] epoch 2, iter 2/2, loss: 1.1201, lr: 0.000006250000.\n",
+ "[2025-03-14 16:14:45.177][ INFO](training) - epoch 2 average loss: 1.1227.\n",
"\n"
]
}
@@ -509,7 +586,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 11,
"id": "1626526d",
"metadata": {},
"outputs": [
@@ -517,8 +594,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "INFO:notebook:Running inference...\n",
- "INFO:notebook:Completed all steps.\n"
+ "[2025-03-14 16:14:49.136][ INFO](notebook) - Running inference...\n",
+ "[2025-03-14 16:15:02.647][ INFO](notebook) - Completed all steps.\n"
]
},
{
@@ -526,24 +603,24 @@
"output_type": "stream",
"text": [
"\n",
- "INFO:inference:Using cuda:0 of 1 with random seed: 93612\n",
- "INFO:inference:[config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n",
- "INFO:inference:[config] random_seed -> 93612.\n",
- "INFO:inference:[config] output_prefix -> unet_3d.\n",
- "INFO:inference:[config] output_size -> (256, 256, 128).\n",
- "INFO:inference:[config] out_spacing -> (1.0, 1.0, 0.75).\n",
- "INFO:root:`controllable_anatomy_size` is not provided.\n",
- "INFO:inference:checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n",
- "INFO:inference:scale_factor -> 0.8903454542160034.\n",
- "INFO:inference:num_downsample_level -> 4, divisor -> 4.\n",
- "INFO:inference:noise: cuda:0, torch.float32, \n",
+ "[2025-03-14 16:14:56.275][ INFO](inference) - Using cuda:0 of 1 with random seed: 59473\n",
+ "[2025-03-14 16:14:56.275][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n",
+ "[2025-03-14 16:14:56.275][ INFO](inference) - [config] random_seed -> 59473.\n",
+ "[2025-03-14 16:14:56.275][ INFO](inference) - [config] output_prefix -> unet_3d.\n",
+ "[2025-03-14 16:14:56.275][ INFO](inference) - [config] output_size -> (256, 256, 128).\n",
+ "[2025-03-14 16:14:56.275][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n",
+ "[2025-03-14 16:14:56.275][ INFO](root) - `controllable_anatomy_size` is not provided.\n",
+ "[2025-03-14 16:14:58.525][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n",
+ "[2025-03-14 16:14:58.527][ INFO](inference) - scale_factor -> 1.159693956375122.\n",
+ "[2025-03-14 16:14:58.528][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n",
+ "[2025-03-14 16:14:58.536][ INFO](inference) - noise: cuda:0, torch.float32, \n",
"\n",
- " 0%| | 0/10 [00:00, ?it/s]\n",
- " 10%|ββββββββ | 1/10 [00:00<00:02, 3.48it/s]\n",
- " 40%|ββββββββββββββββββββββββββββββ | 4/10 [00:00<00:00, 12.23it/s]\n",
- " 80%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | 8/10 [00:00<00:00, 19.26it/s]\n",
- "100%|βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 10/10 [00:00<00:00, 17.80it/s]\n",
- "INFO:inference:Saved ./temp_work_dir/./predictions/unet_3d_seed93612_size256x256x128_spacing1.00x1.00x0.75_20240930063144_rank0.nii.gz.\n",
+ " 0%| | 0/10 [00:00, ?it/s]\n",
+ " 10%|█         | 1/10 [00:00<00:03, 2.86it/s]\n",
+ " 50%|█████     | 5/10 [00:00<00:00, 13.34it/s]\n",
+ " 90%|█████████ | 9/10 [00:00<00:00, 20.02it/s]\n",
+ "100%|██████████| 10/10 [00:00<00:00, 16.56it/s]\n",
+ "[2025-03-14 16:15:00.652][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed59473_size256x256x128_spacing1.00x1.00x0.75_20250314161500_rank0.nii.gz.\n",
"\n"
]
}
@@ -579,7 +656,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
"id": "0d8a344d",
"metadata": {},
"outputs": [
@@ -615,7 +692,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.13"
+ "version": "3.10.12"
}
},
"nbformat": 4,
diff --git a/generation/maisi/scripts/augmentation.py b/generation/maisi/scripts/augmentation.py
index ffdb25c200..64469403a6 100644
--- a/generation/maisi/scripts/augmentation.py
+++ b/generation/maisi/scripts/augmentation.py
@@ -9,94 +9,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Functions to perform augmentation.
-Reference: (TODO), a 132-class label dict that maps organ names and index. URL: TBD
-"""
-
-from typing import Sequence
-
import numpy as np
import torch
+import torch.nn.functional as F
from monai.transforms import Rand3DElastic, RandAffine, RandZoom
-from torch import Tensor
-
-from .utils import dilate_one_img, erode_one_img
-
-MAX_COUNT = 100
+from monai.utils import ensure_tuple_rep
-def initialize_tumor_mask(volume: Tensor, tumor_label: Sequence[int]) -> Tensor:
- """
- Initialize tumor mask for tumor augmentation.
+def erode3d(input_tensor, erosion=3):
+    # Define the structuring element
+    erosion = ensure_tuple_rep(erosion, 3)
+    structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device)
- Args:
- volume: input 3D multi-label mask, [1,H,W,D] torch tensor.
- tumor_label: tumor label in whole_mask, list of int.
-
- Return:
- tumor_mask_, initialized tumor mask, [1,H,W,D] torch tensor.
- """
- tumor_mask_ = torch.zeros_like(volume, dtype=torch.uint8)
- for idx, label in enumerate(tumor_label):
- tumor_mask_[volume == label] = idx + 1
- return tumor_mask_
+    # Pad the input tensor to handle border pixels (F.pad lists the last dimension first)
+    input_padded = F.pad(
+        input_tensor.float().unsqueeze(0).unsqueeze(0),
+        (erosion[2] // 2, erosion[2] // 2, erosion[1] // 2, erosion[1] // 2, erosion[0] // 2, erosion[0] // 2),
+        mode="constant",
+        value=1.0,
+    )
+    # Apply erosion operation (sum over the structuring element)
+    output = F.conv3d(input_padded, structuring_element, padding=0)
-def finalize_tumor_mask(augmented_mask: Tensor, organ_mask: Tensor, threshold_tumor_size: float):
- """
- Try to generate the final tumor mask by combining the augmented tumor mask and organ mask.
- Need to make sure tumor is inside of organ and is larger than threshold_tumor_size.
+    # Keep a voxel only if every voxel under the structuring element is set
+    output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0)
- Args:
- augmented_mask: input 3D binary tumor mask, [1,H,W,D] torch tensor.
- organ_mask: input 3D binary organ mask, [1,H,W,D] torch tensor.
- threshold_tumor_size: threshold tumor size, float
+    return output.squeeze(0).squeeze(0)
- Return:
- tumor_mask, [H,W,D] torch tensor; or None if the size did not qualify
- """
- tumor_mask = augmented_mask * organ_mask
- if torch.sum(tumor_mask) >= threshold_tumor_size:
- tumor_mask = dilate_one_img(tumor_mask.squeeze(0), filter_size=5, pad_value=1.0)
- tumor_mask = erode_one_img(tumor_mask, filter_size=5, pad_value=1.0).unsqueeze(0).to(torch.uint8)
- return tumor_mask
- else:
- return None
+def dilate3d(input_tensor, erosion=3):
+    # Define the structuring element
+    erosion = ensure_tuple_rep(erosion, 3)
+    structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device)
-def augmentation_bone_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, int] | int | None = None) -> Tensor:
- """
- Bone tumor augmentation.
+ # Pad the input tensor to handle border pixels
+ input_padded = F.pad(
+ input_tensor.float().unsqueeze(0).unsqueeze(0),
+ (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2),
+ mode="constant",
+ value=1.0,
+ )
- Args:
- whole_mask: input 3D multi-label mask, [1,1,H,W,D] torch tensor.
- spatial_size: output image spatial size, used in random transform.
- If not defined, will use (H,W,D). If some components are non-positive values,
- the transform will use the corresponding components of whole_mask size.
- For example, spatial_size=(128, 128, -1) will be adapted to (128, 128, 64)
- if the third spatial dimension size of whole_mask is 64.
+ # Apply dilation operation
+ output = F.conv3d(input_padded, structuring_element, padding=0)
- Return:
- augmented mask, with shape of spatial_size and data type as whole_mask.
+ # Set output to 1 wherever any value within the structuring element is nonzero
+ output = torch.where(output > 0, 1.0, 0.0)
- Example:
+ return output.squeeze(0).squeeze(0)
- .. code-block:: python
- # define a multi-label mask
- whole_mask = torch.zeros([1,1,128,128,128])
- whole_mask[0,0, 90:110, 90:110, 90:110]=127
- whole_mask[0,0, 97:103, 97:103, 97:103]=128
- augmented_whole_mask = augmentation_bone_tumor(whole_mask)
- """
- # Initialize binary tumor mask
- device = whole_mask.device
- volume = whole_mask.squeeze(0).cuda() if not whole_mask.is_cuda else whole_mask.squeeze(0)
- tumor_label = [128]
- tumor_mask_ = initialize_tumor_mask(volume, tumor_label)
+def augmentation_tumor_bone(pt_nda, output_size, random_seed):
+ volume = pt_nda.squeeze(0)
+ real_l_volume_ = torch.zeros_like(volume)
+ real_l_volume_[volume == 128] = 1
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
- # Define augmentation transform
elastic = RandAffine(
mode="nearest",
prob=1.0,
@@ -105,78 +74,52 @@ def augmentation_bone_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, in
scale_range=(0.15, 0.15, 0),
padding_mode="zeros",
)
+ elastic.set_random_state(seed=random_seed)
- tumor_size = torch.sum((tumor_mask_ > 0).float())
+ tumor_szie = torch.sum((real_l_volume_ > 0).float())
###########################
# remove pred in pseudo_label in real lesion region
- volume[tumor_mask_ > 0] = 200
+ volume[real_l_volume_ > 0] = 200
###########################
- if tumor_size > 0:
+ if tumor_szie > 0:
# get organ mask
organ_mask = (
torch.logical_and(33 <= volume, volume <= 56).float()
+ torch.logical_and(63 <= volume, volume <= 97).float()
+ (volume == 127).float()
+ (volume == 114).float()
- + tumor_mask_
+ + real_l_volume_
)
organ_mask = (organ_mask > 0).float()
-
- # augment mask
- count = 0
+ cnt = 0
while True:
- threshold = 0.8 if count < 40 else 0.75
- tumor_mask = tumor_mask_
- # apply random augmentation
- augmented_mask = elastic(tumor_mask > 0, spatial_size=spatial_size).as_tensor()
- # generate final tumor mask
- count += 1
- tumor_mask = finalize_tumor_mask(augmented_mask, organ_mask, tumor_size * threshold)
- if tumor_mask is not None:
+ threshold = 0.8 if cnt < 40 else 0.75
+ real_l_volume = real_l_volume_
+ # randomly distort the mask
+ distored_mask = elastic((real_l_volume > 0).cuda(), spatial_size=tuple(output_size)).as_tensor()
+ real_l_volume = distored_mask * organ_mask
+ cnt += 1
+ print(torch.sum(real_l_volume), "|", tumor_szie * threshold)
+ if torch.sum(real_l_volume) >= tumor_szie * threshold:
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
break
- if count > MAX_COUNT:
- raise ValueError("Please check if bone lesion is inside bone.")
else:
- tumor_mask = tumor_mask_
-
- # update the new tumor mask
- volume[tumor_mask == 1] = tumor_label[0]
-
- return volume.unsqueeze(0).to(device)
-
-
-def augmentation_liver_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, int] | int | None = None) -> Tensor:
- """
- Bone liver augmentation.
+ real_l_volume = real_l_volume_
- Args:
- whole_mask: input 3D multi-label mask, [1,1,H,W,D] torch tensor.
- spatial_size: output image spatial size, used in random transform.
- If not defined, will use (H,W,D). If some components are non-positive values,
- the transform will use the corresponding components of whole_mask size.
- For example, spatial_size=(128, 128, -1) will be adapted to (128, 128, 64)
- if the third spatial dimension size of whole_mask is 64.
+ volume[real_l_volume == 1] = 128
- Return:
- augmented mask, with shape of spatial_size and data type as whole_mask.
+ pt_nda = volume.unsqueeze(0)
+ return pt_nda
- Example:
- .. code-block:: python
+def augmentation_tumor_liver(pt_nda, output_size, random_seed):
+ volume = pt_nda.squeeze(0)
+ real_l_volume_ = torch.zeros_like(volume)
+ real_l_volume_[volume == 1] = 1
+ real_l_volume_[volume == 26] = 2
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
- # define a multi-label mask
- whole_mask = torch.zeros([1,1,128,128,128])
- whole_mask[0,0, 90:110, 90:110, 90:110]=1
- whole_mask[0,0, 97:103, 97:103, 97:103]=26
- augmented_whole_mask = augment_liver_tumor(whole_mask)
- """
- # Initialize binary tumor mask
- device = whole_mask.device
- volume = whole_mask.squeeze(0).cuda() if not whole_mask.is_cuda else whole_mask.squeeze(0)
- tumor_label = [1, 26]
- tumor_mask_ = initialize_tumor_mask(volume, tumor_label)
-
- # Define augmentation transform
elastic = Rand3DElastic(
mode="nearest",
prob=1.0,
@@ -187,74 +130,45 @@ def augmentation_liver_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, i
scale_range=(0.2, 0.2, 0.2),
padding_mode="zeros",
)
+ elastic.set_random_state(seed=random_seed)
- tumor_size = torch.sum(tumor_mask_ == 2)
+ tumor_szie = torch.sum(real_l_volume_ == 2)
###########################
# remove pred organ labels
volume[volume == 1] = 0
volume[volume == 26] = 0
# before move tumor maks, full the original location by organ labels
- volume[tumor_mask_ == 1] = 1
- volume[tumor_mask_ == 2] = 1
+ volume[real_l_volume_ == 1] = 1
+ volume[real_l_volume_ == 2] = 1
###########################
- if tumor_size > 0:
- count = 0
+ while True:
+ real_l_volume = real_l_volume_
+ # randomly distort the mask
+ real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor()
# get organ mask
- organ_mask = (tumor_mask_ == 1).float() + (tumor_mask_ == 2).float()
- organ_mask = dilate_one_img(organ_mask.squeeze(0), filter_size=5, pad_value=1.0)
- organ_mask = erode_one_img(organ_mask, filter_size=5, pad_value=1.0).unsqueeze(0)
- while True:
- tumor_mask = tumor_mask_
- # apply random augmentation
- augmented_mask = elastic((tumor_mask == 2), spatial_size=spatial_size).as_tensor()
-
- # generate final tumor mask
- count += 1
- tumor_mask = finalize_tumor_mask(augmented_mask, organ_mask, tumor_size * 0.80)
- if tumor_mask is not None:
- break
- if count > MAX_COUNT:
- raise ValueError("Please check if liver tumor is inside liver.")
- else:
- tumor_mask = tumor_mask_
-
- volume[tumor_mask == 1] = 26
-
- return volume.unsqueeze(0).to(device)
-
+ organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float()
-def augmentation_lung_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, int] | int | None = None) -> Tensor:
- """
- Lung tumor augmentation.
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
+ real_l_volume = real_l_volume * organ_mask
+ print(torch.sum(real_l_volume), "|", tumor_szie * 0.80)
+ if torch.sum(real_l_volume) >= tumor_szie * 0.80:
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0)
+ break
- Args:
- whole_mask: input 3D multi-label mask, [1,1,H,W,D] torch tensor.
- spatial_size: output image spatial size, used in random transform.
- If not defined, will use (H,W,D). If some components are non-positive values,
- the transform will use the corresponding components of whole_mask size.
- For example, spatial_size=(128, 128, -1) will be adapted to (128, 128, 64)
- if the third spatial dimension size of whole_mask is 64.
+ volume[real_l_volume == 1] = 26
- Return:
- augmented mask, with shape of spatial_size and data type as whole_mask.
+ pt_nda = volume.unsqueeze(0)
+ return pt_nda
- Example:
- .. code-block:: python
+def augmentation_tumor_lung(pt_nda, output_size, random_seed):
+ volume = pt_nda.squeeze(0)
+ real_l_volume_ = torch.zeros_like(volume)
+ real_l_volume_[volume == 23] = 1
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
- # define a multi-label mask
- whole_mask = torch.zeros([1,1,128,128,128])
- whole_mask[0,0, 90:110, 90:110, 90:110]=28
- whole_mask[0,0, 97:103, 97:103, 97:103]=23
- augmented_whole_mask = augmentation_lung_tumor(whole_mask)
- """
- # Initialize binary tumor mask
- device = whole_mask.device
- volume = whole_mask.squeeze(0).cuda() if not whole_mask.is_cuda else whole_mask.squeeze(0)
- tumor_label = [23]
- tumor_mask_ = initialize_tumor_mask(volume, tumor_label)
-
- # Define augmentation transform
elastic = Rand3DElastic(
mode="nearest",
prob=1.0,
@@ -265,87 +179,61 @@ def augmentation_lung_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, in
scale_range=(0.15, 0.15, 0.15),
padding_mode="zeros",
)
+ elastic.set_random_state(seed=random_seed)
- tumor_size = torch.sum(tumor_mask_)
+ tumor_szie = torch.sum(real_l_volume_)
# before move lung tumor maks, full the original location by lung labels
- new_tumor_mask_ = dilate_one_img(tumor_mask_.squeeze(0), filter_size=3, pad_value=1.0)
- new_tumor_mask_ = new_tumor_mask_.unsqueeze(0)
- new_tumor_mask_[tumor_mask_ > 0] = 0
- new_tumor_mask_[volume < 28] = 0
- new_tumor_mask_[volume > 32] = 0
- tmp = volume[(volume * new_tumor_mask_).nonzero(as_tuple=True)].view(-1)
+ new_real_l_volume_ = dilate3d(real_l_volume_.squeeze(0), erosion=3)
+ new_real_l_volume_ = new_real_l_volume_.unsqueeze(0)
+ new_real_l_volume_[real_l_volume_ > 0] = 0
+ new_real_l_volume_[volume < 28] = 0
+ new_real_l_volume_[volume > 32] = 0
+ tmp = volume[(volume * new_real_l_volume_).nonzero(as_tuple=True)].view(-1)
mode = torch.mode(tmp, 0)[0].item()
+ print(mode)
assert 28 <= mode <= 32
- volume[tumor_mask_.bool()] = mode
+ volume[real_l_volume_.bool()] = mode
###########################
-
- if tumor_size > 0:
- count = 0
- # get lung mask v2 (133 order)
- organ_mask = (
- (volume == 28).float()
- + (volume == 29).float()
- + (volume == 30).float()
- + (volume == 31).float()
- + (volume == 32).float()
- )
- organ_mask = dilate_one_img(organ_mask.squeeze(0), filter_size=5, pad_value=1.0)
- organ_mask = erode_one_img(organ_mask, filter_size=5, pad_value=1.0).unsqueeze(0)
-
+ if tumor_szie > 0:
# aug
while True:
- tumor_mask = tumor_mask_
- # apply random augmentation
- augmented_mask = elastic(tumor_mask, spatial_size=spatial_size).as_tensor()
-
- # generate final tumor mask
- count += 1
- tumor_mask = finalize_tumor_mask(augmented_mask, organ_mask, tumor_size * 0.85)
- if tumor_mask is not None:
+ real_l_volume = real_l_volume_
+ # randomly distort the mask
+ real_l_volume = elastic(real_l_volume, spatial_size=tuple(output_size)).as_tensor()
+ # get lung mask v2 (133 order)
+ lung_mask = (
+ (volume == 28).float()
+ + (volume == 29).float()
+ + (volume == 30).float()
+ + (volume == 31).float()
+ + (volume == 32).float()
+ )
+
+ lung_mask = dilate3d(lung_mask.squeeze(0), erosion=5)
+ lung_mask = erode3d(lung_mask, erosion=5).unsqueeze(0)
+ real_l_volume = real_l_volume * lung_mask
+ print(torch.sum(real_l_volume), "|", tumor_szie * 0.85)
+ if torch.sum(real_l_volume) >= tumor_szie * 0.85:
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
break
- if count > MAX_COUNT:
- raise ValueError("Please check if lung tumor is inside lung.")
else:
- tumor_mask = tumor_mask_
-
- volume[tumor_mask == 1] = tumor_label[0]
+ real_l_volume = real_l_volume_
- return volume.unsqueeze(0).to(device)
+ volume[real_l_volume == 1] = 23
+ pt_nda = volume.unsqueeze(0)
+ return pt_nda
-def augmentation_pancreas_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, int] | int | None = None) -> Tensor:
- """
- Pancreas tumor augmentation.
- Args:
- whole_mask: input 3D multi-label mask, [1,1,H,W,D] torch tensor.
- spatial_size: output image spatial size, used in random transform.
- If not defined, will use (H,W,D). If some components are non-positive values,
- the transform will use the corresponding components of whole_mask size.
- For example, spatial_size=(128, 128, -1) will be adapted to (128, 128, 64)
- if the third spatial dimension size of whole_mask is 64.
+def augmentation_tumor_pancreas(pt_nda, output_size, random_seed):
+ volume = pt_nda.squeeze(0)
+ real_l_volume_ = torch.zeros_like(volume)
+ real_l_volume_[volume == 4] = 1
+ real_l_volume_[volume == 24] = 2
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
- Return:
- augmented mask, with shape of spatial_size and data type as whole_mask.
-
- Example:
-
- .. code-block:: python
-
- # define a multi-label mask
- whole_mask = torch.zeros([1,1,128,128,128])
- whole_mask[0,0, 90:110, 90:110, 90:110]=24
- whole_mask[0,0, 97:103, 97:103, 97:103]=4
- augmented_whole_mask = augmentation_pancreas_tumor(whole_mask)
- """
- # Initialize binary tumor mask
- device = whole_mask.device
- volume = whole_mask.squeeze(0).cuda() if not whole_mask.is_cuda else whole_mask.squeeze(0)
- tumor_label = [4, 24]
- tumor_mask_ = initialize_tumor_mask(volume, tumor_label)
-
- # Define augmentation transform
elastic = Rand3DElastic(
mode="nearest",
prob=1.0,
@@ -356,74 +244,45 @@ def augmentation_pancreas_tumor(whole_mask: Tensor, spatial_size: tuple[int, int
scale_range=(0.1, 0.1, 0.1),
padding_mode="zeros",
)
+ elastic.set_random_state(seed=random_seed)
- tumor_size = torch.sum(tumor_mask_ == 2)
+ tumor_szie = torch.sum(real_l_volume_ == 2)
###########################
# remove pred organ labels
volume[volume == 24] = 0
volume[volume == 4] = 0
# before move tumor maks, full the original location by organ labels
- volume[tumor_mask_ == 1] = 4
- volume[tumor_mask_ == 2] = 4
+ volume[real_l_volume_ == 1] = 4
+ volume[real_l_volume_ == 2] = 4
###########################
- if tumor_size > 0:
- count = 0
+ while True:
+ real_l_volume = real_l_volume_
+ # randomly distort the mask
+ real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor()
# get organ mask
- organ_mask = (tumor_mask_ == 1).float() + (tumor_mask_ == 2).float()
- organ_mask = dilate_one_img(organ_mask.squeeze(0), filter_size=5, pad_value=1.0)
- organ_mask = erode_one_img(organ_mask, filter_size=5, pad_value=1.0).unsqueeze(0)
- while True:
- tumor_mask = tumor_mask_
- # apply random augmentation
- augmented_mask = elastic((tumor_mask == 2), spatial_size=spatial_size).as_tensor()
-
- # generate final tumor mask
- count += 1
- tumor_mask = finalize_tumor_mask(augmented_mask, organ_mask, tumor_size * 0.80)
- if tumor_mask is not None:
- break
- if count > MAX_COUNT:
- raise ValueError("Please check if pancreas tumor is inside pancreas.")
- else:
- tumor_mask = tumor_mask_
+ organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float()
- volume[tumor_mask == 1] = 24
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
+ real_l_volume = real_l_volume * organ_mask
+ print(torch.sum(real_l_volume), "|", tumor_szie * 0.80)
+ if torch.sum(real_l_volume) >= tumor_szie * 0.80:
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0)
+ break
- return volume.unsqueeze(0).to(device)
+ volume[real_l_volume == 1] = 24
+ pt_nda = volume.unsqueeze(0)
+ return pt_nda
-def augmentation_colon_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, int] | int | None = None) -> Tensor:
- """
- Colon tumor augmentation.
- Args:
- whole_mask: input 3D multi-label mask, [1,1,H,W,D] torch tensor.
- spatial_size: output image spatial size, used in random transform.
- If not defined, will use (H,W,D). If some components are non-positive values,
- the transform will use the corresponding components of whole_mask size.
- For example, spatial_size=(128, 128, -1) will be adapted to (128, 128, 64)
- if the third spatial dimension size of whole_mask is 64.
+def augmentation_tumor_colon(pt_nda, output_size, random_seed):
+ volume = pt_nda.squeeze(0)
+ real_l_volume_ = torch.zeros_like(volume)
+ real_l_volume_[volume == 27] = 1
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
- Return:
- augmented mask, with shape of spatial_size and data type as whole_mask.
-
- Example:
-
- .. code-block:: python
-
- # define a multi-label mask
- whole_mask = torch.zeros([1,1,128,128,128])
- whole_mask[0,0, 90:110, 90:110, 90:110]=62
- whole_mask[0,0, 97:103, 97:103, 97:103]=27
- augmented_whole_mask = augmentation_colon_tumor(whole_mask)
- """
- # Initialize binary tumor mask
- device = whole_mask.device
- volume = whole_mask.squeeze(0).cuda() if not whole_mask.is_cuda else whole_mask.squeeze(0)
- tumor_label = [27]
- tumor_mask_ = initialize_tumor_mask(volume, tumor_label)
-
- # Define augmentation transform
elastic = Rand3DElastic(
mode="nearest",
prob=1.0,
@@ -434,125 +293,81 @@ def augmentation_colon_tumor(whole_mask: Tensor, spatial_size: tuple[int, int, i
scale_range=(0.1, 0.1, 0.1),
padding_mode="zeros",
)
+ elastic.set_random_state(seed=random_seed)
- tumor_size = torch.sum(tumor_mask_)
+ tumor_szie = torch.sum(real_l_volume_)
###########################
# before move tumor maks, full the original location by organ labels
- volume[tumor_mask_.bool()] = 62
+ volume[real_l_volume_.bool()] = 62
###########################
- if tumor_size > 0:
+ if tumor_szie > 0:
# get organ mask
organ_mask = (volume == 62).float()
- organ_mask = dilate_one_img(organ_mask.squeeze(0), filter_size=5, pad_value=1.0)
- organ_mask = erode_one_img(organ_mask, filter_size=5, pad_value=1.0).unsqueeze(0)
-
- count = 0
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
+ # cnt = 0
+ cnt = 0
while True:
threshold = 0.8
- tumor_mask = tumor_mask_
- if count < 20:
- # apply random augmentation
- augmented_mask = elastic((tumor_mask == 1), spatial_size=spatial_size).as_tensor()
- tumor_mask = augmented_mask * organ_mask
- elif 20 <= count < 40:
+ real_l_volume = real_l_volume_
+ if cnt < 20:
+ # randomly distort the mask
+ distored_mask = elastic((real_l_volume == 1).cuda(), spatial_size=tuple(output_size)).as_tensor()
+ real_l_volume = distored_mask * organ_mask
+ elif 20 <= cnt < 40:
threshold = 0.75
else:
break
- # generate final tumor mask
- count += 1
- tumor_mask = finalize_tumor_mask(tumor_mask, organ_mask, tumor_size * threshold)
- if tumor_mask is not None:
+ real_l_volume = real_l_volume * organ_mask
+ print(torch.sum(real_l_volume), "|", tumor_szie * threshold)
+ cnt += 1
+ if torch.sum(real_l_volume) >= tumor_szie * threshold:
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
break
- if count > MAX_COUNT:
- raise ValueError("Please check if colon tumor is inside colon.")
else:
- tumor_mask = tumor_mask_
-
- volume[tumor_mask == 1] = tumor_label[0]
+ real_l_volume = real_l_volume_
+ # break
+ volume[real_l_volume == 1] = 27
- return volume.unsqueeze(0).to(device)
+ pt_nda = volume.unsqueeze(0)
+ return pt_nda
-def augmentation_body(whole_mask: Tensor) -> Tensor:
- """
- Whole body mask augmentation.
+def augmentation_body(pt_nda, random_seed):
+ volume = pt_nda.squeeze(0)
- Args:
- whole_mask: input 3D multi-label mask, [1,1,H,W,D] torch tensor.
+ zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0)
+ zoom.set_random_state(seed=random_seed)
- Return:
- augmented mask, with same shape and data type as whole_mask.
-
- Example:
-
- .. code-block:: python
-
- # define a multi-label mask
- whole_mask = torch.zeros([1,1,128,128,128])
- whole_mask[0,0, 90:110, 90:110, 90:110]=127
- whole_mask[0,0, 97:103, 97:103, 97:103]=128
- augmented_whole_mask = augmentation_body(whole_mask)
- """
- device = whole_mask.device
- volume = whole_mask.squeeze(0).cuda() if not whole_mask.is_cuda else whole_mask.squeeze(0)
-
- # Define augmentation transform
- zoom = RandZoom(
- min_zoom=0.99,
- max_zoom=1.01,
- mode="nearest",
- align_corners=None,
- prob=1.0,
- )
- # apply random augmentation
volume = zoom(volume)
- return volume.unsqueeze(0).to(device)
-
-
-def augmentation(whole_mask: Tensor, spatial_size: tuple[int, int, int] | int | None = None) -> Tensor:
- """
- Tumor or whole body mask augmentation. If tumor exist, augment tumor mask; if not, augment whole body mask
-
- Args:
- whole_mask: input 3D multi-label mask, [1,1,H,W,D] torch tensor.
- spatial_size: output image spatial size, used in random transform. If not defined, will use (H,W,D). If some components are non-positive values, the transform will use the corresponding components of whole_mask size. For example, spatial_size=(128, 128, -1) will be adapted to (128, 128, 64) if the third spatial dimension size of whole_mask is 64.
-
- Return:
- augmented mask, with shape of spatial_size and data type as whole_mask.
-
- Example:
+ pt_nda = volume.unsqueeze(0)
+ return pt_nda
- .. code-block:: python
- # define a multi-label mask
- whole_mask = torch.zeros([1,1, 128,128,128])
- whole_mask[0,0, 90:110, 90:110, 90:110]=127
- whole_mask[0,0, 97:103, 97:103, 97:103]=128
- augmented_whole_mask = augmentation(whole_mask)
- """
- label_list = torch.unique(whole_mask)
+def augmentation(pt_nda, output_size, random_seed):
+ label_list = torch.unique(pt_nda)
label_list = list(label_list.cpu().numpy())
- # Note that we only augment one type of tumor.
if 128 in label_list:
print("augmenting bone lesion/tumor")
- whole_mask = augmentation_bone_tumor(whole_mask, spatial_size)
+ pt_nda = augmentation_tumor_bone(pt_nda, output_size, random_seed)
elif 26 in label_list:
print("augmenting liver tumor")
- whole_mask = augmentation_liver_tumor(whole_mask, spatial_size)
+ pt_nda = augmentation_tumor_liver(pt_nda, output_size, random_seed)
elif 23 in label_list:
print("augmenting lung tumor")
- whole_mask = augmentation_lung_tumor(whole_mask, spatial_size)
+ pt_nda = augmentation_tumor_lung(pt_nda, output_size, random_seed)
elif 24 in label_list:
print("augmenting pancreas tumor")
- whole_mask = augmentation_pancreas_tumor(whole_mask, spatial_size)
+ pt_nda = augmentation_tumor_pancreas(pt_nda, output_size, random_seed)
elif 27 in label_list:
print("augmenting colon tumor")
- whole_mask = augmentation_colon_tumor(whole_mask, spatial_size)
+ pt_nda = augmentation_tumor_colon(pt_nda, output_size, random_seed)
else:
print("augmenting body")
- whole_mask = augmentation_body(whole_mask)
+ pt_nda = augmentation_body(pt_nda, random_seed)
- return whole_mask
+ return pt_nda
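Review note: the `dilate3d` followed by `erode3d` pair used throughout the augmentation functions above is a morphological closing. Dilation keeps a voxel when the window sum is greater than zero, and erosion keeps it only when the window sum equals the window size. The sketch below shows the same idea in one dimension with plain Python lists. The names `dilate1d`, `erode1d`, and `close1d` are hypothetical and are not part of this patch. The sketch assumes a binary mask and an odd window size, and it mirrors the patch's choice of padding the borders with 1.

```python
def _window_sums(mask, size, pad_value=1):
    """Sum of each length-`size` window over `mask`, padded on both sides."""
    half = size // 2
    padded = [pad_value] * half + list(mask) + [pad_value] * half
    return [sum(padded[i:i + size]) for i in range(len(mask))]


def dilate1d(mask, size=3):
    # A position turns on if any value in its window is on (window sum > 0).
    return [1 if s > 0 else 0 for s in _window_sums(mask, size)]


def erode1d(mask, size=3):
    # A position stays on only if its whole window is on (window sum == size).
    return [1 if s == size else 0 for s in _window_sums(mask, size)]


def close1d(mask, size=3):
    # Closing = dilation followed by erosion; it fills gaps narrower than the window.
    return erode1d(dilate1d(mask, size), size)


mask = [0, 0, 0, 1, 0, 1, 0, 0, 0]
closed = close1d(mask, size=3)
print(closed)
```

Because the padding value is 1, the dilation step also switches on positions next to the border. In this example the subsequent erosion removes them again, but that may not hold for every mask and window size.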
diff --git a/generation/maisi/scripts/diff_model_create_training_data.py b/generation/maisi/scripts/diff_model_create_training_data.py
index ca44b43cc7..177dfa34cf 100644
--- a/generation/maisi/scripts/diff_model_create_training_data.py
+++ b/generation/maisi/scripts/diff_model_create_training_data.py
@@ -17,12 +17,11 @@
import os
from pathlib import Path
+import monai
import nibabel as nib
import numpy as np
import torch
import torch.distributed as dist
-
-import monai
from monai.transforms import Compose
from monai.utils import set_determinism
diff --git a/generation/maisi/scripts/diff_model_infer.py b/generation/maisi/scripts/diff_model_infer.py
index 9ba837328c..8b01e1cc96 100644
--- a/generation/maisi/scripts/diff_model_infer.py
+++ b/generation/maisi/scripts/diff_model_infer.py
@@ -21,14 +21,15 @@
import numpy as np
import torch
import torch.distributed as dist
-from tqdm import tqdm
-
from monai.inferers import sliding_window_inference
+from monai.inferers.inferer import SlidingWindowInferer
+from monai.networks.schedulers import RFlowScheduler
from monai.utils import set_determinism
+from tqdm import tqdm
from .diff_model_setting import initialize_distributed, load_config, setup_logging
from .sample import ReconModel, check_input
-from .utils import define_instance
+from .utils import define_instance, dynamic_infer
def set_random_seed(seed: int) -> int:
@@ -94,8 +95,11 @@ def prepare_tensors(args: argparse.Namespace, device: torch.device) -> tuple:
top_region_index_tensor = torch.from_numpy(top_region_index_tensor[np.newaxis, :]).half().to(device)
bottom_region_index_tensor = torch.from_numpy(bottom_region_index_tensor[np.newaxis, :]).half().to(device)
spacing_tensor = torch.from_numpy(spacing_tensor[np.newaxis, :]).half().to(device)
+ modality_tensor = args.diffusion_unet_inference["modality"] * torch.ones(
+ (len(spacing_tensor)), dtype=torch.long
+ ).to(device)
- return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor
+ return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor
def run_inference(
@@ -107,6 +111,7 @@ def run_inference(
top_region_index_tensor: torch.Tensor,
bottom_region_index_tensor: torch.Tensor,
spacing_tensor: torch.Tensor,
+ modality_tensor: torch.Tensor,
output_size: tuple,
divisor: int,
logger: logging.Logger,
@@ -123,6 +128,7 @@ def run_inference(
top_region_index_tensor (torch.Tensor): Top region index tensor.
bottom_region_index_tensor (torch.Tensor): Bottom region index tensor.
spacing_tensor (torch.Tensor): Spacing tensor.
+ modality_tensor (torch.Tensor): Modality tensor.
output_size (tuple): Output size of the synthetic image.
divisor (int): Divisor for downsample level.
logger (logging.Logger): Logger for logging information.
@@ -130,6 +136,9 @@ def run_inference(
Returns:
np.ndarray: Generated synthetic image data.
"""
+ include_body_region = unet.include_top_region_index_input
+ include_modality = unet.num_class_embeds is not None
+
noise = torch.randn(
(
1,
@@ -144,38 +153,64 @@ def run_inference(
image = noise
noise_scheduler = define_instance(args, "noise_scheduler")
- noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"])
+ if isinstance(noise_scheduler, RFlowScheduler):
+ noise_scheduler.set_timesteps(
+ num_inference_steps=args.diffusion_unet_inference["num_inference_steps"],
+ input_img_size_numel=torch.prod(torch.tensor(noise.shape[2:])),
+ )
+ else:
+ noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"])
recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)
autoencoder.eval()
unet.eval()
+ all_timesteps = noise_scheduler.timesteps
+ all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype)))
+ progress_bar = tqdm(
+ zip(all_timesteps, all_next_timesteps),
+ total=min(len(all_timesteps), len(all_next_timesteps)),
+ )
with torch.amp.autocast("cuda", enabled=True):
- for t in tqdm(noise_scheduler.timesteps, ncols=110):
- model_output = unet(
- x=image,
- timesteps=torch.Tensor((t,)).to(device),
- top_region_index_tensor=top_region_index_tensor,
- bottom_region_index_tensor=bottom_region_index_tensor,
- spacing_tensor=spacing_tensor,
- )
- image, _ = noise_scheduler.step(model_output, t, image)
-
- synthetic_images = sliding_window_inference(
- inputs=image,
- roi_size=(
- min(output_size[0] // divisor // 4 * 3, 96),
- min(output_size[1] // divisor // 4 * 3, 96),
- min(output_size[2] // divisor // 4 * 3, 96),
- ),
+ for t, next_t in progress_bar:
+ # Create a dictionary to store the inputs
+ unet_inputs = {
+ "x": image,
+ "timesteps": torch.Tensor((t,)).to(device),
+ "spacing_tensor": spacing_tensor,
+ }
+
+ # Add extra arguments if include_body_region is True
+ if include_body_region:
+ unet_inputs.update(
+ {
+ "top_region_index_tensor": top_region_index_tensor,
+ "bottom_region_index_tensor": bottom_region_index_tensor,
+ }
+ )
+
+ if include_modality:
+ unet_inputs.update(
+ {
+ "class_labels": modality_tensor,
+ }
+ )
+ model_output = unet(**unet_inputs)
+ if not isinstance(noise_scheduler, RFlowScheduler):
+ image, _ = noise_scheduler.step(model_output, t, image) # type: ignore
+ else:
+ image, _ = noise_scheduler.step(model_output, t, image, next_t) # type: ignore
+
+ inferer = SlidingWindowInferer(
+ roi_size=[80, 80, 80],
sw_batch_size=1,
- predictor=recon_model,
+ progress=True,
mode="gaussian",
- overlap=2.0 / 3.0,
+ overlap=0.4,
sw_device=device,
device=device,
)
-
+ synthetic_images = dynamic_infer(inferer, recon_model, image)
data = synthetic_images.squeeze().cpu().detach().numpy()
a_min, a_max, b_min, b_max = -1000, 1000, 0, 1
data = (data - b_min) / (b_max - b_min) * (a_max - a_min) + a_min
@@ -256,7 +291,7 @@ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_pat
divisor = 2 ** (num_downsample_level - 2)
logger.info(f"num_downsample_level -> {num_downsample_level}, divisor -> {divisor}.")
- top_region_index_tensor, bottom_region_index_tensor, spacing_tensor = prepare_tensors(args, device)
+ top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor = prepare_tensors(args, device)
data = run_inference(
args,
device,
@@ -266,6 +301,7 @@ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_pat
top_region_index_tensor,
bottom_region_index_tensor,
spacing_tensor,
+ modality_tensor,
output_size,
divisor,
logger,
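Review note: in the sampling loop of `run_inference`, every timestep is paired with its successor, and the last timestep is paired with 0, because the patch passes `next_t` as an extra argument to `noise_scheduler.step` when the scheduler is an `RFlowScheduler`. The pairing can be sketched in plain Python as below. The helper name `pair_timesteps` and the example schedule are hypothetical, for illustration only.

```python
def pair_timesteps(timesteps):
    """Pair every timestep with its successor; the last one is paired with 0."""
    next_timesteps = list(timesteps[1:]) + [0]
    # Both sequences have the same length, so zip drops nothing.
    return list(zip(timesteps, next_timesteps))


# Hypothetical descending schedule, not taken from any real configuration.
pairs = pair_timesteps([1000, 750, 500, 250])
for t, next_t in pairs:
    print(f"step from t={t} to next_t={next_t}")
```

Since the shifted sequence has one element dropped at the front and one appended at the end, `min(len(all_timesteps), len(all_next_timesteps))` in the patch should equal `len(all_timesteps)`.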
diff --git a/generation/maisi/scripts/diff_model_setting.py b/generation/maisi/scripts/diff_model_setting.py
index 6ba4688867..3118b56d07 100644
--- a/generation/maisi/scripts/diff_model_setting.py
+++ b/generation/maisi/scripts/diff_model_setting.py
@@ -17,7 +17,6 @@
import torch
import torch.distributed as dist
-
from monai.utils import RankFilter
diff --git a/generation/maisi/scripts/diff_model_train.py b/generation/maisi/scripts/diff_model_train.py
index 2309c8a4ee..c616b89c37 100644
--- a/generation/maisi/scripts/diff_model_train.py
+++ b/generation/maisi/scripts/diff_model_train.py
@@ -18,15 +18,16 @@
from datetime import datetime
from pathlib import Path
+import monai
import torch
import torch.distributed as dist
-from torch.amp import GradScaler, autocast
-from torch.nn.parallel import DistributedDataParallel
-
-import monai
from monai.data import DataLoader, partition_dataset
+from monai.networks.schedulers import RFlowScheduler
+from monai.networks.schedulers.ddpm import DDPMPredictionType
from monai.transforms import Compose
from monai.utils import first
+from torch.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel
from .diff_model_setting import initialize_distributed, load_config, setup_logging
from .utils import define_instance
@@ -49,7 +50,12 @@ def load_filenames(data_list_path: str) -> list:
def prepare_data(
- train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
+ train_files: list,
+ device: torch.device,
+ cache_rate: float,
+ num_workers: int = 2,
+ batch_size: int = 1,
+ include_body_region: bool = False,
) -> DataLoader:
"""
Prepare training data.
@@ -60,6 +66,7 @@ def prepare_data(
cache_rate (float): Cache rate for dataset.
num_workers (int): Number of workers for data loading.
batch_size (int): Mini-batch size.
+ include_body_region (bool): Whether to load the top and bottom body region indices with the data.
Returns:
DataLoader: Data loader for training.
@@ -69,22 +76,24 @@ def _load_data_from_file(file_path, key):
with open(file_path) as f:
return torch.FloatTensor(json.load(f)[key])
- train_transforms = Compose(
- [
- monai.transforms.LoadImaged(keys=["image"]),
- monai.transforms.EnsureChannelFirstd(keys=["image"]),
+ train_transforms_list = [
+ monai.transforms.LoadImaged(keys=["image"]),
+ monai.transforms.EnsureChannelFirstd(keys=["image"]),
+ monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")),
+ monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
+ ]
+ if include_body_region:
+ train_transforms_list += [
monai.transforms.Lambdad(
keys="top_region_index", func=lambda x: _load_data_from_file(x, "top_region_index")
),
monai.transforms.Lambdad(
keys="bottom_region_index", func=lambda x: _load_data_from_file(x, "bottom_region_index")
),
- monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")),
monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
- monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
]
- )
+ train_transforms = Compose(train_transforms_list)
train_ds = monai.data.CacheDataset(
data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
@@ -216,6 +225,9 @@ def train_one_epoch(
Returns:
torch.Tensor: Training loss for the epoch.
"""
+ include_body_region = unet.include_top_region_index_input
+ include_modality = unet.num_class_embeds is not None
+
if local_rank == 0:
current_lr = optimizer.param_groups[0]["lr"]
logger.info(f"Epoch {epoch + 1}, lr {current_lr}.")
@@ -231,30 +243,64 @@ def train_one_epoch(
images = train_data["image"].to(device)
images = images * scale_factor
- top_region_index_tensor = train_data["top_region_index"].to(device)
- bottom_region_index_tensor = train_data["bottom_region_index"].to(device)
+ if include_body_region:
+ top_region_index_tensor = train_data["top_region_index"].to(device)
+ bottom_region_index_tensor = train_data["bottom_region_index"].to(device)
+ # We trained with only CT in this version
+ if include_modality:
+ modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device)
spacing_tensor = train_data["spacing"].to(device)
optimizer.zero_grad(set_to_none=True)
with autocast("cuda", enabled=amp):
- noise = torch.randn(
- (num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device
- )
+ noise = torch.randn_like(images)
- timesteps = torch.randint(0, num_train_timesteps, (images.shape[0],), device=images.device).long()
+ if isinstance(noise_scheduler, RFlowScheduler):
+ timesteps = noise_scheduler.sample_timesteps(images)
+ else:
+ timesteps = torch.randint(0, num_train_timesteps, (images.shape[0],), device=images.device).long()
noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
- noise_pred = unet(
- x=noisy_latent,
- timesteps=timesteps,
- top_region_index_tensor=top_region_index_tensor,
- bottom_region_index_tensor=bottom_region_index_tensor,
- spacing_tensor=spacing_tensor,
- )
+ # Create a dictionary to store the inputs
+ unet_inputs = {
+ "x": noisy_latent,
+ "timesteps": timesteps,
+ "spacing_tensor": spacing_tensor,
+ }
+ # Add extra arguments if include_body_region is True
+ if include_body_region:
+ unet_inputs.update(
+ {
+ "top_region_index_tensor": top_region_index_tensor,
+ "bottom_region_index_tensor": bottom_region_index_tensor,
+ }
+ )
+ if include_modality:
+ unet_inputs.update(
+ {
+ "class_labels": modality_tensor,
+ }
+ )
+ model_output = unet(**unet_inputs)
+
+ if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON:
+ # predict noise
+ model_gt = noise
+ elif noise_scheduler.prediction_type == DDPMPredictionType.SAMPLE:
+ # predict sample
+ model_gt = images
+ elif noise_scheduler.prediction_type == DDPMPredictionType.V_PREDICTION:
+ # predict velocity
+ model_gt = images - noise
+ else:
+ raise ValueError(
+ "noise scheduler prediction type has to be chosen from "
+ f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]"
+ )
- loss = loss_pt(noise_pred.float(), noise.float())
+ loss = loss_pt(model_output.float(), model_gt.float())
if amp:
scaler.scale(loss).backward()
@@ -345,6 +391,10 @@ def diff_model_train(
Path(args.model_dir).mkdir(parents=True, exist_ok=True)
+ unet = load_unet(args, device, logger)
+ noise_scheduler = define_instance(args, "noise_scheduler")
+ include_body_region = unet.include_top_region_index_input
+
filenames_train = load_filenames(args.json_data_list)
if local_rank == 0:
logger.info(f"num_files_train: {len(filenames_train)}")
@@ -356,21 +406,24 @@ def diff_model_train(
continue
str_info = os.path.join(args.embedding_base_dir, filenames_train[_i]) + ".json"
- train_files.append(
- {"image": str_img, "top_region_index": str_info, "bottom_region_index": str_info, "spacing": str_info}
- )
+ train_files_i = {"image": str_img, "spacing": str_info}
+ if include_body_region:
+ train_files_i["top_region_index"] = str_info
+ train_files_i["bottom_region_index"] = str_info
+ train_files.append(train_files_i)
if dist.is_initialized():
train_files = partition_dataset(
data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True
)[local_rank]
train_loader = prepare_data(
- train_files, device, args.diffusion_unet_train["cache_rate"], batch_size=args.diffusion_unet_train["batch_size"]
+ train_files,
+ device,
+ args.diffusion_unet_train["cache_rate"],
+ batch_size=args.diffusion_unet_train["batch_size"],
+ include_body_region=include_body_region,
)
- unet = load_unet(args, device, logger)
- noise_scheduler = define_instance(args, "noise_scheduler")
-
scale_factor = calculate_scale_factor(train_loader, device, logger)
optimizer = create_optimizer(unet, args.diffusion_unet_train["lr"])
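Review note on the training change above: the loss target now depends on the scheduler's prediction type instead of always being the noise. The mapping can be sketched in plain Python as below. This is a minimal illustration, not the script's code: `select_target` is a hypothetical helper, flat lists of floats stand in for tensors, and the string labels are assumed stand-ins for the `DDPMPredictionType` members.

```python
def select_target(prediction_type, images, noise):
    """Return the regression target for one training step.

    `images` and `noise` are flat lists of floats standing in for tensors.
    The string labels are assumptions standing in for DDPMPredictionType.
    """
    if prediction_type == "epsilon":
        # the network predicts the noise that was added
        return list(noise)
    if prediction_type == "sample":
        # the network predicts the clean latent
        return list(images)
    if prediction_type == "v_prediction":
        # the network predicts the velocity, taken as images - noise in the diff
        return [x - n for x, n in zip(images, noise)]
    # a single message string, so the full text shows up in str(exception)
    raise ValueError(
        "prediction type has to be chosen from "
        "['epsilon', 'sample', 'v_prediction'], "
        f"yet got {prediction_type!r}."
    )
```

Building the error message by implicit string concatenation (no comma between the pieces) keeps it a single argument; with a comma, `ValueError` receives two arguments and prints them as a tuple.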
diff --git a/generation/maisi/scripts/find_masks.py b/generation/maisi/scripts/find_masks.py
index c919d3932f..b7a730c463 100644
--- a/generation/maisi/scripts/find_masks.py
+++ b/generation/maisi/scripts/find_masks.py
@@ -107,19 +107,21 @@ def find_masks(
if not set(anatomy_list).issubset(_item["label_list"]):
continue
- # extract region indice (top_index and bottom_index) for candidate mask
- top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0]
- top_index = top_index[0]
- bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0]
- bottom_index = bottom_index[0]
-
# whether to keep this mask, default to be True.
keep_mask = True
- # if candiate mask does not contain all the body_region, skip it
- for _idx in body_region:
- if _idx > bottom_index or _idx < top_index:
- keep_mask = False
+ # extract region indices (top_index and bottom_index) for the candidate mask, if present
+ include_body_region = "top_region_index" in _item.keys()
+ if include_body_region:
+ top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0]
+ top_index = top_index[0]
+ bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0]
+ bottom_index = bottom_index[0]
+
+ # if the candidate mask does not contain all of body_region, skip it
+ for _idx in body_region:
+ if _idx > bottom_index or _idx < top_index:
+ keep_mask = False
for tumor_label in [23, 24, 26, 27, 128]:
# we skip those mask with tumors if users do not provide tumor label in anatomy_list
@@ -138,9 +140,10 @@ def find_masks(
"pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]),
"spacing": _item["spacing"],
"dim": _item["dim"],
- "top_region_index": _item["top_region_index"],
- "bottom_region_index": _item["bottom_region_index"],
}
+ if include_body_region:
+ candidate["top_region_index"] = _item["top_region_index"]
+ candidate["bottom_region_index"] = _item["bottom_region_index"]
# Conditionally add the label to the candidate dictionary
if "label_filename" in _item:
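Review note on the `find_masks` change above: the body-region filter is now applied only when the candidate entry carries region indices. A minimal sketch of that behavior follows, assuming one-hot-style index lists as in the diff. `covers_body_region` is a hypothetical helper name, not a function in the script.

```python
def covers_body_region(item, body_region):
    """Return True if the candidate mask should be kept for `body_region`.

    Entries without "top_region_index" skip the check entirely, which
    mirrors the `include_body_region` guard in the diff.
    """
    if "top_region_index" not in item:
        return True
    # position of the first non-zero element marks the region index
    top_index = next(i for i, v in enumerate(item["top_region_index"]) if v != 0)
    bottom_index = next(i for i, v in enumerate(item["bottom_region_index"]) if v != 0)
    # every requested region must lie within [top_index, bottom_index]
    return all(top_index <= idx <= bottom_index for idx in body_region)
```

One difference from the script: an all-zero index list raises `StopIteration` here, whereas the list-indexing form in the diff raises `IndexError`.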
diff --git a/generation/maisi/scripts/infer_controlnet.py b/generation/maisi/scripts/infer_controlnet.py
index 04ac982c98..0e88547d52 100644
--- a/generation/maisi/scripts/infer_controlnet.py
+++ b/generation/maisi/scripts/infer_controlnet.py
@@ -18,7 +18,7 @@
import torch
import torch.distributed as dist
-from monai.data import decollate_batch, MetaTensor
+from monai.data import MetaTensor, decollate_batch
from monai.networks.utils import copy_model_state
from monai.transforms import SaveImage
from monai.utils import RankFilter
@@ -49,6 +49,7 @@ def main():
help="config json file that stores training hyper-parameters",
)
parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
+
args = parser.parse_args()
# Step 0: configuration
@@ -109,6 +110,9 @@ def main():
# define diffusion Model
unet = define_instance(args, "diffusion_unet_def").to(device)
+ include_body_region = unet.include_top_region_index_input
+ include_modality = unet.num_class_embeds is not None
+
# load trained diffusion model
if args.trained_diffusion_path is not None:
if not os.path.exists(args.trained_diffusion_path):
@@ -150,9 +154,14 @@ def main():
# get label mask
labels = batch["label"].to(device)
# get corresponding conditions
- top_region_index_tensor = batch["top_region_index"].to(device)
- bottom_region_index_tensor = batch["bottom_region_index"].to(device)
+ if include_body_region:
+ top_region_index_tensor = batch["top_region_index"].to(device)
+ bottom_region_index_tensor = batch["bottom_region_index"].to(device)
+ else:
+ top_region_index_tensor = None
+ bottom_region_index_tensor = None
spacing_tensor = batch["spacing"].to(device)
+ modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device)
out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())
# get target dimension
dim = batch["dim"]
@@ -162,22 +171,23 @@ def main():
check_input(None, None, None, output_size, out_spacing, None)
# generate a single synthetic image using a latent diffusion model with controlnet.
synthetic_images, _ = ldm_conditional_sample_one_image(
- autoencoder,
- unet,
- controlnet,
- noise_scheduler,
- scale_factor,
- device,
- labels,
- top_region_index_tensor,
- bottom_region_index_tensor,
- spacing_tensor,
+ autoencoder=autoencoder,
+ diffusion_unet=unet,
+ controlnet=controlnet,
+ noise_scheduler=noise_scheduler,
+ scale_factor=scale_factor,
+ device=device,
+ combine_label_or=labels,
+ top_region_index_tensor=top_region_index_tensor,
+ bottom_region_index_tensor=bottom_region_index_tensor,
+ spacing_tensor=spacing_tensor,
+ modality_tensor=modality_tensor,
latent_shape=latent_shape,
output_size=output_size,
noise_factor=1.0,
num_inference_steps=args.controlnet_infer["num_inference_steps"],
- # reduce it when GPU memory is limited
autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"],
+ autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"],
)
# save image/label pairs
labels = decollate_batch(batch)[0]["label"]
diff --git a/generation/maisi/scripts/inference.py b/generation/maisi/scripts/inference.py
index 9049aee89a..3f81f9c49b 100644
--- a/generation/maisi/scripts/inference.py
+++ b/generation/maisi/scripts/inference.py
@@ -14,8 +14,8 @@
import json
import logging
import os
-import tempfile
import sys
+import tempfile
import monai
import torch
@@ -23,6 +23,7 @@
from monai.config import print_config
from monai.transforms import LoadImage, Orientation
from monai.utils import set_determinism
+
from scripts.sample import LDMSampler, check_input
from scripts.utils import define_instance
from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image
@@ -60,10 +61,18 @@ def main():
default=None,
help="random seed, can be None or int",
)
+ parser.add_argument(
+ "--version",
+ default="maisi3d-rflow",
+ type=str,
+ help="maisi_version, choose from ['maisi3d-ddpm', 'maisi3d-rflow']",
+ )
args = parser.parse_args()
# Step 0: configuration
logger = logging.getLogger("maisi.inference")
+ maisi_version = args.version
+
# ## Set deterministic training for reproducibility
if args.random_seed is not None:
set_determinism(seed=args.random_seed)
@@ -79,41 +88,75 @@ def main():
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
+ # TODO: remove the `files` after the files are uploaded to the NGC
files = [
{
"path": "models/autoencoder_epoch273.pt",
- "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt",
- },
- {
- "path": "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt",
- "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt",
- },
- {
- "path": "models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt",
- "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials"
+ "/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt",
},
{
"path": "models/mask_generation_autoencoder.pt",
- "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/mask_generation_autoencoder.pt",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
+ "/tutorials/mask_generation_autoencoder.pt",
},
{
"path": "models/mask_generation_diffusion_unet.pt",
- "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_mask_generation_diffusion_unet_v2.pt",
- },
- {
- "path": "configs/candidate_masks_flexible_size_and_spacing_3000.json",
- "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/candidate_masks_flexible_size_and_spacing_3000.json",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
+ "/tutorials/model_zoo/model_maisi_mask_generation_diffusion_unet_v2.pt",
},
{
"path": "configs/all_anatomy_size_condtions.json",
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/all_anatomy_size_condtions.json",
},
{
- "path": "datasets/all_masks_flexible_size_and_spacing_3000.zip",
- "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_all_masks_flexible_size_and_spacing_3000.zip",
+ "path": "datasets/all_masks_flexible_size_and_spacing_4000.zip",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
+ "/tutorials/all_masks_flexible_size_and_spacing_4000.zip",
},
]
+ if maisi_version == "maisi3d-ddpm":
+ files += [
+ {
+ "path": "models/diff_unet_3d_ddpm.pt",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo"
+ "/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt",
+ },
+ {
+ "path": "models/controlnet_3d_ddpm.pt",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo"
+ "/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt",
+ },
+ {
+ "path": "configs/candidate_masks_flexible_size_and_spacing_3000.json",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
+ "/tutorials/candidate_masks_flexible_size_and_spacing_3000.json",
+ },
+ ]
+ elif maisi_version == "maisi3d-rflow":
+ files += [
+ {
+ "path": "models/diff_unet_3d_rflow.pt",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
+ "diff_unet_ckpt_rflow_epoch19350.pt",
+ },
+ {
+ "path": "models/controlnet_3d_rflow.pt",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
+ "controlnet_rflow_epoch60.pt",
+ },
+ {
+ "path": "configs/candidate_masks_flexible_size_and_spacing_4000.json",
+ "url": "https://developer.download.nvidia.com/assets/Clara/monai"
+ "/tutorials/candidate_masks_flexible_size_and_spacing_4000.json",
+ },
+ ]
+ else:
+ raise ValueError(
+ f"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}."
+ )
+
for file in files:
file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
download_url(url=file["url"], filepath=file["path"])
@@ -230,6 +273,7 @@ def main():
image_output_ext=args.image_output_ext,
label_output_ext=args.label_output_ext,
spacing=args.spacing,
+ modality=args.modality,
num_inference_steps=args.num_inference_steps,
mask_generation_num_inference_steps=args.mask_generation_num_inference_steps,
random_seed=args.random_seed,
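Review note on the `--version` switch above: the version string selects which diffusion UNet, ControlNet, and candidate-mask list are downloaded, and any other value is rejected. The selection can be sketched as a lookup table. The paths are copied from the diff; `VERSION_FILES` and `version_specific_paths` are hypothetical names used only for illustration.

```python
# Local paths per MAISI version, as listed in the diff above.
VERSION_FILES = {
    "maisi3d-ddpm": [
        "models/diff_unet_3d_ddpm.pt",
        "models/controlnet_3d_ddpm.pt",
        "configs/candidate_masks_flexible_size_and_spacing_3000.json",
    ],
    "maisi3d-rflow": [
        "models/diff_unet_3d_rflow.pt",
        "models/controlnet_3d_rflow.pt",
        "configs/candidate_masks_flexible_size_and_spacing_4000.json",
    ],
}


def version_specific_paths(maisi_version):
    """Return the version-dependent file paths, or raise on an unknown version."""
    if maisi_version not in VERSION_FILES:
        raise ValueError(
            f"maisi_version has to be chosen from {sorted(VERSION_FILES)}, "
            f"yet got {maisi_version}."
        )
    # return a copy so callers can extend the list without mutating the table
    return list(VERSION_FILES[maisi_version])
```

A table keeps the valid choices and the error message in one place, at the cost of moving the URLs elsewhere; the if/elif form in the diff keeps each path next to its URL.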
diff --git a/generation/maisi/scripts/quality_check.py b/generation/maisi/scripts/quality_check.py
index 223732761a..bff49b6da0 100644
--- a/generation/maisi/scripts/quality_check.py
+++ b/generation/maisi/scripts/quality_check.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import nibabel as nib
import numpy as np
@@ -110,8 +109,11 @@ def is_outlier(statistics, image_data, label_data, label_int_dict):
for label_name, stats in statistics.items():
# Get the thresholds from the statistics
- low_thresh = stats["sigma_6_low"] # or "sigma_12_low" depending on your needs
- high_thresh = stats["sigma_6_high"] # or "sigma_12_high" depending on your needs
+ low_thresh = min(stats["sigma_6_low"], stats["percentile_0_5"]) # or "sigma_12_low" depending on your needs
+ high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"]) # or "sigma_12_high" depending on your needs
+
+ if label_name == "bone":
+ high_thresh = 1000.0
# Retrieve the corresponding label integers
labels = label_int_dict.get(label_name, [])
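Review note on the threshold change above: the accepted intensity range for each label is widened to the union of the 6-sigma band and the 0.5 to 99.5 percentile band, and the upper bound for bone is fixed at 1000.0. A minimal sketch of that rule follows, using the statistics keys from the diff. `outlier_thresholds` is a hypothetical helper name.

```python
def outlier_thresholds(label_name, stats):
    """Return (low, high) intensity thresholds for one label."""
    # take the looser of the two lower bounds and the looser of the two upper bounds
    low = min(stats["sigma_6_low"], stats["percentile_0_5"])
    high = max(stats["sigma_6_high"], stats["percentile_99_5"])
    if label_name == "bone":
        # bone uses a fixed upper bound, as in the diff
        high = 1000.0
    return low, high
```

Because `min` and `max` can only loosen the band, this change makes the quality check reject fewer images than the 6-sigma thresholds alone.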
diff --git a/generation/maisi/scripts/sample.py b/generation/maisi/scripts/sample.py
index c1e2c8699a..fb1d0f4251 100644
--- a/generation/maisi/scripts/sample.py
+++ b/generation/maisi/scripts/sample.py
@@ -16,20 +16,29 @@
import random
import time
from datetime import datetime
+import warnings
+import gc
import monai
import torch
-from monai.inferers.inferer import DiffusionInferer
from monai.data import MetaTensor
-from monai.inferers import sliding_window_inference
+from monai.inferers.inferer import DiffusionInferer
from monai.transforms import Compose, SaveImage
from monai.utils import set_determinism
from tqdm import tqdm
+from monai.inferers.inferer import SlidingWindowInferer
+from monai.networks.schedulers import RFlowScheduler, DDPMScheduler
from .augmentation import augmentation
from .find_masks import find_masks
-from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels
from .quality_check import is_outlier
+from .utils import (
+ binarize_labels,
+ general_mask_generation_post_process,
+ get_body_region_index_from_mask,
+ remap_labels,
+ dynamic_infer,
+)
class ReconModel(torch.nn.Module):
@@ -122,7 +131,19 @@ def ldm_conditional_sample_one_mask(
latents = initialize_noise_latents(latent_shape, device)
anatomy_size = torch.FloatTensor(anatomy_size).unsqueeze(0).unsqueeze(0).half().to(device)
# synthesize latents
+ if isinstance(noise_scheduler, DDPMScheduler) and num_inference_steps < noise_scheduler.num_train_timesteps:
+ warnings.warn(
+ "**************************************************************\n"
+ "* WARNING: Mask noise_scheduler is a DDPMScheduler.\n"
+ "* We expect num_inference_steps = noise_scheduler.num_train_timesteps"
+ f" = {noise_scheduler.num_train_timesteps}.\n"
+ f"* Yet got num_inference_steps = {num_inference_steps}.\n"
+ "* The generated image quality is not guaranteed.\n"
+ "**************************************************************"
+ )
+
noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)
+ # mask generator is DDPM
inferer_ddpm = DiffusionInferer(noise_scheduler)
latents = inferer_ddpm.sample(
input_noise=latents,
@@ -131,29 +152,17 @@ def ldm_conditional_sample_one_mask(
verbose=True,
conditioning=anatomy_size.to(device),
)
- # decode latents to synthesized masks
- if math.prod(latent_shape[1:]) <= math.prod(autoencoder_sliding_window_infer_size):
- synthetic_mask = recon_model(latents).cpu().detach()
- else:
- synthetic_mask = (
- sliding_window_inference(
- inputs=latents,
- roi_size=(
- autoencoder_sliding_window_infer_size[0],
- autoencoder_sliding_window_infer_size[1],
- autoencoder_sliding_window_infer_size[2],
- ),
- sw_batch_size=1,
- predictor=recon_model,
- mode="gaussian",
- overlap=autoencoder_sliding_window_infer_overlap,
- sw_device=device,
- device=torch.device("cpu"),
- progress=True,
- )
- .cpu()
- .detach()
- )
+
+ inferer = SlidingWindowInferer(
+ roi_size=autoencoder_sliding_window_infer_size,
+ sw_batch_size=1,
+ progress=True,
+ mode="gaussian",
+ overlap=autoencoder_sliding_window_infer_overlap,
+ sw_device=device,
+ device=torch.device("cpu"),
+ )
+ synthetic_mask = dynamic_infer(inferer, recon_model, latents)
synthetic_mask = torch.softmax(synthetic_mask, dim=1)
synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True)
# mapping raw index to 132 labels
@@ -183,12 +192,13 @@ def ldm_conditional_sample_one_image(
scale_factor,
device,
combine_label_or,
- top_region_index_tensor,
- bottom_region_index_tensor,
spacing_tensor,
latent_shape,
output_size,
noise_factor,
+ top_region_index_tensor=None,
+ bottom_region_index_tensor=None,
+ modality_tensor=None,
num_inference_steps=1000,
autoencoder_sliding_window_infer_size=[96, 96, 96],
autoencoder_sliding_window_infer_overlap=0.6667,
@@ -204,12 +214,13 @@ def ldm_conditional_sample_one_image(
scale_factor (float): Scaling factor for the latent space.
device (torch.device): The device to run the computation on.
combine_label_or (torch.Tensor): The combined label tensor.
- top_region_index_tensor (torch.Tensor): Tensor specifying the top region index.
- bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index.
spacing_tensor (torch.Tensor): Tensor specifying the spacing.
latent_shape (tuple): The shape of the latent space.
output_size (tuple): The desired output size of the image.
noise_factor (float): Factor to scale the initial noise.
+ top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. Defaults to None.
+ bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. Defaults to None.
+ modality_tensor (torch.Tensor): Integer tensor specifying the modality. Defaults to None.
num_inference_steps (int): Number of inference steps for the diffusion process.
autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96].
autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667.
@@ -224,6 +235,9 @@ def ldm_conditional_sample_one_image(
b_min = 0.0
b_max = 1
+ include_body_region = diffusion_unet.include_top_region_index_input
+ include_modality = diffusion_unet.num_class_embeds is not None
+
recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)
with torch.no_grad(), torch.amp.autocast("cuda"):
@@ -247,54 +261,106 @@ def ldm_conditional_sample_one_image(
latents = initialize_noise_latents(latent_shape, device) * noise_factor
# synthesize latents
- noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)
- for t in tqdm(noise_scheduler.timesteps, ncols=110):
- # Get controlnet output
- down_block_res_samples, mid_block_res_sample = controlnet(
- x=latents,
- timesteps=torch.Tensor((t,)).to(device),
- controlnet_cond=controlnet_cond_vis,
+ if isinstance(noise_scheduler, RFlowScheduler):
+ noise_scheduler.set_timesteps(
+ num_inference_steps=num_inference_steps,
+ input_img_size_numel=torch.prod(torch.tensor(latents.shape[2:])),
)
- latent_model_input = latents
- noise_pred = diffusion_unet(
- x=latent_model_input,
- timesteps=torch.Tensor((t,)).to(device),
- top_region_index_tensor=top_region_index_tensor,
- bottom_region_index_tensor=bottom_region_index_tensor,
- spacing_tensor=spacing_tensor,
- down_block_additional_residuals=down_block_res_samples,
- mid_block_additional_residual=mid_block_res_sample,
+ else:
+ noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)
+
+ if isinstance(noise_scheduler, DDPMScheduler) and num_inference_steps < noise_scheduler.num_train_timesteps:
+ warnings.warn(
+ "**************************************************************\n"
+ "* WARNING: Image noise_scheduler is a DDPMScheduler.\n"
+ "* We expect num_inference_steps = noise_scheduler.num_train_timesteps"
+ f" = {noise_scheduler.num_train_timesteps}.\n"
+ f"* Yet got num_inference_steps = {num_inference_steps}.\n"
+ "* The generated image quality is not guaranteed.\n"
+ "**************************************************************"
)
- latents, _ = noise_scheduler.step(noise_pred, t, latents)
+
+ all_timesteps = noise_scheduler.timesteps
+ all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype)))
+ progress_bar = tqdm(
+ zip(all_timesteps, all_next_timesteps),
+ total=min(len(all_timesteps), len(all_next_timesteps)),
+ )
+ for t, next_t in progress_bar:
+ # get controlnet output
+ # Create a dictionary to store the inputs
+ controlnet_inputs = {
+ "x": latents,
+ "timesteps": torch.Tensor((t,)).to(device),
+ "controlnet_cond": controlnet_cond_vis,
+ }
+ if include_modality:
+ controlnet_inputs.update(
+ {
+ "class_labels": modality_tensor,
+ }
+ )
+ down_block_res_samples, mid_block_res_sample = controlnet(**controlnet_inputs)
+
+ # get diffusion network output
+ # Create a dictionary to store the inputs
+ unet_inputs = {
+ "x": latents,
+ "timesteps": torch.Tensor((t,)).to(device),
+ "spacing_tensor": spacing_tensor,
+ "down_block_additional_residuals": down_block_res_samples,
+ "mid_block_additional_residual": mid_block_res_sample,
+ }
+ # Add extra arguments if include_body_region is True
+ if include_body_region:
+ unet_inputs.update(
+ {
+ "top_region_index_tensor": top_region_index_tensor,
+ "bottom_region_index_tensor": bottom_region_index_tensor,
+ }
+ )
+ if include_modality:
+ unet_inputs.update(
+ {
+ "class_labels": modality_tensor,
+ }
+ )
+ model_output = diffusion_unet(**unet_inputs)
+
+ if not isinstance(noise_scheduler, RFlowScheduler):
+ latents, _ = noise_scheduler.step(model_output, t, latents) # type: ignore
+ else:
+ latents, _ = noise_scheduler.step(model_output, t, latents, next_t) # type: ignore
end_time = time.time()
- logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----")
- del noise_pred
+ logging.info(f"---- DM/ControlNet Latent features generation time: {end_time - start_time} seconds ----")
+ del (
+ unet_inputs,
+ controlnet_inputs,
+ model_output,
+ controlnet_cond_vis,
+ down_block_res_samples,
+ mid_block_res_sample,
+ )
+ gc.collect()
torch.cuda.empty_cache()
# decode latents to synthesized images
logging.info("---- Start decoding latent features into images... ----")
start_time = time.time()
- if math.prod(latent_shape[1:]) <= math.prod(autoencoder_sliding_window_infer_size):
- synthetic_images = recon_model(latents)
- else:
- synthetic_images = sliding_window_inference(
- inputs=latents,
- roi_size=(
- min(output_size[0] // 4, autoencoder_sliding_window_infer_size[0]),
- min(output_size[1] // 4, autoencoder_sliding_window_infer_size[1]),
- min(output_size[2] // 4, autoencoder_sliding_window_infer_size[2]),
- ),
- sw_batch_size=1,
- predictor=recon_model,
- mode="gaussian",
- overlap=autoencoder_sliding_window_infer_overlap,
- sw_device=device,
- device=torch.device("cpu"),
- progress=True,
- )
+
+ inferer = SlidingWindowInferer(
+ roi_size=autoencoder_sliding_window_infer_size,
+ sw_batch_size=1,
+ progress=True,
+ mode="gaussian",
+ overlap=autoencoder_sliding_window_infer_overlap,
+ sw_device=device,
+ device=torch.device("cpu"),
+ )
+ synthetic_images = dynamic_infer(inferer, recon_model, latents)
synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu()
end_time = time.time()
- logging.info(f"---- Image decoding time: {end_time - start_time} seconds ----")
+ logging.info(f"---- Image VAE decoding time: {end_time - start_time} seconds ----")
## post processing:
# project output to [0, 1]
@@ -510,6 +576,7 @@ def __init__(
label_output_ext=".nii.gz",
real_img_median_statistics="./configs/image_median_statistics.json",
spacing=[1, 1, 1],
+ modality=1,
num_inference_steps=None,
mask_generation_num_inference_steps=None,
random_seed=None,
@@ -522,6 +589,7 @@ def __init__(
Args:
Various parameters related to model configuration, input settings, and output specifications.
"""
+ self.random_seed = random_seed
if random_seed is not None:
set_determinism(seed=random_seed)
@@ -575,7 +643,7 @@ def __init__(
self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap
# quality check args
- self.max_try_time = 5 # if not pass quality check, will try self.max_try_time times
+ self.max_try_time = 2 # if the quality check fails, retry up to self.max_try_time times
with open(real_img_median_statistics, "r") as json_file:
self.median_statistics = json.load(json_file)
self.label_int_dict = {
@@ -601,21 +669,27 @@ def __init__(
self.mask_generation_diffusion_unet.eval()
self.spacing = spacing
-
- self.val_transforms = Compose(
- [
- monai.transforms.LoadImaged(keys=["pseudo_label"]),
- monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]),
- monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"),
- monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8),
+ self.modality_tensor = modality * torch.ones((1,), dtype=torch.long).to(device)
+ self.include_body_region = self.diffusion_unet.include_top_region_index_input
+ self.include_modality = self.diffusion_unet.num_class_embeds is not None
+
+ val_transforms_list = [
+ monai.transforms.LoadImaged(keys=["pseudo_label"]),
+ monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]),
+ monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"),
+ monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8),
+ monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
+ monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
+ ]
+ if self.include_body_region:
+ val_transforms_list += [
monai.transforms.Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
- monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
- monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
]
- )
+
+ self.val_transforms = Compose(val_transforms_list)
logging.info("LDM sampler initialized.")
def sample_multiple_images(self, num_img):
@@ -625,6 +699,7 @@ def sample_multiple_images(self, num_img):
Args:
num_img (int): Number of images to generate.
"""
+ modality_tensor = self.modality_tensor
output_filenames = []
if len(self.controllable_anatomy_size) > 0:
# we will use mask generation instead of finding candidate masks
@@ -653,11 +728,19 @@ def sample_multiple_images(self, num_img):
selected_mask_files = self.select_mask(candidate_mask_files, num_img)
logging.info(f"Images will be generated based on {selected_mask_files}.")
- if len(selected_mask_files) != num_img:
+ if len(selected_mask_files) < num_img:
raise ValueError(
- f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img ({num_img}). This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)."
+ (
+ f"len(selected_mask_files) ({len(selected_mask_files)}) < num_img ({num_img}). "
+ "This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)."
+ )
)
- for item in selected_mask_files:
+
+ num_generated_img = 0
+ for index_s in range(len(selected_mask_files)):
+ item = selected_mask_files[index_s]
+ if num_generated_img >= num_img:
+ break
logging.info("---- Start preparing masks... ----")
start_time = time.time()
if len(self.controllable_anatomy_size) > 0:
@@ -682,58 +765,63 @@ def sample_multiple_images(self, num_img):
combine_label_or = self.ensure_output_size_and_spacing(combine_label_or)
# mask augmentation
if if_aug:
- combine_label_or = augmentation(combine_label_or, self.output_size)
+ combine_label_or = augmentation(combine_label_or, self.output_size, self.random_seed)
end_time = time.time()
logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----")
torch.cuda.empty_cache()
# generate image/label pairs
to_generate = True
try_time = 0
- while to_generate:
- synthetic_images, synthetic_labels = self.sample_one_pair(
- combine_label_or,
- top_region_index_tensor,
- bottom_region_index_tensor,
- spacing_tensor,
- )
- # synthetic image quality check
- pass_quality_check = self.quality_check(
- synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy()
- )
- if pass_quality_check or try_time > self.max_try_time:
- # save image/label pairs
- output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
- synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz"
- synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta)
- img_saver = SaveImage(
- output_dir=self.output_dir,
- output_postfix=output_postfix + "_image",
- output_ext=self.image_output_ext,
- separate_folder=False,
- )
- img_saver(synthetic_images[0])
- synthetic_images_filename = os.path.join(
- self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext
- )
- # filter out the organs that are not in anatomy_list
- synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list)
- label_saver = SaveImage(
- output_dir=self.output_dir,
- output_postfix=output_postfix + "_label",
- output_ext=self.label_output_ext,
- separate_folder=False,
- )
- label_saver(synthetic_labels[0])
- synthetic_labels_filename = os.path.join(
- self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext
- )
- output_filenames.append([synthetic_images_filename, synthetic_labels_filename])
- to_generate = False
- else:
+ # start generation
+ synthetic_images, synthetic_labels = self.sample_one_pair(
+ combine_label_or,
+ top_region_index_tensor,
+ bottom_region_index_tensor,
+ spacing_tensor,
+ modality_tensor,
+ )
+ # synthetic image quality check
+ pass_quality_check = self.quality_check(
+ synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy()
+ )
+ logging.debug(f"Images still needed: {num_img - num_generated_img}; masks remaining: {len(selected_mask_files) - index_s}")
+ if pass_quality_check or (num_img - num_generated_img) >= (len(selected_mask_files) - index_s):
+ if not pass_quality_check:
logging.info(
- "Generated image/label pair did not pass quality check, will re-generate another pair."
+ "Generated image/label pair did not pass quality check, but will still save them. "
+ "Please consider changing spacing and output_size to facilitate a more realistic setting."
)
- try_time += 1
+ num_generated_img = num_generated_img + 1
+ # save image/label pairs
+ output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz"
+ synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta)
+ img_saver = SaveImage(
+ output_dir=self.output_dir,
+ output_postfix=output_postfix + "_image",
+ output_ext=self.image_output_ext,
+ separate_folder=False,
+ )
+ img_saver(synthetic_images[0])
+ synthetic_images_filename = os.path.join(
+ self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext
+ )
+ # filter out the organs that are not in anatomy_list
+ synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list)
+ label_saver = SaveImage(
+ output_dir=self.output_dir,
+ output_postfix=output_postfix + "_label",
+ output_ext=self.label_output_ext,
+ separate_folder=False,
+ )
+ label_saver(synthetic_labels[0])
+ synthetic_labels_filename = os.path.join(
+ self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext
+ )
+ output_filenames.append([synthetic_images_filename, synthetic_labels_filename])
+ to_generate = False
+ else:
+ logging.info("Generated image/label pair did not pass quality check, will re-generate another pair.")
return output_filenames
def select_mask(self, candidate_mask_files, num_img):
@@ -750,7 +838,7 @@ def select_mask(self, candidate_mask_files, num_img):
selected_mask_files = []
random.shuffle(candidate_mask_files)
- for n in range(num_img):
+ for n in range(len(candidate_mask_files)):
mask_file = candidate_mask_files[n % len(candidate_mask_files)]
selected_mask_files.append({"mask_file": mask_file, "if_aug": True})
return selected_mask_files
@@ -761,6 +849,7 @@ def sample_one_pair(
top_region_index_tensor,
bottom_region_index_tensor,
spacing_tensor,
+ modality_tensor,
):
"""
Generate a single pair of synthetic image and mask.
@@ -770,6 +859,7 @@ def sample_one_pair(
top_region_index_tensor (torch.Tensor): Tensor specifying the top region index.
bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index.
spacing_tensor (torch.Tensor): Tensor specifying the spacing.
+ modality_tensor (torch.Tensor): Int Tensor specifying the modality.
Returns:
tuple: A tuple containing the synthetic image and its corresponding label.
@@ -786,6 +876,7 @@ def sample_one_pair(
top_region_index_tensor=top_region_index_tensor,
bottom_region_index_tensor=bottom_region_index_tensor,
spacing_tensor=spacing_tensor,
+ modality_tensor=modality_tensor,
latent_shape=self.latent_shape,
output_size=self.output_size,
noise_factor=self.noise_factor,
@@ -959,13 +1050,11 @@ def read_mask_information(self, mask_file):
"""
val_data = self.val_transforms(mask_file)
- for key in [
- "pseudo_label",
- "spacing",
- "top_region_index",
- "bottom_region_index",
- ]:
- val_data[key] = val_data[key].unsqueeze(0).to(self.device)
+ for key in ["pseudo_label", "spacing", "top_region_index", "bottom_region_index"]:
+ if isinstance(val_data[key], torch.Tensor):
+ val_data[key] = val_data[key].unsqueeze(0).to(self.device)
+ else:
+ val_data[key] = None
return (
val_data["pseudo_label"],
@@ -1000,42 +1089,75 @@ def find_closest_masks(self, num_img):
if len(candidates) < num_img:
raise ValueError(f"candidate masks are less than {num_img}.")
+
# loop through the database and find closest combinations
new_candidates = []
for c in candidates:
diff = 0
+ include_c = True
for axis in range(3):
+ if abs(c["dim"][axis]) < self.output_size[axis] - 64:
+ # we cannot upsample the mask too much
+ include_c = False
+ break
+ # check diff in FOV, major metric
+ diff += abs(
+ (abs(c["dim"][axis] * c["spacing"][axis]) - self.output_size[axis] * self.spacing[axis]) / 10
+ )
# check diff in dim
- diff += abs((c["dim"][axis] - self.output_size[axis]) / 100)
+ diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100)
# check diff in spacing
- diff += abs(c["spacing"][axis] - self.spacing[axis])
- new_candidates.append((c, diff))
+ diff += abs(abs(c["spacing"][axis]) - self.spacing[axis])
+ if include_c:
+ new_candidates.append((c, diff))
+
# keep up to max_try_time * num_img candidates (at least 5)
- new_candidates = sorted(new_candidates, key=lambda x: x[1])[: max(2 * num_img, 5)]
+ num_candidates = max(self.max_try_time * num_img, 5)
+ new_candidates = sorted(new_candidates, key=lambda x: x[1])
+
final_candidates = []
# resample the best-ranked candidates in order and update spacing, stopping once num_candidates are accepted
- image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True)
for c, _ in new_candidates:
- label = image_loader(c["pseudo_label"])
- try:
- label = self.ensure_output_size_and_spacing(label.unsqueeze(0))
- except ValueError as e:
- if "Resampled mask does not contain required class labels" in str(e):
- continue
- else:
- raise e
- # get region_index after resample
- top_region_index, bottom_region_index = get_body_region_index_from_mask(label)
- c["top_region_index"] = top_region_index
- c["bottom_region_index"] = bottom_region_index
- c["spacing"] = self.spacing
- c["dim"] = self.output_size
-
- final_candidates.append(c)
+ c = self.resample_mask_check_organ_list(c)
+ if c is not None:
+ final_candidates.append(c)
+ if len(final_candidates) >= num_candidates:
+ break
if len(final_candidates) == 0:
raise ValueError("Cannot find body region with given organ list.")
return final_candidates
+ def resample_mask_check_organ_list(self, mask):
+ """
+ Resample mask and check if the resampled mask contains the required organ list.
+
+ Args:
+ mask (dict): input mask.
+
+ Returns:
+ dict or None: the resampled mask, or None if it does not contain the required organ list.
+
+ Raises:
+ ValueError: Re-raised if resampling fails for any reason other than missing required class labels.
+ """
+
+ image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True)
+ label = image_loader(mask["pseudo_label"])
+ try:
+ label = self.ensure_output_size_and_spacing(label.unsqueeze(0))
+ except ValueError as e:
+ if "Resampled mask does not contain required class labels" in str(e):
+ return None
+ else:
+ raise e
+ # get region_index after resample
+ top_region_index, bottom_region_index = get_body_region_index_from_mask(label)
+ mask["top_region_index"] = top_region_index
+ mask["bottom_region_index"] = bottom_region_index
+ mask["spacing"] = self.spacing
+ mask["dim"] = self.output_size
+ return mask
+
def quality_check(self, image_data, label_data):
"""
Perform a quality check on the generated image.
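Reviewer note on the candidate ranking in `find_closest_masks`: the sketch below restates the per-candidate score as a standalone function so the weighting is easier to check. It is a minimal illustration, not the code in this patch. The function name `candidate_score` and the `max_upsample_voxels` parameter are hypothetical; the constants (64, 10, 100) are copied from the hunk above.

```python
from typing import Optional, Sequence


def candidate_score(
    dim: Sequence[float],
    spacing: Sequence[float],
    output_size: Sequence[float],
    target_spacing: Sequence[float],
    max_upsample_voxels: int = 64,  # hypothetical name for the hard-coded 64
) -> Optional[float]:
    """Return a distance score (lower is better), or None if the candidate is excluded."""
    diff = 0.0
    for axis in range(3):
        d = abs(dim[axis])
        s = abs(spacing[axis])
        # Exclude masks that would need to be upsampled too much.
        if d < output_size[axis] - max_upsample_voxels:
            return None
        # Field-of-view difference in physical units, the major metric.
        diff += abs((d * s - output_size[axis] * target_spacing[axis]) / 10)
        # Difference in voxel dimensions.
        diff += abs((d - output_size[axis]) / 100)
        # Difference in voxel spacing.
        diff += abs(s - target_spacing[axis])
    return diff
```

Candidates that return `None` are dropped before sorting; the remaining ones are sorted by ascending score and resampled in that order.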
diff --git a/generation/maisi/scripts/train_controlnet.py b/generation/maisi/scripts/train_controlnet.py
index c59bebccff..3d7336be17 100644
--- a/generation/maisi/scripts/train_controlnet.py
+++ b/generation/maisi/scripts/train_controlnet.py
@@ -23,6 +23,8 @@
import torch.nn.functional as F
from monai.networks.utils import copy_model_state
from monai.utils import RankFilter
+from monai.networks.schedulers import RFlowScheduler
+from monai.networks.schedulers.ddpm import DDPMPredictionType
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -51,6 +53,7 @@ def main():
help="config json file that stores training hyper-parameters",
)
parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
+
args = parser.parse_args()
# Step 0: configuration
@@ -105,6 +108,9 @@ def main():
# Step 2: define diffusion model and controlnet
# define diffusion Model
unet = define_instance(args, "diffusion_unet_def").to(device)
+ include_body_region = unet.include_top_region_index_input
+ include_modality = unet.num_class_embeds is not None
+
# load trained diffusion model
if args.trained_diffusion_path is not None:
if not os.path.exists(args.trained_diffusion_path):
@@ -168,57 +174,104 @@ def main():
epoch_loss_ = 0
for step, batch in enumerate(train_loader):
# get image embedding and label mask and scale image embedding by the provided scale_factor
- inputs = batch["image"].to(device) * scale_factor
+ images = batch["image"].to(device) * scale_factor
labels = batch["label"].to(device)
# get corresponding conditions
- top_region_index_tensor = batch["top_region_index"].to(device)
- bottom_region_index_tensor = batch["bottom_region_index"].to(device)
+ if include_body_region:
+ top_region_index_tensor = batch["top_region_index"].to(device)
+ bottom_region_index_tensor = batch["bottom_region_index"].to(device)
+ # We trained with only CT in this version
+ if include_modality:
+ modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device)
spacing_tensor = batch["spacing"].to(device)
optimizer.zero_grad(set_to_none=True)
with autocast("cuda", enabled=True):
# generate random noise
- noise_shape = list(inputs.shape)
- noise = torch.randn(noise_shape, dtype=inputs.dtype).to(device)
+ noise_shape = list(images.shape)
+ noise = torch.randn(noise_shape, dtype=images.dtype).to(device)
# use binary encoding to encode segmentation mask
controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()
# create timesteps
- timesteps = torch.randint(
- 0, noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=device
- ).long()
+ if isinstance(noise_scheduler, RFlowScheduler):
+ timesteps = noise_scheduler.sample_timesteps(images)
+ else:
+ timesteps = torch.randint(
+ 0, noise_scheduler.num_train_timesteps, (images.shape[0],), device=images.device
+ ).long()
# create noisy latent
- noisy_latent = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+ noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
# get controlnet output
- down_block_res_samples, mid_block_res_sample = controlnet(
- x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
- )
- # get noise prediction from diffusion unet
- noise_pred = unet(
- x=noisy_latent,
- timesteps=timesteps,
- top_region_index_tensor=top_region_index_tensor,
- bottom_region_index_tensor=bottom_region_index_tensor,
- spacing_tensor=spacing_tensor,
- down_block_additional_residuals=down_block_res_samples,
- mid_block_additional_residual=mid_block_res_sample,
+ # Create a dictionary to store the inputs
+ controlnet_inputs = {
+ "x": noisy_latent,
+ "timesteps": timesteps,
+ "controlnet_cond": controlnet_cond,
+ }
+ if include_modality:
+ controlnet_inputs.update(
+ {
+ "class_labels": modality_tensor,
+ }
+ )
+ down_block_res_samples, mid_block_res_sample = controlnet(**controlnet_inputs)
+
+ # get diffusion network output
+ # Create a dictionary to store the inputs
+ unet_inputs = {
+ "x": noisy_latent,
+ "timesteps": timesteps,
+ "spacing_tensor": spacing_tensor,
+ "down_block_additional_residuals": down_block_res_samples,
+ "mid_block_additional_residual": mid_block_res_sample,
+ }
+ # Add extra arguments if include_body_region is True
+ if include_body_region:
+ unet_inputs.update(
+ {
+ "top_region_index_tensor": top_region_index_tensor,
+ "bottom_region_index_tensor": bottom_region_index_tensor,
+ }
+ )
+ if include_modality:
+ unet_inputs.update(
+ {
+ "class_labels": modality_tensor,
+ }
+ )
+ model_output = unet(**unet_inputs)
+
+ if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON:
+ # predict noise
+ model_gt = noise
+ elif noise_scheduler.prediction_type == DDPMPredictionType.SAMPLE:
+ # predict sample
+ model_gt = images
+ elif noise_scheduler.prediction_type == DDPMPredictionType.V_PREDICTION:
+ # predict velocity
+ model_gt = images - noise
+ else:
+ raise ValueError(
+ "noise scheduler prediction type has to be chosen from "
+ f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]",
)
if weighted_loss > 1.0:
- weights = torch.ones_like(inputs).to(inputs.device)
- roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device)
- interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
+ weights = torch.ones_like(images).to(images.device)
+ roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(images.device)
+ interpolate_label = F.interpolate(labels, size=images.shape[2:], mode="nearest")
# assign larger weights for ROI (tumor)
for label in weighted_loss_label:
roi[interpolate_label == label] = 1
- weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = weighted_loss
- loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
+ weights[roi.repeat(1, images.shape[1], 1, 1, 1) == 1] = weighted_loss
+ loss = (F.l1_loss(model_output.float(), model_gt.float(), reduction="none") * weights).mean()
else:
- loss = F.l1_loss(noise_pred.float(), noise.float())
+ loss = F.l1_loss(model_output.float(), model_gt.float())
scaler.scale(loss).backward()
scaler.step(optimizer)
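Reviewer note on the training target in `train_controlnet.py`: the branch on `noise_scheduler.prediction_type` can be read as a small pure function. The sketch below is a simplified illustration under stated assumptions. The name `training_target` is hypothetical, and plain strings stand in for the `DDPMPredictionType` members so the example runs without MONAI or PyTorch. Any objects that support subtraction (floats here, tensors in the real script) can be passed in.

```python
# Hypothetical stand-ins for DDPMPredictionType.EPSILON / SAMPLE / V_PREDICTION.
EPSILON = "epsilon"
SAMPLE = "sample"
V_PREDICTION = "v_prediction"


def training_target(prediction_type, images, noise):
    """Pick the regression target that the model output is compared against."""
    if prediction_type == EPSILON:
        # The model predicts the added noise.
        return noise
    if prediction_type == SAMPLE:
        # The model predicts the clean (scaled) latent.
        return images
    if prediction_type == V_PREDICTION:
        # The model predicts the velocity, defined in this script as images - noise.
        return images - noise
    # Adjacent string literals are concatenated into one message argument.
    raise ValueError(
        "noise scheduler prediction type has to be chosen from "
        f"[{EPSILON},{SAMPLE},{V_PREDICTION}]"
    )
```

The loss in both the weighted and unweighted branches should then compare the UNet output against this single target, whichever scheduler is configured.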
diff --git a/generation/maisi/scripts/utils.py b/generation/maisi/scripts/utils.py
index 13d9a240e1..9b1df28921 100644
--- a/generation/maisi/scripts/utils.py
+++ b/generation/maisi/scripts/utils.py
@@ -11,9 +11,9 @@
import copy
import json
+import logging
import math
import os
-import logging
from argparse import Namespace
from datetime import timedelta
from typing import Any, Sequence
@@ -22,11 +22,11 @@
import skimage
import torch
import torch.distributed as dist
-from monai.transforms.utils_morphological_ops import dilate, erode
from monai.bundle import ConfigParser
from monai.config import DtypeLike, NdarrayOrTensor
from monai.data import CacheDataset, DataLoader, partition_dataset
from monai.transforms import Compose, EnsureTyped, Lambdad, LoadImaged, Orientationd
+from monai.transforms.utils_morphological_ops import dilate, erode
from monai.utils import TransformBackends, convert_data_type, convert_to_dst_type, get_equivalent_dtype
from scipy import stats
from torch import Tensor
@@ -306,10 +306,12 @@ def prepare_maisi_controlnet_json_dataloader(
LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
Orientationd(keys=["label"], axcodes="RAS"),
EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
- Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
- Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
+ Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True),
+ Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True),
Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
- Lambdad(keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2),
+ Lambdad(
+ keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2, allow_missing_keys=True
+ ),
]
train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)
@@ -706,7 +708,20 @@ def dynamic_infer(inferer, model, images):
Returns:
torch.Tensor: The output from the model or the inferer, depending on the input size.
"""
- if torch.numel(images[0:1, 0:1, ...]) < math.prod(inferer.roi_size):
+ if torch.numel(images[0:1, 0:1, ...]) <= math.prod(inferer.roi_size):
return model(images)
else:
- return inferer(network=model, inputs=images)
+ # Extract the spatial dimensions from the images tensor (H, W, D)
+ spatial_dims = images.shape[2:]
+ orig_roi = inferer.roi_size
+
+ # Check that roi has the same number of dimensions as spatial_dims
+ if len(orig_roi) != len(spatial_dims):
+ raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).")
+
+ # Iterate and adjust each ROI dimension
+ adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)]
+ inferer.roi_size = adjusted_roi
+ output = inferer(network=model, inputs=images)
+ inferer.roi_size = orig_roi
+ return output
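Reviewer note on `dynamic_infer`: the new branch clamps each ROI dimension to the image size before running the sliding-window inferer. The sketch below isolates that decision in two small helpers so the behavior can be checked without MONAI or PyTorch. The names `needs_sliding_window` and `clamp_roi` are hypothetical, and shapes are plain tuples of spatial dimensions rather than tensors.

```python
import math
from typing import List, Sequence


def needs_sliding_window(spatial_dims: Sequence[int], roi_size: Sequence[int]) -> bool:
    """True when the volume has more voxels than one ROI window can hold."""
    return math.prod(spatial_dims) > math.prod(roi_size)


def clamp_roi(roi_size: Sequence[int], spatial_dims: Sequence[int]) -> List[int]:
    """Shrink each ROI dimension so it never exceeds the image along that axis."""
    if len(roi_size) != len(spatial_dims):
        raise ValueError(
            f"ROI length ({len(roi_size)}) does not match spatial dimensions ({len(spatial_dims)})."
        )
    return [min(roi_dim, img_dim) for roi_dim, img_dim in zip(roi_size, spatial_dims)]
```

A volume of size (256, 512, 128) with an ROI of (320, 320, 160) shows why the clamp matters: the volume has more voxels than the ROI, so the sliding-window path is taken, yet two of its axes are shorter than the ROI. In the patched function the original `roi_size` is restored after the call; wrapping that call in `try`/`finally` would also restore it if inference raises.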