
Commit 62f8ff8

Refactor MAISI tutorial, migrate GenerativeAI import (#1779)
Fixes #1772.

### Description
Refactor MAISI tutorial, migrate GenerativeAI import to MONAI core. Add inference script. Correct logging format. VAE training notebook tested. Diffusion training notebook tested. ControlNet training notebook tested. Inference notebook tested.

### Checks
- [x] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [x] For security purposes, please check the contents and remove any sensitive info such as user names and private key.
- [x] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder
- [x] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

Signed-off-by: Can-Zhao <volcanofly@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6067dc6 commit 62f8ff8

File tree: 9 files changed, +476 −250 lines

generation/maisi/README.md

Lines changed: 43 additions & 6 deletions
@@ -50,19 +50,56 @@ MAISI is based on the following papers:
 
 [**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; “Adding Conditional Control to Text-to-Image Diffusion Models.” ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf)
 
-### 1. Installation
-Please refer to the [Installation of MONAI Generative Model](../README.md).
+### 1. Network Definition
+The network definition is stored in [./configs/config_maisi.json](./configs/config_maisi.json). Training and inference must use the same config file.
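
Editor's note: because the definition uses MONAI bundle-style JSON (`"_target_"` entries with `"@..."` cross-references), one way to materialize a network from it is `monai.bundle.ConfigParser`. A minimal sketch, not part of this commit; the key `"mask_generation_diffusion_def"` is taken from the config_maisi.json hunk further down:

```python
# Hedged sketch: instantiate one network from the bundle-style config.
# ConfigParser resolves "_target_" classes and "@..." references.
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config("./configs/config_maisi.json")
mask_unet = parser.get_parsed_content("mask_generation_diffusion_def")
print(type(mask_unet))  # the DiffusionModelUNet described in the config
```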
+
+### 2. Model Inference
+The inference inputs, such as the body region and the anatomies to generate, are configured in [./configs/config_infer.json](./configs/config_infer.json). Please feel free to experiment with it. The parameters are described below; an example config follows the list.
+
+- `"num_output_samples"`: int, the number of output image/mask pairs to generate.
+- `"spacing"`: voxel size of the generated images. E.g., if set to `[1.5, 1.5, 2.0]`, images are generated with a resolution of 1.5x1.5x2.0 mm.
+- `"output_size"`: volume size of the generated images. E.g., if set to `[512, 512, 256]`, images of size 512x512x256 are generated. Each dimension must be divisible by 16. If your GPU memory is small, reduce these values.
+- `"controllable_anatomy_size"`: a list of controllable anatomies and their size scales (0--1). E.g., if set to `[["liver", 0.5], ["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a relatively small hepatic tumor (around the 30th percentile). The output will contain paired images and segmentation masks for the controllable anatomies.
+- `"body_region"`: if `"controllable_anatomy_size"` is not specified, `"body_region"` constrains the region of the generated images. It must be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower".
+- `"anatomy_list"`: if `"controllable_anatomy_size"` is not specified, the output will contain paired images and segmentation masks for the anatomies listed here, named as in "./configs/label_dict.json".
+- `"autoencoder_sliding_window_infer_size"`: to save GPU memory, sliding window inference is used when decoding latents to images if `"output_size"` is large. This is the patch size of the sliding window; each dimension must be divisible by 16. Smaller values reduce GPU memory use but increase inference time.
+- `"autoencoder_sliding_window_infer_overlap"`: float between 0 and 1. Larger values reduce stitching artifacts between patches during sliding window inference but increase inference time. If you do not observe seam lines in the generated images, you can use a smaller value to save inference time.
 
-Note: MAISI depends on [xFormers](https://github.com/facebookresearch/xformers) library.
-ARM64 users can build xFormers from the [source](https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers) if the available wheel does not meet their requirements.
 
-### 2. Model inference and example outputs
 Please refer to [maisi_inference_tutorial.ipynb](maisi_inference_tutorial.ipynb) for the tutorial for MAISI model inference.
 
-### 3. Training example
+#### Execute Inference
+To run the inference script, please run:
+```bash
+export MONAI_DATA_DIRECTORY=<dir_you_will_download_data>
+python -m scripts.inference -c ./configs/config_maisi.json -i ./configs/config_infer.json -e ./configs/environment.json --random-seed 0
+```
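
Editor's note: for intuition on the two sliding-window parameters documented above, here is how the same patch-size/overlap trade-off looks with MONAI's generic `SlidingWindowInferer`. Illustration only; the tutorial's decoder path has its own implementation, and the identity "decoder" and tensor sizes are stand-ins:

```python
# Illustration only: generic MONAI sliding-window inference, showing how
# roi_size (patch size) and overlap map to the two config options above.
import torch
from monai.inferers import SlidingWindowInferer

decoder = lambda patch: patch  # stand-in for the VAE decoder
inferer = SlidingWindowInferer(
    roi_size=(96, 96, 96),  # smaller -> less GPU memory, more patches
    overlap=0.6,            # larger -> fewer seam artifacts, more compute
    sw_batch_size=1,
)
volume = torch.randn(1, 1, 256, 256, 128)  # (batch, channel, D, H, W)
out = inferer(volume, decoder)
print(out.shape)  # identity "decoder" preserves the input shape
```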
+
+### 3. Model Training
 Training data preparation can be found in [./data/README.md](./data/README.md)
 
 #### [3.1 3D Autoencoder Training](./maisi_train_vae_tutorial.ipynb)
+The training hyperparameters and data processing parameters, such as the learning rate and patch size, are stored in [./configs/config_maisi_vae_train.json](./configs/config_maisi_vae_train.json). The provided configuration works on a 16 GB V100 GPU. Please feel free to tune the parameters for your dataset and device; a sketch of doing so programmatically follows the parameter lists.
+
+Dataset preprocessing (an illustrative transform pipeline follows the list):
+- `"random_aug"`: bool, whether to apply random data augmentation to the training data.
+- `"spacing_type"`: choose from `"original"` (no resampling), `"fixed"` (all images resampled to the same voxel size), and `"rand_zoom"` (images randomly zoomed; valid when `"random_aug"` is true).
+- `"spacing"`: None or a list of three floats. If `"spacing_type"` is `"fixed"`, all images are resampled to the voxel size given by `"spacing"`.
+- `"select_channel"`: int; for multi-channel MRI, the channel to select.
+
+Training configs:
+- `"batch_size"`: training batch size. Consider increasing it if your GPU has more than 16 GB of memory.
+- `"patch_size"`: training patch size. For the released model, we first trained the autoencoder with a small patch size of [64,64,64], then continued training with a patch size of [128,128,128].
+- `"val_patch_size"`: size of the validation patches. If None, the whole volume is used for validation. If given, a central crop of this size is used for validation.
+- `"val_sliding_window_patch_size"`: if the validation patch is too large, sliding window inference is used. Consider increasing it if your GPU has more than 16 GB of memory.
+- `"val_batch_size"`: validation batch size.
+- `"perceptual_weight"`: perceptual loss weight.
+- `"kl_weight"`: KL loss weight, an important hyperparameter. If too large, the decoder cannot reconstruct good results from the latent space. If too small, the latent space is not regularized enough for the diffusion model.
+- `"adv_weight"`: adversarial loss weight.
+- `"recon_loss"`: choose from `"l1"` and `"l2"`.
+- `"val_interval"`: int, run validation every `"val_interval"` epochs.
+- `"cache"`: float between 0 and 1, the fraction of the dataset the dataloader caches; choose a small value if CPU memory is limited.
+- `"n_epochs"`: int, number of epochs to train. Please adjust it based on the size of your dataset. We used 280 epochs for the released model, trained on 58k volumes.
 
 Please refer to [maisi_train_vae_tutorial.ipynb](maisi_train_vae_tutorial.ipynb) for the tutorial for MAISI VAE model training.
 
generation/maisi/configs/config_infer.json

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 {
     "num_output_samples": 1,
     "body_region": ["abdomen"],
-    "anatomy_list": ["liver"],
-    "controllable_anatomy_size": [["hepatic tumor", 0.3], ["liver", 0.5]],
+    "anatomy_list": ["liver","hepatic tumor"],
+    "controllable_anatomy_size": [],
     "num_inference_steps": 1000,
     "mask_generation_num_inference_steps": 1000,
     "output_size": [

generation/maisi/configs/config_maisi.json

Lines changed: 2 additions & 2 deletions
@@ -120,11 +120,11 @@
         "dim_split": 1
     },
     "mask_generation_diffusion_def": {
-        "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
+        "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
         "spatial_dims": "@spatial_dims",
         "in_channels": "@latent_channels",
         "out_channels": "@latent_channels",
-        "num_channels":[64, 128, 256, 512],
+        "channels":[64, 128, 256, 512],
         "attention_levels":[false, false, true, true],
         "num_head_channels":[0, 0, 32, 32],
         "num_res_blocks": 2,

generation/maisi/configs/config_maisi_vae_train.json

Lines changed: 1 addition & 1 deletion
@@ -19,6 +19,6 @@
         "val_interval": 10,
         "cache": 0.5,
         "amp": true,
-        "n_epochs": 2
+        "n_epochs": 1
     }
 }
