Skip to content

Update DiNTS Tutorials #469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 64 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
163a008
initialize dints tutorials
Nov 19, 2021
ddf8392
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2021
1a1a58d
update scripts
Nov 22, 2021
f7473d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2021
1ba684e
update scripts
Nov 22, 2021
85ee1a0
Merge branch 'master' of https://github.com/dongyang0122/tutorials
Nov 22, 2021
db1cafe
update scripts
Nov 22, 2021
981b303
Update README.md
dongyang0122 Nov 22, 2021
d98ca94
Update README.md
dongyang0122 Nov 22, 2021
3e820fd
update scripts
Nov 22, 2021
dcd4e10
Merge branch 'master' of https://github.com/dongyang0122/tutorials
Nov 22, 2021
0d3c8e7
Change Dints interface
heyufan1995 Nov 22, 2021
b8a3101
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2021
1981544
Merge branch 'master' of github.com:dongyang0122/tutorials
heyufan1995 Nov 22, 2021
2a5bb06
Modify scripts for new Dints interface
heyufan1995 Nov 22, 2021
fc2f277
Merge branch 'master' of github.com:dongyang0122/tutorials
heyufan1995 Nov 22, 2021
698a9a0
update scripts
Nov 23, 2021
4a44993
update scripts
Nov 24, 2021
8da5dc3
update scripts
Nov 30, 2021
2260ce2
update scripts
Nov 30, 2021
396c1ae
Test push
heyufan1995 Nov 30, 2021
5669d7b
update scripts
Nov 30, 2021
0a1cc78
update scripts
Nov 30, 2021
ee20f21
update scripts
Nov 30, 2021
3122caf
update readme
Nov 30, 2021
6ccdd50
Add readme
heyufan1995 Nov 30, 2021
51f3127
Merge branch 'master' of github.com:dongyang0122/tutorials
heyufan1995 Nov 30, 2021
79c9aab
Update readme
heyufan1995 Nov 30, 2021
d29f086
Change lr in search
heyufan1995 Dec 1, 2021
bdc759e
update readme
Dec 1, 2021
b2b1170
Enable single GPU
heyufan1995 Dec 1, 2021
1c72524
update readme
Dec 1, 2021
cf20e1d
Add visualization tutorial transform image (#448)
Nic-Ma Nov 22, 2021
03990f9
Update spleen_segmentation_3d.ipynb (#455)
nvahmadi Nov 24, 2021
1cb5a90
Figures added, pretrained weights link added, minor fixes (#456)
finalelement Nov 24, 2021
f573bec
Add itkwidgets example in notebook (#454)
Nic-Ma Nov 25, 2021
3f540b4
MIL example (#431)
myron Nov 25, 2021
829b8d5
450 update AsDiscrete (#451)
yiheng-wang-nv Nov 26, 2021
4422cab
459 update nvidia flare 2.0 example (#460)
yiheng-wang-nv Nov 30, 2021
cfbd3d2
Weights Link Updated (#465)
finalelement Dec 1, 2021
f3ea3c9
Merge remote-tracking branch 'upstream/master'
wyli Dec 1, 2021
a845545
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2021
e69a472
fixes readme typos
wyli Dec 1, 2021
cf5594e
Merge remote-tracking branch 'dong/master'
wyli Dec 1, 2021
6fb6b3e
update readmes
wyli Dec 1, 2021
76120b6
update readme
wyli Dec 1, 2021
7d82859
qa commit
wyli Dec 1, 2021
5cebe49
link
wyli Dec 1, 2021
8434a7f
Add plot arch_code utils
heyufan1995 Dec 3, 2021
5f290b6
Fix bugs in search and update readme
heyufan1995 Dec 3, 2021
327c190
Fix combination weights bug
heyufan1995 Dec 4, 2021
a6ee0ad
Small typo update
heyufan1995 Dec 5, 2021
a165628
Fix minor bug in train_dints
heyufan1995 Dec 6, 2021
e94d32e
update scripts
Dec 6, 2021
564a4f9
Merge branch 'master' into update-dints-tutorials
dongyang0122 Dec 6, 2021
d5d87a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2021
9c076e1
update scripts
Dec 6, 2021
4417673
update scripts
Dec 6, 2021
c491d83
update scripts
Dec 7, 2021
39dd57d
update scripts
Dec 7, 2021
4260ed2
update scripts
Dec 7, 2021
296a030
Update Readme
heyufan1995 Dec 7, 2021
06e08b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2021
0381db6
Merge branch 'master' into update-dints-tutorials
wyli Dec 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added automl/DiNTS/Figures/search_0.2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added automl/DiNTS/Figures/search_0.8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
52 changes: 41 additions & 11 deletions automl/DiNTS/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Examples of DiNTS: Differentiable neural network topology search
# Examples of DiNTS: Differentiable Neural Network Topology Search

In this tutorial, we present a novel neural architecture search algorithm for 3D medical image segmentation. The datasets used in this tutorial are Task07 Pancreas (CT images) and Task09 Spleen (CT images) from [Medical Segmentation Decathlon](http://medicaldecathlon.com/). The implementation is based on:

Expand All @@ -7,33 +7,47 @@ Yufan He, Dong Yang, Holger Roth, Can Zhao, Daguang Xu: "[DiNTS: Differentiable
![0.8](./Figures/arch_ram-cost-0.8.png)
![space](./Figures/search_space.png)

## Requirements
The script is tested with:
- `Ubuntu 20.04` and `CUDA 11`
- The searching and training stage requires at least two 16GB GPUs.

## Dependencies and installation
### Download and install Nvidia PyTorch Docker
The script is tested with: `Ubuntu 20.04` and `CUDA 11`

You can use nvidia docker or conda environments to install the dependencies.
- ### Using Docker Image
1. #### Download and install Nvidia PyTorch Docker
```bash
docker pull nvcr.io/nvidia/pytorch:21.10-py3
```
### Download the repository
2. #### Download the repository
```bash
git clone https://github.com/Project-MONAI/tutorials.git
```
### Run into Docker
3. #### Run into Docker
```
sudo docker run -it --gpus all --pid=host --shm-size 16G -v /location/to/tutorials/automl/DiNTS/:/workspace/DiNTS/ nvcr.io/nvidia/pytorch:21.10-py3
```
### Install MONAI and dependencies
4. #### Install required package in docker
```bash
bash install.sh
```

- ### Using Conda
1. #### Install Pytorch >= 1.6
```bash
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
```
2. #### Install MONAI and dependencies
```bash
bash install.sh
```
- ### Install [Graphviz](https://graphviz.org/download/) for visualization (needed in decode_plot.py)

## Data
[Spleen CT dataset](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2) and [Pancreas MRI dataset](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2)
from [Medical Segmentation Decathlon](http://medicaldecathlon.com/) is used. You can manually download it and save it to args.root. Otherwise, the script will automatic
download the dataset.
from [Medical Segmentation Decathlon](http://medicaldecathlon.com/) is used for this tutorial. You can manually download it and save it to args.root. Or you can use the script `download_msd_datasets.py` to download the MSD datasets of 10 segmentation tasks.
```bash
python download_msd_datasets.py --msd_task "Task07_Pancreas" \
--root "/workspace/data_msd"
```

## Examples
The tutorial contains two stages: searching stage and training stage. An architecture is searched and saved into a `.pth` file using `search_dints.py`.
Expand All @@ -53,6 +67,10 @@ python train_dints.py -h
```
- Change ``NUM_GPUS_PER_NODE`` to your number of GPUs.
- Run `bash search_dints.sh`
- Call the function in `decode_plot.py` to visualize the searched model in a vector image (graphvis needs to be installed).
The searched archtecture with ram cost 0.2 and 0.8 are shown below:
![0.2 search](./Figures/search_0.2.png)
![0.8 search](./Figures/search_0.8.png)

### Training
- Add the following script to the commands of running into docker (Optional)
Expand All @@ -69,6 +87,18 @@ Training loss and validation metric curves are shown as follows. The experiments

![validation_metric](./Figures/validation_metric.png)

## Citation
If you use this code in your work, please cite:
```
@inproceedings{he2021dints,
title={DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation},
author={He, Yufan and Yang, Dong and Roth, Holger and Zhao, Can and Xu, Daguang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={5841--5850},
year={2021}
}
```

## Questions and bugs

- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
Expand Down
93 changes: 93 additions & 0 deletions automl/DiNTS/decode_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import torch

from graphviz import Digraph


parser = argparse.ArgumentParser(
description="training",
)
parser.add_argument(
"--checkpoint",
type=str,
default=None,
help="checkpoint full path",
)
parser.add_argument(
"--directory",
type=str,
default="./",
help="directory to save",
)
parser.add_argument(
"--filename",
type=str,
default="graph",
help="directory to save",
)

def plot_graph(
codepath,
filename="graph",
directory="./",
code2in = [0,1,0,1,2,1,2,3,2,3],
code2out = [0,0,1,1,1,2,2,2,3,3],
):
""" Plot the final searched model
Args:
codepath: path to the saved .pth file, generated from the searching script.
arch_code_a: architecture code (decoded using model.decode).
arch_code_c: cell operation code (decoded using model.decode).
filename: filename to save graph.
directory: directory to save graph.
code2in, code2out: see definition in monai.networks.nets.dints.py.
Return:
graphviz graph.
"""
code = torch.load(codepath)
arch_code_a = code["arch_code_a"]
arch_code_c = code["arch_code_c"]
ga = Digraph("G", filename=filename, engine="neato")
depth = (len(code2in) + 2)//3

# build a initial block
inputs = []
for _ in range(depth):
inputs.append("(in," + str(_) + ")")

with ga.subgraph(name="cluster_all") as g:
with g.subgraph(name="cluster_init") as c:
for idx, _ in enumerate(inputs):
c.node(_,pos="0,"+str(depth-idx)+"!")
for blk_idx in range(arch_code_a.shape[0]):
with g.subgraph(name="cluster"+str(blk_idx)) as c:
outputs = [str((blk_idx,_)) for _ in range(depth)]
for idx, _ in enumerate(outputs):
c.node(_,pos=str(2+2*blk_idx)+","+str(depth-idx)+"!")
for res_idx, activation in enumerate(arch_code_a[blk_idx]):
if activation:
c.edge(inputs[code2in[res_idx]], outputs[code2out[res_idx]], \
label=str(arch_code_c[blk_idx][res_idx]))
inputs = outputs
ga.render(filename=filename, directory=directory, cleanup=True, format="png")
return ga


if __name__ == "__main__":
args = parser.parse_args()
plot_graph(
codepath=args.checkpoint,
filename=args.filename,
directory=args.directory,
)
40 changes: 40 additions & 0 deletions automl/DiNTS/download_msd_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from monai.apps import download_and_extract


def main():
parser = argparse.ArgumentParser(description="training")
parser.add_argument(
"--msd_task",
action="store",
default="Task07_Pancreas",
help="msd task",
)
parser.add_argument(
"--root",
action="store",
default="./data_msd",
help="data root",
)
args = parser.parse_args()

resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + args.msd_task + ".tar"
compressed_file = os.path.join(args.root, args.msd_task + ".tar")
if not os.path.exists(args.root):
download_and_extract(resource, compressed_file, args.root)

if __name__ == "__main__":
main()
Loading