
Commit e49afe3

add tutorial for 3d ldm on brats (#1301)
Fixes # .

### Description
A few sentences describing the changes proposed in this pull request.

### Checks
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key.
- [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

---------

Signed-off-by: Can-Zhao <volcanofly@gmail.com>
Signed-off-by: Can Zhao <canz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e5adbd0 commit e49afe3

17 files changed, +1163 -0 lines changed

generative/3d_ldm/README.md

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# 3D Latent Diffusion Example

This folder contains an example for training and validating a 3D Latent Diffusion Model on BraTS data. The example includes support for multi-GPU training with distributed data parallelism, and is based on a [tutorial designed for a single GPU](https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb).

The workflow of the Latent Diffusion Model is depicted in the figure below. It begins by training an autoencoder in pixel space to encode images into latent features. It then trains a diffusion model in the latent space to denoise the noisy latent features. During inference, latent features are first generated from random noise by applying multiple denoising steps with the trained diffusion model; the denoised latent features are then decoded into images with the trained autoencoder.

<p align="center">
  <img src="./figs/ldm.png" alt="latent diffusion scheme">
</p>
The MONAI latent diffusion model implementation is based on the following paper:

[**Latent Diffusion:** Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf)

This network is designed as a demonstration of the training process for this type of network using MONAI. To achieve optimal performance, it is recommended that users have a GPU with more than 32 GB of memory to accommodate larger networks and attention layers.
### 1. Data

The dataset used in this example is the BraTS 2016 and 2017 data.

BraTS is a public dataset of brain MR images. Using these images, the goal is to generate new images that look similar to those in the BraTS 2016 and 2017 dataset.

The data can be automatically downloaded from [Medical Decathlon](http://medicaldecathlon.com/) at the beginning of training.

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset.

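The training scripts can download the data for you (see the `--download_data` flag below). If you prefer to fetch the data ahead of time, the sketch below uses MONAI's `DecathlonDataset`; this is a suggested alternative, not necessarily what the tutorial scripts do internally.

```python
# Minimal sketch (assumption): download Task01_BrainTumour ahead of time with MONAI.
# The training scripts can also download the data themselves via `--download_data`.
from monai.apps import DecathlonDataset

dataset = DecathlonDataset(
    root_dir="./dataset",          # should match "data_base_dir" in environment.json
    task="Task01_BrainTumour",     # BraTS data hosted by the Medical Decathlon
    section="training",
    download=True,                 # fetches and extracts the archive if it is not present
    cache_rate=0.0,                # no caching; we only want the files on disk
)
print(f"Downloaded {len(dataset)} training cases")
```
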
### 2. Installation

```bash
pip install lpips
pip install git+https://github.com/Project-MONAI/GenerativeModels.git#egg=Generative
```

Or install it from source:

```bash
pip install lpips
git clone https://github.com/Project-MONAI/GenerativeModels.git
cd GenerativeModels/
python setup.py install
cd ..
```

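To check that the installation worked, a quick sanity check (a suggested snippet, not part of the tutorial scripts) is to import the packages used later and print the MONAI environment:

```python
# Quick sanity check (suggested, not part of the tutorial scripts):
# confirm the lpips and generative packages import, and print the MONAI environment.
import lpips  # noqa: F401
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet  # noqa: F401
from monai.config import print_config

print_config()
```
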
### 3. Run the example
40+
41+
#### [3.1 3D Autoencoder Training](./train_autoencoder.py)
42+
43+
The network configuration files are located in [./config/config_train_32g.json](./config/config_train_32g.json) for 32G GPU
44+
and [./config/config_train_16g.json](./config/config_train_16g.json) for 16G GPU.
45+
You can modify the hyperparameters in these files to suit your requirements.
46+
47+
The training script resamples the brain images based on the voxel spacing specified in the `"spacing"` field of the configuration files. For instance, `"spacing": [1.1, 1.1, 1.1]` resamples the images to a resolution of 1.1x1.1x1.1 mm. If you have a GPU with larger memory, you may consider changing the `"spacing"` field to `"spacing": [1.0, 1.0, 1.0]`.
48+
49+
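Resampling to the configured spacing can be expressed with MONAI's `Spacingd` transform; the sketch below is only an illustration of the idea, since the exact transform chain lives in `train_autoencoder.py` and may differ.

```python
# Illustration only: resample an image to the spacing given in the config file.
# The actual preprocessing chain is defined in train_autoencoder.py and may differ.
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Spacingd

preprocess = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    Spacingd(keys=["image"], pixdim=(1.1, 1.1, 1.1), mode="bilinear"),  # matches "spacing" in the config
])
```
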
The training script uses the batch size and patch size defined in the configuration files. If you have a different GPU memory size, you should adjust the `"batch_size"` and `"patch_size"` parameters in the `"autoencoder_train"` section to match your GPU. Note that each dimension of `"patch_size"` needs to be divisible by 4, because the autoencoder downsamples the input twice (by a factor of 2 each time).

Before you start training, please set the paths in [./config/environment.json](./config/environment.json):

- `"model_dir"`: where the trained models are saved.
- `"tfevent_path"`: where the TensorBoard events are saved.
- `"output_dir"`: where the generated images are stored during inference.
- `"resume_ckpt"`: whether to resume training from existing checkpoints.
- `"data_base_dir"`: where the BraTS dataset is stored.

If the BraTS dataset has not been downloaded yet, please add `--download_data` to the training command; the BraTS data will then be downloaded from [Medical Decathlon](http://medicaldecathlon.com/) and extracted to `$data_base_dir`, where you will see a subfolder `Task01_BrainTumour`. With the default `environment.json`, this is `./dataset/Task01_BrainTumour`.
For example, the following command runs the training script with one 32G GPU:
```bash
python train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1 --download_data
```
If `$data_base_dir/Task01_BrainTumour` already exists, you may omit `--download_data`:
```bash
python train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1
```

The training script also enables multi-GPU training. For instance, if you are using eight 32G GPUs, you can run the training script with the following command:
```bash
export NUM_GPUS_PER_NODE=8
torchrun \
    --nproc_per_node=${NUM_GPUS_PER_NODE} \
    --nnodes=1 \
    --master_addr=localhost --master_port=1234 \
    train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```

<p align="center">
  <img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" >
  &nbsp; &nbsp; &nbsp; &nbsp;
  <img src="./figs/val_recon.png" alt="autoencoder validation curve" width="45%" >
</p>

With eight DGX1V 32G GPUs, it took around 55 hours to train 1000 epochs.

#### [3.2 3D Latent Diffusion Training](./train_diffusion.py)
The training script uses the batch size and patch size defined in the configuration files. If you have a different GPU memory size, you should adjust the `"batch_size"` and `"patch_size"` parameters in the `"diffusion_train"` section to match your GPU. Note that each dimension of `"patch_size"` needs to be divisible by 16: the autoencoder shrinks the patch by a factor of 4, and the diffusion UNet downsamples the resulting latent twice more.

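For orientation, the sketch below shows what a single latent-diffusion training step typically looks like with the components named in the config (`AutoencoderKL`, `DiffusionModelUNet`, `DDPMScheduler`). It is a simplified assumption about the training loop; the authoritative version is in `train_diffusion.py`.

```python
# Simplified sketch of one latent-diffusion training step (assumption; see train_diffusion.py).
import torch
import torch.nn.functional as F


def diffusion_train_step(images, autoencoder, diffusion_model, scheduler, optimizer, device):
    images = images.to(device)
    with torch.no_grad():
        # encode images into latent features with the frozen, pre-trained autoencoder
        latents = autoencoder.encode_stage_2_inputs(images)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.num_train_timesteps, (latents.shape[0],), device=device).long()
    noisy_latents = scheduler.add_noise(original_samples=latents, noise=noise, timesteps=timesteps)
    # the diffusion UNet predicts the noise that was added to the latents
    noise_pred = diffusion_model(x=noisy_latents, timesteps=timesteps)
    loss = F.mse_loss(noise_pred.float(), noise.float())
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()
```
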
To train with a single 32G GPU, please run:
```bash
python train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1
```

The training script also enables multi-GPU training. For instance, if you are using eight 32G GPUs, you can run the training script with the following command:
```bash
export NUM_GPUS_PER_NODE=8
torchrun \
    --nproc_per_node=${NUM_GPUS_PER_NODE} \
    --nnodes=1 \
    --master_addr=localhost --master_port=1234 \
    train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
<p align="center">
  <img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
  &nbsp; &nbsp; &nbsp; &nbsp;
  <img src="./figs/val_diffusion.png" alt="latent diffusion validation curve" width="45%" >
</p>

#### [3.3 Inference](./inference.py)
To generate one image during inference, please run the following command:
```bash
python inference.py -c ./config/config_train_32g.json -e ./config/environment.json --num 1
```
`--num` defines how many images to generate.

An example output is shown below.
<p align="center">
  <img src="./figs/syn_axial.png" width="30%" >
  &nbsp; &nbsp; &nbsp; &nbsp;
  <img src="./figs/syn_sag.png" width="30%" >
  &nbsp; &nbsp; &nbsp; &nbsp;
  <img src="./figs/syn_cor.png" width="30%" >
</p>

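The generated volumes are written to the configured `"output_dir"` as NIfTI files (see [inference.py](./inference.py)). A minimal sketch for inspecting one of them, assuming a hypothetical file name, is:

```python
# Suggested way to inspect a generated volume; adjust the path to a file in your "output_dir".
import nibabel as nib

img = nib.load("./output/synimg_20230101_000000.nii.gz")  # hypothetical file name
volume = img.get_fdata()
print(volume.shape, volume.min(), volume.max())
```
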
### 4. Questions and bugs

- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).

### Reference

[1] [Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf)

[2] [Menze, Bjoern H., et al. "The multimodal brain tumor image segmentation benchmark (BRATS)." IEEE Transactions on Medical Imaging 34.10 (2014): 1993-2024.](https://ieeexplore.ieee.org/document/6975210)

[3] [Pinaya et al. "Brain imaging generation with latent diffusion models." arXiv:2209.07162 (2022).](https://arxiv.org/abs/2209.07162)
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
{
    "channel": 0,
    "spacing": [1.1, 1.1, 1.1],
    "spatial_dims": 3,
    "image_channels": 1,
    "latent_channels": 8,
    "autoencoder_def": {
        "_target_": "generative.networks.nets.AutoencoderKL",
        "spatial_dims": "@spatial_dims",
        "in_channels": "$@image_channels",
        "out_channels": "@image_channels",
        "latent_channels": "@latent_channels",
        "num_channels": [64, 128, 256],
        "num_res_blocks": 2,
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [false, false, false],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false
    },
    "autoencoder_train": {
        "batch_size": 1,
        "patch_size": [112, 128, 80],
        "lr": 5e-6,
        "perceptual_weight": 0.001,
        "kl_weight": 1e-7,
        "recon_loss": "l1",
        "n_epochs": 1000,
        "val_interval": 10
    },
    "diffusion_def": {
        "_target_": "generative.networks.nets.DiffusionModelUNet",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "out_channels": "@latent_channels",
        "num_channels": [256, 256, 512],
        "attention_levels": [false, true, true],
        "num_head_channels": [0, 64, 64],
        "num_res_blocks": 2
    },
    "diffusion_train": {
        "batch_size": 2,
        "patch_size": [144, 176, 112],
        "lr": 5e-6,
        "n_epochs": 10000,
        "val_interval": 2
    },
    "NoiseScheduler": {
        "num_train_timesteps": 1000,
        "beta_start": 0.0015,
        "beta_end": 0.0195
    }
}
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
{
    "channel": 0,
    "spacing": [1.1, 1.1, 1.1],
    "spatial_dims": 3,
    "image_channels": 1,
    "latent_channels": 8,
    "autoencoder_def": {
        "_target_": "generative.networks.nets.AutoencoderKL",
        "spatial_dims": "@spatial_dims",
        "in_channels": "$@image_channels",
        "out_channels": "@image_channels",
        "latent_channels": "@latent_channels",
        "num_channels": [64, 128, 256],
        "num_res_blocks": 2,
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [false, false, false],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false
    },
    "autoencoder_train": {
        "batch_size": 2,
        "patch_size": [112, 128, 80],
        "lr": 1e-5,
        "perceptual_weight": 0.001,
        "kl_weight": 1e-7,
        "recon_loss": "l1",
        "n_epochs": 1000,
        "val_interval": 10
    },
    "diffusion_def": {
        "_target_": "generative.networks.nets.DiffusionModelUNet",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "out_channels": "@latent_channels",
        "num_channels": [256, 256, 512],
        "attention_levels": [false, true, true],
        "num_head_channels": [0, 64, 64],
        "num_res_blocks": 2
    },
    "diffusion_train": {
        "batch_size": 3,
        "patch_size": [144, 176, 112],
        "lr": 1e-5,
        "n_epochs": 10000,
        "val_interval": 2
    },
    "NoiseScheduler": {
        "num_train_timesteps": 1000,
        "beta_start": 0.0015,
        "beta_end": 0.0195
    }
}
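
In these config files, `"_target_"` names the class to instantiate, and `"@name"` (or `"$@name"` inside an expression) references another config entry, following MONAI bundle config conventions. The tutorial's `utils.define_instance` presumably resolves these references; a rough equivalent using `monai.bundle.ConfigParser` (an assumption, not the tutorial's own code) looks like:

```python
# Rough equivalent of instantiating "autoencoder_def" from the config (assumption;
# the tutorial's utils.define_instance may do this differently).
import json
from monai.bundle import ConfigParser

with open("./config/config_train_32g.json") as f:
    config = json.load(f)

parser = ConfigParser(config)
autoencoder = parser.get_parsed_content("autoencoder_def")  # builds AutoencoderKL with "@" refs resolved
print(type(autoencoder))
```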
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
{
    "data_base_dir": "./dataset",
    "model_dir": "./trained_weights/diffusion_3d",
    "tfevent_path": "./tfevent/diffusion_3d",
    "output_dir": "./output",
    "resume_ckpt": false
}

generative/3d_ldm/figs/ldm.png (69.9 KB)

generative/3d_ldm/figs/syn_axial.png (106 KB)

generative/3d_ldm/figs/syn_cor.png (98.5 KB)

generative/3d_ldm/figs/syn_sag.png (85 KB)

Three additional figure files (338 KB, 294 KB, 330 KB)

generative/3d_ldm/figs/val_recon.png (205 KB)

generative/3d_ldm/inference.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path

import nibabel as nib
import numpy as np
import torch
from generative.inferers import LatentDiffusionInferer
from generative.networks.schedulers import DDPMScheduler
from monai.config import print_config
from monai.utils import set_determinism

from utils import define_instance


def main():
    parser = argparse.ArgumentParser(description="PyTorch Latent Diffusion Model Inference")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./config/environment.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./config/config_train_48g.json",
        help="config json file that stores hyper-parameters",
    )
    parser.add_argument(
        "-n",
        "--num",
        type=int,
        default=1,
        help="number of generated images",
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print_config()
    torch.backends.cudnn.benchmark = True
    torch.set_num_threads(4)

    env_dict = json.load(open(args.environment_file, "r"))
    config_dict = json.load(open(args.config_file, "r"))

    # merge the environment and config json contents into the args namespace
    for k, v in env_dict.items():
        setattr(args, k, v)
    for k, v in config_dict.items():
        setattr(args, k, v)

    set_determinism(42)

    # load trained networks
    autoencoder = define_instance(args, "autoencoder_def").to(device)
    trained_g_path = os.path.join(args.model_dir, "autoencoder.pt")
    autoencoder.load_state_dict(torch.load(trained_g_path))

    diffusion_model = define_instance(args, "diffusion_def").to(device)
    trained_diffusion_path = os.path.join(args.model_dir, "diffusion_unet.pt")
    diffusion_model.load_state_dict(torch.load(trained_diffusion_path))

    scheduler = DDPMScheduler(
        num_train_timesteps=args.NoiseScheduler["num_train_timesteps"],
        beta_schedule="scaled_linear",
        beta_start=args.NoiseScheduler["beta_start"],
        beta_end=args.NoiseScheduler["beta_end"],
    )
    inferer = LatentDiffusionInferer(scheduler, scale_factor=1.0)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # the autoencoder downsamples twice, so the latent volume is 1/4 of the patch size per dimension
    latent_shape = [p // 4 for p in args.diffusion_train["patch_size"]]
    noise_shape = [1, args.latent_channels] + latent_shape

    # sample latent noise, run the reverse diffusion process, and decode to image space
    for _ in range(args.num):
        noise = torch.randn(noise_shape, dtype=torch.float32).to(device)
        with torch.no_grad():
            synthetic_images = inferer.sample(
                input_noise=noise,
                autoencoder_model=autoencoder,
                diffusion_model=diffusion_model,
                scheduler=scheduler,
            )
        # include the ".nii.gz" extension so nibabel writes a compressed NIfTI file
        filename = os.path.join(args.output_dir, datetime.now().strftime("synimg_%Y%m%d_%H%M%S.nii.gz"))
        final_img = nib.Nifti1Image(synthetic_images[0, 0, ...].unsqueeze(-1).cpu().numpy(), np.eye(4))
        nib.save(final_img, filename)


if __name__ == "__main__":
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
