diff --git a/automl/DiNTS/Figures/arch_ram-cost-0.8.png b/automl/DiNTS/Figures/arch_ram-cost-0.8.png
new file mode 100755
index 0000000000..f946818dc0
Binary files /dev/null and b/automl/DiNTS/Figures/arch_ram-cost-0.8.png differ
diff --git a/automl/DiNTS/Figures/search_space.png b/automl/DiNTS/Figures/search_space.png
new file mode 100755
index 0000000000..f8ad66318d
Binary files /dev/null and b/automl/DiNTS/Figures/search_space.png differ
diff --git a/automl/DiNTS/Figures/training_loss.png b/automl/DiNTS/Figures/training_loss.png
new file mode 100644
index 0000000000..a78662249d
Binary files /dev/null and b/automl/DiNTS/Figures/training_loss.png differ
diff --git a/automl/DiNTS/Figures/validation_metric.png b/automl/DiNTS/Figures/validation_metric.png
new file mode 100644
index 0000000000..97655d7ff7
Binary files /dev/null and b/automl/DiNTS/Figures/validation_metric.png differ
diff --git a/automl/DiNTS/README.md b/automl/DiNTS/README.md
new file mode 100644
index 0000000000..c3bd1d2a88
--- /dev/null
+++ b/automl/DiNTS/README.md
@@ -0,0 +1,76 @@
+# Examples of DiNTS: Differentiable neural network topology search
+
+In this tutorial, we present a novel neural architecture search algorithm for 3D medical image segmentation. The datasets used in this tutorial are Task07 Pancreas (CT images) and Task09 Spleen (CT images) from the [Medical Segmentation Decathlon](http://medicaldecathlon.com/). The implementation is based on:
+
+Yufan He, Dong Yang, Holger Roth, Can Zhao, Daguang Xu: "[DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation.](https://openaccess.thecvf.com/content/CVPR2021/papers/He_DiNTS_Differentiable_Neural_Network_Topology_Search_for_3D_Medical_Image_CVPR_2021_paper.pdf)" In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5841-5850. 2021.
+
+![0.8](./Figures/arch_ram-cost-0.8.png)
+![space](./Figures/search_space.png)
+
+## Requirements
+The scripts are tested with:
+- `Ubuntu 20.04` and `CUDA 11`
+- The searching and training stages require at least two 16GB GPUs.
+
+## Dependencies and installation
+### Download and install the NVIDIA PyTorch Docker image
+```bash
+docker pull nvcr.io/nvidia/pytorch:21.10-py3
+```
+### Download the repository
+```bash
+git clone https://github.com/Project-MONAI/tutorials.git
+```
+### Run the Docker container
+```bash
+sudo docker run -it --gpus all --pid=host --shm-size 16G -v /location/to/tutorials/automl/DiNTS/:/workspace/DiNTS/ nvcr.io/nvidia/pytorch:21.10-py3
+```
+### Install MONAI and dependencies
+```bash
+bash install.sh
+```
+
+## Data
+The [Spleen CT dataset](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2) and the [Pancreas CT dataset](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2)
+from the [Medical Segmentation Decathlon](http://medicaldecathlon.com/) are used. You can manually download them and save them to `args.root`. Otherwise, the scripts will automatically
+download the datasets.
+
+## Examples
+The tutorial contains two stages: a searching stage and a training stage. An architecture is searched and saved into a `.pth` file using `search_dints.py`.
+The searched architecture is then loaded by `train_dints.py` and re-trained for spleen segmentation.
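+
+The saved architecture file is a plain PyTorch checkpoint. Below is a minimal sketch (not part of the tutorial scripts) for inspecting such a file before re-training; the key names follow what `search_dints.py` writes, and `arch_code.pth` is the example checkpoint shipped in this folder:
+```python
+import torch
+
+# Load a searched architecture checkpoint produced by search_dints.py.
+ckpt = torch.load("arch_code.pth", map_location="cpu")
+
+# "node_a", "code_a" and "code_c" are consumed by train_dints.py to build the
+# final network; the remaining keys are bookkeeping from the search run.
+for key in ("node_a", "code_a", "code_a_max", "code_c", "best_dsc", "epochs"):
+    if key in ckpt:
+        value = ckpt[key]
+        print(key, getattr(value, "shape", value))
+```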
+ +Check all possible options: +```bash +cd ./DiNTS/ +python search_dints.py -h +python train_dints.py -h +``` + +### Searching +- Add the following script to the commands of running into docker (optional) +``` +-v /your_downloaded_data_root/Task07_Pancreas/:/workspace/data_msd/Task07_Pancreas/ +``` +- Change ``NUM_GPUS_PER_NODE`` to your number of GPUs. +- Run `bash search_dints.sh` + +### Training +- Add the following script to the commands of running into docker (Optional) +``` +-v /your_downloaded_data_root/Task09_Spleen/:/workspace/data_msd/Task09_Spleen/ +``` +- Change ``ARCH_CKPT`` to point to the architecture file (.pth) from the searching stage. +- Change ``NUM_GPUS_PER_NODE`` to your number of GPUs. +- Run `bash train_dints.sh` + +Training loss and validation metric curves are shown as follows. The experiments utilized 8 NVIDIA A100 GPUs. + +![training_loss](./Figures/training_loss.png) + +![validation_metric](./Figures/validation_metric.png) + +## Questions and bugs + +- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI. +- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues). +- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues). diff --git a/automl/DiNTS/arch_code.pth b/automl/DiNTS/arch_code.pth new file mode 100644 index 0000000000..341ba1afd7 Binary files /dev/null and b/automl/DiNTS/arch_code.pth differ diff --git a/automl/DiNTS/install.sh b/automl/DiNTS/install.sh new file mode 100644 index 0000000000..6a3fac43ee --- /dev/null +++ b/automl/DiNTS/install.sh @@ -0,0 +1,12 @@ +#!/bin/bash +clear + +pip install nibabel +pip install pandas + +# Update pip +python -m pip install -U pip +# Install scikit-image +python -m pip install -U scikit-image + +pip install git+https://github.com/Project-MONAI/MONAI#egg=monai diff --git a/automl/DiNTS/search_dints.py b/automl/DiNTS/search_dints.py new file mode 100644 index 0000000000..38cf1ca7a6 --- /dev/null +++ b/automl/DiNTS/search_dints.py @@ -0,0 +1,714 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to execute the DiNTS algorithm to search the optimal neural architecture with +distributed training based on PyTorch native `DistributedDataParallel` module. +It can run on several nodes with multiple GPU devices on every node. +This example is a real-world task based on Decathlon challenge Task09: Spleen (CT) segmentation. +Under default settings, each single GPU needs to use ~13GB memory for network training. +Main steps to set up the distributed training: +- Execute `torch.distributed.launch` to create processes on every node for every GPU. 
+ It receives parameters as below: + `--nproc_per_node=NUM_GPUS_PER_NODE` + `--nnodes=NUM_NODES` + `--node_rank=INDEX_CURRENT_NODE` + `--master_addr="192.168.1.1"` + `--master_port=1234` + For more details, refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py. +- Use `init_process_group` to initialize every process, every GPU runs in a separate process with unique rank. + Here we use `NVIDIA NCCL` as the backend and must set `init_method="env://"` if use `torch.distributed.launch`. +- Wrap the model with `DistributedDataParallel` after moving to expected device. +- Partition dataset before training, so every rank process will only handle its own data partition. +Note: + `torch.distributed.launch` will launch `nnodes * nproc_per_node = world_size` processes in total. + Suggest setting exactly the same software environment for every node, especially `PyTorch`, `nccl`, etc. + A good practice is to use the same MONAI docker image for all nodes directly. + Example script to execute this program on every node: + python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE + --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE + --master_addr="192.168.1.1" --master_port=1234 + brats_training_ddp.py -d DIR_OF_TESTDATA + This example was tested with [Ubuntu 16.04/20.04], [NCCL 2.6.3]. +Referring to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html +Some codes are taken from https://github.com/pytorch/examples/blob/master/imagenet/main.py + +""" + +import argparse +import copy +import json +import logging +import monai +import nibabel as nib +import numpy as np +import os +import pandas as pd +import pathlib +import random +import shutil +import sys +import tempfile +import time +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +import yaml +import pdb +from datetime import datetime +from glob import glob +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from monai.apps import download_and_extract +from monai.data import ( + DataLoader, + ThreadDataLoader, + decollate_batch, +) +from torch.utils.tensorboard import SummaryWriter +from monai.transforms import ( + apply_transform, + Randomizable, + Transform, + AsDiscrete, + AsDiscreted, + AddChannel, + AddChanneld, + AsChannelFirstd, + CastToTyped, + Compose, + ConcatItemsd, + CopyItemsd, + CropForegroundd, + DivisiblePadd, + EnsureChannelFirstd, + EnsureType, + EnsureTyped, + KeepLargestConnectedComponent, + Lambdad, + LoadImaged, + NormalizeIntensityd, + Orientationd, + ScaleIntensityRanged, + ThresholdIntensityd, + RandCropByLabelClassesd, + RandCropByPosNegLabeld, + RandGaussianNoised, + RandGaussianSmoothd, + RandShiftIntensityd, + RandScaleIntensityd, + RandSpatialCropd, + RandSpatialCropSamplesd, + RandFlipd, + RandRotated, + RandRotate90d, + RandZoomd, + Spacingd, + SpatialPadd, + SqueezeDimd, + ToDeviced, + ToNumpyd, + ToTensord, +) +from monai.data import Dataset, create_test_image_3d, DistributedSampler, list_data_collate, partition_dataset +from monai.inferers import sliding_window_inference +from monai.metrics import compute_meandice +from monai.utils import set_determinism +from scipy import ndimage + + +def main(): + parser = argparse.ArgumentParser(description="training") + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="checkpoint full path", + ) + parser.add_argument( + "--fold", + action="store", + required=True, + help="fold index in N-fold cross-validation", + ) + 
parser.add_argument( + "--json", + action="store", + required=True, + help="full path of .json file", + ) + parser.add_argument( + "--json_key", + action="store", + required=True, + help="selected key in .json data list", + ) + parser.add_argument( + "--local_rank", + required=int, + help="local process rank", + ) + parser.add_argument( + "--num_folds", + action="store", + required=True, + help="number of folds in cross-validation", + ) + parser.add_argument( + "--output_root", + action="store", + required=True, + help="output root", + ) + parser.add_argument( + "--root", + action="store", + required=True, + help="data root", + ) + args = parser.parse_args() + + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + if not os.path.exists(args.output_root): + os.makedirs(args.output_root) + + amp = True + determ = False + factor_ram_cost = 0.2 + fold = int(args.fold) + input_channels = 1 + learning_rate = 0.0002 + learning_rate_final = 0.00001 + num_images_per_batch = 1 + num_epochs = 1430 + num_epochs_per_validation = 60 + num_epochs_warmup = 715 + num_folds = int(args.num_folds) + num_patches_per_image = 1 + num_sw_batch_size = 6 + output_classes = 3 + overlap_ratio = 0.625 + patch_size = (96, 96, 96) + patch_size_valid = (96, 96, 96) + spacing = [1.0, 1.0, 1.0] + + # deterministic training + if determ: + set_determinism(seed=0) + + # initialize the distributed training process, every GPU runs in a process + dist.init_process_group(backend="nccl", init_method="env://") + + # data + if dist.get_rank() == 0: + resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + args.root.split(os.sep)[-1] + ".tar" + compressed_file = args.root + ".tar" + data_dir = args.root + root_dir = os.path.join(*args.root.split(os.sep)[:-1]) + if not os.path.exists(data_dir): + download_and_extract(resource, compressed_file, root_dir) + + dist.barrier() + world_size = dist.get_world_size() + + with open(args.json, "r") as f: + json_data = json.load(f) + + split = len(json_data[args.json_key]) // num_folds + list_train = json_data[args.json_key][:(split * fold)] + json_data[args.json_key][(split * (fold + 1)):] + list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))] + + # training data + files = [] + for _i in range(len(list_train)): + str_img = os.path.join(args.root, list_train[_i]["image"]) + str_seg = os.path.join(args.root, list_train[_i]["label"]) + + if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)): + continue + + files.append({"image": str_img, "label": str_seg}) + + train_files = files + + random.shuffle(train_files) + + train_files_w = train_files[:len(train_files)//2] + train_files_w = partition_dataset(data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True)[dist.get_rank()] + print("train_files_w:", len(train_files_w)) + + train_files_a = train_files[len(train_files)//2:] + train_files_a = partition_dataset(data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True)[dist.get_rank()] + print("train_files_a:", len(train_files_a)) + + # validation data + files = [] + for _i in range(len(list_valid)): + str_img = os.path.join(args.root, list_valid[_i]["image"]) + str_seg = os.path.join(args.root, list_valid[_i]["label"]) + + if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)): + continue + + files.append({"image": str_img, "label": str_seg}) + val_files = files + val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[dist.get_rank()] + 
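+    # At this point every rank holds its own shard of the data: "train_files_w"
+    # is used to update the network weights, "train_files_a" to update the
+    # architecture parameters (the bi-level optimization of DiNTS), and
+    # "val_files" is used for periodic validation.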
print("val_files:", len(val_files)) + + # network architecture + device = torch.device(f"cuda:{args.local_rank}") + torch.cuda.set_device(device) + + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RAS"), + Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)), + CastToTyped(keys=["image"], dtype=(torch.float32)), + ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True), + CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)), + CopyItemsd(keys=["label"], times=1, names=["label4crop"]), + Lambdad( + keys=["label4crop"], + func=lambda x: np.concatenate(tuple([ndimage.binary_dilation((x==_k).astype(x.dtype), iterations=48).astype(x.dtype) for _k in range(output_classes)]), axis=0), + overwrite=True, + ), + EnsureTyped(keys=["image", "label"]), + CastToTyped(keys=["image"], dtype=(torch.float32)), + SpatialPadd(keys=["image", "label", "label4crop"], spatial_size=patch_size, mode=["reflect", "constant", "constant"]), + RandCropByLabelClassesd( + keys=["image", "label"], + label_key="label4crop", + num_classes=output_classes, + ratios=[1,] * output_classes, + spatial_size=patch_size, + num_samples=num_patches_per_image + ), + Lambdad(keys=["label4crop"], func=lambda x: 0), + RandRotated(keys=["image", "label"], range_x=0.3, range_y=0.3, range_z=0.3, mode=["bilinear", "nearest"], prob=0.2), + RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2, mode=["trilinear", "nearest"], align_corners=[True, None], prob=0.16), + RandGaussianSmoothd(keys=["image"], sigma_x=(0.5,1.15), sigma_y=(0.5,1.15), sigma_z=(0.5,1.15), prob=0.15), + RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5), + RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5), + RandGaussianNoised(keys=["image"], std=0.01, prob=0.15), + RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5), + RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5), + RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5), + CastToTyped(keys=["image", "label"], dtype=(torch.float32, torch.uint8)), + ToTensord(keys=["image", "label"]), + ] + ) + + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RAS"), + Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)), + CastToTyped(keys=["image"], dtype=(torch.float32)), + ScaleIntensityRanged(keys=["image"], a_min=-87.0, a_max=199.0, b_min=0.0, b_max=1.0, clip=True), + CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)), + EnsureTyped(keys=["image", "label"]), + ToTensord(keys=["image", "label"]) + ] + ) + + train_ds_a = monai.data.CacheDataset(data=train_files_a, transform=train_transforms, cache_rate=1.0, num_workers=8) + train_ds_w = monai.data.CacheDataset(data=train_files_w, transform=train_transforms, cache_rate=1.0, num_workers=8) + val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2) + + # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms) + # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms) + # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + + # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, 
shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) + # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available()) + # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available()) + + train_loader_a = ThreadDataLoader(train_ds_a, num_workers=0, batch_size=num_images_per_batch, shuffle=True) + train_loader_w = ThreadDataLoader(train_ds_w, num_workers=0, batch_size=num_images_per_batch, shuffle=True) + val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False) + + dints_space = monai.networks.nets.TopologySearch( + channel_mul=0.5, + num_blocks=12, + num_depths=4, + use_downsample=True, + device=device, + ) + + model = monai.networks.nets.DiNTS( + dints_space = dints_space, + in_channels=input_channels, + num_classes=output_classes, + use_downsample=True, + ) + + model = model.to(device) + + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + + post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=output_classes)]) + post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)]) + + # loss function + loss_func = monai.losses.DiceCELoss( + include_background=False, + to_onehot_y=True, + softmax=True, + squared_pred=True, + batch=True, + smooth_nr=0.00001, + smooth_dr=0.00001, + ) + + # optimizer + optimizer = torch.optim.SGD(model.weight_parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.00004) + arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a], lr=0.001, betas=(0.5, 0.999), weight_decay=0.0) + arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c], lr=0.001, betas=(0.5, 0.999), weight_decay=0.0) + + print() + + if torch.cuda.device_count() > 1: + if dist.get_rank() == 0: + print("Let's use", torch.cuda.device_count(), "GPUs!") + + model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True) + + if args.checkpoint != None and os.path.isfile(args.checkpoint): + print("[info] fine-tuning pre-trained checkpoint {0:s}".format(args.checkpoint)) + model.load_state_dict(torch.load(args.checkpoint, map_location=device)) + torch.cuda.empty_cache() + else: + print("[info] training from scratch") + + # amp + if amp: + from torch.cuda.amp import autocast, GradScaler + scaler = GradScaler() + if dist.get_rank() == 0: + print("[info] amp enabled") + + # start a typical PyTorch training + val_interval = num_epochs_per_validation + best_metric = -1 + best_metric_epoch = -1 + epoch_loss_values = list() + idx_iter = 0 + metric_values = list() + + if dist.get_rank() == 0: + writer = SummaryWriter(log_dir=os.path.join(args.output_root, "Events")) + + with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f: + f.write("epoch\tmetric\tloss\tlr\ttime\titer\n") + + dataloader_a_iterator = iter(train_loader_a) + + start_time = time.time() + for epoch in range(num_epochs): + if learning_rate_final > -0.000001 and learning_rate_final < learning_rate: + # lr = (learning_rate - learning_rate_final) * (1 - epoch / (num_epochs - 1)) ** 0.9 + learning_rate_final + milestones = np.array([0.4, 0.8]) + decay = 0.5 ** np.sum([(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > milestones]) + lr = learning_rate * decay + for param_group in optimizer.param_groups: + param_group["lr"] = lr + else: + lr = learning_rate + + lr = optimizer.param_groups[0]["lr"] + + if dist.get_rank() == 0: + print("-" * 10) + print(f"epoch {epoch + 
1}/{num_epochs}") + print("learning rate is set to {}".format(lr)) + + model.train() + epoch_loss = 0 + loss_torch = torch.zeros(2, dtype=torch.float, device=device) + epoch_loss_arch = 0 + loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device) + step = 0 + + for batch_data in train_loader_w: + step += 1 + inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) + if world_size == 1: + for _ in model.weight_parameters(): + _.requires_grad = True + else: + for _ in model.module.weight_parameters(): + _.requires_grad = True + dints_space.log_alpha_a.requires_grad = False + dints_space.log_alpha_c.requires_grad = False + if amp: + with autocast(): + outputs = model(inputs) + if output_classes == 2: + loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels) + else: + loss = loss_func(outputs, labels) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + outputs = model(inputs) + if output_classes == 2: + loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels) + else: + loss = loss_func(outputs, labels) + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + loss_torch[0] += loss.item() + loss_torch[1] += 1.0 + epoch_len = len(train_loader_w) + idx_iter += 1 + + if dist.get_rank() == 0: + print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") + writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) + + if epoch < num_epochs_warmup: + continue + + try: + sample_a = next(dataloader_a_iterator) + except StopIteration: + dataloader_a_iterator = iter(train_loader_a) + sample_a = next(dataloader_a_iterator) + inputs_search, labels_search = sample_a["image"].to(device), sample_a["label"].to(device) + if world_size == 1: + for _ in model.weight_parameters(): + _.requires_grad = False + else: + for _ in model.module.weight_parameters(): + _.requires_grad = False + dints_space.log_alpha_a.requires_grad = True + dints_space.log_alpha_c.requires_grad = True + + # linear increase topology and RAM loss + entropy_alpha_c = torch.tensor(0.).cuda() + entropy_alpha_a = torch.tensor(0.).cuda() + ram_cost_full = torch.tensor(0.).cuda() + ram_cost_usage = torch.tensor(0.).cuda() + ram_cost_loss = torch.tensor(0.).cuda() + topology_loss = torch.tensor(0.).cuda() + + probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True) + entropy_alpha_a = -((probs_a)*torch.log(probs_a + 1e-5)).mean() + entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1) * \ + F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean() + topology_loss = dints_space.get_topology_entropy(probs_a) + + ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True) + ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape) + ram_cost_loss = torch.abs(factor_ram_cost - ram_cost_usage / ram_cost_full) + + arch_optimizer_a.zero_grad() + arch_optimizer_c.zero_grad() + + if amp: + with autocast(): + outputs_search = model(inputs_search) + if output_classes == 2: + loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search) + else: + loss = loss_func(outputs_search, labels_search) + + loss += 1.0 * (1.0 * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \ + + 0.001 * topology_loss) + + scaler.scale(loss).backward() + scaler.step(arch_optimizer_a) + scaler.step(arch_optimizer_c) + scaler.update() + else: + outputs_search = model(inputs_search) + if output_classes == 2: + loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search) + else: + loss = 
loss_func(outputs_search, labels_search) + + loss += 1.0 * (1.0 * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \ + + 0.001 * topology_loss) + + loss.backward() + arch_optimizer_a.step() + arch_optimizer_c.step() + + epoch_loss_arch += loss.item() + loss_torch_arch[0] += loss.item() + loss_torch_arch[1] += 1.0 + + if dist.get_rank() == 0: + print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}") + writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step) + + # synchronizes all processes and reduce results + dist.barrier() + dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM) + loss_torch = loss_torch.tolist() + loss_torch_arch = loss_torch_arch.tolist() + if dist.get_rank() == 0: + loss_torch_epoch = loss_torch[0] / loss_torch[1] + print(f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}") + + if epoch < num_epochs_warmup: + continue + + loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1] + print(f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}") + + if (epoch + 1) % val_interval == 0: + torch.cuda.empty_cache() + model.eval() + with torch.no_grad(): + metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device) + metric_sum = 0.0 + metric_count = 0 + metric_mat = [] + val_images = None + val_labels = None + val_outputs = None + + _index = 0 + for val_data in val_loader: + val_images = val_data["image"].to(device) + val_labels = val_data["label"].to(device) + + roi_size = patch_size_valid + sw_batch_size = num_sw_batch_size + + if amp: + with torch.cuda.amp.autocast(): + pred = sliding_window_inference( + val_images, + roi_size, + sw_batch_size, + lambda x: model(x), + mode="gaussian", + overlap=overlap_ratio, + ) + else: + pred = sliding_window_inference( + val_images, + roi_size, + sw_batch_size, + lambda x: model(x), + mode="gaussian", + overlap=overlap_ratio, + ) + val_outputs = pred + + val_outputs = post_pred(val_outputs[0, ...]) + val_outputs = val_outputs[None, ...] + val_labels = post_label(val_labels[0, ...]) + val_labels = val_labels[None, ...] 
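+                    # The post-transforms above convert prediction and label to
+                    # one-hot form so that a Dice score can be accumulated per
+                    # foreground class below.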
+ + value = compute_meandice( + y_pred=val_outputs, + y=val_labels, + include_background=False + ) + + print(_index + 1, "/", len(val_loader), value) + + metric_count += len(value) + metric_sum += value.sum().item() + metric_vals = value.cpu().numpy() + if len(metric_mat) == 0: + metric_mat = metric_vals + else: + metric_mat = np.concatenate((metric_mat, metric_vals), axis=0) + + for _c in range(output_classes - 1): + val0 = torch.nan_to_num(value[0, _c], nan=0.0) + val1 = 1.0 - torch.isnan(value[0, 0]).float() + metric[2 * _c] += val0 * val1 + metric[2 * _c + 1] += val1 + + _index += 1 + + # synchronizes all processes and reduce results + dist.barrier() + dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM) + metric = metric.tolist() + if dist.get_rank() == 0: + for _c in range(output_classes - 1): + print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1]) + avg_metric = 0 + for _c in range(output_classes - 1): + avg_metric += metric[2 * _c] / metric[2 * _c + 1] + avg_metric = avg_metric / float(output_classes - 1) + print("avg_metric", avg_metric) + + if avg_metric > best_metric: + best_metric = avg_metric + best_metric_epoch = epoch + 1 + best_metric_iterations = idx_iter + + node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode() + torch.save( + { + "node_a": node_a_d, + "code_a": arch_code_a_d, + "code_a_max": arch_code_a_max_d, + "code_c": arch_code_c_d, + "iter_num": idx_iter, + "epochs": epoch + 1, + "best_dsc": best_metric, + "best_path": best_metric_iterations, + }, + os.path.join(args.output_root, "search_code_" + str(idx_iter) + ".pth"), + ) + print("saved new best metric model") + + dict_file = {} + dict_file["best_avg_dice_score"] = float(best_metric) + dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch) + dict_file["best_avg_dice_score_iteration"] = int(idx_iter) + with open(os.path.join(args.output_root, "progress.yaml"), "w") as out_file: + documents = yaml.dump(dict_file, stream=out_file) + + print( + "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( + epoch + 1, avg_metric, best_metric, best_metric_epoch + ) + ) + + current_time = time.time() + elapsed_time = (current_time - start_time) / 60.0 + with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f: + f.write("{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter)) + + dist.barrier() + + torch.cuda.empty_cache() + + print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}") + + if dist.get_rank() == 0: + writer.close() + + dist.destroy_process_group() + + return + + +if __name__ == "__main__": + main() diff --git a/automl/DiNTS/search_dints.sh b/automl/DiNTS/search_dints.sh new file mode 100644 index 0000000000..8c94ae9885 --- /dev/null +++ b/automl/DiNTS/search_dints.sh @@ -0,0 +1,47 @@ +#!/bin/bash +clear + +TASK="Task07_Pancreas" + +# DATA_ROOT="/home/dongy/Data/MSD/${TASK}" +DATA_ROOT="/workspace/data_msd/${TASK}" +JSON_PATH="${DATA_ROOT}/dataset.json" + +FOLD=4 +NUM_FOLDS=5 + +NUM_GPUS_PER_NODE=8 +NUM_NODES=1 + +if [ ${NUM_GPUS_PER_NODE} -eq 1 ] +then + export CUDA_VISIBLE_DEVICES=0 +elif [ ${NUM_GPUS_PER_NODE} -eq 2 ] +then + export CUDA_VISIBLE_DEVICES=0,1 +elif [ ${NUM_GPUS_PER_NODE} -eq 4 ] +then + export CUDA_VISIBLE_DEVICES=0,1,2,3 +elif [ ${NUM_GPUS_PER_NODE} -eq 8 ] +then + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +fi + 
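+# CHECKPOINT is only used to resume from an existing model state; when the file
+# does not exist, search_dints.py starts training from scratch. The search
+# outputs (search_code_*.pth, TensorBoard events, progress.yaml) are written
+# under OUTPUT_ROOT.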
+CHECKPOINT_ROOT="models/search_${TASK}_fold${FOLD}" +CHECKPOINT="${CHECKPOINT_ROOT}/best_metric_model.pth" +JSON_KEY="training" +OUTPUT_ROOT="models/search_${TASK}_fold${FOLD}" + +python -m torch.distributed.launch \ + --nproc_per_node=${NUM_GPUS_PER_NODE} \ + --nnodes=${NUM_NODES} \ + --node_rank=0 \ + --master_addr=localhost \ + --master_port=1234 \ + search_dints.py --checkpoint=${CHECKPOINT} \ + --fold=${FOLD} \ + --json=${JSON_PATH} \ + --json_key=${JSON_KEY} \ + --num_folds=${NUM_FOLDS} \ + --output_root=${OUTPUT_ROOT} \ + --root=${DATA_ROOT} diff --git a/automl/DiNTS/train_dints.py b/automl/DiNTS/train_dints.py new file mode 100644 index 0000000000..8d4abaa577 --- /dev/null +++ b/automl/DiNTS/train_dints.py @@ -0,0 +1,593 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example shows how to execute training from scratch with DiNTS's searched model with +distributed training based on PyTorch native `DistributedDataParallel` module. +It can run on several nodes with multiple GPU devices on every node. +This example is a real-world task based on Decathlon challenge Task09: Spleen (CT) segmentation. +Under default settings, each single GPU needs to use ~13GB memory for network training. +Main steps to set up the distributed training: +- Execute `torch.distributed.launch` to create processes on every node for every GPU. + It receives parameters as below: + `--nproc_per_node=NUM_GPUS_PER_NODE` + `--nnodes=NUM_NODES` + `--node_rank=INDEX_CURRENT_NODE` + `--master_addr="192.168.1.1"` + `--master_port=1234` + For more details, refer to https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py. +- Use `init_process_group` to initialize every process, every GPU runs in a separate process with unique rank. + Here we use `NVIDIA NCCL` as the backend and must set `init_method="env://"` if use `torch.distributed.launch`. +- Wrap the model with `DistributedDataParallel` after moving to expected device. +- Partition dataset before training, so every rank process will only handle its own data partition. +Note: + `torch.distributed.launch` will launch `nnodes * nproc_per_node = world_size` processes in total. + Suggest setting exactly the same software environment for every node, especially `PyTorch`, `nccl`, etc. + A good practice is to use the same MONAI docker image for all nodes directly. + Example script to execute this program on every node: + python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE + --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE + --master_addr="192.168.1.1" --master_port=1234 + brats_training_ddp.py -d DIR_OF_TESTDATA + This example was tested with [Ubuntu 16.04/20.04], [NCCL 2.6.3]. 
+Referring to: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html +Some codes are taken from https://github.com/pytorch/examples/blob/master/imagenet/main.py + +""" + +import argparse +import copy +import json +import logging +import monai +import nibabel as nib +import numpy as np +import os +import pandas as pd +import pathlib +import shutil +import sys +import tempfile +import time +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +import yaml + +from datetime import datetime +from glob import glob +from monai.apps import download_and_extract +from monai.data import ( + DataLoader, + ThreadDataLoader, + decollate_batch, +) +from monai.transforms import ( + apply_transform, + Randomizable, + Transform, + AsDiscrete, + AsDiscreted, + AddChannel, + AddChanneld, + AsChannelFirstd, + CastToTyped, + Compose, + ConcatItemsd, + CopyItemsd, + CropForegroundd, + DivisiblePadd, + EnsureChannelFirstd, + EnsureType, + EnsureTyped, + KeepLargestConnectedComponent, + Lambdad, + LoadImaged, + NormalizeIntensityd, + Orientationd, + ScaleIntensityRanged, + ThresholdIntensityd, + RandCropByLabelClassesd, + RandCropByPosNegLabeld, + RandGaussianNoised, + RandGaussianSmoothd, + RandShiftIntensityd, + RandScaleIntensityd, + RandSpatialCropd, + RandSpatialCropSamplesd, + RandFlipd, + RandRotated, + RandRotate90d, + RandZoomd, + Spacingd, + SpatialPadd, + SqueezeDimd, + ToDeviced, + ToNumpyd, + ToTensord, +) +from monai.data import Dataset, create_test_image_3d, DistributedSampler, list_data_collate, partition_dataset +from monai.inferers import sliding_window_inference +from monai.metrics import compute_meandice +from monai.utils import set_determinism +from scipy import ndimage +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from torch.utils.tensorboard import SummaryWriter + + +def main(): + parser = argparse.ArgumentParser(description="training") + parser.add_argument( + "--arch_ckpt", + action="store", + required=True, + help="data root", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="checkpoint full path", + ) + parser.add_argument( + "--fold", + action="store", + required=True, + help="fold index in N-fold cross-validation", + ) + parser.add_argument( + "--json", + action="store", + required=True, + help="full path of .json file", + ) + parser.add_argument( + "--json_key", + action="store", + required=True, + help="selected key in .json data list", + ) + parser.add_argument( + "--local_rank", + required=int, + help="local process rank", + ) + parser.add_argument( + "--num_folds", + action="store", + required=True, + help="number of folds in cross-validation", + ) + parser.add_argument( + "--output_root", + action="store", + required=True, + help="output root", + ) + parser.add_argument( + "--root", + action="store", + required=True, + help="data root", + ) + args = parser.parse_args() + + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + if not os.path.exists(args.output_root): + os.makedirs(args.output_root) + + amp = True + determ = False + fold = int(args.fold) + input_channels = 1 + learning_rate = 0.0002 + learning_rate_final = 0.00001 + num_images_per_batch = 2 + num_epochs = 13500 + num_epochs_per_validation = 50 + num_folds = int(args.num_folds) + num_patches_per_image = 1 + num_sw_batch_size = 6 + output_classes = 2 + overlap_ratio = 0.5 + patch_size = (96, 96, 96) + patch_size_valid = (96, 96, 96) + spacing = [1.0, 1.0, 1.0] + + # deterministic training + 
if determ: + set_determinism(seed=0) + + # initialize the distributed training process, every GPU runs in a process + dist.init_process_group(backend="nccl", init_method="env://") + + # download data + if dist.get_rank() == 0: + resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + args.root.split(os.sep)[-1] + ".tar" + compressed_file = args.root + ".tar" + data_dir = args.root + root_dir = os.path.join(*args.root.split(os.sep)[:-1]) + if not os.path.exists(data_dir): + download_and_extract(resource, compressed_file, root_dir) + + dist.barrier() + + # load data list (.json) + with open(args.json, "r") as f: + json_data = json.load(f) + + split = len(json_data[args.json_key]) // num_folds + list_train = json_data[args.json_key][:(split * fold)] + json_data[args.json_key][(split * (fold + 1)):] + list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))] + + # training data + files = [] + for _i in range(len(list_train)): + str_img = os.path.join(args.root, list_train[_i]["image"]) + str_seg = os.path.join(args.root, list_train[_i]["label"]) + + if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)): + continue + + files.append({"image": str_img, "label": str_seg}) + + train_files = files + train_files = partition_dataset(data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True)[dist.get_rank()] + print("train_files:", len(train_files)) + + # validation data + files = [] + for _i in range(len(list_valid)): + str_img = os.path.join(args.root, list_valid[_i]["image"]) + str_seg = os.path.join(args.root, list_valid[_i]["label"]) + + if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)): + continue + + files.append({"image": str_img, "label": str_seg}) + val_files = files + val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=dist.get_world_size(), even_divisible=False)[dist.get_rank()] + print("val_files:", len(val_files)) + + # network architecture + device = torch.device(f"cuda:{args.local_rank}") + torch.cuda.set_device(device) + + train_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RAS"), + Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)), + CastToTyped(keys=["image"], dtype=(torch.float32)), + ScaleIntensityRanged(keys=["image"], a_min=-125.0, a_max=275.0, b_min=0.0, b_max=1.0, clip=True), + CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)), + CopyItemsd(keys=["label"], times=1, names=["label4crop"]), + Lambdad( + keys=["label4crop"], + func=lambda x: np.concatenate(tuple([ndimage.binary_dilation((x==_k).astype(x.dtype), iterations=48).astype(x.dtype) for _k in range(output_classes)]), axis=0), + overwrite=True, + ), + EnsureTyped(keys=["image", "label"]), + CastToTyped(keys=["image"], dtype=(torch.float32)), + SpatialPadd(keys=["image", "label", "label4crop"], spatial_size=patch_size, mode=["reflect", "constant", "constant"]), + RandCropByLabelClassesd( + keys=["image", "label"], + label_key="label4crop", + num_classes=output_classes, + ratios=[1,] * output_classes, + spatial_size=patch_size, + num_samples=num_patches_per_image + ), + Lambdad(keys=["label4crop"], func=lambda x: 0), + RandRotated(keys=["image", "label"], range_x=0.3, range_y=0.3, range_z=0.3, mode=["bilinear", "nearest"], prob=0.2), + RandZoomd(keys=["image", "label"], min_zoom=0.8, max_zoom=1.2, mode=["trilinear", "nearest"], 
align_corners=[True, None], prob=0.16), + RandGaussianSmoothd(keys=["image"], sigma_x=(0.5,1.15), sigma_y=(0.5,1.15), sigma_z=(0.5,1.15), prob=0.15), + RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5), + RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5), + RandGaussianNoised(keys=["image"], std=0.01, prob=0.15), + RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5), + RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5), + RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5), + CastToTyped(keys=["image", "label"], dtype=(torch.float32, torch.uint8)), + ToTensord(keys=["image", "label"]), + ] + ) + + val_transforms = Compose( + [ + LoadImaged(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RAS"), + Spacingd(keys=["image", "label"], pixdim=spacing, mode=("bilinear", "nearest"), align_corners=(True, True)), + CastToTyped(keys=["image"], dtype=(torch.float32)), + ScaleIntensityRanged(keys=["image"], a_min=-125.0, a_max=275.0, b_min=0.0, b_max=1.0, clip=True), + CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)), + EnsureTyped(keys=["image", "label"]), + ToTensord(keys=["image", "label"]) + ] + ) + + # train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + + train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8) + val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2) + + train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=num_images_per_batch, shuffle=True) + val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False) + + ckpt = torch.load(args.arch_ckpt) + node_a = ckpt['node_a'] + arch_code_a = ckpt['code_a'] + arch_code_c = ckpt['code_c'] + + dints_space = monai.networks.nets.TopologyInstance( + channel_mul=1.0, + num_blocks=12, + num_depths=4, + use_downsample=True, + arch_code=[arch_code_a, arch_code_c], + device=device, + ) + + model = monai.networks.nets.DiNTS( + dints_space=dints_space, + in_channels=input_channels, + num_classes=output_classes, + use_downsample=True, + node_a=node_a, + ) + + model = model.to(device) + + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + + post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=output_classes)]) + post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)]) + + # loss function + loss_func = monai.losses.DiceCELoss( + include_background=False, + to_onehot_y=True, + softmax=True, + squared_pred=True, + batch=True, + smooth_nr=0.00001, + smooth_dr=0.00001, + ) + + # optimizer + optimizer = torch.optim.AdamW( + model.parameters(), + lr=learning_rate + ) + + print() + + if torch.cuda.device_count() > 1: + if dist.get_rank() == 0: + print("Let's use", torch.cuda.device_count(), "GPUs!") + + model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True) + + if args.checkpoint != None and os.path.isfile(args.checkpoint): + print("[info] fine-tuning pre-trained checkpoint {0:s}".format(args.checkpoint)) + model.load_state_dict(torch.load(args.checkpoint, map_location=device)) + torch.cuda.empty_cache() + else: + print("[info] training from scratch") + + # amp + if amp: + from torch.cuda.amp import autocast, GradScaler + scaler = GradScaler() + if dist.get_rank() == 0: + print("[info] amp enabled") + + # start 
a typical PyTorch training + val_interval = num_epochs_per_validation + best_metric = -1 + best_metric_epoch = -1 + epoch_loss_values = list() + idx_iter = 0 + metric_values = list() + + if dist.get_rank() == 0: + writer = SummaryWriter(log_dir=os.path.join(args.output_root, "Events")) + + with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f: + f.write("epoch\tmetric\tloss\tlr\ttime\titer\n") + + start_time = time.time() + for epoch in range(num_epochs): + if learning_rate_final > -0.000001 and learning_rate_final < learning_rate: + lr = (learning_rate - learning_rate_final) * (1 - epoch / (num_epochs - 1)) ** 0.9 + learning_rate_final + for param_group in optimizer.param_groups: + param_group["lr"] = lr + else: + lr = learning_rate + + lr = optimizer.param_groups[0]["lr"] + + if dist.get_rank() == 0: + print("-" * 10) + print(f"epoch {epoch + 1}/{num_epochs}") + print('learning rate is set to {}'.format(lr)) + + model.train() + epoch_loss = 0 + loss_torch = torch.zeros(2, dtype=torch.float, device=device) + step = 0 + for batch_data in train_loader: + step += 1 + inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) + + for param in model.parameters(): + param.grad = None + + if amp: + with autocast(): + outputs = model(inputs) + if output_classes == 2: + loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels) + else: + loss = loss_func(outputs, labels) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + outputs = model(inputs) + if output_classes == 2: + loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels) + else: + loss = loss_func(outputs, labels) + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + loss_torch[0] += loss.item() + loss_torch[1] += 1.0 + epoch_len = len(train_loader) + idx_iter += 1 + + if dist.get_rank() == 0: + print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") + writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) + + # synchronizes all processes and reduce results + dist.barrier() + dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM) + loss_torch = loss_torch.tolist() + if dist.get_rank() == 0: + loss_torch_epoch = loss_torch[0] / loss_torch[1] + print(f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}") + + if (epoch + 1) % val_interval == 0: + torch.cuda.empty_cache() + model.eval() + with torch.no_grad(): + metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device) + metric_sum = 0.0 + metric_count = 0 + metric_mat = [] + val_images = None + val_labels = None + val_outputs = None + + _index = 0 + for val_data in val_loader: + val_images = val_data["image"].to(device) + val_labels = val_data["label"].to(device) + + roi_size = patch_size_valid + sw_batch_size = num_sw_batch_size + + # test time augmentation + ct = 1.0 + with torch.cuda.amp.autocast(): + pred = sliding_window_inference( + val_images, + roi_size, + sw_batch_size, + lambda x: model(x), + mode="gaussian", + overlap=overlap_ratio, + ) + + val_outputs = pred / ct + + val_outputs = post_pred(val_outputs[0, ...]) + val_outputs = val_outputs[None, ...] + val_labels = post_label(val_labels[0, ...]) + val_labels = val_labels[None, ...] 
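+                    # Per-class Dice statistics are accumulated locally and then
+                    # summed across ranks via dist.all_reduce after this loop.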
+ + value = compute_meandice( + y_pred=val_outputs, + y=val_labels, + include_background=False + ) + + print(_index + 1, "/", len(val_loader), value) + + metric_count += len(value) + metric_sum += value.sum().item() + metric_vals = value.cpu().numpy() + if len(metric_mat) == 0: + metric_mat = metric_vals + else: + metric_mat = np.concatenate((metric_mat, metric_vals), axis=0) + + for _c in range(output_classes - 1): + val0 = torch.nan_to_num(value[0, _c], nan=0.0) + val1 = 1.0 - torch.isnan(value[0, 0]).float() + metric[2 * _c] += val0 * val1 + metric[2 * _c + 1] += val1 + + _index += 1 + + # synchronizes all processes and reduce results + dist.barrier() + dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM) + metric = metric.tolist() + if dist.get_rank() == 0: + for _c in range(output_classes - 1): + print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1]) + avg_metric = 0 + for _c in range(output_classes - 1): + avg_metric += metric[2 * _c] / metric[2 * _c + 1] + avg_metric = avg_metric / float(output_classes - 1) + print("avg_metric", avg_metric) + + if avg_metric > best_metric: + best_metric = avg_metric + best_metric_epoch = epoch + 1 + torch.save(model.state_dict(), os.path.join(args.output_root, "best_metric_model.pth")) + print("saved new best metric model") + + dict_file = {} + dict_file["best_avg_dice_score"] = float(best_metric) + dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch) + dict_file["best_avg_dice_score_iteration"] = int(idx_iter) + with open(os.path.join(args.output_root, "progress.yaml"), "w") as out_file: + documents = yaml.dump(dict_file, stream=out_file) + + print( + "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( + epoch + 1, avg_metric, best_metric, best_metric_epoch + ) + ) + + current_time = time.time() + elapsed_time = (current_time - start_time) / 60.0 + with open(os.path.join(args.output_root, "accuracy_history.csv"), "a") as f: + f.write("{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter)) + + dist.barrier() + + torch.cuda.empty_cache() + + print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}") + + if dist.get_rank() == 0: + writer.close() + + dist.destroy_process_group() + + return + + +if __name__ == "__main__": + main() diff --git a/automl/DiNTS/train_dints.sh b/automl/DiNTS/train_dints.sh new file mode 100644 index 0000000000..3a3e4caa91 --- /dev/null +++ b/automl/DiNTS/train_dints.sh @@ -0,0 +1,49 @@ +#!/bin/bash +clear + +TASK="Task09_Spleen" + +ARCH_CKPT="arch_code.pth" +# DATA_ROOT="/home/dongy/Data/MSD/${TASK}" +DATA_ROOT="/workspace/data_msd/${TASK}" +JSON_PATH="${DATA_ROOT}/dataset.json" + +FOLD=0 +NUM_FOLDS=5 + +NUM_GPUS_PER_NODE=8 +NUM_NODES=1 + +if [ ${NUM_GPUS_PER_NODE} -eq 1 ] +then + export CUDA_VISIBLE_DEVICES=0 +elif [ ${NUM_GPUS_PER_NODE} -eq 2 ] +then + export CUDA_VISIBLE_DEVICES=0,1 +elif [ ${NUM_GPUS_PER_NODE} -eq 4 ] +then + export CUDA_VISIBLE_DEVICES=0,1,2,3 +elif [ ${NUM_GPUS_PER_NODE} -eq 8 ] +then + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +fi + +CHECKPOINT_ROOT="models/${TASK}_fold${FOLD}" +CHECKPOINT="${CHECKPOINT_ROOT}/best_metric_model.pth" +JSON_KEY="training" +OUTPUT_ROOT="models/${TASK}_fold${FOLD}" + +python -m torch.distributed.launch \ + --nproc_per_node=${NUM_GPUS_PER_NODE} \ + --nnodes=${NUM_NODES} \ + --node_rank=0 \ + --master_addr=localhost \ + --master_port=1234 \ + train_dints.py --arch_ckpt=${ARCH_CKPT} 
\ + --checkpoint=${CHECKPOINT} \ + --fold=${FOLD} \ + --json=${JSON_PATH} \ + --json_key=${JSON_KEY} \ + --num_folds=${NUM_FOLDS} \ + --output_root=${OUTPUT_ROOT} \ + --root=${DATA_ROOT} diff --git a/automl/README.md b/automl/README.md index 7ea5c24c57..30208a0652 100644 --- a/automl/README.md +++ b/automl/README.md @@ -1,4 +1,24 @@ # AutoML -## DiNTS: Differentiable neural network topology search -This section is coming soon. +Here we showcase the most recent AutoML techniques in medical imaging based on MONAI modules. + +## [DiNTS: Differentiable neural network topology search](./DiNTS) +Recently, neural architecture search (NAS) has been applied to automatically +search high-performance networks for medical image segmentation. The NAS search +space usually contains a network topology level (controlling connections among +cells with different spatial scales) and a cell level (operations within each +cell). Existing methods either require long searching time for large-scale 3D +image datasets, or are limited to pre-defined topologies (such as U-shaped or +single-path). + +In this work, we focus on three important aspects of NAS in 3D medical image +segmentation: flexible multi-path network topology, high search efficiency, and +budgeted GPU memory usage. A novel differentiable search framework is proposed +to support fast gradient-based search within a highly flexible network topology +search space. The discretization of the searched optimal continuous model in +differentiable scheme may produce a sub-optimal final discrete model +(discretization gap). Therefore, we propose a topology loss to alleviate this +problem. In addition, the GPU memory usage for the searched 3D model is limited +with budget constraints during search. The Differentiable Network Topology +Search scheme (DiNTS) was evaluated on the Medical Segmentation Decathlon (MSD) +challenge with state-of-the-art performance. diff --git a/pathology/multiple_instance_learning/README.md b/pathology/multiple_instance_learning/README.md index bca59bbe2c..c38e62ca8b 100644 --- a/pathology/multiple_instance_learning/README.md +++ b/pathology/multiple_instance_learning/README.md @@ -118,6 +118,6 @@ Expected validation QWK metric ## Questions and bugs -- For questions relating to the use of MONAI, please us our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI. +- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI. - For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues). - For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).