diff --git a/federated_learning/nvflare/nvflare_spleen_example/3-Admin.ipynb b/federated_learning/nvflare/nvflare_spleen_example/3-Admin.ipynb index 342a71b8f2..193b2edc0a 100644 --- a/federated_learning/nvflare/nvflare_spleen_example/3-Admin.ipynb +++ b/federated_learning/nvflare/nvflare_spleen_example/3-Admin.ipynb @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -46,17 +46,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "a working folder for the admin exists!!!\n" - ] - } - ], + "outputs": [], "source": [ "admin_path = \"poc/admin/startup/\"\n", "\n", @@ -73,50 +65,28 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!mkdir -p poc/admin/transfer\n", - "!cp -r hello_monai/ poc/admin/transfer/" + "!cp -r hello-monai/ poc/admin/transfer/" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['hello_monai']" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "os.listdir(\"poc/admin/transfer/\")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['fl_admin.sh']" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "os.listdir(admin_path)" ] @@ -144,23 +114,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - " Open a new terminal" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "HTML(' Open a new terminal')" ] @@ -192,51 +148,18 @@ "\n", "The commands can be:\n", "```\n", - "upload_app hello_monai\n", + "upload_app hello-monai\n", "set_run_number 1\n", - "deploy_app hello_monai server\n", - "deploy_app hello_monai client\n", - "```\n", - "\n", - "Now, let's check if the folder has been distributed into the server and all client(s):" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "config files on server: ['app_server', 'fl_app.txt']\n", - " \n", - "config files on site-1: ['app_site-1', 'fl_app.txt']\n", - " \n", - "config files on site-2: ['fl_app.txt', 'app_site-2']\n", - " \n" - ] - } - ], - "source": [ - "run_file = \"run_1\"\n", - "\n", - "poc_path = \"poc/\"\n", - "\n", - "for name in [\"server\", \"site-1\", \"site-2\"]:\n", - " path = os.path.join(poc_path, name, run_file)\n", - " if os.path.exists(path):\n", - " print(\"config files on {}: {}\".\n", - " format(name, os.listdir(path)))\n", - " print(\" \")" + "deploy_app hello-monai server\n", + "deploy_app hello-monai client\n", + "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This example prepares two different data list files: `dataset_part1.json` and `dataset_part2.json`, and they have the same validation set and totally different training set. The default file used in `config_train.json` is `config/dataset_part1.json`. 
Therefore, if you want to let two clients train on different data, you can switch to use `dataset_part2.json` for `org1-b`.\n", + "This example prepares two different data list files: `dataset_part1.json` and `dataset_part2.json`; they share the same validation set but have completely different training sets. The default file used in `config_train.json` is `config/dataset_part1.json`. Therefore, if you want the two clients to train on different data, you can switch to `dataset_part2.json` for `org1-b`.\n", "\n", "[Link to site-1 config](poc/site-1/run_1/app_site-1/config/config_train.json)\n", "\n", @@ -250,13 +173,12 @@ "### (Optional) Copy Dataset\n", "\n", "After starting a client (for example `site-1`), the Spleen dataset will be downloaded into:\n", - "`run_1/app_site-1/Task09_Spleen.tar`.\n", + "`poc/site-1/`.\n", "\n", - "To prevent repeatedly downloading the dataset, you can copy the uncompressed `Task09_Spleen` into the corresponding place after running the `deploy_app` command.\n", - "For example:\n", + "If you already have the `Task09_Spleen` dataset, you can copy it into the corresponding place directly to avoid downloading it again. For example:\n", "\n", "```\n", - "cp -r /path-to-dataset/Task09_Spleen poc/site-1/run_2/app_site-1/\n", + "cp -r /path-to-dataset/Task09_Spleen poc/site-1/\n", "```" ] }, diff --git a/federated_learning/nvflare/nvflare_spleen_example/README.md b/federated_learning/nvflare/nvflare_spleen_example/README.md index 9f27a8fdaa..a4574ff48e 100644 --- a/federated_learning/nvflare/nvflare_spleen_example/README.md +++ b/federated_learning/nvflare/nvflare_spleen_example/README.md @@ -2,13 +2,15 @@ ## Brief Introduction -This repository contains an end-to-end Federated training example based on MONAI trainers and [NVIDIA FLARE](https://github.com/nvidia/nvflare). +This repository contains an end-to-end Federated training example based on MONAI trainers and [NVIDIA FLARE](https://github.com/nvidia/nvflare). Please also download the `hello-monai` folder from [NVFlare/examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/hello-monai) and copy it into this directory. + + This example requires Python 3.8.10 -Inside this folder: +Inside this directory: - All Jupyter notebooks are used to build an FL experiment step-by-step. -- `hello-monai` is a folder containing all required config files for the experiment (in `config/`) and the customized trainer (in `custom`) and its components. +- `hello-monai` (needs to be downloaded) is a folder containing all required config files for the experiment (in `config/`), as well as the customized trainer (in `custom/`) and its components.
## Installation diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_client.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_client.json deleted file mode 100644 index 98bb905fbc..0000000000 --- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_client.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "format_version": 2, - "executors": [ - { - "tasks": ["train"], - "executor": { - "path": "monai_trainer.MONAITrainer", - "args": { - "aggregation_epochs": 10 - } - } - } - ], - "task_result_filters": [ - ], - "task_data_filters": [ - ], - "components": [ - ] -} diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_server.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_server.json deleted file mode 100644 index 555b919dee..0000000000 --- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_server.json +++ /dev/null @@ -1,63 +0,0 @@ -{ - "format_version": 2, - - "server": { - "heart_beat_timeout": 600 - }, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "persistor", - "name": "PTFileModelPersistor", - "args": { - "model": { - "path": "monai.networks.nets.unet.UNet", - "args": { - "dimensions": 3, - "in_channels": 1, - "out_channels": 2, - "channels": [16, 32, 64, 128, 256], - "strides": [2, 2, 2, 2], - "num_res_units": 2, - "norm": "batch" - } - } - } - }, - { - "id": "shareable_generator", - "name": "FullModelShareableGenerator", - "args": {} - }, - { - "id": "aggregator", - "name": "AccumulateWeightedAggregator", - "args": { - "aggregation_weights": { - "site-1": 1.0, - "site-2": 0.5 - }, - "expected_data_kind": "WEIGHTS" - } - } - ], - "workflows": [ - { - "id": "scatter_and_gather", - "name": "ScatterAndGather", - "args": { - "min_clients" : 1, - "num_rounds" : 100, - "start_round": 0, - "wait_time_after_min_received": 10, - "aggregator_id": "aggregator", - "persistor_id": "persistor", - "shareable_generator_id": "shareable_generator", - "train_task_name": "train", - "train_timeout": 0, - "ignore_result_error": true - } - } - ] -} diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_train.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_train.json deleted file mode 100644 index 9724632477..0000000000 --- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_train.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "max_epochs": 100, - "learning_rate": 2e-4, - "amp": true, - "use_gpu": true, - "val_interval": 5, - "data_list_json_file": "config/dataset_part1.json", - "ckpt_dir": "models" -} diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part1.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part1.json deleted file mode 100644 index f88ccc8402..0000000000 --- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part1.json +++ /dev/null @@ -1 +0,0 @@ -{"training": [{"image": "./imagesTr/spleen_40.nii.gz", "label": "./labelsTr/spleen_40.nii.gz"}, {"image": "./imagesTr/spleen_44.nii.gz", "label": "./labelsTr/spleen_44.nii.gz"}, {"image": "./imagesTr/spleen_38.nii.gz", "label": "./labelsTr/spleen_38.nii.gz"}, {"image": "./imagesTr/spleen_25.nii.gz", "label": "./labelsTr/spleen_25.nii.gz"}, {"image": "./imagesTr/spleen_13.nii.gz", "label": "./labelsTr/spleen_13.nii.gz"}, 
{"image": "./imagesTr/spleen_6.nii.gz", "label": "./labelsTr/spleen_6.nii.gz"}, {"image": "./imagesTr/spleen_19.nii.gz", "label": "./labelsTr/spleen_19.nii.gz"}, {"image": "./imagesTr/spleen_24.nii.gz", "label": "./labelsTr/spleen_24.nii.gz"}, {"image": "./imagesTr/spleen_52.nii.gz", "label": "./labelsTr/spleen_52.nii.gz"}, {"image": "./imagesTr/spleen_9.nii.gz", "label": "./labelsTr/spleen_9.nii.gz"}, {"image": "./imagesTr/spleen_10.nii.gz", "label": "./labelsTr/spleen_10.nii.gz"}, {"image": "./imagesTr/spleen_41.nii.gz", "label": "./labelsTr/spleen_41.nii.gz"}, {"image": "./imagesTr/spleen_60.nii.gz", "label": "./labelsTr/spleen_60.nii.gz"}, {"image": "./imagesTr/spleen_56.nii.gz", "label": "./labelsTr/spleen_56.nii.gz"}, {"image": "./imagesTr/spleen_26.nii.gz", "label": "./labelsTr/spleen_26.nii.gz"}, {"image": "./imagesTr/spleen_17.nii.gz", "label": "./labelsTr/spleen_17.nii.gz"}], "validation": [{"image": "./imagesTr/spleen_31.nii.gz", "label": "./labelsTr/spleen_31.nii.gz"}, {"image": "./imagesTr/spleen_33.nii.gz", "label": "./labelsTr/spleen_33.nii.gz"}, {"image": "./imagesTr/spleen_8.nii.gz", "label": "./labelsTr/spleen_8.nii.gz"}, {"image": "./imagesTr/spleen_21.nii.gz", "label": "./labelsTr/spleen_21.nii.gz"}, {"image": "./imagesTr/spleen_22.nii.gz", "label": "./labelsTr/spleen_22.nii.gz"}, {"image": "./imagesTr/spleen_2.nii.gz", "label": "./labelsTr/spleen_2.nii.gz"}, {"image": "./imagesTr/spleen_3.nii.gz", "label": "./labelsTr/spleen_3.nii.gz"}, {"image": "./imagesTr/spleen_45.nii.gz", "label": "./labelsTr/spleen_45.nii.gz"}, {"image": "./imagesTr/spleen_32.nii.gz", "label": "./labelsTr/spleen_32.nii.gz"}]} diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part2.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part2.json deleted file mode 100644 index aab7f22875..0000000000 --- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part2.json +++ /dev/null @@ -1 +0,0 @@ -{"training": [{"image": "./imagesTr/spleen_16.nii.gz", "label": "./labelsTr/spleen_16.nii.gz"}, {"image": "./imagesTr/spleen_20.nii.gz", "label": "./labelsTr/spleen_20.nii.gz"}, {"image": "./imagesTr/spleen_18.nii.gz", "label": "./labelsTr/spleen_18.nii.gz"}, {"image": "./imagesTr/spleen_46.nii.gz", "label": "./labelsTr/spleen_46.nii.gz"}, {"image": "./imagesTr/spleen_27.nii.gz", "label": "./labelsTr/spleen_27.nii.gz"}, {"image": "./imagesTr/spleen_49.nii.gz", "label": "./labelsTr/spleen_49.nii.gz"}, {"image": "./imagesTr/spleen_62.nii.gz", "label": "./labelsTr/spleen_62.nii.gz"}, {"image": "./imagesTr/spleen_53.nii.gz", "label": "./labelsTr/spleen_53.nii.gz"}, {"image": "./imagesTr/spleen_12.nii.gz", "label": "./labelsTr/spleen_12.nii.gz"}, {"image": "./imagesTr/spleen_47.nii.gz", "label": "./labelsTr/spleen_47.nii.gz"}, {"image": "./imagesTr/spleen_28.nii.gz", "label": "./labelsTr/spleen_28.nii.gz"}, {"image": "./imagesTr/spleen_61.nii.gz", "label": "./labelsTr/spleen_61.nii.gz"}, {"image": "./imagesTr/spleen_29.nii.gz", "label": "./labelsTr/spleen_29.nii.gz"}, {"image": "./imagesTr/spleen_14.nii.gz", "label": "./labelsTr/spleen_14.nii.gz"}, {"image": "./imagesTr/spleen_63.nii.gz", "label": "./labelsTr/spleen_63.nii.gz"}, {"image": "./imagesTr/spleen_59.nii.gz", "label": "./labelsTr/spleen_59.nii.gz"}], "validation": [{"image": "./imagesTr/spleen_31.nii.gz", "label": "./labelsTr/spleen_31.nii.gz"}, {"image": "./imagesTr/spleen_33.nii.gz", "label": "./labelsTr/spleen_33.nii.gz"}, {"image": 
"./imagesTr/spleen_8.nii.gz", "label": "./labelsTr/spleen_8.nii.gz"}, {"image": "./imagesTr/spleen_21.nii.gz", "label": "./labelsTr/spleen_21.nii.gz"}, {"image": "./imagesTr/spleen_22.nii.gz", "label": "./labelsTr/spleen_22.nii.gz"}, {"image": "./imagesTr/spleen_2.nii.gz", "label": "./labelsTr/spleen_2.nii.gz"}, {"image": "./imagesTr/spleen_3.nii.gz", "label": "./labelsTr/spleen_3.nii.gz"}, {"image": "./imagesTr/spleen_45.nii.gz", "label": "./labelsTr/spleen_45.nii.gz"}, {"image": "./imagesTr/spleen_32.nii.gz", "label": "./labelsTr/spleen_32.nii.gz"}]} diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/monai_trainer.py b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/monai_trainer.py deleted file mode 100644 index 6593897f91..0000000000 --- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/monai_trainer.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Dict - -import numpy as np -import torch -from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable -from nvflare.apis.event_type import EventType -from nvflare.apis.executor import Executor -from nvflare.apis.fl_constant import FLContextKey, ReturnCode -from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable -from nvflare.apis.signal import Signal -from nvflare.app_common.app_constant import AppConstants - -from train_configer import TrainConfiger - - -class MONAITrainer(Executor): - """ - This class implements a MONAI based trainer that can be used for Federated Learning with NVFLARE. - - Args: - aggregation_epochs: the number of training epochs for a round. Defaults to 1. - - """ - - def __init__( - self, - aggregation_epochs: int = 1, - train_task_name: str = AppConstants.TASK_TRAIN, - ): - super().__init__() - """ - Trainer init happens at the very beginning, only the basic info regarding the trainer is set here, - and the actual run has not started at this point. - """ - self.aggregation_epochs = aggregation_epochs - self._train_task_name = train_task_name - - def _initialize_trainer(self, fl_ctx: FLContext): - """ - The trainer's initialization function. At the beginning of a FL experiment, - the train and evaluate engines, as well as train context and FL context - should be initialized. - """ - # Initialize train and evaluation engines. 
- app_root = fl_ctx.get_prop(FLContextKey.APP_ROOT) - fl_args = fl_ctx.get_prop(FLContextKey.ARGS) - # will update multi-gpu supports later - # num_gpus = fl_ctx.get_prop(AppConstants.NUMBER_OF_GPUS, 1) - # self.multi_gpu = num_gpus > 1 - self.client_name = fl_ctx.get_identity_name() - self.log_info( - fl_ctx, - f"Client {self.client_name} initialized at \n {app_root} \n with args: {fl_args}", - ) - conf = TrainConfiger( - app_root=app_root, - wf_config_file_name=fl_args.train_config, - local_rank=fl_args.local_rank, - ) - conf.configure() - - # train_engine, and eval_engine are MONAI engines that will be used for training and validation. - # The corresponding training/validation settings, such as transforms, network and dataset - # are contained in `TrainConfiger`. - # The engine will be started when `.run()` is called, and when `.terminate()` is called, - # it will be completely terminated after the current iteration is finished. - self.train_engine = conf.train_engine - self.eval_engine = conf.eval_engine - - def assign_current_model(self, model_weights: Dict[str, np.ndarray]): - """ - This function is used to load provided weights for the network. - Before loading weights, tensors might need to be reshaped to support HE for secure aggregation. - More info of HE: - https://github.com/NVIDIA/clara-train-examples/blob/master/PyTorch/NoteBooks/FL/Homomorphic_Encryption.ipynb - - """ - net = self.train_engine.network - - local_var_dict = net.state_dict() - model_keys = model_weights.keys() - for var_name in local_var_dict: - if var_name in model_keys: - weights = model_weights[var_name] - try: - local_var_dict[var_name] = torch.as_tensor( - np.reshape(weights, local_var_dict[var_name].shape) - ) - except Exception as e: - raise ValueError( - "Convert weight from {} failed with error: {}".format( - var_name, str(e) - ) - ) - - net.load_state_dict(local_var_dict) - - def extract_model(self) -> Dict[str, np.ndarray]: - """ - This function is used to extract weights of the network. - The extracted weights will be converted into a numpy array based dict. - """ - net = self.train_engine.network - local_state_dict = net.state_dict() - local_model_dict = {} - for var_name in local_state_dict: - try: - local_model_dict[var_name] = local_state_dict[var_name].cpu().numpy() - except Exception as e: - raise ValueError( - "Convert weight from {} failed with error: {}".format( - var_name, str(e) - ) - ) - - return local_model_dict - - def generate_shareable(self): - """ - This function is used to generate a DXO instance. - The instance can contain not only model weights, but also - some additional information that clients want to share. - """ - # update meta, NUM_STEPS_CURRENT_ROUND is needed for aggregation. - if self.achieved_meta is None: - meta = {MetaKey.NUM_STEPS_CURRENT_ROUND: self.current_iters} - else: - meta = self.achieved_meta - meta[MetaKey.NUM_STEPS_CURRENT_ROUND] = self.current_iters - return DXO( - data_kind=DataKind.WEIGHTS, - data=self.extract_model(), - meta=meta, - ).to_shareable() - - def handle_event(self, event_type: str, fl_ctx: FLContext): - """ - This function is an extended function from the super class. - It is used to handle two events: - - 1) `START_RUN`. At the start point of a FL experiment, - necessary components should be initialized. - 2) `ABORT_TASK`, when this event is fired, the running engines - should be terminated (this example uses MONAI engines to do train - and validation, and the engines can be terminated from another thread. 
- If the solution does not provide any way to interrupt/end the execution, - handle this event is not feasible). - - - Args: - event_type: the type of event that will be fired. In MONAITrainer, - only `START_RUN` and `END_RUN` need to be handled. - fl_ctx: an `FLContext` object. - - """ - if event_type == EventType.START_RUN: - self._initialize_trainer(fl_ctx) - elif event_type == EventType.ABORT_TASK: - # This event is fired to abort the current execution task. We are using the ignite engine to run the task. - # Unfortunately the ignite engine does not support the abort of task right now. We have to wait until - # the current task finishes. - pass - elif event_type == EventType.END_RUN: - self.eval_engine.terminate() - self.train_engine.terminate() - - def _abort_execution(self) -> Shareable: - shareable = Shareable() - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable - - def execute( - self, - task_name: str, - shareable: Shareable, - fl_ctx: FLContext, - abort_signal: Signal, - ) -> Shareable: - """ - This function is an extended function from the super class. - As a supervised learning based trainer, the execute function will run - evaluate and train engines based on model weights from `shareable`. - After fininshing training, a new `Shareable` object will be submitted - to server for aggregation. - - Args: - task_name: decide which task will be executed. - shareable: the `Shareable` object acheived from server. - fl_ctx: the `FLContext` object achieved from server. - abort_signal: if triggered, the training will be aborted. In order to interrupt the training/validation - state, a separate is used to check the signal information every few seconds. The implementation is - shown in the `handle_event` function. - Returns: - a new `Shareable` object to be submitted to server for aggregation. - """ - if task_name == self._train_task_name: - # convert shareable into DXO instance - dxo = from_shareable(shareable) - # check if dxo is valid. - if not isinstance(dxo, DXO): - self.log_exception( - fl_ctx, f"dxo excepted type DXO. Got {type(dxo)} instead." - ) - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable - - # ensure data kind is weights. - if not dxo.data_kind == DataKind.WEIGHTS: - self.log_exception( - fl_ctx, - f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.", - ) - shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION) - return shareable - - # load weights from dxo - self.assign_current_model(dxo.data) - # collect meta from dxo - self.achieved_meta = dxo.meta - - # set engine state max epochs. - self.train_engine.state.max_epochs = ( - self.train_engine.state.epoch + self.aggregation_epochs - ) - # get current iteration when a round starts - iter_of_start_time = self.train_engine.state.iteration - - # execute validation at the beginning of every round - self.eval_engine.run(self.train_engine.state.epoch + 1) - - # check abort signal after validation - if abort_signal.triggered: - return self._abort_execution() - - self.train_engine.run() - - # check abort signal after train - if abort_signal.triggered: - return self._abort_execution() - - # calculate current iteration and epoch data after training. - self.current_iters = self.train_engine.state.iteration - iter_of_start_time - # create a new `Shareable` object - return self.generate_shareable() - else: - # If unknown task name, set ReturnCode accordingly. 
- shareable = Shareable() - shareable.set_return_code(ReturnCode.TASK_UNKNOWN) - return shareable diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/train_configer.py b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/train_configer.py deleted file mode 100644 index ca7f315b1f..0000000000 --- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/train_configer.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os - -import torch -from monai.apps.utils import download_and_extract -from monai.data import CacheDataset, DataLoader, load_decathlon_datalist -from monai.engines import SupervisedEvaluator, SupervisedTrainer -from monai.handlers import ( - CheckpointSaver, - LrScheduleHandler, - MeanDice, - StatsHandler, - TensorBoardStatsHandler, - ValidationHandler, - from_engine, -) -from monai.inferers import SimpleInferer, SlidingWindowInferer -from monai.losses import DiceLoss -from monai.networks.layers import Norm -from monai.networks.nets import UNet -from monai.transforms import ( - Activationsd, - AsDiscreted, - Compose, - CropForegroundd, - EnsureChannelFirstd, - LoadImaged, - Orientationd, - RandCropByPosNegLabeld, - ScaleIntensityRanged, - Spacingd, - ToTensord, -) - - -class TrainConfiger: - """ - This class is used to config the necessary components of train and evaluate engines - for MONAI trainer. - Please check the implementation of `SupervisedEvaluator` and `SupervisedTrainer` - from `monai.engines` and determine which components can be used. - Args: - app_root: root folder path of config files. - wf_config_file_name: json file name of the workflow config file. - """ - - def __init__( - self, - app_root: str, - wf_config_file_name: str, - local_rank: int = 0, - dataset_folder_name: str = "Task09_Spleen", - ): - with open(os.path.join(app_root, wf_config_file_name)) as file: - wf_config = json.load(file) - - self.wf_config = wf_config - """ - config Args: - max_epochs: the total epoch number for trainer to run. - learning_rate: the learning rate for optimizer. - dataset_dir: the directory containing the dataset. if `dataset_folder_name` does not - exist in the directory, it will be downloaded first. - data_list_json_file: the data list json file. - val_interval: the interval (number of epochs) to do validation. - ckpt_dir: the directory to save the checkpoint. - amp: whether to enable auto-mixed-precision training. - use_gpu: whether to use GPU in training. 
- - """ - self.max_epochs = wf_config["max_epochs"] - self.learning_rate = wf_config["learning_rate"] - self.data_list_json_file = wf_config["data_list_json_file"] - self.val_interval = wf_config["val_interval"] - self.ckpt_dir = wf_config["ckpt_dir"] - self.amp = wf_config["amp"] - self.use_gpu = wf_config["use_gpu"] - self.local_rank = local_rank - self.app_root = app_root - self.dataset_folder_name = dataset_folder_name - if not os.path.exists(os.path.join(app_root, self.dataset_folder_name)): - self.download_spleen_dataset() - - def set_device(self): - device = torch.device("cuda" if self.use_gpu else "cpu") - self.device = device - - def download_spleen_dataset(self): - url = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar" - name = os.path.join(self.app_root, self.dataset_folder_name) - tarfile_name = f"{name}.tar" - download_and_extract(url=url, filepath=tarfile_name, output_dir=self.app_root) - - def configure(self): - self.set_device() - network = UNet( - dimensions=3, - in_channels=1, - out_channels=2, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - norm=Norm.BATCH, - ).to(self.device) - train_transforms = Compose( - [ - LoadImaged(keys=("image", "label")), - EnsureChannelFirstd(keys=("image", "label")), - Spacingd( - keys=["image", "label"], - pixdim=(1.5, 1.5, 2.0), - mode=("bilinear", "nearest"), - ), - Orientationd(keys=["image", "label"], axcodes="RAS"), - ScaleIntensityRanged( - keys="image", - a_min=-57, - a_max=164, - b_min=0.0, - b_max=1.0, - clip=True, - ), - CropForegroundd(keys=("image", "label"), source_key="image"), - RandCropByPosNegLabeld( - keys=("image", "label"), - label_key="label", - spatial_size=(64, 64, 64), - pos=1, - neg=1, - num_samples=4, - image_key="image", - image_threshold=0, - ), - ToTensord(keys=("image", "label")), - ] - ) - # set datalist - train_datalist = load_decathlon_datalist( - os.path.join(self.app_root, self.data_list_json_file), - is_segmentation=True, - data_list_key="training", - base_dir=os.path.join(self.app_root, self.dataset_folder_name), - ) - val_datalist = load_decathlon_datalist( - os.path.join(self.app_root, self.data_list_json_file), - is_segmentation=True, - data_list_key="validation", - base_dir=os.path.join(self.app_root, self.dataset_folder_name), - ) - train_ds = CacheDataset( - data=train_datalist, - transform=train_transforms, - cache_rate=1.0, - num_workers=4, - ) - train_data_loader = DataLoader( - train_ds, - batch_size=2, - shuffle=True, - num_workers=4, - ) - val_transforms = Compose( - [ - LoadImaged(keys=("image", "label")), - EnsureChannelFirstd(keys=("image", "label")), - Spacingd( - keys=["image", "label"], - pixdim=(1.5, 1.5, 2.0), - mode=("bilinear", "nearest"), - ), - Orientationd(keys=["image", "label"], axcodes="RAS"), - ScaleIntensityRanged( - keys="image", - a_min=-57, - a_max=164, - b_min=0.0, - b_max=1.0, - clip=True, - ), - CropForegroundd(keys=("image", "label"), source_key="image"), - ToTensord(keys=("image", "label")), - ] - ) - - val_ds = CacheDataset( - data=val_datalist, transform=val_transforms, cache_rate=0.0, num_workers=4 - ) - val_data_loader = DataLoader( - val_ds, - batch_size=1, - shuffle=False, - num_workers=4, - ) - post_transform = Compose( - [ - Activationsd(keys="pred", softmax=True), - AsDiscreted( - keys=["pred", "label"], - argmax=[True, False], - to_onehot=2, - ), - ] - ) - # metric - key_val_metric = { - "val_mean_dice": MeanDice( - include_background=False, - output_transform=from_engine(["pred", "label"]), - ) - } - 
val_handlers = [ - StatsHandler(output_transform=lambda x: None), - CheckpointSaver( - save_dir=self.ckpt_dir, - save_dict={"model": network}, - save_key_metric=True, - ), - TensorBoardStatsHandler( - log_dir=self.ckpt_dir, output_transform=lambda x: None - ), - ] - self.eval_engine = SupervisedEvaluator( - device=self.device, - val_data_loader=val_data_loader, - network=network, - inferer=SlidingWindowInferer( - roi_size=[160, 160, 160], - sw_batch_size=4, - overlap=0.5, - ), - postprocessing=post_transform, - key_val_metric=key_val_metric, - val_handlers=val_handlers, - amp=self.amp, - ) - - optimizer = torch.optim.Adam(network.parameters(), self.learning_rate) - loss_function = DiceLoss(to_onehot_y=True, softmax=True) - lr_scheduler = torch.optim.lr_scheduler.StepLR( - optimizer, step_size=5000, gamma=0.1 - ) - train_handlers = [ - LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), - ValidationHandler( - validator=self.eval_engine, interval=self.val_interval, epoch_level=True - ), - StatsHandler( - tag_name="train_loss", output_transform=from_engine("loss", first=True) - ), - TensorBoardStatsHandler( - log_dir=self.ckpt_dir, - tag_name="train_loss", - output_transform=from_engine("loss", first=True), - ), - ] - - self.train_engine = SupervisedTrainer( - device=self.device, - max_epochs=self.max_epochs, - train_data_loader=train_data_loader, - network=network, - optimizer=optimizer, - loss_function=loss_function, - inferer=SimpleInferer(), - postprocessing=post_transform, - key_train_metric=None, - train_handlers=train_handlers, - amp=self.amp, - ) diff --git a/federated_learning/nvflare/nvflare_spleen_example/requirements.txt b/federated_learning/nvflare/nvflare_spleen_example/requirements.txt index e37fc83d93..5eb8087a25 100644 --- a/federated_learning/nvflare/nvflare_spleen_example/requirements.txt +++ b/federated_learning/nvflare/nvflare_spleen_example/requirements.txt @@ -1,6 +1,6 @@ pip setuptools -nvflare==2.0.1 +nvflare==2.0.2 monai==0.8.0 pytorch-ignite==0.4.6 tqdm==4.61.2