diff --git a/automl/DiNTS/Figures/search_0.2.png b/automl/DiNTS/Figures/search_0.2.png new file mode 100644 index 0000000000..c6c5c1232d Binary files /dev/null and b/automl/DiNTS/Figures/search_0.2.png differ diff --git a/automl/DiNTS/Figures/search_0.8.png b/automl/DiNTS/Figures/search_0.8.png new file mode 100644 index 0000000000..010bcf7e15 Binary files /dev/null and b/automl/DiNTS/Figures/search_0.8.png differ diff --git a/automl/DiNTS/README.md b/automl/DiNTS/README.md index c3bd1d2a88..043388a3c9 100644 --- a/automl/DiNTS/README.md +++ b/automl/DiNTS/README.md @@ -1,4 +1,4 @@ -# Examples of DiNTS: Differentiable neural network topology search +# Examples of DiNTS: Differentiable Neural Network Topology Search In this tutorial, we present a novel neural architecture search algorithm for 3D medical image segmentation. The datasets used in this tutorial are Task07 Pancreas (CT images) and Task09 Spleen (CT images) from [Medical Segmentation Decathlon](http://medicaldecathlon.com/). The implementation is based on: @@ -7,33 +7,47 @@ Yufan He, Dong Yang, Holger Roth, Can Zhao, Daguang Xu: "[DiNTS: Differentiable ![0.8](./Figures/arch_ram-cost-0.8.png) ![space](./Figures/search_space.png) -## Requirements -The script is tested with: -- `Ubuntu 20.04` and `CUDA 11` -- The searching and training stage requires at least two 16GB GPUs. ## Dependencies and installation -### Download and install Nvidia PyTorch Docker +The script is tested with: `Ubuntu 20.04` and `CUDA 11` + +You can use nvidia docker or conda environments to install the dependencies. +- ### Using Docker Image +1. #### Download and install Nvidia PyTorch Docker ```bash docker pull nvcr.io/nvidia/pytorch:21.10-py3 ``` -### Download the repository +2. #### Download the repository ```bash git clone https://github.com/Project-MONAI/tutorials.git ``` -### Run into Docker +3. #### Run into Docker ``` sudo docker run -it --gpus all --pid=host --shm-size 16G -v /location/to/tutorials/automl/DiNTS/:/workspace/DiNTS/ nvcr.io/nvidia/pytorch:21.10-py3 ``` -### Install MONAI and dependencies +4. #### Install required package in docker +```bash +bash install.sh +``` + +- ### Using Conda +1. #### Install Pytorch >= 1.6 +```bash +conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch +``` +2. #### Install MONAI and dependencies ```bash bash install.sh ``` +- ### Install [Graphviz](https://graphviz.org/download/) for visualization (needed in decode_plot.py) ## Data [Spleen CT dataset](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2) and [Pancreas MRI dataset](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2) -from [Medical Segmentation Decathlon](http://medicaldecathlon.com/) is used. You can manually download it and save it to args.root. Otherwise, the script will automatic -download the dataset. +from [Medical Segmentation Decathlon](http://medicaldecathlon.com/) is used for this tutorial. You can manually download it and save it to args.root. Or you can use the script `download_msd_datasets.py` to download the MSD datasets of 10 segmentation tasks. +```bash +python download_msd_datasets.py --msd_task "Task07_Pancreas" \ + --root "/workspace/data_msd" +``` ## Examples The tutorial contains two stages: searching stage and training stage. An architecture is searched and saved into a `.pth` file using `search_dints.py`. @@ -53,6 +67,10 @@ python train_dints.py -h ``` - Change ``NUM_GPUS_PER_NODE`` to your number of GPUs. 
- Run `bash search_dints.sh`
+- Run `decode_plot.py` to visualize the searched model (the graph is rendered to a PNG via Graphviz, which needs to be installed).
+The searched architectures with RAM cost factors 0.2 and 0.8 are shown below:
+![0.2 search](./Figures/search_0.2.png)
+![0.8 search](./Figures/search_0.8.png)

### Training
- Add the following script to the commands of running into docker (Optional)
@@ -69,6 +87,18 @@ Training loss and validation metric curves are shown as follows. The experiments

![validation_metric](./Figures/validation_metric.png)

+## Citation
+If you use this code in your work, please cite:
+```
+@inproceedings{he2021dints,
+  title={DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation},
+  author={He, Yufan and Yang, Dong and Roth, Holger and Zhao, Can and Xu, Daguang},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={5841--5850},
+  year={2021}
+}
+```
+
## Questions and bugs

- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
diff --git a/automl/DiNTS/decode_plot.py b/automl/DiNTS/decode_plot.py
new file mode 100644
index 0000000000..747fbeb940
--- /dev/null
+++ b/automl/DiNTS/decode_plot.py
@@ -0,0 +1,93 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import torch
+
+from graphviz import Digraph
+
+
+parser = argparse.ArgumentParser(
+    description="plot the searched DiNTS architecture",
+)
+parser.add_argument(
+    "--checkpoint",
+    type=str,
+    default=None,
+    help="checkpoint full path",
+)
+parser.add_argument(
+    "--directory",
+    type=str,
+    default="./",
+    help="directory to save the graph",
+)
+parser.add_argument(
+    "--filename",
+    type=str,
+    default="graph",
+    help="filename of the saved graph",
+)
+
+def plot_graph(
+    codepath,
+    filename="graph",
+    directory="./",
+    code2in=[0, 1, 0, 1, 2, 1, 2, 3, 2, 3],
+    code2out=[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
+):
+    """ Plot the final searched model.
+        Args:
+            codepath: path to the saved .pth file generated by the searching script;
+                it stores "arch_code_a" (architecture code) and "arch_code_c"
+                (cell operation code), both decoded with model.decode.
+            filename: filename of the saved graph.
+            directory: directory to save the graph.
+            code2in, code2out: see the definition in monai.networks.nets.dints.py.
+        Return:
+            graphviz graph.
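+        Example (illustrative paths; pass the .pth file saved by the searching
+        script, which contains "arch_code_a" and "arch_code_c"):
+            python decode_plot.py --checkpoint /path/to/search/checkpoint.pth \
+                --directory ./Figures --filename search_0.8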
+ """ + code = torch.load(codepath) + arch_code_a = code["arch_code_a"] + arch_code_c = code["arch_code_c"] + ga = Digraph("G", filename=filename, engine="neato") + depth = (len(code2in) + 2)//3 + + # build a initial block + inputs = [] + for _ in range(depth): + inputs.append("(in," + str(_) + ")") + + with ga.subgraph(name="cluster_all") as g: + with g.subgraph(name="cluster_init") as c: + for idx, _ in enumerate(inputs): + c.node(_,pos="0,"+str(depth-idx)+"!") + for blk_idx in range(arch_code_a.shape[0]): + with g.subgraph(name="cluster"+str(blk_idx)) as c: + outputs = [str((blk_idx,_)) for _ in range(depth)] + for idx, _ in enumerate(outputs): + c.node(_,pos=str(2+2*blk_idx)+","+str(depth-idx)+"!") + for res_idx, activation in enumerate(arch_code_a[blk_idx]): + if activation: + c.edge(inputs[code2in[res_idx]], outputs[code2out[res_idx]], \ + label=str(arch_code_c[blk_idx][res_idx])) + inputs = outputs + ga.render(filename=filename, directory=directory, cleanup=True, format="png") + return ga + + +if __name__ == "__main__": + args = parser.parse_args() + plot_graph( + codepath=args.checkpoint, + filename=args.filename, + directory=args.directory, + ) diff --git a/automl/DiNTS/download_msd_datasets.py b/automl/DiNTS/download_msd_datasets.py new file mode 100644 index 0000000000..0bd91654c4 --- /dev/null +++ b/automl/DiNTS/download_msd_datasets.py @@ -0,0 +1,40 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os + +from monai.apps import download_and_extract + + +def main(): + parser = argparse.ArgumentParser(description="training") + parser.add_argument( + "--msd_task", + action="store", + default="Task07_Pancreas", + help="msd task", + ) + parser.add_argument( + "--root", + action="store", + default="./data_msd", + help="data root", + ) + args = parser.parse_args() + + resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + args.msd_task + ".tar" + compressed_file = os.path.join(args.root, args.msd_task + ".tar") + if not os.path.exists(args.root): + download_and_extract(resource, compressed_file, args.root) + +if __name__ == "__main__": + main() diff --git a/automl/DiNTS/search_dints.py b/automl/DiNTS/search_dints.py index 38cf1ca7a6..8369f92671 100644 --- a/automl/DiNTS/search_dints.py +++ b/automl/DiNTS/search_dints.py @@ -134,6 +134,12 @@ def main(): default=None, help="checkpoint full path", ) + parser.add_argument( + "--factor_ram_cost", + default=0.0, + type=float, + help="factor to determine RAM cost in the searched architecture", + ) parser.add_argument( "--fold", action="store", @@ -180,18 +186,19 @@ def main(): logging.basicConfig(stream=sys.stdout, level=logging.INFO) if not os.path.exists(args.output_root): - os.makedirs(args.output_root) + os.makedirs(args.output_root, exist_ok=True) amp = True - determ = False - factor_ram_cost = 0.2 + determ = True + factor_ram_cost = args.factor_ram_cost fold = int(args.fold) input_channels = 1 - learning_rate = 0.0002 - learning_rate_final = 0.00001 + learning_rate = 0.025 + learning_rate_arch = 0.001 + learning_rate_milestones = np.array([0.4, 0.8]) num_images_per_batch = 1 - num_epochs = 1430 - num_epochs_per_validation = 60 + num_epochs = 1430 # around 20k iteration + num_epochs_per_validation = 100 num_epochs_warmup = 715 num_folds = int(args.num_folds) num_patches_per_image = 1 @@ -202,6 +209,8 @@ def main(): patch_size_valid = (96, 96, 96) spacing = [1.0, 1.0, 1.0] + print("factor_ram_cost", factor_ram_cost) + # deterministic training if determ: set_determinism(seed=0) @@ -209,16 +218,7 @@ def main(): # initialize the distributed training process, every GPU runs in a process dist.init_process_group(backend="nccl", init_method="env://") - # data - if dist.get_rank() == 0: - resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + args.root.split(os.sep)[-1] + ".tar" - compressed_file = args.root + ".tar" - data_dir = args.root - root_dir = os.path.join(*args.root.split(os.sep)[:-1]) - if not os.path.exists(data_dir): - download_and_extract(resource, compressed_file, root_dir) - - dist.barrier() + # dist.barrier() world_size = dist.get_world_size() with open(args.json, "r") as f: @@ -238,7 +238,6 @@ def main(): continue files.append({"image": str_img, "label": str_seg}) - train_files = files random.shuffle(train_files) @@ -328,18 +327,20 @@ def main(): train_ds_w = monai.data.CacheDataset(data=train_files_w, transform=train_transforms, cache_rate=1.0, num_workers=8) val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2) + # monai.data.Dataset can be used as alternatives when debugging or RAM space is limited. 
# train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms) # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms) # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) - # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) - # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) - # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available()) - train_loader_a = ThreadDataLoader(train_ds_a, num_workers=0, batch_size=num_images_per_batch, shuffle=True) train_loader_w = ThreadDataLoader(train_ds_w, num_workers=0, batch_size=num_images_per_batch, shuffle=True) val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False) + # DataLoader can be used as alternatives when ThreadDataLoader is less efficient. + # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) + # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) + # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available()) + dints_space = monai.networks.nets.TopologySearch( channel_mul=0.5, num_blocks=12, @@ -374,9 +375,9 @@ def main(): ) # optimizer - optimizer = torch.optim.SGD(model.weight_parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.00004) - arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a], lr=0.001, betas=(0.5, 0.999), weight_decay=0.0) - arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c], lr=0.001, betas=(0.5, 0.999), weight_decay=0.0) + optimizer = torch.optim.SGD(model.weight_parameters(), lr=learning_rate * world_size, momentum=0.9, weight_decay=0.00004) + arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0) + arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0) print() @@ -418,17 +419,10 @@ def main(): start_time = time.time() for epoch in range(num_epochs): - if learning_rate_final > -0.000001 and learning_rate_final < learning_rate: - # lr = (learning_rate - learning_rate_final) * (1 - epoch / (num_epochs - 1)) ** 0.9 + learning_rate_final - milestones = np.array([0.4, 0.8]) - decay = 0.5 ** np.sum([(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > milestones]) - lr = learning_rate * decay - for param_group in optimizer.param_groups: - param_group["lr"] = lr - else: - lr = learning_rate - - lr = optimizer.param_groups[0]["lr"] + decay = 0.5 ** np.sum([(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones]) + lr = learning_rate * decay + for param_group in optimizer.param_groups: + param_group["lr"] = lr if dist.get_rank() == 0: print("-" * 10) @@ -453,6 +447,9 @@ def main(): _.requires_grad = True dints_space.log_alpha_a.requires_grad = False dints_space.log_alpha_c.requires_grad = False + + optimizer.zero_grad() + if amp: with autocast(): outputs = model(inputs) @@ -502,12 +499,12 @@ def main(): dints_space.log_alpha_c.requires_grad = True # linear increase topology and RAM loss - entropy_alpha_c = torch.tensor(0.).cuda() - 
entropy_alpha_a = torch.tensor(0.).cuda() - ram_cost_full = torch.tensor(0.).cuda() - ram_cost_usage = torch.tensor(0.).cuda() - ram_cost_loss = torch.tensor(0.).cuda() - topology_loss = torch.tensor(0.).cuda() + entropy_alpha_c = torch.tensor(0.).to(device) + entropy_alpha_a = torch.tensor(0.).to(device) + ram_cost_full = torch.tensor(0.).to(device) + ram_cost_usage = torch.tensor(0.).to(device) + ram_cost_loss = torch.tensor(0.).to(device) + topology_loss = torch.tensor(0.).to(device) probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True) entropy_alpha_a = -((probs_a)*torch.log(probs_a + 1e-5)).mean() @@ -522,6 +519,8 @@ def main(): arch_optimizer_a.zero_grad() arch_optimizer_c.zero_grad() + combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) + if amp: with autocast(): outputs_search = model(inputs_search) @@ -530,7 +529,7 @@ def main(): else: loss = loss_func(outputs_search, labels_search) - loss += 1.0 * (1.0 * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \ + loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \ + 0.001 * topology_loss) scaler.scale(loss).backward() @@ -544,7 +543,7 @@ def main(): else: loss = loss_func(outputs_search, labels_search) - loss += 1.0 * (1.0 * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \ + loss += 1.0 * (combination_weights * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \ + 0.001 * topology_loss) loss.backward() @@ -568,11 +567,9 @@ def main(): loss_torch_epoch = loss_torch[0] / loss_torch[1] print(f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}") - if epoch < num_epochs_warmup: - continue - - loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1] - print(f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}") + if epoch >= num_epochs_warmup: + loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1] + print(f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}") if (epoch + 1) % val_interval == 0: torch.cuda.empty_cache() @@ -666,9 +663,9 @@ def main(): torch.save( { "node_a": node_a_d, - "code_a": arch_code_a_d, - "code_a_max": arch_code_a_max_d, - "code_c": arch_code_c_d, + "arch_code_a": arch_code_a_d, + "arch_code_a_max": arch_code_a_max_d, + "arch_code_c": arch_code_c_d, "iter_num": idx_iter, "epochs": epoch + 1, "best_dsc": best_metric, diff --git a/automl/DiNTS/search_dints.sh b/automl/DiNTS/search_dints.sh index 8c94ae9885..e4277869ce 100644 --- a/automl/DiNTS/search_dints.sh +++ b/automl/DiNTS/search_dints.sh @@ -3,14 +3,14 @@ clear TASK="Task07_Pancreas" -# DATA_ROOT="/home/dongy/Data/MSD/${TASK}" -DATA_ROOT="/workspace/data_msd/${TASK}" +DATA_ROOT="/home/dongy/Data/MSD/${TASK}" +# DATA_ROOT="/workspace/data_msd/${TASK}" JSON_PATH="${DATA_ROOT}/dataset.json" FOLD=4 NUM_FOLDS=5 -NUM_GPUS_PER_NODE=8 +NUM_GPUS_PER_NODE=4 NUM_NODES=1 if [ ${NUM_GPUS_PER_NODE} -eq 1 ] @@ -29,8 +29,9 @@ fi CHECKPOINT_ROOT="models/search_${TASK}_fold${FOLD}" CHECKPOINT="${CHECKPOINT_ROOT}/best_metric_model.pth" +FACTOR_RAM_COST=0.8 JSON_KEY="training" -OUTPUT_ROOT="models/search_${TASK}_fold${FOLD}" +OUTPUT_ROOT="models/search_${TASK}_fold${FOLD}_ram${FACTOR_RAM_COST}" python -m torch.distributed.launch \ --nproc_per_node=${NUM_GPUS_PER_NODE} \ @@ -39,6 +40,7 @@ python -m torch.distributed.launch \ --master_addr=localhost \ 
--master_port=1234 \ search_dints.py --checkpoint=${CHECKPOINT} \ + --factor_ram_cost=${FACTOR_RAM_COST} \ --fold=${FOLD} \ --json=${JSON_PATH} \ --json_key=${JSON_KEY} \ diff --git a/automl/DiNTS/train_dints.py b/automl/DiNTS/train_dints.py index 8d4abaa577..66947b8bd8 100644 --- a/automl/DiNTS/train_dints.py +++ b/automl/DiNTS/train_dints.py @@ -185,17 +185,17 @@ def main(): logging.basicConfig(stream=sys.stdout, level=logging.INFO) if not os.path.exists(args.output_root): - os.makedirs(args.output_root) + os.makedirs(args.output_root, exist_ok=True) amp = True - determ = False + determ = True fold = int(args.fold) input_channels = 1 - learning_rate = 0.0002 - learning_rate_final = 0.00001 + learning_rate = 0.025 + learning_rate_milestones = np.array([0.2, 0.4, 0.6, 0.8]) num_images_per_batch = 2 num_epochs = 13500 - num_epochs_per_validation = 50 + num_epochs_per_validation = 500 num_folds = int(args.num_folds) num_patches_per_image = 1 num_sw_batch_size = 6 @@ -212,16 +212,8 @@ def main(): # initialize the distributed training process, every GPU runs in a process dist.init_process_group(backend="nccl", init_method="env://") - # download data - if dist.get_rank() == 0: - resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + args.root.split(os.sep)[-1] + ".tar" - compressed_file = args.root + ".tar" - data_dir = args.root - root_dir = os.path.join(*args.root.split(os.sep)[:-1]) - if not os.path.exists(data_dir): - download_and_extract(resource, compressed_file, root_dir) - dist.barrier() + world_size = dist.get_world_size() # load data list (.json) with open(args.json, "r") as f: @@ -243,7 +235,7 @@ def main(): files.append({"image": str_img, "label": str_seg}) train_files = files - train_files = partition_dataset(data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True)[dist.get_rank()] + train_files = partition_dataset(data=train_files, shuffle=True, num_partitions=world_size, even_divisible=True)[dist.get_rank()] print("train_files:", len(train_files)) # validation data @@ -257,7 +249,7 @@ def main(): files.append({"image": str_img, "label": str_seg}) val_files = files - val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=dist.get_world_size(), even_divisible=False)[dist.get_rank()] + val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[dist.get_rank()] print("val_files:", len(val_files)) # network architecture @@ -330,8 +322,8 @@ def main(): ckpt = torch.load(args.arch_ckpt) node_a = ckpt['node_a'] - arch_code_a = ckpt['code_a'] - arch_code_c = ckpt['code_c'] + arch_code_a = ckpt['arch_code_a'] + arch_code_c = ckpt['arch_code_c'] dints_space = monai.networks.nets.TopologyInstance( channel_mul=1.0, @@ -369,9 +361,11 @@ def main(): ) # optimizer - optimizer = torch.optim.AdamW( + optimizer = torch.optim.SGD( model.parameters(), - lr=learning_rate + lr=learning_rate * world_size, + momentum=0.9, + weight_decay=0.00004 ) print() @@ -412,14 +406,10 @@ def main(): start_time = time.time() for epoch in range(num_epochs): - if learning_rate_final > -0.000001 and learning_rate_final < learning_rate: - lr = (learning_rate - learning_rate_final) * (1 - epoch / (num_epochs - 1)) ** 0.9 + learning_rate_final - for param_group in optimizer.param_groups: - param_group["lr"] = lr - else: - lr = learning_rate - - lr = optimizer.param_groups[0]["lr"] + decay = 0.5 ** np.sum([epoch / num_epochs > learning_rate_milestones]) + lr = learning_rate * decay * world_size + for 
param_group in optimizer.param_groups: + param_group["lr"] = lr if dist.get_rank() == 0: print("-" * 10) diff --git a/automl/README.md b/automl/README.md index 30208a0652..9ca5188b64 100644 --- a/automl/README.md +++ b/automl/README.md @@ -2,7 +2,7 @@ Here we showcase the most recent AutoML techniques in medical imaging based on MONAI modules. -## [DiNTS: Differentiable neural network topology search](./DiNTS) +## [DiNTS: Differentiable Neural Network Topology Search](./DiNTS) Recently, neural architecture search (NAS) has been applied to automatically search high-performance networks for medical image segmentation. The NAS search space usually contains a network topology level (controlling connections among