diff --git a/federated_learning/nvflare/nvflare_spleen_example/3-Admin.ipynb b/federated_learning/nvflare/nvflare_spleen_example/3-Admin.ipynb
index 342a71b8f2..193b2edc0a 100644
--- a/federated_learning/nvflare/nvflare_spleen_example/3-Admin.ipynb
+++ b/federated_learning/nvflare/nvflare_spleen_example/3-Admin.ipynb
@@ -27,7 +27,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -46,17 +46,9 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "a working folder for the admin exists!!!\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"admin_path = \"poc/admin/startup/\"\n",
"\n",
@@ -73,50 +65,28 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!mkdir -p poc/admin/transfer\n",
- "!cp -r hello_monai/ poc/admin/transfer/"
+ "!cp -r hello-monai/ poc/admin/transfer/"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "['hello_monai']"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"os.listdir(\"poc/admin/transfer/\")"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "['fl_admin.sh']"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"os.listdir(admin_path)"
]
@@ -144,23 +114,9 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- " Open a new terminal"
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"HTML(' Open a new terminal')"
]
@@ -192,51 +148,18 @@
"\n",
"The commands can be:\n",
"```\n",
- "upload_app hello_monai\n",
+ "upload_app hello-monai\n",
"set_run_number 1\n",
- "deploy_app hello_monai server\n",
- "deploy_app hello_monai client\n",
- "```\n",
- "\n",
- "Now, let's check if the folder has been distributed into the server and all client(s):"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "config files on server: ['app_server', 'fl_app.txt']\n",
- " \n",
- "config files on site-1: ['app_site-1', 'fl_app.txt']\n",
- " \n",
- "config files on site-2: ['fl_app.txt', 'app_site-2']\n",
- " \n"
- ]
- }
- ],
- "source": [
- "run_file = \"run_1\"\n",
- "\n",
- "poc_path = \"poc/\"\n",
- "\n",
- "for name in [\"server\", \"site-1\", \"site-2\"]:\n",
- " path = os.path.join(poc_path, name, run_file)\n",
- " if os.path.exists(path):\n",
- " print(\"config files on {}: {}\".\n",
- " format(name, os.listdir(path)))\n",
- " print(\" \")"
+ "deploy_app hello-monai server\n",
+ "deploy_app hello-monai client\n",
+ "```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "This example prepares two different data list files: `dataset_part1.json` and `dataset_part2.json`, and they have the same validation set and totally different training set. The default file used in `config_train.json` is `config/dataset_part1.json`. Therefore, if you want to let two clients train on different data, you can switch to use `dataset_part2.json` for `org1-b`.\n",
+ "This example prepares two different data list files: `dataset_part1.json` and `dataset_part2.json`, and they have the same validation set and totally different training sets. The default file used in `config_train.json` is `config/dataset_part1.json`. Therefore, if you want to let two clients train on different data, you can switch to use `dataset_part2.json` for `org1-b`.\n",
"\n",
"[Link to site-1 config](poc/site-1/run_1/app_site-1/config/config_train.json)\n",
"\n",
@@ -250,13 +173,12 @@
"### (Optional) Copy Dataset\n",
"\n",
"After starting a client (for example `site-1`), the Spleen dataset will be downloaded into:\n",
- "`run_1/app_site-1/Task09_Spleen.tar`.\n",
+ "`poc/site-1/`.\n",
"\n",
- "To prevent repeatedly downloading the dataset, you can copy the uncompressed `Task09_Spleen` into the corresponding place after running the `deploy_app` command.\n",
- "For example:\n",
+ "If you already have the `Task09_Spleen`, you can copy it into the corresponding place directly to prevent repeatedly download. For example:\n",
"\n",
"```\n",
- "cp -r /path-to-dataset/Task09_Spleen poc/site-1/run_2/app_site-1/\n",
+ "cp -r /path-to-dataset/Task09_Spleen poc/site-1/\n",
"```"
]
},
diff --git a/federated_learning/nvflare/nvflare_spleen_example/README.md b/federated_learning/nvflare/nvflare_spleen_example/README.md
index 9f27a8fdaa..a4574ff48e 100644
--- a/federated_learning/nvflare/nvflare_spleen_example/README.md
+++ b/federated_learning/nvflare/nvflare_spleen_example/README.md
@@ -2,13 +2,21 @@
## Brief Introduction
-This repository contains an end-to-end Federated training example based on MONAI trainers and [NVIDIA FLARE](https://github.com/nvidia/nvflare).
+This repository contains an end-to-end Federated training example based on MONAI trainers and [NVIDIA FLARE](https://github.com/nvidia/nvflare). Please also download the `hello-monai` folder from [NVFlare/examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/hello-monai) and copy it into this directory.
+
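+One possible way to fetch it (a sketch only; it assumes `git` is available and that the example is still located at `examples/hello-monai` on the `main` branch):
+
+```
+git clone --depth 1 https://github.com/NVIDIA/NVFlare.git
+cp -r NVFlare/examples/hello-monai ./
+```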
+
This example requires Python 3.8.10
-Inside this folder:
+Inside this directory:
- All Jupiter notebooks are used to build an FL experiment step-by-step.
-- `hello-monai` is a folder containing all required config files for the experiment (in `config/`) and the customized trainer (in `custom`) and its components.
+- `hello-monai` (needs to be downloaded) is a folder containing all required config files for the experiment (in `config/`), as well as the customized trainer (in `custom`) and its components.
## Installation
diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_client.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_client.json
deleted file mode 100644
index 98bb905fbc..0000000000
--- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_client.json
+++ /dev/null
@@ -1,20 +0,0 @@
-{
- "format_version": 2,
- "executors": [
- {
- "tasks": ["train"],
- "executor": {
- "path": "monai_trainer.MONAITrainer",
- "args": {
- "aggregation_epochs": 10
- }
- }
- }
- ],
- "task_result_filters": [
- ],
- "task_data_filters": [
- ],
- "components": [
- ]
-}
diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_server.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_server.json
deleted file mode 100644
index 555b919dee..0000000000
--- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_fed_server.json
+++ /dev/null
@@ -1,63 +0,0 @@
-{
- "format_version": 2,
-
- "server": {
- "heart_beat_timeout": 600
- },
- "task_data_filters": [],
- "task_result_filters": [],
- "components": [
- {
- "id": "persistor",
- "name": "PTFileModelPersistor",
- "args": {
- "model": {
- "path": "monai.networks.nets.unet.UNet",
- "args": {
- "dimensions": 3,
- "in_channels": 1,
- "out_channels": 2,
- "channels": [16, 32, 64, 128, 256],
- "strides": [2, 2, 2, 2],
- "num_res_units": 2,
- "norm": "batch"
- }
- }
- }
- },
- {
- "id": "shareable_generator",
- "name": "FullModelShareableGenerator",
- "args": {}
- },
- {
- "id": "aggregator",
- "name": "AccumulateWeightedAggregator",
- "args": {
- "aggregation_weights": {
- "site-1": 1.0,
- "site-2": 0.5
- },
- "expected_data_kind": "WEIGHTS"
- }
- }
- ],
- "workflows": [
- {
- "id": "scatter_and_gather",
- "name": "ScatterAndGather",
- "args": {
- "min_clients" : 1,
- "num_rounds" : 100,
- "start_round": 0,
- "wait_time_after_min_received": 10,
- "aggregator_id": "aggregator",
- "persistor_id": "persistor",
- "shareable_generator_id": "shareable_generator",
- "train_task_name": "train",
- "train_timeout": 0,
- "ignore_result_error": true
- }
- }
- ]
-}
diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_train.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_train.json
deleted file mode 100644
index 9724632477..0000000000
--- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/config_train.json
+++ /dev/null
@@ -1,9 +0,0 @@
-{
- "max_epochs": 100,
- "learning_rate": 2e-4,
- "amp": true,
- "use_gpu": true,
- "val_interval": 5,
- "data_list_json_file": "config/dataset_part1.json",
- "ckpt_dir": "models"
-}
diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part1.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part1.json
deleted file mode 100644
index f88ccc8402..0000000000
--- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part1.json
+++ /dev/null
@@ -1 +0,0 @@
-{"training": [{"image": "./imagesTr/spleen_40.nii.gz", "label": "./labelsTr/spleen_40.nii.gz"}, {"image": "./imagesTr/spleen_44.nii.gz", "label": "./labelsTr/spleen_44.nii.gz"}, {"image": "./imagesTr/spleen_38.nii.gz", "label": "./labelsTr/spleen_38.nii.gz"}, {"image": "./imagesTr/spleen_25.nii.gz", "label": "./labelsTr/spleen_25.nii.gz"}, {"image": "./imagesTr/spleen_13.nii.gz", "label": "./labelsTr/spleen_13.nii.gz"}, {"image": "./imagesTr/spleen_6.nii.gz", "label": "./labelsTr/spleen_6.nii.gz"}, {"image": "./imagesTr/spleen_19.nii.gz", "label": "./labelsTr/spleen_19.nii.gz"}, {"image": "./imagesTr/spleen_24.nii.gz", "label": "./labelsTr/spleen_24.nii.gz"}, {"image": "./imagesTr/spleen_52.nii.gz", "label": "./labelsTr/spleen_52.nii.gz"}, {"image": "./imagesTr/spleen_9.nii.gz", "label": "./labelsTr/spleen_9.nii.gz"}, {"image": "./imagesTr/spleen_10.nii.gz", "label": "./labelsTr/spleen_10.nii.gz"}, {"image": "./imagesTr/spleen_41.nii.gz", "label": "./labelsTr/spleen_41.nii.gz"}, {"image": "./imagesTr/spleen_60.nii.gz", "label": "./labelsTr/spleen_60.nii.gz"}, {"image": "./imagesTr/spleen_56.nii.gz", "label": "./labelsTr/spleen_56.nii.gz"}, {"image": "./imagesTr/spleen_26.nii.gz", "label": "./labelsTr/spleen_26.nii.gz"}, {"image": "./imagesTr/spleen_17.nii.gz", "label": "./labelsTr/spleen_17.nii.gz"}], "validation": [{"image": "./imagesTr/spleen_31.nii.gz", "label": "./labelsTr/spleen_31.nii.gz"}, {"image": "./imagesTr/spleen_33.nii.gz", "label": "./labelsTr/spleen_33.nii.gz"}, {"image": "./imagesTr/spleen_8.nii.gz", "label": "./labelsTr/spleen_8.nii.gz"}, {"image": "./imagesTr/spleen_21.nii.gz", "label": "./labelsTr/spleen_21.nii.gz"}, {"image": "./imagesTr/spleen_22.nii.gz", "label": "./labelsTr/spleen_22.nii.gz"}, {"image": "./imagesTr/spleen_2.nii.gz", "label": "./labelsTr/spleen_2.nii.gz"}, {"image": "./imagesTr/spleen_3.nii.gz", "label": "./labelsTr/spleen_3.nii.gz"}, {"image": "./imagesTr/spleen_45.nii.gz", "label": "./labelsTr/spleen_45.nii.gz"}, {"image": "./imagesTr/spleen_32.nii.gz", "label": "./labelsTr/spleen_32.nii.gz"}]}
diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part2.json b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part2.json
deleted file mode 100644
index aab7f22875..0000000000
--- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/config/dataset_part2.json
+++ /dev/null
@@ -1 +0,0 @@
-{"training": [{"image": "./imagesTr/spleen_16.nii.gz", "label": "./labelsTr/spleen_16.nii.gz"}, {"image": "./imagesTr/spleen_20.nii.gz", "label": "./labelsTr/spleen_20.nii.gz"}, {"image": "./imagesTr/spleen_18.nii.gz", "label": "./labelsTr/spleen_18.nii.gz"}, {"image": "./imagesTr/spleen_46.nii.gz", "label": "./labelsTr/spleen_46.nii.gz"}, {"image": "./imagesTr/spleen_27.nii.gz", "label": "./labelsTr/spleen_27.nii.gz"}, {"image": "./imagesTr/spleen_49.nii.gz", "label": "./labelsTr/spleen_49.nii.gz"}, {"image": "./imagesTr/spleen_62.nii.gz", "label": "./labelsTr/spleen_62.nii.gz"}, {"image": "./imagesTr/spleen_53.nii.gz", "label": "./labelsTr/spleen_53.nii.gz"}, {"image": "./imagesTr/spleen_12.nii.gz", "label": "./labelsTr/spleen_12.nii.gz"}, {"image": "./imagesTr/spleen_47.nii.gz", "label": "./labelsTr/spleen_47.nii.gz"}, {"image": "./imagesTr/spleen_28.nii.gz", "label": "./labelsTr/spleen_28.nii.gz"}, {"image": "./imagesTr/spleen_61.nii.gz", "label": "./labelsTr/spleen_61.nii.gz"}, {"image": "./imagesTr/spleen_29.nii.gz", "label": "./labelsTr/spleen_29.nii.gz"}, {"image": "./imagesTr/spleen_14.nii.gz", "label": "./labelsTr/spleen_14.nii.gz"}, {"image": "./imagesTr/spleen_63.nii.gz", "label": "./labelsTr/spleen_63.nii.gz"}, {"image": "./imagesTr/spleen_59.nii.gz", "label": "./labelsTr/spleen_59.nii.gz"}], "validation": [{"image": "./imagesTr/spleen_31.nii.gz", "label": "./labelsTr/spleen_31.nii.gz"}, {"image": "./imagesTr/spleen_33.nii.gz", "label": "./labelsTr/spleen_33.nii.gz"}, {"image": "./imagesTr/spleen_8.nii.gz", "label": "./labelsTr/spleen_8.nii.gz"}, {"image": "./imagesTr/spleen_21.nii.gz", "label": "./labelsTr/spleen_21.nii.gz"}, {"image": "./imagesTr/spleen_22.nii.gz", "label": "./labelsTr/spleen_22.nii.gz"}, {"image": "./imagesTr/spleen_2.nii.gz", "label": "./labelsTr/spleen_2.nii.gz"}, {"image": "./imagesTr/spleen_3.nii.gz", "label": "./labelsTr/spleen_3.nii.gz"}, {"image": "./imagesTr/spleen_45.nii.gz", "label": "./labelsTr/spleen_45.nii.gz"}, {"image": "./imagesTr/spleen_32.nii.gz", "label": "./labelsTr/spleen_32.nii.gz"}]}
diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/monai_trainer.py b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/monai_trainer.py
deleted file mode 100644
index 6593897f91..0000000000
--- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/monai_trainer.py
+++ /dev/null
@@ -1,262 +0,0 @@
-# Copyright 2020 MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from typing import Dict
-
-import numpy as np
-import torch
-from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
-from nvflare.apis.event_type import EventType
-from nvflare.apis.executor import Executor
-from nvflare.apis.fl_constant import FLContextKey, ReturnCode
-from nvflare.apis.fl_context import FLContext
-from nvflare.apis.shareable import Shareable
-from nvflare.apis.signal import Signal
-from nvflare.app_common.app_constant import AppConstants
-
-from train_configer import TrainConfiger
-
-
-class MONAITrainer(Executor):
- """
- This class implements a MONAI based trainer that can be used for Federated Learning with NVFLARE.
-
- Args:
- aggregation_epochs: the number of training epochs for a round. Defaults to 1.
-
- """
-
- def __init__(
- self,
- aggregation_epochs: int = 1,
- train_task_name: str = AppConstants.TASK_TRAIN,
- ):
- super().__init__()
- """
- Trainer init happens at the very beginning, only the basic info regarding the trainer is set here,
- and the actual run has not started at this point.
- """
- self.aggregation_epochs = aggregation_epochs
- self._train_task_name = train_task_name
-
- def _initialize_trainer(self, fl_ctx: FLContext):
- """
- The trainer's initialization function. At the beginning of a FL experiment,
- the train and evaluate engines, as well as train context and FL context
- should be initialized.
- """
- # Initialize train and evaluation engines.
- app_root = fl_ctx.get_prop(FLContextKey.APP_ROOT)
- fl_args = fl_ctx.get_prop(FLContextKey.ARGS)
- # will update multi-gpu supports later
- # num_gpus = fl_ctx.get_prop(AppConstants.NUMBER_OF_GPUS, 1)
- # self.multi_gpu = num_gpus > 1
- self.client_name = fl_ctx.get_identity_name()
- self.log_info(
- fl_ctx,
- f"Client {self.client_name} initialized at \n {app_root} \n with args: {fl_args}",
- )
- conf = TrainConfiger(
- app_root=app_root,
- wf_config_file_name=fl_args.train_config,
- local_rank=fl_args.local_rank,
- )
- conf.configure()
-
- # train_engine, and eval_engine are MONAI engines that will be used for training and validation.
- # The corresponding training/validation settings, such as transforms, network and dataset
- # are contained in `TrainConfiger`.
- # The engine will be started when `.run()` is called, and when `.terminate()` is called,
- # it will be completely terminated after the current iteration is finished.
- self.train_engine = conf.train_engine
- self.eval_engine = conf.eval_engine
-
- def assign_current_model(self, model_weights: Dict[str, np.ndarray]):
- """
- This function is used to load provided weights for the network.
- Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
- More info of HE:
- https://github.com/NVIDIA/clara-train-examples/blob/master/PyTorch/NoteBooks/FL/Homomorphic_Encryption.ipynb
-
- """
- net = self.train_engine.network
-
- local_var_dict = net.state_dict()
- model_keys = model_weights.keys()
- for var_name in local_var_dict:
- if var_name in model_keys:
- weights = model_weights[var_name]
- try:
- local_var_dict[var_name] = torch.as_tensor(
- np.reshape(weights, local_var_dict[var_name].shape)
- )
- except Exception as e:
- raise ValueError(
- "Convert weight from {} failed with error: {}".format(
- var_name, str(e)
- )
- )
-
- net.load_state_dict(local_var_dict)
-
- def extract_model(self) -> Dict[str, np.ndarray]:
- """
- This function is used to extract weights of the network.
- The extracted weights will be converted into a numpy array based dict.
- """
- net = self.train_engine.network
- local_state_dict = net.state_dict()
- local_model_dict = {}
- for var_name in local_state_dict:
- try:
- local_model_dict[var_name] = local_state_dict[var_name].cpu().numpy()
- except Exception as e:
- raise ValueError(
- "Convert weight from {} failed with error: {}".format(
- var_name, str(e)
- )
- )
-
- return local_model_dict
-
- def generate_shareable(self):
- """
- This function is used to generate a DXO instance.
- The instance can contain not only model weights, but also
- some additional information that clients want to share.
- """
- # update meta, NUM_STEPS_CURRENT_ROUND is needed for aggregation.
- if self.achieved_meta is None:
- meta = {MetaKey.NUM_STEPS_CURRENT_ROUND: self.current_iters}
- else:
- meta = self.achieved_meta
- meta[MetaKey.NUM_STEPS_CURRENT_ROUND] = self.current_iters
- return DXO(
- data_kind=DataKind.WEIGHTS,
- data=self.extract_model(),
- meta=meta,
- ).to_shareable()
-
- def handle_event(self, event_type: str, fl_ctx: FLContext):
- """
- This function is an extended function from the super class.
- It is used to handle two events:
-
- 1) `START_RUN`. At the start point of a FL experiment,
- necessary components should be initialized.
- 2) `ABORT_TASK`, when this event is fired, the running engines
- should be terminated (this example uses MONAI engines to do train
- and validation, and the engines can be terminated from another thread.
- If the solution does not provide any way to interrupt/end the execution,
- handle this event is not feasible).
-
-
- Args:
- event_type: the type of event that will be fired. In MONAITrainer,
- only `START_RUN` and `END_RUN` need to be handled.
- fl_ctx: an `FLContext` object.
-
- """
- if event_type == EventType.START_RUN:
- self._initialize_trainer(fl_ctx)
- elif event_type == EventType.ABORT_TASK:
- # This event is fired to abort the current execution task. We are using the ignite engine to run the task.
- # Unfortunately the ignite engine does not support the abort of task right now. We have to wait until
- # the current task finishes.
- pass
- elif event_type == EventType.END_RUN:
- self.eval_engine.terminate()
- self.train_engine.terminate()
-
- def _abort_execution(self) -> Shareable:
- shareable = Shareable()
- shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
- return shareable
-
- def execute(
- self,
- task_name: str,
- shareable: Shareable,
- fl_ctx: FLContext,
- abort_signal: Signal,
- ) -> Shareable:
- """
- This function is an extended function from the super class.
- As a supervised learning based trainer, the execute function will run
- evaluate and train engines based on model weights from `shareable`.
- After fininshing training, a new `Shareable` object will be submitted
- to server for aggregation.
-
- Args:
- task_name: decide which task will be executed.
- shareable: the `Shareable` object acheived from server.
- fl_ctx: the `FLContext` object achieved from server.
- abort_signal: if triggered, the training will be aborted. In order to interrupt the training/validation
- state, a separate is used to check the signal information every few seconds. The implementation is
- shown in the `handle_event` function.
- Returns:
- a new `Shareable` object to be submitted to server for aggregation.
- """
- if task_name == self._train_task_name:
- # convert shareable into DXO instance
- dxo = from_shareable(shareable)
- # check if dxo is valid.
- if not isinstance(dxo, DXO):
- self.log_exception(
- fl_ctx, f"dxo excepted type DXO. Got {type(dxo)} instead."
- )
- shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
- return shareable
-
- # ensure data kind is weights.
- if not dxo.data_kind == DataKind.WEIGHTS:
- self.log_exception(
- fl_ctx,
- f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.",
- )
- shareable.set_return_code(ReturnCode.EXECUTION_EXCEPTION)
- return shareable
-
- # load weights from dxo
- self.assign_current_model(dxo.data)
- # collect meta from dxo
- self.achieved_meta = dxo.meta
-
- # set engine state max epochs.
- self.train_engine.state.max_epochs = (
- self.train_engine.state.epoch + self.aggregation_epochs
- )
- # get current iteration when a round starts
- iter_of_start_time = self.train_engine.state.iteration
-
- # execute validation at the beginning of every round
- self.eval_engine.run(self.train_engine.state.epoch + 1)
-
- # check abort signal after validation
- if abort_signal.triggered:
- return self._abort_execution()
-
- self.train_engine.run()
-
- # check abort signal after train
- if abort_signal.triggered:
- return self._abort_execution()
-
- # calculate current iteration and epoch data after training.
- self.current_iters = self.train_engine.state.iteration - iter_of_start_time
- # create a new `Shareable` object
- return self.generate_shareable()
- else:
- # If unknown task name, set ReturnCode accordingly.
- shareable = Shareable()
- shareable.set_return_code(ReturnCode.TASK_UNKNOWN)
- return shareable
diff --git a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/train_configer.py b/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/train_configer.py
deleted file mode 100644
index ca7f315b1f..0000000000
--- a/federated_learning/nvflare/nvflare_spleen_example/hello_monai/custom/train_configer.py
+++ /dev/null
@@ -1,280 +0,0 @@
-# Copyright 2020 MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import os
-
-import torch
-from monai.apps.utils import download_and_extract
-from monai.data import CacheDataset, DataLoader, load_decathlon_datalist
-from monai.engines import SupervisedEvaluator, SupervisedTrainer
-from monai.handlers import (
- CheckpointSaver,
- LrScheduleHandler,
- MeanDice,
- StatsHandler,
- TensorBoardStatsHandler,
- ValidationHandler,
- from_engine,
-)
-from monai.inferers import SimpleInferer, SlidingWindowInferer
-from monai.losses import DiceLoss
-from monai.networks.layers import Norm
-from monai.networks.nets import UNet
-from monai.transforms import (
- Activationsd,
- AsDiscreted,
- Compose,
- CropForegroundd,
- EnsureChannelFirstd,
- LoadImaged,
- Orientationd,
- RandCropByPosNegLabeld,
- ScaleIntensityRanged,
- Spacingd,
- ToTensord,
-)
-
-
-class TrainConfiger:
- """
- This class is used to config the necessary components of train and evaluate engines
- for MONAI trainer.
- Please check the implementation of `SupervisedEvaluator` and `SupervisedTrainer`
- from `monai.engines` and determine which components can be used.
- Args:
- app_root: root folder path of config files.
- wf_config_file_name: json file name of the workflow config file.
- """
-
- def __init__(
- self,
- app_root: str,
- wf_config_file_name: str,
- local_rank: int = 0,
- dataset_folder_name: str = "Task09_Spleen",
- ):
- with open(os.path.join(app_root, wf_config_file_name)) as file:
- wf_config = json.load(file)
-
- self.wf_config = wf_config
- """
- config Args:
- max_epochs: the total epoch number for trainer to run.
- learning_rate: the learning rate for optimizer.
- dataset_dir: the directory containing the dataset. if `dataset_folder_name` does not
- exist in the directory, it will be downloaded first.
- data_list_json_file: the data list json file.
- val_interval: the interval (number of epochs) to do validation.
- ckpt_dir: the directory to save the checkpoint.
- amp: whether to enable auto-mixed-precision training.
- use_gpu: whether to use GPU in training.
-
- """
- self.max_epochs = wf_config["max_epochs"]
- self.learning_rate = wf_config["learning_rate"]
- self.data_list_json_file = wf_config["data_list_json_file"]
- self.val_interval = wf_config["val_interval"]
- self.ckpt_dir = wf_config["ckpt_dir"]
- self.amp = wf_config["amp"]
- self.use_gpu = wf_config["use_gpu"]
- self.local_rank = local_rank
- self.app_root = app_root
- self.dataset_folder_name = dataset_folder_name
- if not os.path.exists(os.path.join(app_root, self.dataset_folder_name)):
- self.download_spleen_dataset()
-
- def set_device(self):
- device = torch.device("cuda" if self.use_gpu else "cpu")
- self.device = device
-
- def download_spleen_dataset(self):
- url = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
- name = os.path.join(self.app_root, self.dataset_folder_name)
- tarfile_name = f"{name}.tar"
- download_and_extract(url=url, filepath=tarfile_name, output_dir=self.app_root)
-
- def configure(self):
- self.set_device()
- network = UNet(
- dimensions=3,
- in_channels=1,
- out_channels=2,
- channels=(16, 32, 64, 128, 256),
- strides=(2, 2, 2, 2),
- num_res_units=2,
- norm=Norm.BATCH,
- ).to(self.device)
- train_transforms = Compose(
- [
- LoadImaged(keys=("image", "label")),
- EnsureChannelFirstd(keys=("image", "label")),
- Spacingd(
- keys=["image", "label"],
- pixdim=(1.5, 1.5, 2.0),
- mode=("bilinear", "nearest"),
- ),
- Orientationd(keys=["image", "label"], axcodes="RAS"),
- ScaleIntensityRanged(
- keys="image",
- a_min=-57,
- a_max=164,
- b_min=0.0,
- b_max=1.0,
- clip=True,
- ),
- CropForegroundd(keys=("image", "label"), source_key="image"),
- RandCropByPosNegLabeld(
- keys=("image", "label"),
- label_key="label",
- spatial_size=(64, 64, 64),
- pos=1,
- neg=1,
- num_samples=4,
- image_key="image",
- image_threshold=0,
- ),
- ToTensord(keys=("image", "label")),
- ]
- )
- # set datalist
- train_datalist = load_decathlon_datalist(
- os.path.join(self.app_root, self.data_list_json_file),
- is_segmentation=True,
- data_list_key="training",
- base_dir=os.path.join(self.app_root, self.dataset_folder_name),
- )
- val_datalist = load_decathlon_datalist(
- os.path.join(self.app_root, self.data_list_json_file),
- is_segmentation=True,
- data_list_key="validation",
- base_dir=os.path.join(self.app_root, self.dataset_folder_name),
- )
- train_ds = CacheDataset(
- data=train_datalist,
- transform=train_transforms,
- cache_rate=1.0,
- num_workers=4,
- )
- train_data_loader = DataLoader(
- train_ds,
- batch_size=2,
- shuffle=True,
- num_workers=4,
- )
- val_transforms = Compose(
- [
- LoadImaged(keys=("image", "label")),
- EnsureChannelFirstd(keys=("image", "label")),
- Spacingd(
- keys=["image", "label"],
- pixdim=(1.5, 1.5, 2.0),
- mode=("bilinear", "nearest"),
- ),
- Orientationd(keys=["image", "label"], axcodes="RAS"),
- ScaleIntensityRanged(
- keys="image",
- a_min=-57,
- a_max=164,
- b_min=0.0,
- b_max=1.0,
- clip=True,
- ),
- CropForegroundd(keys=("image", "label"), source_key="image"),
- ToTensord(keys=("image", "label")),
- ]
- )
-
- val_ds = CacheDataset(
- data=val_datalist, transform=val_transforms, cache_rate=0.0, num_workers=4
- )
- val_data_loader = DataLoader(
- val_ds,
- batch_size=1,
- shuffle=False,
- num_workers=4,
- )
- post_transform = Compose(
- [
- Activationsd(keys="pred", softmax=True),
- AsDiscreted(
- keys=["pred", "label"],
- argmax=[True, False],
- to_onehot=2,
- ),
- ]
- )
- # metric
- key_val_metric = {
- "val_mean_dice": MeanDice(
- include_background=False,
- output_transform=from_engine(["pred", "label"]),
- )
- }
- val_handlers = [
- StatsHandler(output_transform=lambda x: None),
- CheckpointSaver(
- save_dir=self.ckpt_dir,
- save_dict={"model": network},
- save_key_metric=True,
- ),
- TensorBoardStatsHandler(
- log_dir=self.ckpt_dir, output_transform=lambda x: None
- ),
- ]
- self.eval_engine = SupervisedEvaluator(
- device=self.device,
- val_data_loader=val_data_loader,
- network=network,
- inferer=SlidingWindowInferer(
- roi_size=[160, 160, 160],
- sw_batch_size=4,
- overlap=0.5,
- ),
- postprocessing=post_transform,
- key_val_metric=key_val_metric,
- val_handlers=val_handlers,
- amp=self.amp,
- )
-
- optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
- loss_function = DiceLoss(to_onehot_y=True, softmax=True)
- lr_scheduler = torch.optim.lr_scheduler.StepLR(
- optimizer, step_size=5000, gamma=0.1
- )
- train_handlers = [
- LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
- ValidationHandler(
- validator=self.eval_engine, interval=self.val_interval, epoch_level=True
- ),
- StatsHandler(
- tag_name="train_loss", output_transform=from_engine("loss", first=True)
- ),
- TensorBoardStatsHandler(
- log_dir=self.ckpt_dir,
- tag_name="train_loss",
- output_transform=from_engine("loss", first=True),
- ),
- ]
-
- self.train_engine = SupervisedTrainer(
- device=self.device,
- max_epochs=self.max_epochs,
- train_data_loader=train_data_loader,
- network=network,
- optimizer=optimizer,
- loss_function=loss_function,
- inferer=SimpleInferer(),
- postprocessing=post_transform,
- key_train_metric=None,
- train_handlers=train_handlers,
- amp=self.amp,
- )
diff --git a/federated_learning/nvflare/nvflare_spleen_example/requirements.txt b/federated_learning/nvflare/nvflare_spleen_example/requirements.txt
index e37fc83d93..5eb8087a25 100644
--- a/federated_learning/nvflare/nvflare_spleen_example/requirements.txt
+++ b/federated_learning/nvflare/nvflare_spleen_example/requirements.txt
@@ -1,6 +1,6 @@
pip
setuptools
-nvflare==2.0.1
+nvflare==2.0.2
monai==0.8.0
pytorch-ignite==0.4.6
tqdm==4.61.2