From 51d6b1c1b88e6e56cc66051d94abfae63a4b6841 Mon Sep 17 00:00:00 2001 From: mmelqin Date: Thu, 19 May 2022 23:40:45 -0700 Subject: [PATCH 01/13] WIP MONAI Bundle inference operator. Signed-off-by: mmelqin --- .../ai_spleen_seg_app/spleen_seg_operator.py | 80 +-- monai/deploy/operators/__init__.py | 2 + .../monai_bundle_inference_operator.py | 489 ++++++++++++++++++ 3 files changed, 502 insertions(+), 69 deletions(-) create mode 100644 monai/deploy/operators/monai_bundle_inference_operator.py diff --git a/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py b/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py index b8ef28f5..f41340bc 100644 --- a/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py +++ b/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py @@ -16,39 +16,23 @@ import monai.deploy.core as md from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, Operator, OutputContext -from monai.deploy.operators.monai_seg_inference_operator import InMemImageReader, MonaiSegInferenceOperator -from monai.transforms import ( - Activationsd, - AsDiscreted, - Compose, - CropForegroundd, - EnsureChannelFirstd, - Invertd, - LoadImaged, - SaveImaged, - ScaleIntensityRanged, - Spacingd, - ToTensord, -) +from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator @md.input("image", Image, IOType.IN_MEMORY) @md.output("seg_image", Image, IOType.IN_MEMORY) -@md.env(pip_packages=["monai>=0.8.1", "torch>=1.5", "numpy>=1.21", "nibabel"]) +@md.env(pip_packages=["monai>=0.8.1", "torch>=1.10.2", "numpy>=1.21", "nibabel"]) class SpleenSegOperator(Operator): """Performs Spleen segmentation with a 3D image converted from a DICOM CT series. - This operator makes use of the App SDK MonaiSegInferenceOperator in a compsition approach. - It creates the pre-transforms as well as post-transforms with MONAI dictionary based transforms. - Note that the App SDK InMemImageReader, derived from MONAI ImageReader, is passed to LoadImaged. - This derived reader is needed to parse the in memory image object, and return the expected data structure. - Loading of the model, and predicting using in-proc PyTorch inference is done by MonaiSegInferenceOperator. + This operator makes use of the App SDK MonaiBundleInferenceOperator in a composition approach. + Loading of the model, and predicting using in-proc PyTorch inference is done by MonaiBundleInferenceOperator. """ - def __init__(self): + def __init__(self, *args, **kwargs): self.logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__)) - super().__init__() + super().__init__(*args, **kwargs) self._input_dataset_key = "image" self._pred_dataset_key = "pred" @@ -58,25 +42,13 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe if not input_image: raise ValueError("Input image is not found.") - # Get the output path from the execution context for saving file(s) to app output. - # Without using this path, operator would be saving files to its designated path, e.g. - # $PWD/.monai_workdir/operators/6048d75a-5de1-45b9-8bd1-2252f88827f2/0/output - output_path = context.output.get().path - - # This operator gets an in-memory Image object, so a specialized ImageReader is needed. - _reader = InMemImageReader(input_image) - pre_transforms = self.pre_process(_reader) - post_transforms = self.post_process(pre_transforms, path.join(output_path, "prediction_output")) - # Delegates inference and saving output to the built-in operator. - infer_operator = MonaiSegInferenceOperator( - ( - 160, - 160, - 160, + infer_operator = MonaiBundleInferenceOperator( + roi_size=( + 96, + 96, + 96, ), - pre_transforms, - post_transforms, ) # Setting the keys used in the dictironary based transforms may change. @@ -85,33 +57,3 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe # Now let the built-in operator handles the work with the I/O spec and execution context. infer_operator.compute(op_input, op_output, context) - - def pre_process(self, img_reader) -> Compose: - """Composes transforms for preprocessing input before predicting on a model.""" - - my_key = self._input_dataset_key - return Compose( - [ - LoadImaged(keys=my_key, reader=img_reader), - EnsureChannelFirstd(keys=my_key), - Spacingd(keys=my_key, pixdim=[1.0, 1.0, 1.0], mode=["bilinear"], align_corners=True), - ScaleIntensityRanged(keys=my_key, a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True), - CropForegroundd(keys=my_key, source_key=my_key), - ToTensord(keys=my_key), - ] - ) - - def post_process(self, pre_transforms: Compose, out_dir: str = "./prediction_output") -> Compose: - """Composes transforms for postprocessing the prediction results.""" - - pred_key = self._pred_dataset_key - return Compose( - [ - Activationsd(keys=pred_key, softmax=True), - AsDiscreted(keys=pred_key, argmax=True), - Invertd( - keys=pred_key, transform=pre_transforms, orig_keys=self._input_dataset_key, nearest_interp=True - ), - SaveImaged(keys=pred_key, output_dir=out_dir, output_postfix="seg", output_dtype=uint8, resample=False), - ] - ) diff --git a/monai/deploy/operators/__init__.py b/monai/deploy/operators/__init__.py index c1b25701..10ab403d 100644 --- a/monai/deploy/operators/__init__.py +++ b/monai/deploy/operators/__init__.py @@ -19,6 +19,7 @@ DICOMSeriesToVolumeOperator DICOMTextSRWriterOperator InferenceOperator + MonaiBundleInferenceOperator MonaiSegInferenceOperator PNGConverterOperator PublisherOperator @@ -33,6 +34,7 @@ from .dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator from .dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo from .inference_operator import InferenceOperator +from .monai_bundle_inference_operator import MonaiBundleInferenceOperator from .monai_seg_inference_operator import MonaiSegInferenceOperator from .png_converter_operator import PNGConverterOperator from .publisher_operator import PublisherOperator diff --git a/monai/deploy/operators/monai_bundle_inference_operator.py b/monai/deploy/operators/monai_bundle_inference_operator.py new file mode 100644 index 00000000..7c2d83d1 --- /dev/null +++ b/monai/deploy/operators/monai_bundle_inference_operator.py @@ -0,0 +1,489 @@ +# Copyright 2021-2002 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import zipfile +import tempfile +import json + +from pathlib import Path +from threading import Lock +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np + +from monai.deploy.utils.importutil import optional_import +from monai.transforms.io.dictionary import LoadImageD + +torch, _ = optional_import("torch", "1.5") +np_str_obj_array_pattern, _ = optional_import("torch.utils.data._utils.collate", name="np_str_obj_array_pattern") +Dataset, _ = optional_import("monai.data", name="Dataset") +DataLoader, _ = optional_import("monai.data", name="DataLoader") +ImageReader_, image_reader_ok_ = optional_import("monai.data", name="ImageReader") +# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477 +ImageReader: Any = ImageReader_ +if not image_reader_ok_: + ImageReader = object # for 'class InMemImageReader(ImageReader):' to work +decollate_batch, _ = optional_import("monai.data", name="decollate_batch") +sliding_window_inference, _ = optional_import("monai.inferers", name="sliding_window_inference") +ensure_tuple, _ = optional_import("monai.utils", name="ensure_tuple") +Compose_, _ = optional_import("monai.transforms", name="Compose") +MapTransform_, _ = optional_import("monai.transforms", name="MapTransform") +LoadImaged_, _ = optional_import("monai.transforms", name="LoadImaged") +ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser") +# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477 +Compose: Any = Compose_ +MapTransform: Any = MapTransform_ +LoadImaged: Any = LoadImaged_ +ConfigParser: Any = ConfigParser_ + +simple_inference, _ = optional_import("monai.inferers", name="simple_inference") +sliding_window_inference, _ = optional_import("monai.inferers", name="sliding_window_inference") + +import monai.deploy.core as md +from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, OutputContext + +from .inference_operator import InferenceOperator + +__all__ = ["MonaiBundleInferenceOperator", "InMemImageReader"] + +# TODO: For now, assuming single input and single output, but need to change +@md.env(pip_packages=["monai>=0.8.1", "torch>=1.10.02", "numpy>=1.21"]) +class MonaiBundleInferenceOperator(InferenceOperator): + """This segmentation operator uses MONAI transforms and Sliding Window Inference. + + This operator preforms pre-transforms on a input image, inference + using a given model, and post-transforms. The segmentation image is saved + as a named Image object in memory. + + If specified in the post transforms, results may also be saved to disk. + """ + + DISALLOWED_TRANSFORMS = ["LoadImage", "SaveImage"] + + def __init__( + self, + model_name: Optional[str] = "", + bundle_path: Optional[str] = None, + preproc_name: Optional[str] = "preprocessing", + postproc_name: Optional[str] = "postprocessing", + pre_transforms: Optional[Compose] = None , + post_transforms: Optional[Compose] = None, + roi_size: Union[Sequence[int], int] = (96, 96, 96,), + overlap: float = 0.5, + *args, + **kwargs, + ): + """Creates a instance of this class. + + Args: + model_name (Optional[str]): The name of the model in the MONAI Bundle. + bundle_path: Optional[str]: Path of the MONAI Bundle, overridden by model loader. + preproc_name: Optional[str]: Inference config item name for "preprocessing". + postproc_name: Optional[str]: Inference config name for "postprocessing". + pre_transforms (Compose): MONAI Compose object used for pre-transforms. + post_transforms (Compose): MONAI Compose object used for post-transforms. + roi_size (Union[Sequence[int], int]): The tensor size used in inference. + overlap (float): The overlap used in sliding window inference. + """ + + super().__init__(*args, **kwargs) + self._executing = False + self._lock = Lock() + self._model_name = model_name if model_name else "" + self._bundle_path = Path(bundle_path).expanduser().resolve() \ + if bundle_path and len(bundle_path) > 0 else None + self._preproc_name = preproc_name + self._postproc_name = postproc_name + self._parser = None # Delay init till execution context is set. + self._pre_transform = pre_transforms + self._post_transforms = post_transforms + + #TODO MQ to clean up + self._input_dataset_key = "image" + self._pred_dataset_key = "pred" + self._input_image = None # Image will come in when compute is called. + self._reader: Any = None + + self._roi_size = ensure_tuple(roi_size) + self.overlap = overlap + + @property + def model_name(self) -> str: + """The name of the model in the MONAI Bundle.""" + return self._model_name + + @model_name.setter + def model_name(self, name: str): + if not name or len(name) == 0: + raise ValueError(f"Value, {name}, must be a non-empty string.") + self._model_name = name + + @property + def bundle_path(self) -> Union[Path, None]: + """The path of the MONAI Bundle model.""" + return self._bundle_path + + @bundle_path.setter + def bundle_path(self, bundle_path: Union[str, Path]): + if not bundle_path or not Path(bundle_path).expanduser().is_file(): + raise ValueError(f"Value, {bundle_path}, is not a valid file path.") + self._bundle_path = Path(bundle_path).expanduser().resolve() + + @property + def parser(self) -> Union[ConfigParser, None]: + """The ConfigParser object.""" + return self._parser + + @parser.setter + def parser(self, parser: ConfigParser): + if parser and isinstance(parser, ConfigParser): + self._parser = parser + else: + raise ValueError(f"Value must be a valid ConfigParser object.") + + ## + + @property + def input_dataset_key(self): + """This is the input image key name used in dictionary based MONAI pre-transforms.""" + return self._input_dataset_key + + @input_dataset_key.setter + def input_dataset_key(self, val: str): + if not val or len(val) < 1: + raise ValueError("Value cannot be None or blank.") + self._input_dataset_key = val + + @property + def pred_dataset_key(self): + """This is the prediction key name used in dictionary based MONAI post-transforms.""" + return self._pred_dataset_key + + @pred_dataset_key.setter + def pred_dataset_key(self, val: str): + if not val or len(val) < 1: + raise ValueError("Value cannot be None or blank.") + self._pred_dataset_key = val + + @property + def overlap(self): + """This is the overlap used during sliding window inference""" + return self._overlap + + @overlap.setter + def overlap(self, val: float): + if val < 0 or val > 1: + raise ValueError("Overlap must be between 0 and 1.") + self._overlap = val + + def _get_bundle_config(self, bundle_path: Path) -> ConfigParser: + """Get the MONAI configuration parser from the specified MONAI Bundle file path. + + Args: + bundle_path (Path): Path of the MONAI Bundle + + Returns: + ConfigParser: MONAI Bundle config parser + """ + # The final path component, without its suffix, is expected to the model name + name = bundle_path.stem + parser = ConfigParser() + + with tempfile.TemporaryDirectory() as td: + archive = zipfile.ZipFile(str(bundle_path), "r") + archive.extract(name + "/extra/metadata.json", td) + archive.extract(name + "/extra/inference.json", td) + + os.rename(f"{td}/{name}/extra/inference.json", f"{td}/{name}/extra/config.json") + + parser.read_meta(f=f"{td}/{name}/extra/metadata.json") + parser.read_config(f=f"{td}/{name}/extra/config.json") + + parser.parse() + + return parser + + def _filter_compose(self, compose:Compose): + """ + Remove transforms from the given Compose object which shouldn't be used in an Operator. + """ + + if not compose: + return Compose([]) # Could just bounce the None input back. + + filtered = [] + for t in compose.transforms: + tname = type(t).__name__ + if not any(dis in tname for dis in self.DISALLOWED_TRANSFORMS): + filtered.append(t) + + return Compose(filtered) + + def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext): + """Infers with the input image and save the predicted image to output + + Args: + op_input (InputContext): An input context for the operator. + op_output (OutputContext): An output context for the operator. + context (ExecutionContext): An execution context for the operator. + """ + with self._lock: + if self._executing: + raise RuntimeError("Operator is already executing.") + else: + self._executing = True + + # If present, get the compliant model from context, else, from bundle path if given + model = None + if context.models: + # `context.models.get(model_name)` returns a model instance if exists. + # If model_name is not specified and only one model exists, it returns that model. + model = context.models.get(self.model_name) + if model: + self.bundle_path = model.path + if not model and self.bundle_path: + print(f"Loading TorchScript model from: {self.bundle_path}") + model = torch.jit.load(self.bundle_path, map_location=device) + + if not model: + raise IOError("Cannot find model file.") + + # Load the ConfigParser + self.parser = self._get_bundle_config(self.bundle_path) + + try: + input_image = op_input.get() + if not input_image: + raise ValueError("Input is None.") + + input_img_metadata = input_image.metadata() + # Need to give a name to the image as in-mem Image obj has no name. + img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context")) + + self._reader = InMemImageReader(input_image) # For convering Image to MONAI expected format + pre_transforms: Compose = \ + self._pre_transform if self._pre_transform else self.pre_process(self._reader) + post_transforms: Compose = \ + self._post_transforms if self._post_transforms else self.post_process(pre_transforms) + + #TODO: From bundle config + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms) + dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + + with torch.no_grad(): + for d in dataloader: + images = d[self._input_dataset_key].to(device) + sw_batch_size = 4 + d[self._pred_dataset_key] = sliding_window_inference( + inputs=images, + roi_size=self._roi_size, + sw_batch_size=sw_batch_size, + overlap=self.overlap, + predictor=model, + ) + d = [post_transforms(i) for i in decollate_batch(d)] + out_ndarray = d[0][self._pred_dataset_key].cpu().numpy() + # Need to squeeze out the channel dim fist + out_ndarray = np.squeeze(out_ndarray, 0) + # NOTE: The domain Image object simply contains a Arraylike obj as image as of now. + # When the original DICOM series is converted by the Series to Volume operator, + # using pydicom pixel_array, the 2D ndarray of each slice has index order HW, and + # when all slices are stacked with depth as first axis, DHW. In the pre-transforms, + # the image gets transposed to WHD and used as such in the inference pipeline. + # So once post-transforms have completed, and the channel is squeezed out, + # the resultant ndarray for the prediction image needs to be transposed back, so the + # array index order is back to DHW, the same order as the in-memory input Image obj. + out_ndarray = out_ndarray.T.astype(np.uint8) + print(f"Output Seg image numpy array shaped: {out_ndarray.shape}") + print(f"Output Seg image pixel max value: {np.amax(out_ndarray)}") + out_image = Image(out_ndarray, input_img_metadata) + op_output.set(out_image) + finally: + # Reset state on completing this method execution. + with self._lock: + self._executing = False + + def pre_process(self, img_reader) -> Union[Any, Image, Compose]: + """Transforms input before being used for predicting on a model.""" + + if not self.parser: + raise RuntimeError("ConfigParser object is None.") + + if self.parser.get(self._preproc_name) is not None: + preproc = self.parser.get_parsed_content(self._preproc_name) + self._pre_transform = self._filter_compose(preproc) + else: + self._pre_transform = Compose([]) # Could there be a scenario with no pre_processing? + + # Need to add the loadimage transform, single dataset key for now. + # TODO: MQ to find a better solution + load_image_transform = LoadImaged(keys=self.input_dataset_key, reader=img_reader) + self._pre_transform.transforms = (load_image_transform,) + self._pre_transform.transforms + + return self._pre_transform + + def post_process(self, pre_transforms: Compose, out_dir: str = "./infer_out") -> Union[Any, Image, Compose]: + """Transforms the prediction results from the model(s).""" + + if self.parser.get(self._postproc_name) is not None: + postproc = self.parser.get_parsed_content(self._postproc_name) + self._post_transforms = self._filter_compose(postproc) + else: + self._post_transforms = Compose([]) + return self._post_transforms + + def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any]: + """Predicts results using the models(s) with input tensors. + + This method must be overridden by a derived class. + + Raises: + NotImplementedError: When the subclass does not override this method. + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class InMemImageReader(ImageReader): + """Converts the App SDK Image object from memory. + + This is derived from MONAI ImageReader. Instead of reading image from file system, this + class simply converts a in-memory SDK Image object to the expected formats from ImageReader. + + The loaded data array will be in C order, for example, a 3D image NumPy array index order + will be `WHDC`. The actual data array loaded is to be the same as that from the + MONAI ITKReader, which can also load DICOM series. Furthermore, all Readers need to return the + array data the same way as the NibabelReader, i.e. a numpy array of index order WHDC with channel + being the last dim if present. More details are in the get_data() function. + + + """ + + def __init__(self, input_image: Image, channel_dim: Optional[int] = None, **kwargs): + super().__init__() + self.input_image = input_image + self.kwargs = kwargs + self.channel_dim = channel_dim + + def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: + return True + + def read(self, data: Union[Sequence[str], str], **kwargs) -> Union[Sequence[Any], Any]: + # Really does not have anything to do. Simply return the Image object + return self.input_image + + def get_data(self, input_image): + """Extracts data array and meta data from loaded image and return them. + + This function returns two objects, first is numpy array of image data, second is dict of meta data. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + A single image is loaded with a single set of metadata as of now. + + The App SDK Image asnumpy() function is expected to return a numpy array of index order `DHW`. + This is because in the DICOM series to volume operator pydicom Dataset pixel_array is used to + to get per instance pixel numpy array, with index order of `HW`. When all instances are stacked, + along the first axis, the Image numpy array's index order is `DHW`. ITK array_view_from_image + and SimpleITK GetArrayViewFromImage also returns a numpy array with the index order of `DHW`. + The channel would be the last dim/index if present. In the ITKReader get_data(), this numpy array + is then transposed, and the channel axis moved to be last dim post transpose; this is to be + consistent with the numpy returned from NibabelReader get_data(). + + The NibabelReader loads NIfTI image and uses the get_fdata() function of the loaded image to get + the numpy array, which has the index order in WHD with the channel being the last dim if present. + + Args: + input_image (Image): an App SDK Image object. + """ + + img_array: List[np.ndarray] = [] + compatible_meta: Dict = {} + + for i in ensure_tuple(input_image): + if not isinstance(i, Image): + raise TypeError("Only object of Image type is supported.") + + # The Image asnumpy() returns NumPy array similar to ITK array_view_from_image + # The array then needs to be transposed, as does in MONAI ITKReader, to align + # with the output from Nibabel reader loading NIfTI files. + data = i.asnumpy().T + img_array.append(data) + header = self._get_meta_dict(i) + _copy_compatible_dict(header, compatible_meta) + + # Stacking image is not really needed, as there is one image only. + return _stack_images(img_array, compatible_meta), compatible_meta + + def _get_meta_dict(self, img: Image) -> Dict: + """ + Gets the metadata of the image and converts to dict type. + + Args: + img: A SDK Image object. + """ + img_meta_dict: Dict = img.metadata() + meta_dict = {key: img_meta_dict[key] for key in img_meta_dict.keys()} + + # Will have to derive some key metadata as the SDK Image lacks the necessary interfaces. + # So, for now have to get to the Image generator, namely DICOMSeriesToVolumeOperator, and + # rely on its published metadata. + + # Referring to the MONAI ITKReader, the spacing is simply a NumPy array from the ITK image + # GetSpacing, in WHD. + meta_dict["spacing"] = np.asarray( + [ + img_meta_dict["row_pixel_spacing"], + img_meta_dict["col_pixel_spacing"], + img_meta_dict["depth_pixel_spacing"], + ] + ) + meta_dict["original_affine"] = np.asarray(img_meta_dict.get("nifti_affine_transform", None)) + meta_dict["affine"] = meta_dict["original_affine"] + # The spatial shape, again, referring to ITKReader, it is the WHD + meta_dict["spatial_shape"] = np.asarray(img.asnumpy().T.shape) + # Well, no channel as the image data shape is forced to the the same as spatial shape + meta_dict["original_channel_dim"] = "no_channel" + + return meta_dict + + +# Reuse MONAI code for the derived ImageReader +def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): + if not isinstance(to_dict, dict): + raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.") + if not to_dict: + for key in from_dict: + datum = from_dict[key] + if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None: + continue + to_dict[key] = datum + else: + affine_key, shape_key = "affine", "spatial_shape" + if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]): + raise RuntimeError( + "affine matrix of all images should be the same for channel-wise concatenation. " + f"Got {from_dict[affine_key]} and {to_dict[affine_key]}." + ) + if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]): + raise RuntimeError( + "spatial_shape of all images should be the same for channel-wise concatenation. " + f"Got {from_dict[shape_key]} and {to_dict[shape_key]}." + ) + + +def _stack_images(image_list: List, meta_dict: Dict): + if len(image_list) <= 1: + return image_list[0] + if meta_dict.get("original_channel_dim", None) not in ("no_channel", None): + raise RuntimeError("can not read a list of images which already have channel dimension.") + meta_dict["original_channel_dim"] = 0 + return np.stack(image_list, axis=0) From ca46a940870157e604f445418e514a4af3c92c7c Mon Sep 17 00:00:00 2001 From: mmelqin Date: Tue, 31 May 2022 11:15:15 -0700 Subject: [PATCH 02/13] Further simplified the code requiring model name only for bundle. Signed-off-by: mmelqin --- examples/apps/ai_spleen_seg_app/app.py | 11 +-- .../ai_spleen_seg_app/spleen_seg_operator.py | 22 ++---- .../monai_bundle_inference_operator.py | 79 +++++++++---------- 3 files changed, 49 insertions(+), 63 deletions(-) diff --git a/examples/apps/ai_spleen_seg_app/app.py b/examples/apps/ai_spleen_seg_app/app.py index f9da3846..31156990 100644 --- a/examples/apps/ai_spleen_seg_app/app.py +++ b/examples/apps/ai_spleen_seg_app/app.py @@ -11,13 +11,14 @@ import logging -from spleen_seg_operator import SpleenSegOperator +#from spleen_seg_operator import SpleenSegOperator from monai.deploy.core import Application, resource from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator +from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator from monai.deploy.operators.stl_conversion_operator import STLConversionOperator @@ -46,7 +47,7 @@ def compose(self): series_selector_op = DICOMSeriesSelectorOperator(Sample_Rules_Text) series_to_vol_op = DICOMSeriesToVolumeOperator() # Model specific inference operator, supporting MONAI transforms. - spleen_seg_op = SpleenSegOperator() + bundle_spleen_seg_op = MonaiBundleInferenceOperator(model_name="model") # Create DICOM Seg writer with segment label name in a string list dicom_seg_writer = DICOMSegmentationWriterOperator(seg_labels=["Spleen"]) # Create the surface mesh STL conversion operator @@ -58,14 +59,14 @@ def compose(self): self.add_flow( series_selector_op, series_to_vol_op, {"study_selected_series_list": "study_selected_series_list"} ) - self.add_flow(series_to_vol_op, spleen_seg_op, {"image": "image"}) + self.add_flow(series_to_vol_op, bundle_spleen_seg_op, {"image": ""}) # Note below the dicom_seg_writer requires two inputs, each coming from a upstream operator. self.add_flow( series_selector_op, dicom_seg_writer, {"study_selected_series_list": "study_selected_series_list"} ) - self.add_flow(spleen_seg_op, dicom_seg_writer, {"seg_image": "seg_image"}) + self.add_flow(bundle_spleen_seg_op, dicom_seg_writer, {"": "seg_image"}) # Add the STL conversion operator as another leaf operator taking as input the seg image. - self.add_flow(spleen_seg_op, stl_conversion_op, {"seg_image": "image"}) + self.add_flow(bundle_spleen_seg_op, stl_conversion_op, {"": "image"}) self._logger.debug(f"End {self.compose.__name__}") diff --git a/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py b/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py index f41340bc..2b6f1d16 100644 --- a/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py +++ b/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py @@ -26,15 +26,17 @@ class SpleenSegOperator(Operator): """Performs Spleen segmentation with a 3D image converted from a DICOM CT series. This operator makes use of the App SDK MonaiBundleInferenceOperator in a composition approach. - Loading of the model, and predicting using in-proc PyTorch inference is done by MonaiBundleInferenceOperator. + Parsing of the bundle, transforms and inference are done in the MonaiBundleInferenceOperator. + + Single named input of Image type and single named output of Image type are supported here. + Mapping of the I/O of the operator to the "keys" in MONAI transforms is not supported yet. """ - def __init__(self, *args, **kwargs): + def __init__(self, model_name:str = "", *args, **kwargs): self.logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__)) super().__init__(*args, **kwargs) - self._input_dataset_key = "image" - self._pred_dataset_key = "pred" + self._model_name = model_name def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext): @@ -43,17 +45,7 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe raise ValueError("Input image is not found.") # Delegates inference and saving output to the built-in operator. - infer_operator = MonaiBundleInferenceOperator( - roi_size=( - 96, - 96, - 96, - ), - ) - - # Setting the keys used in the dictironary based transforms may change. - infer_operator.input_dataset_key = self._input_dataset_key - infer_operator.pred_dataset_key = self._pred_dataset_key + infer_operator = MonaiBundleInferenceOperator(self._model_name) # Now let the built-in operator handles the work with the I/O spec and execution context. infer_operator.compute(op_input, op_output, context) diff --git a/monai/deploy/operators/monai_bundle_inference_operator.py b/monai/deploy/operators/monai_bundle_inference_operator.py index 7c2d83d1..6ae36ca5 100644 --- a/monai/deploy/operators/monai_bundle_inference_operator.py +++ b/monai/deploy/operators/monai_bundle_inference_operator.py @@ -9,11 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os -import zipfile import tempfile -import json +# from types import NoneType +import zipfile from pathlib import Path from threading import Lock from typing import Any, Dict, List, Optional, Sequence, Union @@ -39,6 +40,7 @@ MapTransform_, _ = optional_import("monai.transforms", name="MapTransform") LoadImaged_, _ = optional_import("monai.transforms", name="LoadImaged") ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser") +SimpleInferer, _ = optional_import("monai.inferers", name="SimpleInferer") # Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477 Compose: Any = Compose_ MapTransform: Any = MapTransform_ @@ -56,7 +58,10 @@ __all__ = ["MonaiBundleInferenceOperator", "InMemImageReader"] # TODO: For now, assuming single input and single output, but need to change -@md.env(pip_packages=["monai>=0.8.1", "torch>=1.10.02", "numpy>=1.21"]) +@md.input("image", Image, IOType.IN_MEMORY) +@md.output("seg_image", Image, IOType.IN_MEMORY) +#@md.env(pip_packages=["monai>=0.9.0", "torch>=1.10.02", "numpy>=1.21"]) +@md.env(pip_packages=["monai-weekly>=0.9.dev2221", "torch>=1.10.02", "numpy>=1.21"]) class MonaiBundleInferenceOperator(InferenceOperator): """This segmentation operator uses MONAI transforms and Sliding Window Inference. @@ -75,50 +80,40 @@ def __init__( bundle_path: Optional[str] = None, preproc_name: Optional[str] = "preprocessing", postproc_name: Optional[str] = "postprocessing", - pre_transforms: Optional[Compose] = None , + inferer_name: Optional[str] = "inferer", + pre_transforms: Optional[Compose] = None, post_transforms: Optional[Compose] = None, - roi_size: Union[Sequence[int], int] = (96, 96, 96,), + roi_size: Union[Sequence[int], int] = ( + 96, + 96, + 96, + ), overlap: float = 0.5, *args, **kwargs, ): - """Creates a instance of this class. - - Args: - model_name (Optional[str]): The name of the model in the MONAI Bundle. - bundle_path: Optional[str]: Path of the MONAI Bundle, overridden by model loader. - preproc_name: Optional[str]: Inference config item name for "preprocessing". - postproc_name: Optional[str]: Inference config name for "postprocessing". - pre_transforms (Compose): MONAI Compose object used for pre-transforms. - post_transforms (Compose): MONAI Compose object used for post-transforms. - roi_size (Union[Sequence[int], int]): The tensor size used in inference. - overlap (float): The overlap used in sliding window inference. - """ super().__init__(*args, **kwargs) self._executing = False self._lock = Lock() self._model_name = model_name if model_name else "" - self._bundle_path = Path(bundle_path).expanduser().resolve() \ - if bundle_path and len(bundle_path) > 0 else None + self._bundle_path = Path(bundle_path).expanduser().resolve() if bundle_path and len(bundle_path) > 0 else None self._preproc_name = preproc_name self._postproc_name = postproc_name + self._inferer_name = inferer_name self._parser = None # Delay init till execution context is set. self._pre_transform = pre_transforms self._post_transforms = post_transforms + self._inferer = None - #TODO MQ to clean up + # TODO MQ to clean up self._input_dataset_key = "image" self._pred_dataset_key = "pred" self._input_image = None # Image will come in when compute is called. self._reader: Any = None - self._roi_size = ensure_tuple(roi_size) - self.overlap = overlap - @property def model_name(self) -> str: - """The name of the model in the MONAI Bundle.""" return self._model_name @model_name.setter @@ -198,21 +193,20 @@ def _get_bundle_config(self, bundle_path: Path) -> ConfigParser: name = bundle_path.stem parser = ConfigParser() + print(f"bundle path: {bundle_path}") with tempfile.TemporaryDirectory() as td: archive = zipfile.ZipFile(str(bundle_path), "r") archive.extract(name + "/extra/metadata.json", td) archive.extract(name + "/extra/inference.json", td) - os.rename(f"{td}/{name}/extra/inference.json", f"{td}/{name}/extra/config.json") - parser.read_meta(f=f"{td}/{name}/extra/metadata.json") - parser.read_config(f=f"{td}/{name}/extra/config.json") + parser.read_config(f=f"{td}/{name}/extra/inference.json") parser.parse() return parser - def _filter_compose(self, compose:Compose): + def _filter_compose(self, compose: Compose): """ Remove transforms from the given Compose object which shouldn't be used in an Operator. """ @@ -260,6 +254,12 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe # Load the ConfigParser self.parser = self._get_bundle_config(self.bundle_path) + # Get the inferer + if self._parser.get(self._inferer_name) is not None: + self._inferer = self._parser.get_parsed_content(self._inferer_name) + else: + self._inferer = SimpleInferer() + try: input_image = op_input.get() if not input_image: @@ -270,12 +270,12 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context")) self._reader = InMemImageReader(input_image) # For convering Image to MONAI expected format - pre_transforms: Compose = \ - self._pre_transform if self._pre_transform else self.pre_process(self._reader) - post_transforms: Compose = \ + pre_transforms: Compose = self._pre_transform if self._pre_transform else self.pre_process(self._reader) + post_transforms: Compose = ( self._post_transforms if self._post_transforms else self.post_process(pre_transforms) + ) - #TODO: From bundle config + # TODO: From bundle config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms) @@ -284,14 +284,7 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe with torch.no_grad(): for d in dataloader: images = d[self._input_dataset_key].to(device) - sw_batch_size = 4 - d[self._pred_dataset_key] = sliding_window_inference( - inputs=images, - roi_size=self._roi_size, - sw_batch_size=sw_batch_size, - overlap=self.overlap, - predictor=model, - ) + d[self._pred_dataset_key] = self._inferer(inputs=images, network=model) d = [post_transforms(i) for i in decollate_batch(d)] out_ndarray = d[0][self._pred_dataset_key].cpu().numpy() # Need to squeeze out the channel dim fist @@ -327,7 +320,7 @@ def pre_process(self, img_reader) -> Union[Any, Image, Compose]: self._pre_transform = Compose([]) # Could there be a scenario with no pre_processing? # Need to add the loadimage transform, single dataset key for now. - # TODO: MQ to find a better solution + # TODO: MQ to find a better solution, or use Compose callable directly instead of dataloader load_image_transform = LoadImaged(keys=self.input_dataset_key, reader=img_reader) self._pre_transform.transforms = (load_image_transform,) + self._pre_transform.transforms @@ -399,7 +392,7 @@ def get_data(self, input_image): consistent with the numpy returned from NibabelReader get_data(). The NibabelReader loads NIfTI image and uses the get_fdata() function of the loaded image to get - the numpy array, which has the index order in WHD with the channel being the last dim if present. + the numpy array, which has the index order in WHD with the channel being the last dim_get_compose if present. Args: input_image (Image): an App SDK Image object. @@ -413,7 +406,7 @@ def get_data(self, input_image): raise TypeError("Only object of Image type is supported.") # The Image asnumpy() returns NumPy array similar to ITK array_view_from_image - # The array then needs to be transposed, as does in MONAI ITKReader, to align + # The array then needs to be transposed, as does in MONAI ITKReader, to align_get_compose # with the output from Nibabel reader loading NIfTI files. data = i.asnumpy().T img_array.append(data) From 2a5f48f980345b21266b10f407dab2e90f83df68 Mon Sep 17 00:00:00 2001 From: mmelqin Date: Tue, 21 Jun 2022 22:43:37 -0700 Subject: [PATCH 03/13] Check in the bundle operator and multipl model support. Signed-off-by: mmelqin --- .../ai_spleen_seg_app/spleen_seg_operator.py | 51 ------------------- 1 file changed, 51 deletions(-) delete mode 100644 examples/apps/ai_spleen_seg_app/spleen_seg_operator.py diff --git a/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py b/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py deleted file mode 100644 index 2b6f1d16..00000000 --- a/examples/apps/ai_spleen_seg_app/spleen_seg_operator.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2021 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from os import path - -from numpy import uint8 - -import monai.deploy.core as md -from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, Operator, OutputContext -from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator - - -@md.input("image", Image, IOType.IN_MEMORY) -@md.output("seg_image", Image, IOType.IN_MEMORY) -@md.env(pip_packages=["monai>=0.8.1", "torch>=1.10.2", "numpy>=1.21", "nibabel"]) -class SpleenSegOperator(Operator): - """Performs Spleen segmentation with a 3D image converted from a DICOM CT series. - - This operator makes use of the App SDK MonaiBundleInferenceOperator in a composition approach. - Parsing of the bundle, transforms and inference are done in the MonaiBundleInferenceOperator. - - Single named input of Image type and single named output of Image type are supported here. - Mapping of the I/O of the operator to the "keys" in MONAI transforms is not supported yet. - """ - - def __init__(self, model_name:str = "", *args, **kwargs): - - self.logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__)) - super().__init__(*args, **kwargs) - self._model_name = model_name - - def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext): - - input_image = op_input.get("image") - if not input_image: - raise ValueError("Input image is not found.") - - # Delegates inference and saving output to the built-in operator. - infer_operator = MonaiBundleInferenceOperator(self._model_name) - - # Now let the built-in operator handles the work with the I/O spec and execution context. - infer_operator.compute(op_input, op_output, context) From 7e496e258a1e4d51fb7ccc16214aea3da8e71e96 Mon Sep 17 00:00:00 2001 From: mmelqin Date: Wed, 22 Jun 2022 00:18:10 -0700 Subject: [PATCH 04/13] Added changes for bundle operator and app. Signed-off-by: mmelqin --- docs/requirements.txt | 4 +- .../livertumor_seg_operator.py | 1 + examples/apps/ai_spleen_seg_app/app.py | 55 +- monai/deploy/operators/__init__.py | 4 +- .../monai_bundle_inference_operator.py | 897 +++++++++++------- .../operators/monai_seg_inference_operator.py | 5 +- requirements-dev.txt | 3 +- requirements-examples.txt | 3 + requirements.txt | 2 +- setup.cfg | 2 +- 10 files changed, 612 insertions(+), 364 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 9b36d855..2a9b707c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,7 @@ Sphinx==4.1.2 sphinx-autobuild==2021.3.14 myst-parser==0.15.2 -numpy==1.21 # CVE-2021-33430 +numpy==1.21.2 # CVE-2021-33430 matplotlib==3.3.4 ipywidgets==7.6.4 pandas==1.1.5 @@ -21,7 +21,7 @@ sphinxemoji==0.1.8 scipy scikit-image plotly -nibabel +nibabel>=3.2.1 monai pydicom sphinx-autodoc-typehints==1.12.0 diff --git a/examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py b/examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py index d30c0f04..f4ba3d44 100644 --- a/examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py +++ b/examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py @@ -94,6 +94,7 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe pre_transforms, post_transforms, overlap=0.6, + model_name=" ", ) # Setting the keys used in the dictironary based transforms may change. diff --git a/examples/apps/ai_spleen_seg_app/app.py b/examples/apps/ai_spleen_seg_app/app.py index 31156990..12b69e55 100644 --- a/examples/apps/ai_spleen_seg_app/app.py +++ b/examples/apps/ai_spleen_seg_app/app.py @@ -1,4 +1,4 @@ -# Copyright 2021 MONAI Consortium +# Copyright 2021-2022 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,15 +11,17 @@ import logging -#from spleen_seg_operator import SpleenSegOperator - from monai.deploy.core import Application, resource +from monai.deploy.core.domain import Image +from monai.deploy.core.domain.datapath import DataPath +from monai.deploy.core.io_type import IOType from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator -from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator -from monai.deploy.operators.stl_conversion_operator import STLConversionOperator +from monai.deploy.operators.monai_bundle_inference_operator import IOMapping, MonaiBundleInferenceOperator + +# from monai.deploy.operators.stl_conversion_operator import STLConversionOperator # import as needed. @resource(cpu=1, gpu=1, memory="7Gi") @@ -33,25 +35,35 @@ def __init__(self, *args, **kwargs): def run(self, *args, **kwargs): # This method calls the base class to run. Can be omitted if simply calling through. - self._logger.debug(f"Begin {self.run.__name__}") + self._logger.info(f"Begin {self.run.__name__}") super().run(*args, **kwargs) - self._logger.debug(f"End {self.run.__name__}") + self._logger.info(f"End {self.run.__name__}") def compose(self): """Creates the app specific operators and chain them up in the processing DAG.""" - self._logger.debug(f"Begin {self.compose.__name__}") + logging.info(f"Begin {self.compose.__name__}") - # Creates the custom operator(s) as well as SDK built-in operator(s). + # Create the custom operator(s) as well as SDK built-in operator(s). study_loader_op = DICOMDataLoaderOperator() series_selector_op = DICOMSeriesSelectorOperator(Sample_Rules_Text) series_to_vol_op = DICOMSeriesToVolumeOperator() - # Model specific inference operator, supporting MONAI transforms. - bundle_spleen_seg_op = MonaiBundleInferenceOperator(model_name="model") + + # Create the inference operator that supports MONAI Bundle and automates the inference. + # The IOMapping labels match the input and prediction keys in the pre and post processing. + # The model_name is optional when the app has only one model. + # The bundle_path argument optionally can be set to an accessible bundle file path in the dev + # environment, so when the app is packaged into a MAP, the operator can complete the bundle parsing + # during init to provide the optional packages info, parsed from the bundle, to the packager + # for it to install the packages in the MAP docker image. + # Setting output IOType to DISK only works only for leaf operators, not the case in this example. + bundle_spleen_seg_op = MonaiBundleInferenceOperator( + input_mapping=[IOMapping("image", Image, IOType.IN_MEMORY)], + output_mapping=[IOMapping("pred", Image, IOType.IN_MEMORY)], + ) + # Create DICOM Seg writer with segment label name in a string list dicom_seg_writer = DICOMSegmentationWriterOperator(seg_labels=["Spleen"]) - # Create the surface mesh STL conversion operator - stl_conversion_op = STLConversionOperator(output_file="stl/spleen.stl") # Create the processing pipeline, by specifying the upstream and downstream operators, and # ensuring the output from the former matches the input of the latter, in both name and type. @@ -59,16 +71,18 @@ def compose(self): self.add_flow( series_selector_op, series_to_vol_op, {"study_selected_series_list": "study_selected_series_list"} ) - self.add_flow(series_to_vol_op, bundle_spleen_seg_op, {"image": ""}) + self.add_flow(series_to_vol_op, bundle_spleen_seg_op, {"image": "image"}) # Note below the dicom_seg_writer requires two inputs, each coming from a upstream operator. self.add_flow( series_selector_op, dicom_seg_writer, {"study_selected_series_list": "study_selected_series_list"} ) - self.add_flow(bundle_spleen_seg_op, dicom_seg_writer, {"": "seg_image"}) - # Add the STL conversion operator as another leaf operator taking as input the seg image. - self.add_flow(bundle_spleen_seg_op, stl_conversion_op, {"": "image"}) + self.add_flow(bundle_spleen_seg_op, dicom_seg_writer, {"pred": "seg_image"}) + # Create the surface mesh STL conversion operator and add it to the app execution flow, if needed, by + # uncommenting the following couple lines. + # stl_conversion_op = STLConversionOperator(output_file="stl/spleen.stl") + # self.add_flow(bundle_spleen_seg_op, stl_conversion_op, {"pred": "image"}) - self._logger.debug(f"End {self.compose.__name__}") + logging.info(f"End {self.compose.__name__}") # This is a sample series selection rule in JSON, simply selecting CT series. @@ -95,8 +109,7 @@ def compose(self): # -i , for input DICOM CT series folder # -o , for the output folder, default $PWD/output # e.g. - # python3 app.py -i input -m model/model.ts + # monai-deploy exec app.py -i input -m model/model.ts # logging.basicConfig(level=logging.DEBUG) - app_instance = AISpleenSegApp() # Optional params' defaults are fine. - app_instance.run() + app_instance = AISpleenSegApp(do_run=True) diff --git a/monai/deploy/operators/__init__.py b/monai/deploy/operators/__init__.py index 10ab403d..b0ac5ed9 100644 --- a/monai/deploy/operators/__init__.py +++ b/monai/deploy/operators/__init__.py @@ -12,6 +12,7 @@ .. autosummary:: :toctree: _autosummary + BundleConfigNames ClaraVizOperator DICOMDataLoaderOperator DICOMSegmentationWriterOperator @@ -19,6 +20,7 @@ DICOMSeriesToVolumeOperator DICOMTextSRWriterOperator InferenceOperator + IOMapping MonaiBundleInferenceOperator MonaiSegInferenceOperator PNGConverterOperator @@ -34,7 +36,7 @@ from .dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator from .dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo from .inference_operator import InferenceOperator -from .monai_bundle_inference_operator import MonaiBundleInferenceOperator +from .monai_bundle_inference_operator import BundleConfigNames, IOMapping, MonaiBundleInferenceOperator from .monai_seg_inference_operator import MonaiSegInferenceOperator from .png_converter_operator import PNGConverterOperator from .publisher_operator import PublisherOperator diff --git a/monai/deploy/operators/monai_bundle_inference_operator.py b/monai/deploy/operators/monai_bundle_inference_operator.py index 6ae36ca5..407c46fa 100644 --- a/monai/deploy/operators/monai_bundle_inference_operator.py +++ b/monai/deploy/operators/monai_bundle_inference_operator.py @@ -1,4 +1,4 @@ -# Copyright 2021-2002 MONAI Consortium +# Copyright 2002 MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,107 +10,285 @@ # limitations under the License. import json +import logging import os -import tempfile - -# from types import NoneType +import pickle +import time import zipfile +from http.client import OK +from multiprocessing.sharedctypes import Value from pathlib import Path from threading import Lock -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy as np +from genericpath import exists +import monai.deploy.core as md +from monai.deploy.core import DataPath, ExecutionContext, Image, InputContext, IOType, OutputContext +from monai.deploy.core.operator import OperatorEnv +from monai.deploy.exceptions import ItemNotExistsError from monai.deploy.utils.importutil import optional_import -from monai.transforms.io.dictionary import LoadImageD -torch, _ = optional_import("torch", "1.5") -np_str_obj_array_pattern, _ = optional_import("torch.utils.data._utils.collate", name="np_str_obj_array_pattern") -Dataset, _ = optional_import("monai.data", name="Dataset") -DataLoader, _ = optional_import("monai.data", name="DataLoader") -ImageReader_, image_reader_ok_ = optional_import("monai.data", name="ImageReader") -# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477 -ImageReader: Any = ImageReader_ -if not image_reader_ok_: - ImageReader = object # for 'class InMemImageReader(ImageReader):' to work -decollate_batch, _ = optional_import("monai.data", name="decollate_batch") -sliding_window_inference, _ = optional_import("monai.inferers", name="sliding_window_inference") -ensure_tuple, _ = optional_import("monai.utils", name="ensure_tuple") +from .inference_operator import InferenceOperator + +nibabel, _ = optional_import("nibabel", "3.2.1") +torch, _ = optional_import("torch", "1.10.0") + +PostFix, _ = optional_import("monai.utils.enums", name="PostFix") # For the default meta_key_postfix +first, _ = optional_import("monai.utils.misc", name="first") Compose_, _ = optional_import("monai.transforms", name="Compose") -MapTransform_, _ = optional_import("monai.transforms", name="MapTransform") -LoadImaged_, _ = optional_import("monai.transforms", name="LoadImaged") ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser") +MapTransform_, _ = optional_import("monai.transforms", name="MapTransform") SimpleInferer, _ = optional_import("monai.inferers", name="SimpleInferer") + # Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477 Compose: Any = Compose_ MapTransform: Any = MapTransform_ -LoadImaged: Any = LoadImaged_ ConfigParser: Any = ConfigParser_ -simple_inference, _ = optional_import("monai.inferers", name="simple_inference") -sliding_window_inference, _ = optional_import("monai.inferers", name="sliding_window_inference") +__all__ = ["MonaiBundleInferenceOperator", "IOMapping", "BundleConfigNames"] -import monai.deploy.core as md -from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, OutputContext -from .inference_operator import InferenceOperator +def get_bundle_config(bundle_path, config_names): + """ + Gets the configuration parser from the specified Torchscript bundle file path. + """ -__all__ = ["MonaiBundleInferenceOperator", "InMemImageReader"] + def _read_from_archive(archive, root_name: str, relative_path: str, path_list: List[str]): + """A helper function for reading a file in an zip archive. -# TODO: For now, assuming single input and single output, but need to change -@md.input("image", Image, IOType.IN_MEMORY) -@md.output("seg_image", Image, IOType.IN_MEMORY) -#@md.env(pip_packages=["monai>=0.9.0", "torch>=1.10.02", "numpy>=1.21"]) -@md.env(pip_packages=["monai-weekly>=0.9.dev2221", "torch>=1.10.02", "numpy>=1.21"]) -class MonaiBundleInferenceOperator(InferenceOperator): - """This segmentation operator uses MONAI transforms and Sliding Window Inference. + Tries to read with the full path of # a archive file, if error, then find the relative + path and then read the file. + """ + content_text = None + try: + content_text = archive.read(f"{root_name}/{relative_path}") + except KeyError: + logging.debug(f"Trying to find the metadata/config file in the bundle archive: {relative_path}.") + for n in path_list: + if relative_path in n: + content_text = archive.read(n) + break + if content_text is None: + raise + + return content_text + + if isinstance(config_names, str): + config_names = [config_names] + + name, _ = os.path.splitext(os.path.basename(bundle_path)) + parser = ConfigParser() + + # Parser to read the required metadata and extra config contents from the archive + with zipfile.ZipFile(bundle_path, "r") as archive: + name_list = archive.namelist() + metadata_relative_path = "extra/metadata.json" + metadata_text = _read_from_archive(archive, name, metadata_relative_path, name_list) + parser.read_meta(f=json.loads(metadata_text)) + + for cn in config_names: + config_relative_path = f"extra/{cn}.json" + config_text = _read_from_archive(archive, name, config_relative_path, name_list) + parser.read_config(f=json.loads(config_text)) + + parser.parse() + + return parser + + +DISALLOW_LOAD_SAVE = ["LoadImage", "SaveImage"] +DISALLOW_SAVE = ["SaveImage"] + + +def filter_compose(compose, disallowed_prefixes): + """ + Removes transforms from the given Compose object whose names begin with `disallowed_prefixes`. + """ + filtered = [] + for t in compose.transforms: + tname = type(t).__name__ + if not any(dis in tname for dis in disallowed_prefixes): + filtered.append(t) + + compose.transforms = tuple(filtered) + return compose + + +def is_map_compose(compose): + """ + Returns True if the given Compose object uses MapTransform instances. + """ + return isinstance(first(compose.transforms), MapTransform) + + +class IOMapping: + """This object holds an I/O definition for an operator.""" + + def __init__( + self, + label: str, + data_type: Type, + storage_type: IOType, + ): + """Creates an object holding an operator I/O definitions. + + Limitations apply with the combination of data_type and storage_type, which will + be validated at runtime. + + Args: + label (str): Label for the operator input or output. + data_type (Type): Datatype of the I/O data content. + storage_type (IOType): The storage type expected, i.e. IN_MEMORY or DISK. + """ + self.label: str = label + self.data_type: Type = data_type + self.storage_type: IOType = storage_type + + +class BundleConfigNames: + """This object holds the name of relevant config items used in a MONAI Bundle.""" + + def __init__( + self, + preproc_name: str = "preprocessing", + postproc_name: str = "postprocessing", + inferer_name: str = "inferer", + config_names: Union[List[str], Tuple[str], str] = ["inference"], + ) -> None: + """Creates an object holding the names of relevant config items in a MONAI Bundle. + + This object holds the names of the config items in a MONAI Bundle that will need to be + parsed by the inference operator for automating the object creations and inference. + Defaults values are provided per conversion, so the arguments only need to be set as needed. + + Args: + preproc_name (str, optional): Name of the config item for pre-processing transforms. + Defaults to "preprocessing". + postproc_name (str, optional): Name of the config item for post-processing transforms. + Defaults to "postprocessing". + inferer_name (str, optional): Name of the config item for inferer. + Defaults to "inferer". + config_names (List[str], optional): Name of config file(s) in the Bundle for parsing. + Defaults to ["inference"]. File ext must be .json. + """ + + def _ensure_str_list(config_names): + names = [] + if isinstance(config_names, (List, Tuple)): + if len(config_names) < 1: + raise ValueError("At least one config name must be provided.") + names = [str(name) for name in config_names] + else: + names = [str(config_names)] + + return names - This operator preforms pre-transforms on a input image, inference - using a given model, and post-transforms. The segmentation image is saved - as a named Image object in memory. + self.preproc_name: str = preproc_name + self.postproc_name: str = postproc_name + self.inferer_name: str = inferer_name + self.config_names: List[str] = _ensure_str_list(config_names) - If specified in the post transforms, results may also be saved to disk. + +# The operator env decorator defines the required pip packages commonly used in the Bundles. +# The MONAI Deploy App SDK packager currently relies on the App to consolidate all required packages in order to +# install them in the MAP Docker image. +# TODO: Dynamically setting the pip_packages env on init requires the bundle path be passed in. Apps using this +# operator may choose to pass in a accessible bundle path at development and packaging stage. Ideally, +# the bundle path should be passed in by the Packager, e.g. via env var, when the App is initialized. +# As of now, the Packager only passes in the model path after the App including all operators are init'ed. +@md.env(pip_packages=["monai>=0.9.0", "torch>=1.10.02", "numpy>=1.21", "nibabel>=3.2.1"]) +class MonaiBundleInferenceOperator(InferenceOperator): + """This inference operator automates the inference operation for a given MONAI Bundle. + + This inference operator configures itself based on the parsed data from a MONAI bundle file. This file is included + with a MAP as a Torchscript file with added bundle metadata or a zipped bundle with weights. The class will + configure how to do pre- and post-processing, inference, which device to use, state its inputs, outputs, and + dependencies. Its compute method is meant to be general purpose to most any bundle such that it will handle + any input specified in the bundle and produce output as specified, using the inference object the bundle defines. + A number of methods are provided which define parts of functionality relating to this behavior, users may wish + to overwrite these to change behavior is needed for specific bundles. + + The input(s) and output(s) for this operator need to be provided when an instance is created, and their labels need + to correspond to the bundle network input and output names, which are also used as the keys in the pre and post processing. + + For image input and output, the type is the `Image` class. For output of probabilities, the type is `Dict`. + + This operator is expected to be linked with both upstream and downstream operators, e.g. receiving an `Image` object from + the `DICOMSeriesToVolumeOperator`, and passing a segmentation `Image` to the `DICOMSegmentationWriterOperator`. + In such cases, the I/O storage type can only be `IN_MEMORY` due to the restrictions imposed by the application executor. + However, when used as the first operator in an application, its input storage type needs to be `DISK`, and the file needs + to be a Python pickle file, e.g. containing an `Image` instance. When used as the last operator, its output storage type + also needs to `DISK` with the path being the application's output folder, and the operator's output will be saved as + a pickle file whose name is the same as the output name. """ DISALLOWED_TRANSFORMS = ["LoadImage", "SaveImage"] + known_io_data_types = { + "image": Image, # Image object + "series": np.ndarray, + "tuples": np.ndarray, + "probabilities": Dict[str, Any], # dictionary containing probabilities and predicted labels + } + def __init__( self, + input_mapping: List[IOMapping], + output_mapping: List[IOMapping], model_name: Optional[str] = "", bundle_path: Optional[str] = None, - preproc_name: Optional[str] = "preprocessing", - postproc_name: Optional[str] = "postprocessing", - inferer_name: Optional[str] = "inferer", - pre_transforms: Optional[Compose] = None, - post_transforms: Optional[Compose] = None, - roi_size: Union[Sequence[int], int] = ( - 96, - 96, - 96, - ), - overlap: float = 0.5, + bundle_config_names: BundleConfigNames = BundleConfigNames(), *args, **kwargs, ): + """_summary_ + + Args: + input_mapping (List[IOMapping]): Define the inputs' name, type, and storage type. + output_mapping (List[IOMapping]): Defines the outputs' name, type, and storage type. + model_name (Optional[str], optional): Name of the model/bundle, needed in multi-model case. Defaults to "". + bundle_path (Optional[str], optional): For completing . Defaults to None. + bundle_config_names (BundleConfigNames, optional): Relevant config item names in a the bundle. Defaults to BundleConfigNames(). + """ super().__init__(*args, **kwargs) self._executing = False self._lock = Lock() - self._model_name = model_name if model_name else "" - self._bundle_path = Path(bundle_path).expanduser().resolve() if bundle_path and len(bundle_path) > 0 else None - self._preproc_name = preproc_name - self._postproc_name = postproc_name - self._inferer_name = inferer_name - self._parser = None # Delay init till execution context is set. - self._pre_transform = pre_transforms - self._post_transforms = post_transforms - self._inferer = None - - # TODO MQ to clean up - self._input_dataset_key = "image" - self._pred_dataset_key = "pred" - self._input_image = None # Image will come in when compute is called. - self._reader: Any = None + + self._model_name = model_name.strip() if isinstance(model_name, str) else "" + self._bundle_config_names = bundle_config_names + self._input_mapping = input_mapping + self._output_mapping = output_mapping + + self._parser = None # Needs known bundle path, either on init or when compute function is called. + self._inferer = None # Will be set during bundle parsing. + self._init_completed = False + + # Need to set the operator's input(s) and output(s). Even when the bundle parsing is done in init, + # there is still a need to define what op inputs/outputs map to what keys in the bundle config, + # along with the op input/output storage type. + # Also, the App Executor needs to set the IO context of the operator before calling the compute function. + self._add_inputs(self._input_mapping) + self._add_outputs(self._output_mapping) + + # Complete the init if the bundle path is known, otherwise delay till the compute function is called + # and try to get the model/bundle path from the execution context. + try: + self._bundle_path = ( + Path(bundle_path).expanduser().resolve() if bundle_path and len(bundle_path.strip()) > 0 else None + ) + + if self._bundle_path and self._bundle_path.exists(): + self._init_config(self._bundle_config_names.config_names) + self._init_completed = True + else: + logging.debug(f"Bundle path, {self._bundle_path}, not valid. Will get it in the execution context.") + self._bundle_path = None + except Exception: + logging.warn("Bundle parsing is not completed on init, delayed till this operator is called to execute.") + self._bundle_path = None @property def model_name(self) -> str: @@ -118,7 +296,7 @@ def model_name(self) -> str: @model_name.setter def model_name(self, name: str): - if not name or len(name) == 0: + if not name or isinstance(name, str): raise ValueError(f"Value, {name}, must be a non-empty string.") self._model_name = name @@ -145,293 +323,366 @@ def parser(self, parser: ConfigParser): else: raise ValueError(f"Value must be a valid ConfigParser object.") - ## + def _init_config(self, config_names): + """Completes the init with a known path to the MONAI Bundle - @property - def input_dataset_key(self): - """This is the input image key name used in dictionary based MONAI pre-transforms.""" - return self._input_dataset_key + Args: + config_names ([str]): Names of the config (files) in the bundle + """ - @input_dataset_key.setter - def input_dataset_key(self, val: str): - if not val or len(val) < 1: - raise ValueError("Value cannot be None or blank.") - self._input_dataset_key = val + parser = get_bundle_config(str(self._bundle_path), config_names) + self._parser = parser - @property - def pred_dataset_key(self): - """This is the prediction key name used in dictionary based MONAI post-transforms.""" - return self._pred_dataset_key + meta = self.parser["_meta_"] - @pred_dataset_key.setter - def pred_dataset_key(self, val: str): - if not val or len(val) < 1: - raise ValueError("Value cannot be None or blank.") - self._pred_dataset_key = val + # When this function is NOT called by the __init__, setting the pip_packages env here + # will not get dependencies to the App SDK Packager to install the packages in the MAP. + pip_packages = ["monai"] + [f"{k}=={v}" for k, v in meta["optional_packages_version"].items()] + if self._env: + self._env.pip_packages.extend(pip_packages) # Duplicates will be figured out on use. + else: + self._env = OperatorEnv(pip_packages=pip_packages) - @property - def overlap(self): - """This is the overlap used during sliding window inference""" - return self._overlap + if parser.get("device") is not None: + self._device = parser.get_parsed_content("device") + else: + self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - @overlap.setter - def overlap(self, val: float): - if val < 0 or val > 1: - raise ValueError("Overlap must be between 0 and 1.") - self._overlap = val + if parser.get(self._bundle_config_names.inferer_name) is not None: + self._inferer = parser.get_parsed_content(self._bundle_config_names.inferer_name) + else: + self._inferer = SimpleInferer() - def _get_bundle_config(self, bundle_path: Path) -> ConfigParser: - """Get the MONAI configuration parser from the specified MONAI Bundle file path. + self._inputs = meta["network_data_format"]["inputs"] + self._outputs = meta["network_data_format"]["outputs"] - Args: - bundle_path (Path): Path of the MONAI Bundle + # Given the restriction on operator I/O storage type, and known use cases, the I/O storage type of + # this operator is limited to IN_MEMRORY objects, so we will remove the LoadImage and SaveImage + self._preproc = self._get_compose(self._bundle_config_names.preproc_name, DISALLOW_LOAD_SAVE) + self._postproc = self._get_compose(self._bundle_config_names.postproc_name, DISALLOW_LOAD_SAVE) - Returns: - ConfigParser: MONAI Bundle config parser - """ - # The final path component, without its suffix, is expected to the model name - name = bundle_path.stem - parser = ConfigParser() + # Need to find out the meta_key_postfix. The key name of the input concatenated with this postfix + # will be the key name for the metadata for the input. + # Customized metadata key names are not supported as of now. + self._meta_key_postfix = self._get_meta_key_postfix(self._preproc) + + logging.debug(f"Effective transforms in pre-processing: {[type(t).__name__ for t in self._preproc.transforms]}") + logging.debug( + f"Effective Transforms in post-processing: {[type(t).__name__ for t in self._preproc.transforms]}" + ) - print(f"bundle path: {bundle_path}") - with tempfile.TemporaryDirectory() as td: - archive = zipfile.ZipFile(str(bundle_path), "r") - archive.extract(name + "/extra/metadata.json", td) - archive.extract(name + "/extra/inference.json", td) + def _get_compose(self, obj_name, disallowed_prefixes): + """Gets a Compose object containing a sequence fo transforms from item `obj_name` in `self._parser`.""" - parser.read_meta(f=f"{td}/{name}/extra/metadata.json") - parser.read_config(f=f"{td}/{name}/extra/inference.json") + if self._parser.get(obj_name) is not None: + compose = self._parser.get_parsed_content(obj_name) + return filter_compose(compose, disallowed_prefixes) - parser.parse() + return Compose([]) - return parser + def _get_meta_key_postfix(self, compose: Compose, key_name: str = "meta_key_postfix") -> str: + post_fix = PostFix.meta() + if compose and key_name: + for t in compose.transforms: + if isinstance(t, MapTransform) and hasattr(t, key_name): + post_fix = getattr(t, key_name) + # For some reason the attr is a tuple + if isinstance(post_fix, tuple): + post_fix = post_fix[0] + break - def _filter_compose(self, compose: Compose): + return post_fix + + def _get_io_data_type(self, conf): """ - Remove transforms from the given Compose object which shouldn't be used in an Operator. + Gets the input/output type of the given input or output metadata dictionary. The known Python types for input + or output types are given in the dictionary `BundleOperator.known_io_data_types` which relate type names to + the actual type. if `conf["type"]` is an actual object that's not a string then this is assumed to be the + type specifier and is returned. The fallback type is `bytes` which indicates the value is a pickled object. + + Args: + conf: configuration dictionary for an input or output from the "network_data_format" metadata section + + Returns: + A concrete type associated with this input/output type, this can be Image or np.ndarray or a Python type """ - if not compose: - return Compose([]) # Could just bounce the None input back. + # The Bundle's network_data_format for inputs and outputs does not indicate the storage type, i.e. IN_MEMORY + # or DISK, for the input(s) and output(s) of the operators. Configuration is required, though limited to + # IN_MEMORY for now. + # Certain association and transform are also required. The App SDK IN_MEMORY I/O can hold + # Any type, so if the type match and content format matches, data can simply be used as is, however, with + # the Type being Image, the object needs to be converted before being used as the expected "image" type. + ctype = conf["type"] + if ctype in self.known_io_data_types: # known type name from the specification + return self.known_io_data_types[ctype] + elif isinstance(ctype, type): # type object + return ctype + else: # don't know, something that hasn't been figured out + return object + + def _add_inputs(self, input_mapping: List[IOMapping]): + """Adds operator inputs as specified.""" - filtered = [] - for t in compose.transforms: - tname = type(t).__name__ - if not any(dis in tname for dis in self.DISALLOWED_TRANSFORMS): - filtered.append(t) + [self.add_input(v.label, v.data_type, v.storage_type) for v in input_mapping] - return Compose(filtered) + def _add_outputs(self, output_mapping: List[IOMapping]): + """Adds operator outputs as specified.""" + + [self.add_output(v.label, v.data_type, v.storage_type) for v in output_mapping] def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext): - """Infers with the input image and save the predicted image to output + """Infers with the input(s) and saves the prediction result(s) to output Args: op_input (InputContext): An input context for the operator. op_output (OutputContext): An output context for the operator. context (ExecutionContext): An execution context for the operator. """ - with self._lock: - if self._executing: - raise RuntimeError("Operator is already executing.") - else: - self._executing = True - - # If present, get the compliant model from context, else, from bundle path if given - model = None - if context.models: - # `context.models.get(model_name)` returns a model instance if exists. - # If model_name is not specified and only one model exists, it returns that model. - model = context.models.get(self.model_name) - if model: - self.bundle_path = model.path - if not model and self.bundle_path: - print(f"Loading TorchScript model from: {self.bundle_path}") - model = torch.jit.load(self.bundle_path, map_location=device) - - if not model: - raise IOError("Cannot find model file.") - - # Load the ConfigParser - self.parser = self._get_bundle_config(self.bundle_path) - - # Get the inferer - if self._parser.get(self._inferer_name) is not None: - self._inferer = self._parser.get_parsed_content(self._inferer_name) - else: - self._inferer = SimpleInferer() - try: - input_image = op_input.get() - if not input_image: - raise ValueError("Input is None.") - - input_img_metadata = input_image.metadata() - # Need to give a name to the image as in-mem Image obj has no name. - img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context")) - - self._reader = InMemImageReader(input_image) # For convering Image to MONAI expected format - pre_transforms: Compose = self._pre_transform if self._pre_transform else self.pre_process(self._reader) - post_transforms: Compose = ( - self._post_transforms if self._post_transforms else self.post_process(pre_transforms) - ) - - # TODO: From bundle config - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms) - dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) - - with torch.no_grad(): - for d in dataloader: - images = d[self._input_dataset_key].to(device) - d[self._pred_dataset_key] = self._inferer(inputs=images, network=model) - d = [post_transforms(i) for i in decollate_batch(d)] - out_ndarray = d[0][self._pred_dataset_key].cpu().numpy() - # Need to squeeze out the channel dim fist - out_ndarray = np.squeeze(out_ndarray, 0) - # NOTE: The domain Image object simply contains a Arraylike obj as image as of now. - # When the original DICOM series is converted by the Series to Volume operator, - # using pydicom pixel_array, the 2D ndarray of each slice has index order HW, and - # when all slices are stacked with depth as first axis, DHW. In the pre-transforms, - # the image gets transposed to WHD and used as such in the inference pipeline. - # So once post-transforms have completed, and the channel is squeezed out, - # the resultant ndarray for the prediction image needs to be transposed back, so the - # array index order is back to DHW, the same order as the in-memory input Image obj. - out_ndarray = out_ndarray.T.astype(np.uint8) - print(f"Output Seg image numpy array shaped: {out_ndarray.shape}") - print(f"Output Seg image pixel max value: {np.amax(out_ndarray)}") - out_image = Image(out_ndarray, input_img_metadata) - op_output.set(out_image) - finally: - # Reset state on completing this method execution. - with self._lock: - self._executing = False - - def pre_process(self, img_reader) -> Union[Any, Image, Compose]: - """Transforms input before being used for predicting on a model.""" - - if not self.parser: - raise RuntimeError("ConfigParser object is None.") - - if self.parser.get(self._preproc_name) is not None: - preproc = self.parser.get_parsed_content(self._preproc_name) - self._pre_transform = self._filter_compose(preproc) + # Try to get the Model object and its path from the context. + # If operator is not fully initialized, use model path as bundle path to finish it. + # If Model not loaded, but bundle path exists, load model, just in case. + # + # `context.models.get(model_name)` returns a model instance if exists. + # If model_name is not specified and only one model exists, it returns that model. + model = context.models.get(self._model_name) if context.models else None + if model: + if not self._init_completed: + with self._lock: + if not self._init_completed: + self._bundle_path = model.path + self._init_config(self._bundle_config_names.config_names) + self._init_completed + elif self._bundle_path: + logging.debug(f"Model network not loaded. Trying to load from model path: {self._bundle_path}") + model = torch.jit.load(self.bundle_path, map_location=self._device).eval() else: - self._pre_transform = Compose([]) # Could there be a scenario with no pre_processing? - - # Need to add the loadimage transform, single dataset key for now. - # TODO: MQ to find a better solution, or use Compose callable directly instead of dataloader - load_image_transform = LoadImaged(keys=self.input_dataset_key, reader=img_reader) - self._pre_transform.transforms = (load_image_transform,) + self._pre_transform.transforms - - return self._pre_transform - - def post_process(self, pre_transforms: Compose, out_dir: str = "./infer_out") -> Union[Any, Image, Compose]: - """Transforms the prediction results from the model(s).""" - - if self.parser.get(self._postproc_name) is not None: - postproc = self.parser.get_parsed_content(self._postproc_name) - self._post_transforms = self._filter_compose(postproc) + raise IOError("Model network is not load and model file not found.") + + first_input_name, *other_names = list(self._inputs.keys()) + + with torch.no_grad(): + inputs = {} + + start = time.time() + for name in self._inputs.keys(): + value, metadata = self._receive_input(name, op_input, context) + inputs[name] = value + if metadata: + inputs[(f"{name}_{self._meta_key_postfix}")] = metadata + + inputs = self.pre_process(inputs) + first_input = inputs.pop(first_input_name)[None].to(self._device) # select first input + input_metadata = inputs.get(f"{first_input_name}_{self._meta_key_postfix}", None) + + # select other tensor inputs + other_inputs = {k: v[None].to(self._device) for k, v in inputs.items() if isinstance(v, torch.Tensor)} + # select other non-tensor inputs + other_inputs.update({k: inputs[k] for k in other_names if not isinstance(inputs[k], torch.Tensor)}) + logging.debug(f"Ingest and Pre-processing elapsed time (seconds): {time.time() - start}") + + start = time.time() + outputs = self.predict(data=first_input, network=model, **other_inputs) + logging.debug(f"Inference elapsed time (seconds): {time.time() - start}") + + # TODO: Does this work for models where multiple outputs are returned? + # Note that the inputs are needed because the invert transform requires it. + start = time.time() + outputs = self.post_process(outputs[0], inputs) + logging.debug(f"Post-processing elapsed time (seconds): {time.time() - start}") + if isinstance(outputs, (tuple, list)): + output_dict = dict(zip(self._outputs.keys(), outputs)) + elif not isinstance(outputs, dict): + output_dict = {first(self._outputs.keys()): outputs} else: - self._post_transforms = Compose([]) - return self._post_transforms - - def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any]: - """Predicts results using the models(s) with input tensors. - - This method must be overridden by a derived class. - - Raises: - NotImplementedError: When the subclass does not override this method. - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - -class InMemImageReader(ImageReader): - """Converts the App SDK Image object from memory. - - This is derived from MONAI ImageReader. Instead of reading image from file system, this - class simply converts a in-memory SDK Image object to the expected formats from ImageReader. + output_dict = outputs - The loaded data array will be in C order, for example, a 3D image NumPy array index order - will be `WHDC`. The actual data array loaded is to be the same as that from the - MONAI ITKReader, which can also load DICOM series. Furthermore, all Readers need to return the - array data the same way as the NibabelReader, i.e. a numpy array of index order WHDC with channel - being the last dim if present. More details are in the get_data() function. + for name in self._outputs.keys(): + # Note that the input metadata needs to be passed. + # Please see the comments in the called function for the reasons. + self._send_output(output_dict[name], name, input_metadata, op_output, context) + def predict(self, data: Any, network: Any, *args, **kwargs) -> Union[Image, Any]: + """Predicts output using the inferer.""" + return self._inferer(inputs=data, network=network, *args, **kwargs) - """ - - def __init__(self, input_image: Image, channel_dim: Optional[int] = None, **kwargs): - super().__init__() - self.input_image = input_image - self.kwargs = kwargs - self.channel_dim = channel_dim + def pre_process(self, data: Any) -> Union[Image, Any]: + """Processes the input dictionary with the stored transform sequence `self._preproc`.""" - def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool: - return True + if is_map_compose(self._preproc): + return self._preproc(data) + return {k: self._preproc(v) for k, v in data.items()} - def read(self, data: Union[Sequence[str], str], **kwargs) -> Union[Sequence[Any], Any]: - # Really does not have anything to do. Simply return the Image object - return self.input_image + def post_process(self, data: Any, inputs: Dict) -> Union[Image, Any]: + """Processes the output list/dictionary with the stored transform sequence `self._postproc`.""" - def get_data(self, input_image): - """Extracts data array and meta data from loaded image and return them. - - This function returns two objects, first is numpy array of image data, second is dict of meta data. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - A single image is loaded with a single set of metadata as of now. + if is_map_compose(self._postproc): + if isinstance(data, (list, tuple)): + outputs_dict = dict(zip(data, self._outputs.keys())) + elif not isinstance(data, dict): + oname = first(self._outputs.keys()) + outputs_dict = {oname: data} + else: + outputs_dict = data - The App SDK Image asnumpy() function is expected to return a numpy array of index order `DHW`. - This is because in the DICOM series to volume operator pydicom Dataset pixel_array is used to - to get per instance pixel numpy array, with index order of `HW`. When all instances are stacked, - along the first axis, the Image numpy array's index order is `DHW`. ITK array_view_from_image - and SimpleITK GetArrayViewFromImage also returns a numpy array with the index order of `DHW`. - The channel would be the last dim/index if present. In the ITKReader get_data(), this numpy array - is then transposed, and the channel axis moved to be last dim post transpose; this is to be - consistent with the numpy returned from NibabelReader get_data(). + # Need to add back the inputs including metadata as they are needed by the invert transform. + outputs_dict.update(inputs) + logging.debug(f"Effective output dict keys: {outputs_dict.keys()}") + return self._postproc(outputs_dict) + else: + if isinstance(data, (list, tuple)): + return list(map(self._postproc, data)) + + return self._postproc(data) + + def _receive_input(self, name: str, op_input: InputContext, context: ExecutionContext): + """Extracts the input value for the given input name.""" + + # The op_input can have the storage type of IN_MEMORY with the data type being Image or others, + # as well as the other type of DISK with data type being DataPath. + # The problem is, the op_input object does not have an attribute for the storage type, which + # needs to be inferred from data type, with DataPath meaning DISK storage type. The file + # content type may be interpreted from the bundle's network input type, but it is indirect + # as the op_input is the input for processing transforms, not necessarily directly for the network. + in_conf = self._inputs[name] + itype = self._get_io_data_type(in_conf) + value = op_input.get(name) + + metadata = None + if isinstance(value, DataPath): + if not value.path.exists(): + raise ValueError(f"Input path, {value.path}, does not exist.") + + file_path = value.path / name + # The named input can only be a folder as of now, but just in case things change. + if value.path.is_file(): + file_path = value.path + elif not file_path.exists() and value.path.is_dir: + # Expect one and only one file exists for use. + files = [f for f in value.path.glob("*") if f.is_file()] + if len(files) != 1: + raise ValueError(f"Input path, {value.path}, should have one and only one file.") + + file_path = files[0] + + # Only Python pickle file and or numpy file are supported as of now. + with open(file_path, "rb") as f: + if itype == np.ndarray: + value = np.load(file_path, allow_pickle=True) + else: + value = pickle.load(f) + + # Once extracted, the input data may be further processed depending on its actual type. + if isinstance(value, Image): + # Need to get the image ndarray as well as metadata + value, metadata = self._convert_from_image(value) + logging.debug(f"Shape of the converted input image: {value.shape}") + logging.debug(f"Metadata of the converted input image: {metadata}") + elif isinstance(value, np.ndarray): + value = torch.from_numpy(value).to(self._device) + + # else value is some other object from memory + + return value, metadata + + def _send_output(self, value, name: str, metadata: Dict, op_output: OutputContext, context: ExecutionContext): + """Send the given output value to the output context.""" + + logging.debug(f"Setting output {name}") + + out_conf = self._outputs[name] + otype = self._get_io_data_type(out_conf) + + if otype == Image: + # The value must be torch.tensor or ndarray. Note also that by convention the image/tensor + # out of the MONAI post processing is [CWHD] with dim for batch already squeezed out. + # Prediction image, e.g. segmentation image, needs to have its dimensions + # rearranged to fit the conventions used by Image class, i.e. [DHW], without channel dim. + # Also, based on known use cases, e.g. prediction being seg image and the downstream + # operators expect the data type to be unit8, conversion needs to be done as well. + # Metadata, such as pixel spacing and orientation, also needs to be set in the Image object, + # which is why metadata is expected to be passed in. + # TODO: Revisit when multi-channel images are supported. + + if isinstance(value, torch.Tensor): + value = value.cpu().numpy() + elif not isinstance(value, np.ndarray): + raise TypeError("arg 1 must be of type torch.Tensor or ndarray.") + + logging.debug(f"Output {name} numpy image shape: {value.shape}") + result = Image(np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8), metadata=metadata) + logging.debug(f"Converted Image shape: {result.asnumpy().shape}") + elif otype == np.ndarray: + result = np.asarray(value) + elif out_conf["type"] == "probabilities": + _, value_class = value.max(dim=0) + prediction = [out_conf["channel_def"][str(int(v))] for v in value.flatten()] + + result = {"result": prediction, "probabilities": value.cpu().numpy()} + elif isinstance(value, torch.Tensor): + result = value.cpu().numpy() + + # The operator output currently has many limitation depending on if the operator is + # a leaf node or not. The get method throws for non-leaf node, irrespective of storage type, + # and for leaf node if the storage type is IN_MEMORY. + try: + op_output_config = op_output.get(name) + if isinstance(op_output_config, DataPath): + output_file = op_output_config.path / name + output_file.parent.mkdir(exist_ok=True) + # Save pickle file + with open(output_file, "wb") as wf: + pickle.dump(result, wf) + + # Cannot (re)set/modify the op_output path to the actual file like below + # op_output.set(str(output_file), name) + else: + op_output.set(result, name) + except ItemNotExistsError: + # The following throws if the output storage type is DISK, but The OutputContext + # currently does not expose the storage type. Try and let it throw if need be. + op_output.set(result, name) - The NibabelReader loads NIfTI image and uses the get_fdata() function of the loaded image to get - the numpy array, which has the index order in WHD with the channel being the last dim_get_compose if present. + def _convert_from_image(self, img: Image) -> Tuple[np.ndarray, Dict]: + """Converts the Image object to the expected numpy array with metadata dictionary. Args: - input_image (Image): an App SDK Image object. + img: A SDK Image object. """ - img_array: List[np.ndarray] = [] - compatible_meta: Dict = {} + # The Image class provides a numpy array and a metadata dict without a defined set of keys. + # In most scenarios, if not all, DICOM series is converted to Image by the + # DICOMSeriesToVolumeOperator, but the generated metadata lacks the specifics keys expected + # by the MONAI transforms. So there is need to convert the Image object. + # Also, there is not a defined key to express the source or producer of an Image object, so, + # one has to inspect certain keys, based on known conversion, to infer the producer. + # An issues already exists for the improvement of the Image class. - for i in ensure_tuple(input_image): - if not isinstance(i, Image): - raise TypeError("Only object of Image type is supported.") + img_meta_dict: Dict = img.metadata() - # The Image asnumpy() returns NumPy array similar to ITK array_view_from_image - # The array then needs to be transposed, as does in MONAI ITKReader, to align_get_compose - # with the output from Nibabel reader loading NIfTI files. - data = i.asnumpy().T - img_array.append(data) - header = self._get_meta_dict(i) - _copy_compatible_dict(header, compatible_meta) + if ( + not img_meta_dict + or ("spacing" in img_meta_dict and "original_affine" in img_meta_dict) + or "row_pixel_spacing" not in img_meta_dict + ): - # Stacking image is not really needed, as there is one image only. - return _stack_images(img_array, compatible_meta), compatible_meta + return img.asnumpy(), img_meta_dict + else: + return self._convert_from_image_dicom_source(img) - def _get_meta_dict(self, img: Image) -> Dict: - """ - Gets the metadata of the image and converts to dict type. + def _convert_from_image_dicom_source(self, img: Image) -> Tuple[np.ndarray, Dict]: + """Converts the Image object to the expected numpy array with metadata dictionary. Args: - img: A SDK Image object. + img: A SDK Image object converted from DICOM instances. """ + img_meta_dict: Dict = img.metadata() meta_dict = {key: img_meta_dict[key] for key in img_meta_dict.keys()} - # Will have to derive some key metadata as the SDK Image lacks the necessary interfaces. - # So, for now have to get to the Image generator, namely DICOMSeriesToVolumeOperator, and - # rely on its published metadata. - - # Referring to the MONAI ITKReader, the spacing is simply a NumPy array from the ITK image - # GetSpacing, in WHD. + # The MONAI ImageReader, e.g. the ITKReader, arranges the image spatial dims in WHD, + # so the "spacing" needs to be expressed in such an order too, as expected by the transforms. meta_dict["spacing"] = np.asarray( [ img_meta_dict["row_pixel_spacing"], @@ -441,42 +692,16 @@ def _get_meta_dict(self, img: Image) -> Dict: ) meta_dict["original_affine"] = np.asarray(img_meta_dict.get("nifti_affine_transform", None)) meta_dict["affine"] = meta_dict["original_affine"] - # The spatial shape, again, referring to ITKReader, it is the WHD - meta_dict["spatial_shape"] = np.asarray(img.asnumpy().T.shape) - # Well, no channel as the image data shape is forced to the the same as spatial shape - meta_dict["original_channel_dim"] = "no_channel" - return meta_dict - - -# Reuse MONAI code for the derived ImageReader -def _copy_compatible_dict(from_dict: Dict, to_dict: Dict): - if not isinstance(to_dict, dict): - raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.") - if not to_dict: - for key in from_dict: - datum = from_dict[key] - if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None: - continue - to_dict[key] = datum - else: - affine_key, shape_key = "affine", "spatial_shape" - if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]): - raise RuntimeError( - "affine matrix of all images should be the same for channel-wise concatenation. " - f"Got {from_dict[affine_key]} and {to_dict[affine_key]}." - ) - if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]): - raise RuntimeError( - "spatial_shape of all images should be the same for channel-wise concatenation. " - f"Got {from_dict[shape_key]} and {to_dict[shape_key]}." - ) + # Similarly the Image ndarray has dim order DHW, to be rearranged to WHD. + # TODO: Need to revisit this once multi-channel image is supported and the Image class itself + # is enhanced to provide attributes or functions for channel and dim order details. + converted_image = np.swapaxes(img.asnumpy(), 0, 2) + + # The spatial shape is then that of the converted image, in WHD + meta_dict["spatial_shape"] = np.asarray(converted_image.shape) + # Well, now channel for now. + meta_dict["original_channel_dim"] = "no_channel" -def _stack_images(image_list: List, meta_dict: Dict): - if len(image_list) <= 1: - return image_list[0] - if meta_dict.get("original_channel_dim", None) not in ("no_channel", None): - raise RuntimeError("can not read a list of images which already have channel dimension.") - meta_dict["original_channel_dim"] = 0 - return np.stack(image_list, axis=0) + return converted_image, meta_dict diff --git a/monai/deploy/operators/monai_seg_inference_operator.py b/monai/deploy/operators/monai_seg_inference_operator.py index 576003ad..13f58c8e 100644 --- a/monai/deploy/operators/monai_seg_inference_operator.py +++ b/monai/deploy/operators/monai_seg_inference_operator.py @@ -62,6 +62,7 @@ def __init__( roi_size: Union[Sequence[int], int], pre_transforms: Compose, post_transforms: Compose, + model_name: Optional[str] = "", overlap: float = 0.5, *args, **kwargs, @@ -72,6 +73,7 @@ def __init__( roi_size (Union[Sequence[int], int]): The tensor size used in inference. pre_transforms (Compose): MONAI Compose object used for pre-transforms. post_transforms (Compose): MONAI Compose object used for post-transforms. + model_name (str, optional): Name of the model. Default to "" for single model app. overlap (float): The overlap used in sliding window inference. """ @@ -85,6 +87,7 @@ def __init__( self._roi_size = ensure_tuple(roi_size) self._pre_transform = pre_transforms self._post_transforms = post_transforms + self._model_name = model_name.strip() if isinstance(model_name, str) else "" self.overlap = overlap @property @@ -202,7 +205,7 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe if context.models: # `context.models.get(model_name)` returns a model instance if exists. # If model_name is not specified and only one model exists, it returns that model. - model = context.models.get() + model = context.models.get(self._model_name) else: print(f"Loading TorchScript model from: {MonaiSegInferenceOperator.MODEL_LOCAL_PATH}") model = torch.jit.load(MonaiSegInferenceOperator.MODEL_LOCAL_PATH, map_location=device) diff --git a/requirements-dev.txt b/requirements-dev.txt index 26286eca..9af3e98b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -23,7 +23,7 @@ pytest==6.2.4 pytest-cov==2.12.1 pytest-lazy-fixture==0.6.3 cucim~=21.06; platform_system == "Linux" -monai>=0.8.1 +monai>=0.9.0 docker>=5.0.0 pydicom>=1.4.2 SimpleITK>=2.0.0 @@ -33,3 +33,4 @@ scikit-image >= 0.17.2 nibabel >= 3.2.1 numpy-stl >= 2.12.0 trimesh >= 3.8.11 +torch>=1.10.0 \ No newline at end of file diff --git a/requirements-examples.txt b/requirements-examples.txt index ee7fbc6e..00dcc8f3 100644 --- a/requirements-examples.txt +++ b/requirements-examples.txt @@ -2,6 +2,9 @@ scikit-image >= 0.17.2 pydicom >= 1.4.2 SimpleITK >= 2.0.0 Pillow >= 8.0.0 +numpy-stl>=2.12.0 +trimesh>=3.8.11 nibabel >= 3.2.1 numpy-stl >= 2.12.0 trimesh >= 3.8.11 +torch >= 1.10.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 74252c9b..ada874ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -numpy>=1.21 # CVE-2021-33430 +numpy>=1.21.2 networkx>=2.4 colorama>=0.4.1 typeguard>=2.12.1 diff --git a/setup.cfg b/setup.cfg index 09761522..361fd924 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,7 +23,7 @@ python_requires = >= 3.7 # setup_requires = # cucim install_requires = - numpy>=1.21 # CVE-2021-33430 + numpy>=1.21.2 # CVE-2021-33430 networkx>=2.4 colorama>=0.4.1 typeguard>=2.12.1 From fb43b3ae5a8cd9d9566758f253fa317e1244718e Mon Sep 17 00:00:00 2001 From: mmelqin Date: Wed, 22 Jun 2022 01:29:27 -0700 Subject: [PATCH 05/13] Fix styling check errors. Signed-off-by: mmelqin --- examples/apps/ai_spleen_seg_app/app.py | 1 - .../operators/monai_bundle_inference_operator.py | 15 +++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/apps/ai_spleen_seg_app/app.py b/examples/apps/ai_spleen_seg_app/app.py index 12b69e55..533b0879 100644 --- a/examples/apps/ai_spleen_seg_app/app.py +++ b/examples/apps/ai_spleen_seg_app/app.py @@ -13,7 +13,6 @@ from monai.deploy.core import Application, resource from monai.deploy.core.domain import Image -from monai.deploy.core.domain.datapath import DataPath from monai.deploy.core.io_type import IOType from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator diff --git a/monai/deploy/operators/monai_bundle_inference_operator.py b/monai/deploy/operators/monai_bundle_inference_operator.py index 407c46fa..45c66c91 100644 --- a/monai/deploy/operators/monai_bundle_inference_operator.py +++ b/monai/deploy/operators/monai_bundle_inference_operator.py @@ -15,14 +15,11 @@ import pickle import time import zipfile -from http.client import OK -from multiprocessing.sharedctypes import Value from pathlib import Path from threading import Lock from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy as np -from genericpath import exists import monai.deploy.core as md from monai.deploy.core import DataPath, ExecutionContext, Image, InputContext, IOType, OutputContext @@ -239,7 +236,7 @@ def __init__( output_mapping: List[IOMapping], model_name: Optional[str] = "", bundle_path: Optional[str] = None, - bundle_config_names: BundleConfigNames = BundleConfigNames(), + bundle_config_names: BundleConfigNames = None, *args, **kwargs, ): @@ -248,9 +245,11 @@ def __init__( Args: input_mapping (List[IOMapping]): Define the inputs' name, type, and storage type. output_mapping (List[IOMapping]): Defines the outputs' name, type, and storage type. - model_name (Optional[str], optional): Name of the model/bundle, needed in multi-model case. Defaults to "". + model_name (Optional[str], optional): Name of the model/bundle, needed in multi-model case. + Defaults to "". bundle_path (Optional[str], optional): For completing . Defaults to None. - bundle_config_names (BundleConfigNames, optional): Relevant config item names in a the bundle. Defaults to BundleConfigNames(). + bundle_config_names (BundleConfigNames, optional): Relevant config item names in a the bundle. + Defaults to None. """ super().__init__(*args, **kwargs) @@ -258,7 +257,7 @@ def __init__( self._lock = Lock() self._model_name = model_name.strip() if isinstance(model_name, str) else "" - self._bundle_config_names = bundle_config_names + self._bundle_config_names = bundle_config_names if bundle_config_names else BundleConfigNames() self._input_mapping = input_mapping self._output_mapping = output_mapping @@ -321,7 +320,7 @@ def parser(self, parser: ConfigParser): if parser and isinstance(parser, ConfigParser): self._parser = parser else: - raise ValueError(f"Value must be a valid ConfigParser object.") + raise ValueError("Value must be a valid ConfigParser object.") def _init_config(self, config_names): """Completes the init with a known path to the MONAI Bundle From a9642d68631457b1289005d745eebc0201eb1340 Mon Sep 17 00:00:00 2001 From: mmelqin Date: Wed, 22 Jun 2022 01:36:17 -0700 Subject: [PATCH 06/13] Fix flake8 check error. Signed-off-by: mmelqin --- monai/deploy/operators/monai_bundle_inference_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/deploy/operators/monai_bundle_inference_operator.py b/monai/deploy/operators/monai_bundle_inference_operator.py index 45c66c91..1890ca87 100644 --- a/monai/deploy/operators/monai_bundle_inference_operator.py +++ b/monai/deploy/operators/monai_bundle_inference_operator.py @@ -152,7 +152,7 @@ def __init__( preproc_name: str = "preprocessing", postproc_name: str = "postprocessing", inferer_name: str = "inferer", - config_names: Union[List[str], Tuple[str], str] = ["inference"], + config_names: Union[List[str], Tuple[str], str] = "inference", ) -> None: """Creates an object holding the names of relevant config items in a MONAI Bundle. From 606ba9b245c27b643b977d2dbb1771226a948053 Mon Sep 17 00:00:00 2001 From: mmelqin Date: Fri, 24 Jun 2022 17:14:24 -0700 Subject: [PATCH 07/13] Correct all MyPy type checking errors and improve inline doc. The base inference operator needs to adjust the function signatures due to the expanded use of types resulting from the intro if bundle inference operator. The new version of MyPy also seems to be stricter on types. Signed-off-by: mmelqin --- monai/deploy/operators/inference_operator.py | 8 +-- .../monai_bundle_inference_operator.py | 56 ++++++++++++------- .../operators/monai_seg_inference_operator.py | 25 +++++++-- 3 files changed, 59 insertions(+), 30 deletions(-) diff --git a/monai/deploy/operators/inference_operator.py b/monai/deploy/operators/inference_operator.py index 5a07b02a..d1f47a00 100644 --- a/monai/deploy/operators/inference_operator.py +++ b/monai/deploy/operators/inference_operator.py @@ -10,7 +10,7 @@ # limitations under the License. from abc import abstractmethod -from typing import Any, Union +from typing import Any, Dict, Tuple, Union from monai.deploy.core import ExecutionContext, Image, InputContext, Operator, OutputContext @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs): super().__init__() @abstractmethod - def pre_process(self, data: Any) -> Union[Image, Any]: + def pre_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: """Transforms input before being used for predicting on a model. This method must be overridden by a derived class. @@ -50,7 +50,7 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe pass @abstractmethod - def predict(self, data: Any) -> Union[Image, Any]: + def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: """Predicts results using the models(s) with input tensors. This method must be overridden by a derived class. @@ -61,7 +61,7 @@ def predict(self, data: Any) -> Union[Image, Any]: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @abstractmethod - def post_process(self, data: Any) -> Union[Image, Any]: + def post_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: """Transform the prediction results from the model(s). This method must be overridden by a derived class. diff --git a/monai/deploy/operators/monai_bundle_inference_operator.py b/monai/deploy/operators/monai_bundle_inference_operator.py index 1890ca87..e931d445 100644 --- a/monai/deploy/operators/monai_bundle_inference_operator.py +++ b/monai/deploy/operators/monai_bundle_inference_operator.py @@ -34,6 +34,7 @@ PostFix, _ = optional_import("monai.utils.enums", name="PostFix") # For the default meta_key_postfix first, _ = optional_import("monai.utils.misc", name="first") +ensure_tuple, _ = optional_import("monai.utils", name="ensure_tuple") Compose_, _ = optional_import("monai.transforms", name="Compose") ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser") MapTransform_, _ = optional_import("monai.transforms", name="MapTransform") @@ -236,7 +237,7 @@ def __init__( output_mapping: List[IOMapping], model_name: Optional[str] = "", bundle_path: Optional[str] = None, - bundle_config_names: BundleConfigNames = None, + bundle_config_names: Optional[BundleConfigNames] = None, *args, **kwargs, ): @@ -261,9 +262,9 @@ def __init__( self._input_mapping = input_mapping self._output_mapping = output_mapping - self._parser = None # Needs known bundle path, either on init or when compute function is called. - self._inferer = None # Will be set during bundle parsing. - self._init_completed = False + self._parser: ConfigParser = None # Needs known bundle path, either on init or when compute function is called. + self._inferer: Any = None # Will be set during bundle parsing. + self._init_completed: bool = False # Need to set the operator's input(s) and output(s). Even when the bundle parsing is done in init, # there is still a need to define what op inputs/outputs map to what keys in the bundle config, @@ -289,6 +290,9 @@ def __init__( logging.warn("Bundle parsing is not completed on init, delayed till this operator is called to execute.") self._bundle_path = None + # Lazy init of model network till execution time when the context is fully set. + self._model_network: Any = None + @property def model_name(self) -> str: return self._model_name @@ -390,7 +394,7 @@ def _get_meta_key_postfix(self, compose: Compose, key_name: str = "meta_key_post post_fix = post_fix[0] break - return post_fix + return str(post_fix) def _get_io_data_type(self, conf): """ @@ -441,28 +445,32 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe # Try to get the Model object and its path from the context. # If operator is not fully initialized, use model path as bundle path to finish it. - # If Model not loaded, but bundle path exists, load model, just in case. + # If Model not loaded, but bundle path exists, load model; edge case for local dev. # # `context.models.get(model_name)` returns a model instance if exists. # If model_name is not specified and only one model exists, it returns that model. - model = context.models.get(self._model_name) if context.models else None - if model: + + self._model_network = context.models.get(self._model_name) if context.models else None + if self._model_network: if not self._init_completed: with self._lock: if not self._init_completed: - self._bundle_path = model.path + self._bundle_path = self._model_network.path self._init_config(self._bundle_config_names.config_names) self._init_completed elif self._bundle_path: + # For the case of local dev/testing when the bundle path is not passed in as an exec cmd arg. + # When run as a MAP docker, the bundle file is expected to be in the context, even if the model + # network is loaded on a remote inference server (when the feature is introduced). logging.debug(f"Model network not loaded. Trying to load from model path: {self._bundle_path}") - model = torch.jit.load(self.bundle_path, map_location=self._device).eval() + self._model_network = torch.jit.load(self.bundle_path, map_location=self._device).eval() else: raise IOError("Model network is not load and model file not found.") first_input_name, *other_names = list(self._inputs.keys()) with torch.no_grad(): - inputs = {} + inputs: Any = {} # Use type Any to quiet MyPy type checking complaints. start = time.time() for name in self._inputs.keys(): @@ -482,13 +490,13 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe logging.debug(f"Ingest and Pre-processing elapsed time (seconds): {time.time() - start}") start = time.time() - outputs = self.predict(data=first_input, network=model, **other_inputs) + outputs: Any = self.predict(data=first_input, **other_inputs) # Use type Any to quiet MyPy complaints. logging.debug(f"Inference elapsed time (seconds): {time.time() - start}") # TODO: Does this work for models where multiple outputs are returned? # Note that the inputs are needed because the invert transform requires it. start = time.time() - outputs = self.post_process(outputs[0], inputs) + outputs = self.post_process(ensure_tuple(outputs)[0], preprocessed_inputs=inputs) logging.debug(f"Post-processing elapsed time (seconds): {time.time() - start}") if isinstance(outputs, (tuple, list)): output_dict = dict(zip(self._outputs.keys(), outputs)) @@ -502,19 +510,27 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe # Please see the comments in the called function for the reasons. self._send_output(output_dict[name], name, input_metadata, op_output, context) - def predict(self, data: Any, network: Any, *args, **kwargs) -> Union[Image, Any]: + def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: """Predicts output using the inferer.""" - return self._inferer(inputs=data, network=network, *args, **kwargs) - def pre_process(self, data: Any) -> Union[Image, Any]: + return self._inferer(inputs=data, network=self._model_network, *args, **kwargs) + + def pre_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: """Processes the input dictionary with the stored transform sequence `self._preproc`.""" if is_map_compose(self._preproc): return self._preproc(data) return {k: self._preproc(v) for k, v in data.items()} - def post_process(self, data: Any, inputs: Dict) -> Union[Image, Any]: - """Processes the output list/dictionary with the stored transform sequence `self._postproc`.""" + def post_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: + """Processes the output list/dictionary with the stored transform sequence `self._postproc`. + + The "processed_inputs", in fact the metadata in it, need to be passed in so that the + invertible transforms in the post processing can work properly. + """ + + # Expect the inputs be passed in so that the inversion can work. + inputs = kwargs.get("preprocessed_inputs", {}) if is_map_compose(self._postproc): if isinstance(data, (list, tuple)): @@ -585,7 +601,7 @@ def _receive_input(self, name: str, op_input: InputContext, context: ExecutionCo return value, metadata - def _send_output(self, value, name: str, metadata: Dict, op_output: OutputContext, context: ExecutionContext): + def _send_output(self, value: Any, name: str, metadata: Dict, op_output: OutputContext, context: ExecutionContext): """Send the given output value to the output context.""" logging.debug(f"Setting output {name}") @@ -610,7 +626,7 @@ def _send_output(self, value, name: str, metadata: Dict, op_output: OutputContex raise TypeError("arg 1 must be of type torch.Tensor or ndarray.") logging.debug(f"Output {name} numpy image shape: {value.shape}") - result = Image(np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8), metadata=metadata) + result: Any = Image(np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8), metadata=metadata) logging.debug(f"Converted Image shape: {result.asnumpy().shape}") elif otype == np.ndarray: result = np.asarray(value) diff --git a/monai/deploy/operators/monai_seg_inference_operator.py b/monai/deploy/operators/monai_seg_inference_operator.py index 13f58c8e..5f368f2b 100644 --- a/monai/deploy/operators/monai_seg_inference_operator.py +++ b/monai/deploy/operators/monai_seg_inference_operator.py @@ -10,7 +10,7 @@ # limitations under the License. from threading import Lock -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -31,7 +31,6 @@ Compose_, _ = optional_import("monai.transforms", name="Compose") # Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477 Compose: Any = Compose_ -sliding_window_inference, _ = optional_import("monai.inferers", name="sliding_window_inference") import monai.deploy.core as md from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, OutputContext @@ -246,30 +245,44 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe with self._lock: self._executing = False - def pre_process(self, img_reader) -> Union[Any, Image, Compose]: + def pre_process(self, data: Any, *args, **kwargs) -> Union[Any, Image, Tuple[Any, ...], Dict[Any, Any]]: """Transforms input before being used for predicting on a model. This method must be overridden by a derived class. + Expected return is monai.transforms.Compose. + + Args: + data(monai.data.ImageReader): Reader used in LoadImage to load `monai.deploy.core.Image` as the input. + + Returns: + monai.transforms.Compose encapsulating pre transforms Raises: NotImplementedError: When the subclass does not override this method. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def post_process(self, pre_transforms: Compose, out_dir: str = "./infer_out") -> Union[Any, Image, Compose]: + def post_process(self, data: Any, *args, **kwargs) -> Union[Any, Image, Tuple[Any, ...], Dict[Any, Any]]: """Transforms the prediction results from the model(s). This method must be overridden by a derived class. + Expected return is monai.transforms.Compose. + + Args: + data(monai.transforms.Compose): The pre-processing transforms in a Compose object. + + Returns: + monai.transforms.Compose encapsulating post-processing transforms. Raises: NotImplementedError: When the subclass does not override this method. """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any]: + def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: """Predicts results using the models(s) with input tensors. - This method must be overridden by a derived class. + This method is currently not used in this class, instead monai.inferers.sliding_window_inference is used. Raises: NotImplementedError: When the subclass does not override this method. From be3d64671a6cdb41f3027ce2920b589e72b5c624 Mon Sep 17 00:00:00 2001 From: mmelqin Date: Sat, 25 Jun 2022 00:56:32 -0700 Subject: [PATCH 08/13] MyPy on Git is complaining things that are not found in local checking. Signed-off-by: mmelqin --- .../monai_bundle_inference_operator.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/monai/deploy/operators/monai_bundle_inference_operator.py b/monai/deploy/operators/monai_bundle_inference_operator.py index e931d445..88d2bd36 100644 --- a/monai/deploy/operators/monai_bundle_inference_operator.py +++ b/monai/deploy/operators/monai_bundle_inference_operator.py @@ -189,6 +189,8 @@ def _ensure_str_list(config_names): self.config_names: List[str] = _ensure_str_list(config_names) +DEFAULT_BundleConfigNames = BundleConfigNames() + # The operator env decorator defines the required pip packages commonly used in the Bundles. # The MONAI Deploy App SDK packager currently relies on the App to consolidate all required packages in order to # install them in the MAP Docker image. @@ -222,8 +224,6 @@ class MonaiBundleInferenceOperator(InferenceOperator): a pickle file whose name is the same as the output name. """ - DISALLOWED_TRANSFORMS = ["LoadImage", "SaveImage"] - known_io_data_types = { "image": Image, # Image object "series": np.ndarray, @@ -231,13 +231,15 @@ class MonaiBundleInferenceOperator(InferenceOperator): "probabilities": Dict[str, Any], # dictionary containing probabilities and predicted labels } + kw_preprocessed_inputs = "preprocessed_inputs" + def __init__( self, input_mapping: List[IOMapping], output_mapping: List[IOMapping], model_name: Optional[str] = "", - bundle_path: Optional[str] = None, - bundle_config_names: Optional[BundleConfigNames] = None, + bundle_path: Optional[str] = "", + bundle_config_names: Optional[BundleConfigNames] = DEFAULT_BundleConfigNames, *args, **kwargs, ): @@ -391,7 +393,7 @@ def _get_meta_key_postfix(self, compose: Compose, key_name: str = "meta_key_post post_fix = getattr(t, key_name) # For some reason the attr is a tuple if isinstance(post_fix, tuple): - post_fix = post_fix[0] + post_fix = str(post_fix[0]) break return str(post_fix) @@ -496,7 +498,8 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe # TODO: Does this work for models where multiple outputs are returned? # Note that the inputs are needed because the invert transform requires it. start = time.time() - outputs = self.post_process(ensure_tuple(outputs)[0], preprocessed_inputs=inputs) + kw_args = {self.kw_preprocessed_inputs: inputs} + outputs = self.post_process(ensure_tuple(outputs)[0], **kw_args) logging.debug(f"Post-processing elapsed time (seconds): {time.time() - start}") if isinstance(outputs, (tuple, list)): output_dict = dict(zip(self._outputs.keys(), outputs)) @@ -530,7 +533,7 @@ def post_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[An """ # Expect the inputs be passed in so that the inversion can work. - inputs = kwargs.get("preprocessed_inputs", {}) + inputs = kwargs.get(self.kw_preprocessed_inputs, {}) if is_map_compose(self._postproc): if isinstance(data, (list, tuple)): From 17aa29e69f3db4dc7ed680e9798ad4623a286990 Mon Sep 17 00:00:00 2001 From: mmelqin Date: Wed, 6 Jul 2022 15:49:32 -0700 Subject: [PATCH 09/13] Update per review comments. Signed-off-by: mmelqin --- examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py | 2 +- monai/deploy/operators/monai_bundle_inference_operator.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py b/examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py index f4ba3d44..d2fccf67 100644 --- a/examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py +++ b/examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py @@ -94,7 +94,7 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe pre_transforms, post_transforms, overlap=0.6, - model_name=" ", + model_name="", ) # Setting the keys used in the dictironary based transforms may change. diff --git a/monai/deploy/operators/monai_bundle_inference_operator.py b/monai/deploy/operators/monai_bundle_inference_operator.py index 88d2bd36..e6081624 100644 --- a/monai/deploy/operators/monai_bundle_inference_operator.py +++ b/monai/deploy/operators/monai_bundle_inference_operator.py @@ -424,6 +424,7 @@ def _get_io_data_type(self, conf): elif isinstance(ctype, type): # type object return ctype else: # don't know, something that hasn't been figured out + logging.warn(f"I/O data type, {ctype}, is not a known/supported type. Return as Type object.") return object def _add_inputs(self, input_mapping: List[IOMapping]): From b5648364e1ec5956cef61b038ebba49c1741b45d Mon Sep 17 00:00:00 2001 From: mmelqin Date: Thu, 7 Jul 2022 13:31:51 -0700 Subject: [PATCH 10/13] Rebased for the requirements.txt file. Not req'ed for docs/requirements but for consistency. Signed-off-by: mmelqin --- docs/requirements.txt | 9 ++++++--- requirements-dev.txt | 8 ++++---- requirements-examples.txt | 3 ++- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 2a9b707c..a154b509 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -19,10 +19,13 @@ docutils==0.16 # 0.17 causes error. https://github.com/executablebooks/MyST-Par pydata_sphinx_theme==0.6.3 sphinxemoji==0.1.8 scipy -scikit-image +scikit-image>=0.17.2 plotly nibabel>=3.2.1 -monai +monai>=0.9.0 +torch>=1.10.0 +numpy-stl>=2.12.0 +trimesh>=3.8.11 pydicom sphinx-autodoc-typehints==1.12.0 sphinxcontrib-applehelp==1.0.2 @@ -31,4 +34,4 @@ sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 -sphinxcontrib-mermaid==0.7.1 +sphinxcontrib-mermaid==0.7.1 \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 9af3e98b..8581d1e0 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -29,8 +29,8 @@ pydicom>=1.4.2 SimpleITK>=2.0.0 Pillow>=8.0.0 bump2version==1.0.1 -scikit-image >= 0.17.2 -nibabel >= 3.2.1 -numpy-stl >= 2.12.0 -trimesh >= 3.8.11 +scikit-image>=0.17.2 +nibabel>=3.2.1 +numpy-stl>=2.12.0 +trimesh>=3.8.11 torch>=1.10.0 \ No newline at end of file diff --git a/requirements-examples.txt b/requirements-examples.txt index 00dcc8f3..abbf0a97 100644 --- a/requirements-examples.txt +++ b/requirements-examples.txt @@ -7,4 +7,5 @@ trimesh>=3.8.11 nibabel >= 3.2.1 numpy-stl >= 2.12.0 trimesh >= 3.8.11 -torch >= 1.10.0 \ No newline at end of file +torch >= 1.10.0 +monai >= 0.9.0 \ No newline at end of file From 888ea6ff1e3fe5cb841bccd17b300c024f35005b Mon Sep 17 00:00:00 2001 From: mmelqin Date: Fri, 8 Jul 2022 00:00:00 -0700 Subject: [PATCH 11/13] Add the notebook for creating app with MONAI Bundle inference operator. Signed-off-by: mmelqin --- notebooks/tutorials/06_monai_bundle_app.ipynb | 875 ++++++++++++++++++ 1 file changed, 875 insertions(+) create mode 100644 notebooks/tutorials/06_monai_bundle_app.ipynb diff --git a/notebooks/tutorials/06_monai_bundle_app.ipynb b/notebooks/tutorials/06_monai_bundle_app.ipynb new file mode 100644 index 00000000..35bff664 --- /dev/null +++ b/notebooks/tutorials/06_monai_bundle_app.ipynb @@ -0,0 +1,875 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating a Deploy App with MONAI Deploy App SDK and MONAI Bundle\n", + "\n", + "This tutorial shows how to create an organ segmentation application for a PyTorch model that has been trained with MONAI and packaged in the [MONAI Bundle](https://docs.monai.io/en/latest/bundle_intro.html) format.\n", + "\n", + "Deploying AI models requires the integration with clinical imaging network, even if in a for-research-use setting. This means that the AI deploy application will need to support standards-based imaging protocols, and specifically for Radiological imaging, DICOM protocol.\n", + "\n", + "Typically, DICOM network communication, either in DICOM TCP/IP network protocol or DICOMWeb, would be handled by DICOM devices or services, e.g. MONAI Deploy Informatics Gateway, so the deploy application itself would only need to use DICOM Part 10 files as input and save the AI result in DICOM Part10 file(s). For segmentation use cases, the DICOM instance file for AI results could be a DICOM Segmentation object or a DICOM RT Structure Set, and for classification, DICOM Structure Report and/or DICOM Encapsulated PDF.\n", + "\n", + "When integrated with imaging networks and receiving DICOM instances from modalities and Picture Archiving and Communications System (PACS), an AI deploy application has to deal with a whole DICOM study with multiple series, whose images' spacing may not be the same as expected by the trained model. To address these cases consistently and efficiently, MONAI Deploy Application SDK provides classes, called operators, to parse DICOM studies, select specific series with application-defined rules, and convert the selected DICOM series into domain-specific image format along with meta-data representing the pertinent DICOM attributes. The image is then further processed in the pre-processing stage to normalize spacing, orientation, intensity,etc, before pixel data as Tensors are used for inference.\n", + "\n", + "In the following sections, we will demonstrate how to create a MONAI Deploy application package using the MONAI Deploy App SDK, and importantly, using the built-in MONAI Bundle Inference Operator to perform inference with the Spleen CT Segmentation PyTorch model in a MONAI Bundle.\n", + "\n", + ":::{note}\n", + "For local testing, if there is a lack of DICOM Part 10 files, one can use open source programs, e.g. 3D Slicer, to convert a NIfTI file to a DICOM series.\n", + "\n", + "To make running this example simpler, the DICOM files and the [Spleen CT Segmentation MONAI Bundle](https://github.com/Project-MONAI/model-zoo/tree/dev/models/spleen_ct_segmentation), published in [MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo), have been packaged and shared on Google Drive.\n", + "\n", + ":::\n", + "\n", + "## Creating Operators and connecting them in Application class\n", + "\n", + "We will implement an application that consists of five Operators:\n", + "\n", + "- **DICOMDataLoaderOperator**:\n", + " - **Input(dicom_files)**: a folder path ([`DataPath`](/modules/_autosummary/monai.deploy.core.domain.DataPath))\n", + " - **Output(dicom_study_list)**: a list of DICOM studies in memory (List[[`DICOMStudy`](/modules/_autosummary/monai.deploy.core.domain.DICOMStudy)])\n", + "- **DICOMSeriesSelectorOperator**:\n", + " - **Input(dicom_study_list)**: a list of DICOM studies in memory (List[[`DICOMStudy`](/modules/_autosummary/monai.deploy.core.domain.DICOMStudy)])\n", + " - **Input(selection_rules)**: a selection rule (Dict)\n", + " - **Output(study_selected_series_list)**: a DICOM series object in memory ([`StudySelectedSeries`](/modules/_autosummary/monai.deploy.core.domain.StudySelectedSeries))\n", + "- **DICOMSeriesToVolumeOperator**:\n", + " - **Input(study_selected_series_list)**: a DICOM series object in memory ([`StudySelectedSeries`](/modules/_autosummary/monai.deploy.core.domain.StudySelectedSeries))\n", + " - **Output(image)**: an image object in memory ([`Image`](/modules/_autosummary/monai.deploy.core.domain.Image))\n", + "- **MonaiBundleInferenceOperator**:\n", + " - **Input(image)**: an image object in memory ([`Image`](/modules/_autosummary/monai.deploy.core.domain.Image))\n", + " - **Output(pred)**: an image object in memory ([`Image`](/modules/_autosummary/monai.deploy.core.domain.Image))\n", + "- **DICOMSegmentationWriterOperator**:\n", + " - **Input(seg_image)**: a segmentation image object in memory ([`Image`](/modules/_autosummary/monai.deploy.core.domain.Image))\n", + " - **Input(study_selected_series_list)**: a DICOM series object in memory ([`StudySelectedSeries`](/modules/_autosummary/monai.deploy.core.domain.StudySelectedSeries))\n", + " - **Output(dicom_seg_instance)**: a file path ([`DataPath`](/modules/_autosummary/monai.deploy.core.domain.DataPath))\n", + "\n", + "\n", + ":::{note}\n", + "The `DICOMSegmentationWriterOperator` needs both the segmentation image as well as the original DICOM series meta-data in order to use the patient demographics and the DICOM Study level attributes.\n", + ":::\n", + "\n", + "The workflow of the application is illustrated below.\n", + "\n", + "```{mermaid}\n", + "%%{init: {\"theme\": \"base\", \"themeVariables\": { \"fontSize\": \"16px\"}} }%%\n", + "\n", + "classDiagram\n", + " direction TB\n", + "\n", + " DICOMDataLoaderOperator --|> DICOMSeriesSelectorOperator : dicom_study_list...dicom_study_list\n", + " DICOMSeriesSelectorOperator --|> DICOMSeriesToVolumeOperator : study_selected_series_list...study_selected_series_list\n", + " DICOMSeriesToVolumeOperator --|> MonaiBundleInferenceOperator : image...image\n", + " DICOMSeriesSelectorOperator --|> DICOMSegmentationWriterOperator : study_selected_series_list...study_selected_series_list\n", + " MonaiBundleInferenceOperator --|> DICOMSegmentationWriterOperator : pred...seg_image\n", + "\n", + "\n", + " class DICOMDataLoaderOperator {\n", + " dicom_files : DISK\n", + " dicom_study_list(out) IN_MEMORY\n", + " }\n", + " class DICOMSeriesSelectorOperator {\n", + " dicom_study_list : IN_MEMORY\n", + " selection_rules : IN_MEMORY\n", + " study_selected_series_list(out) IN_MEMORY\n", + " }\n", + " class DICOMSeriesToVolumeOperator {\n", + " study_selected_series_list : IN_MEMORY\n", + " image(out) IN_MEMORY\n", + " }\n", + " class MonaiBundleInferenceOperator {\n", + " image : IN_MEMORY\n", + " pred(out) IN_MEMORY\n", + " }\n", + " class DICOMSegmentationWriterOperator {\n", + " seg_image : IN_MEMORY\n", + " study_selected_series_list : IN_MEMORY\n", + " dicom_seg_instance(out) DISK\n", + " }\n", + "```\n", + "\n", + "### Setup environment\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Install MONAI and other necessary image processing packages for the application\n", + "!python -c \"import monai\" || pip install --upgrade -q \"monai\"\n", + "!python -c \"import torch\" || pip install -q \"torch>=1.5\"\n", + "!python -c \"import numpy\" || pip install -q \"numpy>=1.21\"\n", + "!python -c \"import nibabel\" || pip install -q \"nibabel>=3.2.1\"\n", + "!python -c \"import pydicom\" || pip install -q \"pydicom>=1.4.2\"\n", + "!python -c \"import SimpleITK\" || pip install -q \"SimpleITK>=2.0.0\"\n", + "!python -c \"import typeguard\" || pip install -q \"typeguard>=2.12.1\"\n", + "\n", + "# Install MONAI Deploy App SDK package\n", + "!python -c \"import monai.deploy\" || pip install --upgrade -q \"monai-deploy-app-sdk\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: you may need to restart the Jupyter kernel to use the updated packages." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download/Extract input and model/bundle files from Google Drive" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading...\n", + "From: https://drive.google.com/uc?id=1GC_N8YQk_mOWN02oOzAU_2YDmNRWk--n\n", + "To: ~/src/monai-deploy-app-sdk/notebooks/tutorials/ai_spleen_bundle_data.zip\n", + "104MB [00:10, 10.3MB/s] \n", + "Archive: ai_spleen_bundle_data.zip\n", + " creating: dcm/\n", + " inflating: dcm/IMG0001.dcm \n", + " inflating: dcm/IMG0002.dcm \n", + " inflating: dcm/IMG0003.dcm \n", + " inflating: dcm/IMG0004.dcm \n", + " inflating: dcm/IMG0005.dcm \n", + " inflating: dcm/IMG0006.dcm \n", + " inflating: dcm/IMG0007.dcm \n", + " inflating: dcm/IMG0008.dcm \n", + " inflating: dcm/IMG0009.dcm \n", + "... \n", + " inflating: dcm/IMG0509.dcm \n", + " inflating: dcm/IMG0510.dcm \n", + " inflating: dcm/IMG0511.dcm \n", + " inflating: dcm/IMG0512.dcm \n", + " inflating: dcm/IMG0513.dcm \n", + " inflating: dcm/IMG0514.dcm \n", + " inflating: dcm/IMG0515.dcm \n", + " inflating: model.ts \n" + ] + } + ], + "source": [ + "# Download the test data and MONAI bundle zip file\n", + "!pip install gdown \n", + "!gdown \"https://drive.google.com/uc?id=1cJq0iQh_yzYIxVElSlVa141aEmHZADJh\"\n", + "\n", + "# After downloading ai_spleen_bundle_data zip file from the web browser or using gdown,\n", + "!unzip -o \"ai_spleen_bundle_data.zip\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup imports\n", + "\n", + "Let's import necessary classes/decorators to define Application and Operator." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "from monai.deploy.core import Application, resource\n", + "from monai.deploy.core.domain import Image\n", + "from monai.deploy.core.io_type import IOType\n", + "from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator\n", + "from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator\n", + "from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator\n", + "from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator\n", + "from monai.deploy.operators.monai_bundle_inference_operator import IOMapping, MonaiBundleInferenceOperator\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Determining the Input and Output for the Model Bundle Inference Operator\n", + "\n", + "The App SDK provides a `MonaiBundleInferenceOperator` class to perform inference with a MONAI Bundle, which is essentially a PyTorch model in TorchScript with additional metadata describing the model network and processing specification. This operator uses the MONAI utilities to parse a MONAI Bundle to automatically instantiate the objects required for input and output processing as well as inference, as such it depends on MONAI transforms, inferers, and in turn their dependencies.\n", + "\n", + "Each Operator class inherits from the base [Operator](/modules/_autosummary/monai.deploy.core.Operator) class, and its input/output properties are specified by using [@input](/modules/_autosummary/monai.deploy.core.input)/[@output](/modules/_autosummary/monai.deploy.core.output) decorators. For the `MonaiBundleInferenceOperator` class, the input/output need to be defined to match those of the model network, both in name and data type. For the current release, an `IOMapping` object is used to connect the operator input/output to those of the model network by using the same names. This is likely to change, to be automated, in the future releases once certain limitation in the App SDK is removed.\n", + "\n", + "The Spleen CT Segmentation model network has a named input, called \"image\", and the named output called \"pred\", and both are of image type, which can all be mapped to the App SDK [Image](/modules/_autosummary/monai.deploy.core.domain.Image). This piece of information is typically acquired by examining the model metadata `network_data_format` attribute in the bundle, as seen in this [example] (https://github.com/Project-MONAI/model-zoo/blob/dev/models/spleen_ct_segmentation/configs/metadata.json)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating Application class\n", + "\n", + "Our application class would look like below.\n", + "\n", + "It defines `App` class, inheriting [Application](/modules/_autosummary/monai.deploy.core.Application) class.\n", + "\n", + "The requirements (resource and package dependency) for the App can be specified by using [@resource](/modules/_autosummary/monai.deploy.core.resource) and [@env](/modules/_autosummary/monai.deploy.core.env) decorators.\n", + "\n", + "The base class method, `compose`, is overridden. Objects required for DICOM parsing, series selection (selecting the first series for the current release), pixel data conversion to volume image, and segmentation instance creation are created, so is the model-specific `SpleenSegOperator`. The execution pipeline, as a Directed Acyclic Graph, is created by connecting these objects through self.add_flow()." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "@resource(cpu=1, gpu=1, memory=\"7Gi\")\n", + "class AISpleenSegApp(Application):\n", + " def __init__(self, *args, **kwargs):\n", + " \"\"\"Creates an application instance.\"\"\"\n", + " self._logger = logging.getLogger(\"{}.{}\".format(__name__, type(self).__name__))\n", + " super().__init__(*args, **kwargs)\n", + "\n", + " def run(self, *args, **kwargs):\n", + " self._logger.info(f\"Begin {self.run.__name__}\")\n", + " super().run(*args, **kwargs)\n", + " self._logger.info(f\"End {self.run.__name__}\")\n", + "\n", + " def compose(self):\n", + " \"\"\"Creates the app specific operators and chain them up in the processing DAG.\"\"\"\n", + "\n", + " logging.info(f\"Begin {self.compose.__name__}\")\n", + "\n", + " study_loader_op = DICOMDataLoaderOperator()\n", + " series_selector_op = DICOMSeriesSelectorOperator()\n", + " series_to_vol_op = DICOMSeriesToVolumeOperator()\n", + "\n", + " # Create the inference operator that supports MONAI Bundle and automates the inference.\n", + " # The IOMapping labels match the input and prediction keys in the pre and post processing.\n", + " # The model_name is optional when the app has only one model.\n", + " # The bundle_path argument optionally can be set to an accessible bundle file path in the dev\n", + " # environment, so when the app is packaged into a MAP, the operator can complete the bundle parsing\n", + " # during init to provide the optional packages info, parsed from the bundle, to the packager\n", + " # for it to install the packages in the MAP docker image.\n", + " # Setting output IOType to DISK only works only for leaf operators, not the case in this example.\n", + " bundle_spleen_seg_op = MonaiBundleInferenceOperator(\n", + " input_mapping=[IOMapping(\"image\", Image, IOType.IN_MEMORY)],\n", + " output_mapping=[IOMapping(\"pred\", Image, IOType.IN_MEMORY)],\n", + " )\n", + "\n", + " # Create DICOM Seg writer with segment label name in a string list\n", + " dicom_seg_writer = DICOMSegmentationWriterOperator(seg_labels=[\"Spleen\"])\n", + "\n", + " # Create the processing pipeline, by specifying the upstream and downstream operators, and\n", + " # ensuring the output from the former matches the input of the latter, in both name and type.\n", + " self.add_flow(study_loader_op, series_selector_op, {\"dicom_study_list\": \"dicom_study_list\"})\n", + " self.add_flow(\n", + " series_selector_op, series_to_vol_op, {\"study_selected_series_list\": \"study_selected_series_list\"}\n", + " )\n", + " self.add_flow(series_to_vol_op, bundle_spleen_seg_op, {\"image\": \"image\"})\n", + " # Note below the dicom_seg_writer requires two inputs, each coming from a upstream operator.\n", + " self.add_flow(\n", + " series_selector_op, dicom_seg_writer, {\"study_selected_series_list\": \"study_selected_series_list\"}\n", + " )\n", + " self.add_flow(bundle_spleen_seg_op, dicom_seg_writer, {\"pred\": \"seg_image\"})\n", + "\n", + " logging.info(f\"End {self.compose.__name__}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Executing app locally\n", + "\n", + "We can execute the app in the Jupyter notebook. Note that the DICOM files of the CT Abdomen series must be present in the `dcm` and the Torch Script model at `model.ts`. Please use the actual path in your environment.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mGoing to initiate execution of operator DICOMDataLoaderOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMDataLoaderOperator \u001b[33m(Process ID: 4288, Operator ID: 5874a40f-a44b-4f81-b9d7-86456fc82732)\u001b[39m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2022-07-07 20:10:19,497] [WARNING] (root) - No selection rules given; select all series.\n", + "[2022-07-07 20:10:19,498] [INFO] (root) - Working on study, instance UID: 1.2.826.0.1.3680043.2.1125.1.67295333199898911264201812221946213\n", + "[2022-07-07 20:10:19,499] [INFO] (root) - Working on series, instance UID: 1.2.826.0.1.3680043.2.1125.1.68102559796966796813942775094416763\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mDone performing execution of operator DICOMDataLoaderOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSeriesSelectorOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSeriesSelectorOperator \u001b[33m(Process ID: 4288, Operator ID: d7802157-c044-452f-bbf4-ba381d04e474)\u001b[39m\n", + "Working on study, instance UID: 1.2.826.0.1.3680043.2.1125.1.67295333199898911264201812221946213\n", + "Working on series, instance UID: 1.2.826.0.1.3680043.2.1125.1.68102559796966796813942775094416763\n", + "\u001b[34mDone performing execution of operator DICOMSeriesSelectorOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSeriesToVolumeOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSeriesToVolumeOperator \u001b[33m(Process ID: 4288, Operator ID: 6ae16897-42b4-454f-9598-338682aa0dae)\u001b[39m\n", + "\u001b[34mDone performing execution of operator DICOMSeriesToVolumeOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator MonaiBundleInferenceOperator\u001b[39m\n", + "\u001b[32mExecuting operator MonaiBundleInferenceOperator \u001b[33m(Process ID: 4288, Operator ID: 910d8955-578e-45ef-b60e-f4e0fc7d4b06)\u001b[39m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2022-07-07 20:10:47,624] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Number of DICOM instance datasets in the list: 515\n", + "[2022-07-07 20:10:47,624] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Number of slices in the numpy image: 515\n", + "[2022-07-07 20:10:47,625] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Labels of the segments: ['Spleen']\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mDone performing execution of operator MonaiBundleInferenceOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSegmentationWriterOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSegmentationWriterOperator \u001b[33m(Process ID: 4288, Operator ID: e5fd81c6-0e0c-438d-9e32-cb7c2a4fa481)\u001b[39m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2022-07-07 20:10:49,639] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Unique values in seg image: [0 1]\n", + "[2022-07-07 20:10:50,856] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Saving output file /home/mqin/src/monai-app-sdk/notebooks/tutorials/output/dicom_seg-DICOMSEG.dcm\n", + "[2022-07-07 20:10:50,919] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - File saved.\n", + "[2022-07-07 20:10:50,926] [INFO] (__main__.AISpleenSegApp) - End run\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mDone performing execution of operator DICOMSegmentationWriterOperator\n", + "\u001b[39m\n" + ] + } + ], + "source": [ + "app = AISpleenSegApp()\n", + "\n", + "app.run(input=\"dcm\", output=\"output\", model=\"model.ts\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the application is verified inside Jupyter notebook, we can write the above Python code into Python files in an application folder.\n", + "\n", + "The application folder structure would look like below:\n", + "\n", + "```bash\n", + "my_app\n", + "├── __main__.py\n", + "└── app.py\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Create an application folder\n", + "!mkdir -p my_app" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### app.py" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting my_app/app.py\n" + ] + } + ], + "source": [ + "%%writefile my_app/app.py\n", + "import logging\n", + "\n", + "from monai.deploy.core import Application, resource\n", + "from monai.deploy.core.domain import Image\n", + "from monai.deploy.core.io_type import IOType\n", + "from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator\n", + "from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator\n", + "from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator\n", + "from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator\n", + "from monai.deploy.operators.monai_bundle_inference_operator import IOMapping, MonaiBundleInferenceOperator\n", + "\n", + "\n", + "@resource(cpu=1, gpu=1, memory=\"7Gi\")\n", + "class AISpleenSegApp(Application):\n", + " def __init__(self, *args, **kwargs):\n", + " \"\"\"Creates an application instance.\"\"\"\n", + " self._logger = logging.getLogger(\"{}.{}\".format(__name__, type(self).__name__))\n", + " super().__init__(*args, **kwargs)\n", + "\n", + " def run(self, *args, **kwargs):\n", + " # This method calls the base class to run. Can be omitted if simply calling through.\n", + " self._logger.info(f\"Begin {self.run.__name__}\")\n", + " super().run(*args, **kwargs)\n", + " self._logger.info(f\"End {self.run.__name__}\")\n", + "\n", + " def compose(self):\n", + " \"\"\"Creates the app specific operators and chain them up in the processing DAG.\"\"\"\n", + "\n", + " logging.info(f\"Begin {self.compose.__name__}\")\n", + "\n", + " # Create the custom operator(s) as well as SDK built-in operator(s).\n", + " study_loader_op = DICOMDataLoaderOperator()\n", + " series_selector_op = DICOMSeriesSelectorOperator()\n", + " series_to_vol_op = DICOMSeriesToVolumeOperator()\n", + "\n", + " # Create the inference operator that supports MONAI Bundle and automates the inference.\n", + " # The IOMapping labels match the input and prediction keys in the pre and post processing.\n", + " # The model_name is optional when the app has only one model.\n", + " # The bundle_path argument optionally can be set to an accessible bundle file path in the dev\n", + " # environment, so when the app is packaged into a MAP, the operator can complete the bundle parsing\n", + " # during init to provide the optional packages info, parsed from the bundle, to the packager\n", + " # for it to install the packages in the MAP docker image.\n", + " # Setting output IOType to DISK only works only for leaf operators, not the case in this example.\n", + " bundle_spleen_seg_op = MonaiBundleInferenceOperator(\n", + " input_mapping=[IOMapping(\"image\", Image, IOType.IN_MEMORY)],\n", + " output_mapping=[IOMapping(\"pred\", Image, IOType.IN_MEMORY)],\n", + " )\n", + "\n", + " # Create DICOM Seg writer with segment label name in a string list\n", + " dicom_seg_writer = DICOMSegmentationWriterOperator(seg_labels=[\"Spleen\"])\n", + "\n", + " # Create the processing pipeline, by specifying the upstream and downstream operators, and\n", + " # ensuring the output from the former matches the input of the latter, in both name and type.\n", + " self.add_flow(study_loader_op, series_selector_op, {\"dicom_study_list\": \"dicom_study_list\"})\n", + " self.add_flow(\n", + " series_selector_op, series_to_vol_op, {\"study_selected_series_list\": \"study_selected_series_list\"}\n", + " )\n", + " self.add_flow(series_to_vol_op, bundle_spleen_seg_op, {\"image\": \"image\"})\n", + " # Note below the dicom_seg_writer requires two inputs, each coming from a upstream operator.\n", + " self.add_flow(\n", + " series_selector_op, dicom_seg_writer, {\"study_selected_series_list\": \"study_selected_series_list\"}\n", + " )\n", + " self.add_flow(bundle_spleen_seg_op, dicom_seg_writer, {\"pred\": \"seg_image\"})\n", + "\n", + " logging.info(f\"End {self.compose.__name__}\")\n", + "\n", + "if __name__ == \"__main__\":\n", + " # Creates the app and test it standalone. When running in this mode, please note the following:\n", + " # -m , for model file path\n", + " # -i , for input DICOM CT series folder\n", + " # -o , for the output folder, default $PWD/output\n", + " # e.g.\n", + " # monai-deploy exec app.py -i input -m model/model.ts\n", + " #\n", + " logging.basicConfig(level=logging.DEBUG)\n", + " app_instance = AISpleenSegApp(do_run=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```python\n", + "if __name__ == \"__main__\":\n", + " AISpleenSegApp(do_run=True)\n", + "```\n", + "\n", + "The above lines are needed to execute the application code by using `python` interpreter.\n", + "\n", + "### \\_\\_main\\_\\_.py\n", + "\n", + "\\_\\_main\\_\\_.py is needed for MONAI Application Packager to detect the main application code (`app.py`) when the application is executed with the application folder path (e.g., `python simple_imaging_app`)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting my_app/__main__.py\n" + ] + } + ], + "source": [ + "%%writefile my_app/__main__.py\n", + "from app import AISpleenSegApp\n", + "\n", + "if __name__ == \"__main__\":\n", + " AISpleenSegApp(do_run=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "app.py\t__main__.py __pycache__\n" + ] + } + ], + "source": [ + "!ls my_app" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this time, let's execute the app in the command line." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mGoing to initiate execution of operator DICOMDataLoaderOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMDataLoaderOperator \u001b[33m(Process ID: 4765, Operator ID: b02877e6-a841-43e6-9267-53c2d24402fb)\u001b[39m\n", + "\u001b[34mDone performing execution of operator DICOMDataLoaderOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSeriesSelectorOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSeriesSelectorOperator \u001b[33m(Process ID: 4765, Operator ID: 44b01560-4417-4173-853f-2a4d6d751892)\u001b[39m\n", + "[2022-07-07 20:12:09,138] [WARNING] (root) - No selection rules given; select all series.\n", + "[2022-07-07 20:12:09,138] [INFO] (root) - Working on study, instance UID: 1.2.826.0.1.3680043.2.1125.1.67295333199898911264201812221946213\n", + "Working on study, instance UID: 1.2.826.0.1.3680043.2.1125.1.67295333199898911264201812221946213\n", + "[2022-07-07 20:12:09,138] [INFO] (root) - Working on series, instance UID: 1.2.826.0.1.3680043.2.1125.1.68102559796966796813942775094416763\n", + "Working on series, instance UID: 1.2.826.0.1.3680043.2.1125.1.68102559796966796813942775094416763\n", + "\u001b[34mDone performing execution of operator DICOMSeriesSelectorOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSeriesToVolumeOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSeriesToVolumeOperator \u001b[33m(Process ID: 4765, Operator ID: 58d2e9ea-0c95-4d27-8535-02fc6a3106d3)\u001b[39m\n", + "\u001b[34mDone performing execution of operator DICOMSeriesToVolumeOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator MonaiBundleInferenceOperator\u001b[39m\n", + "\u001b[32mExecuting operator MonaiBundleInferenceOperator \u001b[33m(Process ID: 4765, Operator ID: 1d361473-ae16-4f98-9406-71b7cc11f983)\u001b[39m\n", + "\u001b[34mDone performing execution of operator MonaiBundleInferenceOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSegmentationWriterOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSegmentationWriterOperator \u001b[33m(Process ID: 4765, Operator ID: 76fcb554-8754-49bd-95aa-91496465196e)\u001b[39m\n", + "[2022-07-07 20:12:36,824] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Number of DICOM instance datasets in the list: 515\n", + "[2022-07-07 20:12:36,824] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Number of slices in the numpy image: 515\n", + "[2022-07-07 20:12:36,824] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Labels of the segments: ['Spleen']\n", + "[2022-07-07 20:12:38,800] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Unique values in seg image: [0 1]\n", + "[2022-07-07 20:12:39,891] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Saving output file /home/mqin/src/monai-app-sdk/notebooks/tutorials/output/dicom_seg-DICOMSEG.dcm\n", + "[2022-07-07 20:12:39,954] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - File saved.\n", + "\u001b[34mDone performing execution of operator DICOMSegmentationWriterOperator\n", + "\u001b[39m\n", + "[2022-07-07 20:12:39,958] [INFO] (app.AISpleenSegApp) - End run\n" + ] + } + ], + "source": [ + "!python my_app -i dcm -o output -m model.ts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Above command is same with the following command line:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[34mGoing to initiate execution of operator DICOMDataLoaderOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMDataLoaderOperator \u001b[33m(Process ID: 4852, Operator ID: 5d746e23-c941-4b9f-8d59-1209f6674b2e)\u001b[39m\n", + "\u001b[34mDone performing execution of operator DICOMDataLoaderOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSeriesSelectorOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSeriesSelectorOperator \u001b[33m(Process ID: 4852, Operator ID: 0b43252c-4e8c-45f7-be3e-b6a0f1f9211f)\u001b[39m\n", + "[2022-07-07 20:12:58,224] [WARNING] (root) - No selection rules given; select all series.\n", + "[2022-07-07 20:12:58,224] [INFO] (root) - Working on study, instance UID: 1.2.826.0.1.3680043.2.1125.1.67295333199898911264201812221946213\n", + "Working on study, instance UID: 1.2.826.0.1.3680043.2.1125.1.67295333199898911264201812221946213\n", + "[2022-07-07 20:12:58,224] [INFO] (root) - Working on series, instance UID: 1.2.826.0.1.3680043.2.1125.1.68102559796966796813942775094416763\n", + "Working on series, instance UID: 1.2.826.0.1.3680043.2.1125.1.68102559796966796813942775094416763\n", + "\u001b[34mDone performing execution of operator DICOMSeriesSelectorOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSeriesToVolumeOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSeriesToVolumeOperator \u001b[33m(Process ID: 4852, Operator ID: 8a3b960f-6c2e-4ba1-83cc-155ae4ff1771)\u001b[39m\n", + "\u001b[34mDone performing execution of operator DICOMSeriesToVolumeOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator MonaiBundleInferenceOperator\u001b[39m\n", + "\u001b[32mExecuting operator MonaiBundleInferenceOperator \u001b[33m(Process ID: 4852, Operator ID: b60cb240-43d7-482e-85c1-bdce50bd87be)\u001b[39m\n", + "\u001b[34mDone performing execution of operator MonaiBundleInferenceOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSegmentationWriterOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSegmentationWriterOperator \u001b[33m(Process ID: 4852, Operator ID: 0b34cf71-0fb2-4900-8ed7-872348a0f772)\u001b[39m\n", + "[2022-07-07 20:13:26,091] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Number of DICOM instance datasets in the list: 515\n", + "[2022-07-07 20:13:26,091] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Number of slices in the numpy image: 515\n", + "[2022-07-07 20:13:26,091] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Labels of the segments: ['Spleen']\n", + "[2022-07-07 20:13:28,081] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Unique values in seg image: [0 1]\n", + "[2022-07-07 20:13:29,127] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Saving output file /home/mqin/src/monai-app-sdk/notebooks/tutorials/output/dicom_seg-DICOMSEG.dcm\n", + "[2022-07-07 20:13:29,189] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - File saved.\n", + "\u001b[34mDone performing execution of operator DICOMSegmentationWriterOperator\n", + "\u001b[39m\n", + "[2022-07-07 20:13:29,194] [INFO] (app.AISpleenSegApp) - End run\n" + ] + } + ], + "source": [ + "import os\n", + "os.environ['MKL_THREADING_LAYER'] = 'GNU'\n", + "!monai-deploy exec my_app -i dcm -o output -m model.ts" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dicom_seg-DICOMSEG.dcm\n" + ] + } + ], + "source": [ + "!ls output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Packaging app" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's package the app with [MONAI Application Packager](/developing_with_sdk/packaging_app)." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2022-07-07 20:14:05,029] [INFO] (root) - Begin compose\n", + "[2022-07-07 20:14:05,030] [INFO] (root) - End compose\n", + "Building MONAI Application Package... Done\n", + "[2022-07-07 20:14:05,630] [INFO] (app_packager) - Successfully built my_app:latest\n" + ] + } + ], + "source": [ + "!monai-deploy package -b nvcr.io/nvidia/pytorch:21.11-py3 my_app --tag my_app:latest -m model.ts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ":::{note}\n", + "Building a MONAI Application Package (Docker image) can take time. Use `-l DEBUG` option if you want to see the progress.\n", + ":::\n", + "\n", + "We can see that the Docker image is created." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "my_app latest 05c843c720ad 2 hours ago 15.3GB\n" + ] + } + ], + "source": [ + "!docker image ls | grep my_app" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Executing packaged app locally\n", + "\n", + "The packaged app can be run locally through [MONAI Application Runner](/developing_with_sdk/executing_packaged_app_locally)." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checking dependencies...\n", + "--> Verifying if \"docker\" is installed...\n", + "\n", + "--> Verifying if \"my_app:latest\" is available...\n", + "\n", + "Checking for MAP \"my_app:latest\" locally\n", + "\"my_app:latest\" found.\n", + "\n", + "Reading MONAI App Package manifest...\n", + "--> Verifying if \"nvidia-docker\" is installed...\n", + "\n", + "/opt/conda/lib/python3.8/site-packages/scipy/__init__.py:138: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.0)\n", + " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion} is required for this version of \"\n", + "INFO:root:Begin compose\n", + "DEBUG:root:Bundle path, None, not valid. Will get it in the execution context.\n", + "INFO:root:End compose\n", + "INFO:__main__.AISpleenSegApp:Begin run\n", + "\u001b[34mGoing to initiate execution of operator DICOMDataLoaderOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMDataLoaderOperator \u001b[33m(Process ID: 1, Operator ID: a2b6f335-f5f5-4327-a681-439b36091ca2)\u001b[39m\n", + "\u001b[34mDone performing execution of operator DICOMDataLoaderOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSeriesSelectorOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSeriesSelectorOperator \u001b[33m(Process ID: 1, Operator ID: 4e7f84f5-ca5f-42ba-8984-86ee5c527594)\u001b[39m\n", + "[2022-07-08 03:14:33,651] [WARNING] (root) - No selection rules given; select all series.\n", + "[2022-07-08 03:14:33,651] [INFO] (root) - Working on study, instance UID: 1.2.826.0.1.3680043.2.1125.1.67295333199898911264201812221946213\n", + "Working on study, instance UID: 1.2.826.0.1.3680043.2.1125.1.67295333199898911264201812221946213\n", + "[2022-07-08 03:14:33,651] [INFO] (root) - Working on series, instance UID: 1.2.826.0.1.3680043.2.1125.1.68102559796966796813942775094416763\n", + "Working on series, instance UID: 1.2.826.0.1.3680043.2.1125.1.68102559796966796813942775094416763\n", + "\u001b[34mDone performing execution of operator DICOMSeriesSelectorOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSeriesToVolumeOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSeriesToVolumeOperator \u001b[33m(Process ID: 1, Operator ID: 413e0366-114c-4d06-9faf-3a8f4ba6863f)\u001b[39m\n", + "\u001b[34mDone performing execution of operator DICOMSeriesToVolumeOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator MonaiBundleInferenceOperator\u001b[39m\n", + "\u001b[32mExecuting operator MonaiBundleInferenceOperator \u001b[33m(Process ID: 1, Operator ID: f250496d-9b15-4ab3-83b3-6a608f07dd7c)\u001b[39m\n", + "\u001b[34mDone performing execution of operator MonaiBundleInferenceOperator\n", + "\u001b[39m\n", + "\u001b[34mGoing to initiate execution of operator DICOMSegmentationWriterOperator\u001b[39m\n", + "\u001b[32mExecuting operator DICOMSegmentationWriterOperator \u001b[33m(Process ID: 1, Operator ID: b85c589b-054c-4ecd-b8d9-cb7745bd9011)\u001b[39m\n", + "[2022-07-08 03:15:00,796] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Number of DICOM instance datasets in the list: 515\n", + "[2022-07-08 03:15:00,796] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Number of slices in the numpy image: 515\n", + "[2022-07-08 03:15:00,796] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Labels of the segments: ['Spleen']\n", + "[2022-07-08 03:15:02,511] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Unique values in seg image: [0 1]\n", + "/root/.local/lib/python3.8/site-packages/pydicom/valuerep.py:290: UserWarning: Invalid value for VR DA: '2019-09-16'.\n", + " warnings.warn(msg)\n", + "/root/.local/lib/python3.8/site-packages/pydicom/valuerep.py:290: UserWarning: The value length (94) exceeds the maximum length of 64 allowed for VR LO.\n", + " warnings.warn(msg)\n", + "[2022-07-08 03:15:03,694] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - Saving output file /var/monai/output/dicom_seg-DICOMSEG.dcm\n", + "[2022-07-08 03:15:03,797] [INFO] (monai.deploy.operators.dicom_seg_writer_operator.DICOMSegWriter) - File saved.\n", + "\u001b[34mDone performing execution of operator DICOMSegmentationWriterOperator\n", + "\u001b[39m\n", + "[2022-07-08 03:15:03,803] [INFO] (__main__.AISpleenSegApp) - End run\n" + ] + } + ], + "source": [ + "# Copy DICOM files are in 'dcm' folder\n", + "\n", + "# Launch the app\n", + "!monai-deploy run my_app:latest dcm output" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dicom_seg-DICOMSEG.dcm\n" + ] + } + ], + "source": [ + "!ls output" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 4562c36e521f8d37ccbf1e1466b77c24b55f5b5d Mon Sep 17 00:00:00 2001 From: mmelqin Date: Fri, 8 Jul 2022 12:22:10 -0700 Subject: [PATCH 12/13] Updating the notebook descriptions. Signed-off-by: mmelqin --- notebooks/tutorials/03_segmentation_app.ipynb | 2 +- notebooks/tutorials/05_full_tutorial.ipynb | 2 +- notebooks/tutorials/06_monai_bundle_app.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/notebooks/tutorials/03_segmentation_app.ipynb b/notebooks/tutorials/03_segmentation_app.ipynb index 0c5c49ac..23d9cc36 100644 --- a/notebooks/tutorials/03_segmentation_app.ipynb +++ b/notebooks/tutorials/03_segmentation_app.ipynb @@ -6,7 +6,7 @@ "source": [ "# Creating a Segmentation App with MONAI Deploy App SDK\n", "\n", - "This tutorial shows how to create an organ segmentation application for a PyTorch model that has been trained with MONAI.\n", + "This tutorial shows how to create an organ segmentation application for a PyTorch model that has been trained with MONAI. Please note that this tutorial is based on the [earlier version](https://github.com/Project-MONAI/monai-deploy-app-sdk/blob/7615d73f6ec2125ba5d2e3480f85b060e95b81e4/examples/apps/ai_spleen_seg_app/app.py) of the Spleen Segmentation Application.\n", "\n", "Deploying AI models requires the integration with clinical imaging network, even if in a for-research-use setting. This means that the AI deploy application will need to support standards-based imaging protocols, and specifically for Radiological imaging, DICOM protocol.\n", "\n", diff --git a/notebooks/tutorials/05_full_tutorial.ipynb b/notebooks/tutorials/05_full_tutorial.ipynb index c4333f5b..80b2c570 100644 --- a/notebooks/tutorials/05_full_tutorial.ipynb +++ b/notebooks/tutorials/05_full_tutorial.ipynb @@ -6,7 +6,7 @@ "source": [ "# Full Tutorial Building and Deploying Segmentation App with MONAI Inference Service (MIS)\n", "\n", - "This tutorial begins with creating an organ segmentation application using MONAI App SDK for a PyTorch model that has been trained with MONAI. Then this tutorial transitions into discussing how to deploy the segmentation application with the RESTful [MONAI Inference Service](https://github.com/Project-MONAI/monai-deploy-app-server/blob/main/components/inference-service/README.md).\n", + "This tutorial begins with creating an organ segmentation application using MONAI App SDK for a PyTorch model that has been trained with MONAI. Then this tutorial transitions into discussing how to deploy the segmentation application with the RESTful [MONAI Inference Service](https://github.com/Project-MONAI/monai-deploy-app-server/blob/main/components/inference-service/README.md). Please note that this tutorial is based on the [earlier version](https://github.com/Project-MONAI/monai-deploy-app-sdk/blob/7615d73f6ec2125ba5d2e3480f85b060e95b81e4/examples/apps/ai_spleen_seg_app/app.py) of the Spleen Segmentation Application.\n", "\n", "In the following sections, we will demonstrate how to create a MONAI Deploy application package using the MONAI Deploy App SDK and then will demonstrate how to deploy this package with the [MONAI Inference Service](https://github.com/Project-MONAI/monai-deploy-app-server/blob/main/components/inference-service/README.md). Along the way we will provide verification steps to confirm that our application produces the desired output both locally (for verification) and as a service output.\n", "\n", diff --git a/notebooks/tutorials/06_monai_bundle_app.ipynb b/notebooks/tutorials/06_monai_bundle_app.ipynb index 35bff664..2d940689 100644 --- a/notebooks/tutorials/06_monai_bundle_app.ipynb +++ b/notebooks/tutorials/06_monai_bundle_app.ipynb @@ -135,7 +135,7 @@ "output_type": "stream", "text": [ "Downloading...\n", - "From: https://drive.google.com/uc?id=1GC_N8YQk_mOWN02oOzAU_2YDmNRWk--n\n", + "From: https://drive.google.com/uc?id=1cJq0iQh_yzYIxVElSlVa141aEmHZADJh\n", "To: ~/src/monai-deploy-app-sdk/notebooks/tutorials/ai_spleen_bundle_data.zip\n", "104MB [00:10, 10.3MB/s] \n", "Archive: ai_spleen_bundle_data.zip\n", From 6fb7dd01a851690c02c7ab4812bce2578b1dc369 Mon Sep 17 00:00:00 2001 From: mmelqin Date: Mon, 11 Jul 2022 15:16:35 -0700 Subject: [PATCH 13/13] Update and make consistent the user guides Signed-off-by: mmelqin --- docs/source/getting_started/examples.md | 2 ++ .../tutorials/03_segmentation_app.md | 19 ++++++++++--------- .../source/getting_started/tutorials/index.md | 1 + notebooks/tutorials/03_segmentation_app.ipynb | 16 ++++++++-------- notebooks/tutorials/05_full_tutorial.ipynb | 16 ++++++++-------- 5 files changed, 29 insertions(+), 25 deletions(-) diff --git a/docs/source/getting_started/examples.md b/docs/source/getting_started/examples.md index 40b23192..d981f90c 100644 --- a/docs/source/getting_started/examples.md +++ b/docs/source/getting_started/examples.md @@ -5,7 +5,9 @@ has example apps that you can see. - ai_spleen_seg_app +- ai_livertumor_seg_app - ai_unetr_seg_app - dicom_series_to_image_app - mednist_classifier_monaideploy - simple_imaging_app +- deply_app_on_aarch64 diff --git a/docs/source/getting_started/tutorials/03_segmentation_app.md b/docs/source/getting_started/tutorials/03_segmentation_app.md index e277108b..4e7f5a1f 100644 --- a/docs/source/getting_started/tutorials/03_segmentation_app.md +++ b/docs/source/getting_started/tutorials/03_segmentation_app.md @@ -15,7 +15,7 @@ jupyter-lab ``` ## Executing from Jupyter Notebook - +Please note that the example code used in the Jupyter Notebook is based on an earlier version of the segmentation application, hence not the same as the latest source code on Github, e.g. not using MONAI Bundle inference operator. ```{toctree} :maxdepth: 4 @@ -43,7 +43,7 @@ jupyter-lab ``` ## Executing from Shell - +Please note that this part of the example uses the latest application source code on Github, as well as the corresponding test data. ```bash # Clone the github project (the latest version of main branch only) git clone --branch main --depth 1 https://github.com/Project-MONAI/monai-deploy-app-sdk.git @@ -53,17 +53,18 @@ cd monai-deploy-app-sdk # Install monai-deploy-app-sdk package pip install monai-deploy-app-sdk -# Download/Extract ai_spleen_seg_data zip file from https://drive.google.com/file/d/1GC_N8YQk_mOWN02oOzAU_2YDmNRWk--n/view?usp=sharing +# Download/Extract ai_spleen_bundle_data zip file from https://drive.google.com/file/d/1cJq0iQh_yzYIxVElSlVa141aEmHZADJh/view?usp=sharing -# Download ai_spleen_seg_data.zip +# Download ai_spleen_bundle_data.zip pip install gdown -gdown https://drive.google.com/uc?id=1GC_N8YQk_mOWN02oOzAU_2YDmNRWk--n +gdown https://drive.google.com/uc?id=1cJq0iQh_yzYIxVElSlVa141aEmHZADJh -# After downloading ai_spleen_seg_data.zip from the web browser or using gdown, -unzip -o ai_spleen_seg_data_updated_1203.zip +# After downloading ai_spleen_bundle_data.zip from the web browser or using gdown, +unzip -o ai_spleen_bundle_data.zip -# Install necessary packages from the app -pip install monai pydicom SimpleITK Pillow nibabel +# Install necessary packages from the app; note that numpy-stl and trimesh are only +# needed if the application uses the STL Conversion Operator +pip install monai pydicom SimpleITK Pillow nibabel scikit-image numpy-stl trimesh # Local execution of the app python examples/apps/ai_spleen_seg_app/app.py -i dcm/ -o output -m model.ts diff --git a/docs/source/getting_started/tutorials/index.md b/docs/source/getting_started/tutorials/index.md index d3a292ee..2077b140 100644 --- a/docs/source/getting_started/tutorials/index.md +++ b/docs/source/getting_started/tutorials/index.md @@ -9,4 +9,5 @@ 03_segmentation_app 04_mis_tutorial 05_full_tutorial +06_monai_bundle_app ``` diff --git a/notebooks/tutorials/03_segmentation_app.ipynb b/notebooks/tutorials/03_segmentation_app.ipynb index 23d9cc36..8f9bb6a0 100644 --- a/notebooks/tutorials/03_segmentation_app.ipynb +++ b/notebooks/tutorials/03_segmentation_app.ipynb @@ -120,7 +120,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Download/Extract ai_spleen_seg_data from Google Drive" + "### Download/Extract ai_spleen_bundle_data from Google Drive" ] }, { @@ -144,10 +144,10 @@ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (1.26.6)\n", "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (1.7.1)\n", "Downloading...\n", - "From: https://drive.google.com/uc?id=1GC_N8YQk_mOWN02oOzAU_2YDmNRWk--n\n", - "To: ~/src/monai-deploy-app-sdk/notebooks/tutorials/ai_spleen_seg_data_update_1203.zip\n", + "From: https://drive.google.com/uc?id=1cJq0iQh_yzYIxVElSlVa141aEmHZADJh\n", + "To: ~/src/monai-deploy-app-sdk/notebooks/tutorials/ai_spleen_bundle_data.zip\n", "104MB [00:10, 10.3MB/s] \n", - "Archive: ai_spleen_seg_data_updated_1203.zip\n", + "Archive: ai_spleen_bundle_data.zip\n", " creating: dcm/\n", " inflating: dcm/IMG0001.dcm \n", " inflating: dcm/IMG0002.dcm \n", @@ -669,12 +669,12 @@ } ], "source": [ - "# Download ai_spleen_seg_data test data zip file\n", + "# Download ai_spleen_bundle_data test data zip file\n", "!pip install gdown \n", - "!gdown \"https://drive.google.com/uc?id=1GC_N8YQk_mOWN02oOzAU_2YDmNRWk--n\"\n", + "!gdown \"https://drive.google.com/uc?id=1cJq0iQh_yzYIxVElSlVa141aEmHZADJh\"\n", "\n", - "# After downloading ai_spleen_seg_data zip file from the web browser or using gdown,\n", - "!unzip -o \"ai_spleen_seg_data_updated_1203.zip\"" + "# After downloading ai_spleen_bundle_data zip file from the web browser or using gdown,\n", + "!unzip -o \"ai_spleen_bundle_data.zip\"" ] }, { diff --git a/notebooks/tutorials/05_full_tutorial.ipynb b/notebooks/tutorials/05_full_tutorial.ipynb index 80b2c570..91d67c01 100644 --- a/notebooks/tutorials/05_full_tutorial.ipynb +++ b/notebooks/tutorials/05_full_tutorial.ipynb @@ -112,7 +112,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Download/Extract ai_spleen_seg_data from Google Drive" + "### Download/Extract ai_spleen_bundle_data from Google Drive" ] }, { @@ -136,10 +136,10 @@ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (1.26.6)\n", "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /home/gbae/miniconda3/envs/mednist/lib/python3.6/site-packages (from requests[socks]>=2.12.0->gdown) (1.7.1)\n", "Downloading...\n", - "From: https://drive.google.com/uc?id=1GC_N8YQk_mOWN02oOzAU_2YDmNRWk--n\n", - "To: ~/src/monai-deploy-app-sdk/notebooks/tutorials/ai_spleen_seg_data_updated_1203.zip\n", + "From: https://drive.google.com/uc?id=1cJq0iQh_yzYIxVElSlVa141aEmHZADJh\n", + "To: ~/src/monai-deploy-app-sdk/notebooks/tutorials/ai_spleen_bundle_data.zip\n", "104MB [00:10, 10.3MB/s] \n", - "Archive: ai_spleen_seg_data_update_1203.zip\n", + "Archive: ai_spleen_bundle_data.zip\n", " creating: dcm/\n", " inflating: dcm/IMG0001.dcm \n", " inflating: dcm/IMG0002.dcm \n", @@ -661,12 +661,12 @@ } ], "source": [ - "# Download ai_spleen_seg_data test data zip file\n", + "# Download ai_spleen_bundle_data test data zip file\n", "!pip install gdown \n", - "!gdown \"https://drive.google.com/uc?id=1GC_N8YQk_mOWN02oOzAU_2YDmNRWk--n\"\n", + "!gdown \"https://drive.google.com/uc?id=1cJq0iQh_yzYIxVElSlVa141aEmHZADJh\"\n", "\n", - "# After downloading ai_spleen_seg_data zip file from the web browser or using gdown,\n", - "!unzip -o \"ai_spleen_seg_data_updated_1203.zip\"" + "# After downloading ai_spleen_bundle_data zip file from the web browser or using gdown,\n", + "!unzip -o \"ai_spleen_bundle_data.zip\"" ] }, {