Refactor MAISI tutorial, migrate GenerativeAI import #1779

Merged
merged 38 commits on Aug 15, 2024
Commits
f544302
change mor_op position, rm generative import in vae
Can-Zhao Aug 12, 2024
9da99fb
rm xformer import in vae
Can-Zhao Aug 12, 2024
dcad093
add load ckpt functions, inference notebook can run
Can-Zhao Aug 12, 2024
148484b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
a551102
fix logging format, add inference script, add details in readme
Can-Zhao Aug 12, 2024
966583f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
4c9c780
fix typo
Can-Zhao Aug 12, 2024
0b3e3dc
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 12, 2024
1272254
update readme
Can-Zhao Aug 12, 2024
a45c9f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
4af5f5a
reformat
Can-Zhao Aug 12, 2024
4aa9401
update readme
Can-Zhao Aug 12, 2024
b392528
clear directory in code
Can-Zhao Aug 12, 2024
a8ca930
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2024
b0dda95
clear directory in code
Can-Zhao Aug 12, 2024
1bc2a7d
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 12, 2024
094e938
change epoch to 1 to save notebook run time
Can-Zhao Aug 12, 2024
6b064d1
rm controllable size in inference input
Can-Zhao Aug 12, 2024
fb080de
mv some description to Readme, use subset of data to train in noteboo…
Can-Zhao Aug 12, 2024
5b942f1
rm dir info
Can-Zhao Aug 12, 2024
837cfb7
rm xformer
Can-Zhao Aug 13, 2024
04a78bf
rm generative in sample.py, clean print in utils.py
Can-Zhao Aug 13, 2024
e97a51d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
d5dbabe
clean legacy code in utils.py
Can-Zhao Aug 13, 2024
1b287fe
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 13, 2024
11be5ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
c223fad
rm repeated comment, rm xformer description, rm generative repo descr…
Can-Zhao Aug 13, 2024
daf5c2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
402f4c0
rm xformer description
Can-Zhao Aug 13, 2024
3c178c8
Merge branch 'main' into refactor_maisi
Can-Zhao Aug 13, 2024
288e73d
change docstring
Can-Zhao Aug 13, 2024
2659218
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 13, 2024
4e774bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
0835773
Merge branch 'main' into refactor_maisi
Can-Zhao Aug 14, 2024
c56b0f2
refactor for new controlnet
Can-Zhao Aug 14, 2024
b91b83c
Merge branch 'refactor_maisi' of https://github.com/Can-Zhao/tutorial…
Can-Zhao Aug 14, 2024
59907c4
refactor inference
Can-Zhao Aug 14, 2024
7b3c389
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 14, 2024
49 changes: 43 additions & 6 deletions generation/maisi/README.md
@@ -50,19 +50,56 @@ MAISI is based on the following papers:

[**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; “Adding Conditional Control to Text-to-Image Diffusion Models.” ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf)

### 1. Installation
Please refer to the [Installation of MONAI Generative Model](../README.md).
### 1. Network Definition
Network definition is stored in [./configs/config_maisi.json](./configs/config_maisi.json). Training and inference should use the same [./configs/config_maisi.json](./configs/config_maisi.json).

### 2. Model Inference
The inference inputs, such as the body region and anatomies to generate, are configured in [./configs/config_infer.json](./configs/config_infer.json). Please feel free to experiment with it. The parameters are as follows:

- `"num_output_samples"`: int, the number of image/mask pairs to generate.
- `"spacing"`: voxel size of the generated images. E.g., if set to `[1.5, 1.5, 2.0]`, images are generated with a voxel size of 1.5x1.5x2.0 mm.
- `"output_size"`: volume size of the generated images. E.g., if set to `[512, 512, 256]`, images of size 512x512x256 are generated. The values need to be divisible by 16. If GPU memory is limited, please reduce them.
- `"controllable_anatomy_size"`: a list of controllable anatomies and their size scales (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a relatively small hepatic tumor (around the 30th percentile). The output will contain paired images and segmentation masks for the controllable anatomies.
- `"body_region"`: if `"controllable_anatomy_size"` is not specified, `"body_region"` constrains the region of the generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower".
- `"anatomy_list"`: if `"controllable_anatomy_size"` is not specified, the output will contain paired images and segmentation masks for the anatomies listed in "./configs/label_dict.json".
- `"autoencoder_sliding_window_infer_size"`: patch size of the sliding window used when decoding latents to images; sliding window inference saves GPU memory when `"output_size"` is large. Smaller values reduce GPU memory usage but increase inference time. The values need to be divisible by 16.
- `"autoencoder_sliding_window_infer_overlap"`: float between 0 and 1. Larger values reduce stitching artifacts between patches during sliding window inference but increase inference time. If you do not observe seam lines in the generated images, you can use a smaller value to save inference time.
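
To see how these two sliding-window parameters trade memory for compute, here is a minimal sketch (illustrative only, not part of the MAISI scripts) that counts window positions along one axis under simple ceil-based tiling:

```python
import math

def num_windows(dim_size: int, patch_size: int, overlap: float) -> int:
    """Number of sliding-window positions along one axis.

    Assumes a fixed stride of patch_size * (1 - overlap) and a final
    window snapped to the end of the axis (simple ceil-based tiling).
    """
    if dim_size <= patch_size:
        return 1  # a single window covers the whole axis
    stride = patch_size * (1.0 - overlap)
    return math.ceil((dim_size - patch_size) / stride) + 1

# For a 512-voxel axis decoded with 128-voxel patches, higher overlap
# means more windows: fewer seam artifacts, but more decoder passes.
for overlap in (0.25, 0.5):
    print(overlap, num_windows(512, 128, overlap))
```

Raising the overlap from 0.25 to 0.5 here grows the per-axis window count from 5 to 7, and the cost compounds across all three dimensions.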

Note: MAISI depends on [xFormers](https://github.com/facebookresearch/xformers) library.
ARM64 users can build xFormers from the [source](https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers) if the available wheel does not meet their requirements.

### 2. Model inference and example outputs
Please refer to [maisi_inference_tutorial.ipynb](maisi_inference_tutorial.ipynb) for a tutorial on MAISI model inference.

### 3. Training example
#### Execute Inference:
To run the inference script:
```bash
export MONAI_DATA_DIRECTORY=<dir_you_will_download_data>
python -m scripts.inference -c ./configs/config_maisi.json -i ./configs/config_infer.json -e ./configs/environment.json --random-seed 0
```

### 3. Model Training
Training data preparation can be found in [./data/README.md](./data/README.md)

#### [3.1 3D Autoencoder Training](./maisi_train_vae_tutorial.ipynb)
The training hyperparameters and data processing parameters, such as the learning rate and patch size, are stored in [./configs/config_maisi_vae_train.json](./configs/config_maisi_vae_train.json). The provided configuration works on a 16 GB V100 GPU. Please feel free to tune the parameters for your datasets and devices.

Dataset preprocessing:
- `"random_aug"`: bool, whether to apply random data augmentation to the training data.
- `"spacing_type"`: choose from `"original"` (no resampling), `"fixed"` (all images resampled to the same voxel size), and `"rand_zoom"` (images randomly zoomed; valid when `"random_aug"` is True).
- `"spacing"`: None or a list of three floats. If `"spacing_type"` is `"fixed"`, all images will be interpolated to the voxel size given by `"spacing"`.
- `"select_channel"`: int; for multi-channel MRI, the channel to select.

Training configs:
- `"batch_size"`: training batch size. Please consider increasing it if GPU memory is larger than 16 GB.
- `"patch_size"`: training patch size. For the released model, we first trained the autoencoder with a small patch size of [64, 64, 64], then continued training with a patch size of [128, 128, 128].
- `"val_patch_size"`: size of validation patches. If None, the whole volume is used for validation; otherwise a central patch of this size is cropped for validation.
- `"val_sliding_window_patch_size"`: sliding window patch size used when the validation patch is too large to fit in memory. Please consider increasing it if GPU memory is larger than 16 GB.
- `"val_batch_size"`: validation batch size.
- `"perceptual_weight"`: perceptual loss weight.
- `"kl_weight"`: KL loss weight, an important hyperparameter. If too large, the decoder cannot reconstruct good results from the latent space; if too small, the latent space will not be regularized enough for the diffusion model.
- `"adv_weight"`: adversarial loss weight.
- `"recon_loss"`: reconstruction loss, chosen from `"l1"` and `"l2"`.
- `"val_interval"`: int, run validation every `"val_interval"` epochs.
- `"cache"`: float between 0 and 1, the fraction of data cached by the dataloader; choose a smaller value if CPU memory is limited.
- `"n_epochs"`: int, number of training epochs. Please adjust it based on the size of your dataset. We trained the released model for 280 epochs on about 58k volumes.
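
Putting the options above together, a training config of the same shape might look like the following (an illustrative sketch with example values, not the shipped `./configs/config_maisi_vae_train.json`):

```json
{
    "random_aug": true,
    "spacing_type": "rand_zoom",
    "spacing": null,
    "select_channel": 0,
    "batch_size": 1,
    "patch_size": [64, 64, 64],
    "val_patch_size": null,
    "val_sliding_window_patch_size": [64, 64, 64],
    "val_batch_size": 1,
    "perceptual_weight": 1.0,
    "kl_weight": 1e-6,
    "adv_weight": 0.01,
    "recon_loss": "l1",
    "val_interval": 10,
    "cache": 0.5,
    "amp": true,
    "n_epochs": 1
}
```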

Please refer to [maisi_train_vae_tutorial.ipynb](maisi_train_vae_tutorial.ipynb) for a tutorial on MAISI VAE model training.

4 changes: 2 additions & 2 deletions generation/maisi/configs/config_infer.json
@@ -1,8 +1,8 @@
{
"num_output_samples": 1,
"body_region": ["abdomen"],
"anatomy_list": ["liver"],
"controllable_anatomy_size": [["hepatic tumor", 0.3], ["liver", 0.5]],
"anatomy_list": ["liver","hepatic tumor"],
"controllable_anatomy_size": [],
"num_inference_steps": 1000,
"mask_generation_num_inference_steps": 1000,
"output_size": [
4 changes: 2 additions & 2 deletions generation/maisi/configs/config_maisi.json
@@ -120,11 +120,11 @@
"dim_split": 1
},
"mask_generation_diffusion_def": {
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
"_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"num_channels":[64, 128, 256, 512],
"channels":[64, 128, 256, 512],
"attention_levels":[false, false, true, true],
"num_head_channels":[0, 0, 32, 32],
"num_res_blocks": 2,
2 changes: 1 addition & 1 deletion generation/maisi/configs/config_maisi_vae_train.json
@@ -19,6 +19,6 @@
"val_interval": 10,
"cache": 0.5,
"amp": true,
"n_epochs": 2
"n_epochs": 1
}
}