diff --git a/docs/source/getting_started/examples.md b/docs/source/getting_started/examples.md index e16873c3..521a4e52 100644 --- a/docs/source/getting_started/examples.md +++ b/docs/source/getting_started/examples.md @@ -13,3 +13,4 @@ - ai_unetr_seg_app - dicom_series_to_image_app - breast_density_classifer_app +- cchmc_ped_abd_ct_seg_app diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/README.md b/examples/apps/cchmc_ped_abd_ct_seg_app/README.md new file mode 100644 index 00000000..5e1a8a21 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/README.md @@ -0,0 +1,40 @@ +# MONAI Application Package (MAP) for CCHMC Pediatric Abdominal CT Segmentation MONAI Bundle + +This MAP is based on the [CCHMC Pediatric Abdominal CT Segmentation MONAI Bundle](https://github.com/cchmc-dll/pediatric_abdominal_segmentation_bundle/tree/original). This model was developed at Cincinnati Children's Hospital Medical Center by the Department of Radiology. + +The PyTorch and TorchScript DynUNet models can be downloaded from the [MONAI Bundle Repository](https://github.com/cchmc-dll/pediatric_abdominal_segmentation_bundle/tree/original/models). + +For questions, please feel free to contact Elan Somasundaram (Elanchezhian.Somasundaram@cchmc.org) and Bryan Luna (Bryan.Luna@cchmc.org). + +## Unique Features + +Some unique features of this MAP pipeline include: +- **Custom Inference Operator:** custom `AbdomenSegOperator` enables either PyTorch or TorchScript model loading +- **DICOM Secondary Capture Output:** custom `DICOMSecondaryCaptureWriterOperator` writes a DICOM SC with organ contours +- **Output Filtering:** model produces Liver-Spleen-Pancreas segmentations, but seg visibility in the outputs (DICOM SEG, SC, SR) can be controlled in `app.py` +- **MONAI Deploy Express MongoDB Write:** custom operators (`MongoDBEntryCreatorOperator` and `MongoDBWriterOperator`) allow for writing to the MongoDB database associated with MONAI Deploy Express + +## Scripts +Several scripts have been compiled that quickly execute useful actions (such as running the app code locally with Python interpreter, MAP packaging, MAP execution, etc.). Some scripts require the input of command line arguments; review the `scripts` folder for more details. + +## Notes +The DICOM Series selection criteria has been customized based on the model's training and CCHMC use cases. By default, Axial CT series with Slice Thickness between 3.0 - 5.0 mm (inclusive) will be selected for. + +If MongoDB writing is not desired, please comment out the relevant sections in `app.py` and the `AbdomenSegOperator`. + +To execute the pipeline with MongoDB writing enabled, it is best to create a `.env` file that the `MongoDBWriterOperator` can load in. Below is an example `.env` file that follows the format outlined in this operator; note that these values are the default variable values as defined in the [.env](https://github.com/Project-MONAI/monai-deploy/blob/main/deploy/monai-deploy-express/.env) and [docker-compose.yaml](https://github.com/Project-MONAI/monai-deploy/blob/main/deploy/monai-deploy-express/docker-compose.yml) files of v0.6.0 of MONAI Deploy Express: + +```dotenv +MONGODB_USERNAME=root +MONGODB_PASSWORD=rootpassword +MONGODB_PORT=27017 +MONGODB_IP_DOCKER=172.17.0.1 # default Docker bridge network IP +``` + +Prior to packaging into a MAP, the MongoDB credentials should be hardcoded into the `MongoDBWriterOperator`. + +The MONAI Deploy Express MongoDB Docker container (`mdl-mongodb`) needs to be connected to the Docker bridge network in order for the MongoDB write to be successful. Executing the following command in a MONAI Deploy Express terminal will establish this connection: + +```bash +docker network connect bridge mdl-mongodb +``` diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/__init__.py b/examples/apps/cchmc_ped_abd_ct_seg_app/__init__.py new file mode 100644 index 00000000..06014cc7 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# __init__.py is used to initialize a Python package +# ensures that the directory __init__.py resides in is included at the start of the sys.path +# this is useful when you want to import modules from this directory, even if it’s not the +# directory where your Python script is running. + +# give access to operating system and Python interpreter +import os +import sys + +# grab absolute path of directory containing __init__.py +_current_dir = os.path.abspath(os.path.dirname(__file__)) + +# if sys.path is not the same as the directory containing the __init__.py file +if sys.path and os.path.abspath(sys.path[0]) != _current_dir: + # insert directory containing __init__.py file at the beginning of sys.path + sys.path.insert(0, _current_dir) +# delete variable +del _current_dir diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/__main__.py b/examples/apps/cchmc_ped_abd_ct_seg_app/__main__.py new file mode 100644 index 00000000..80cca2fa --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/__main__.py @@ -0,0 +1,26 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# __main__.py is needed for MONAI Application Packager to detect the main app code (app.py) when +# app.py is executed in the application folder path +# e.g., python my_app + +import logging + +# import AIAbdomenSegApp class from app.py +from app import AIAbdomenSegApp + +# if __main__.py is being run directly +if __name__ == "__main__": + logging.info(f"Begin {__name__}") + # create and run an instance of AIAbdomenSegApp + AIAbdomenSegApp().run() + logging.info(f"End {__name__}") diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/abdomen_seg_operator.py b/examples/apps/cchmc_ped_abd_ct_seg_app/abdomen_seg_operator.py new file mode 100644 index 00000000..2f412f14 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/abdomen_seg_operator.py @@ -0,0 +1,294 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from typing import List + +import torch +from numpy import float32, int16 + +# import custom transforms from post_transforms.py +from post_transforms import CalculateVolumeFromMaskd, ExtractVolumeToTextd, LabelToContourd, OverlayImageLabeld + +import monai +from monai.deploy.core import AppContext, Fragment, Model, Operator, OperatorSpec +from monai.deploy.operators.monai_seg_inference_operator import InfererType, InMemImageReader, MonaiSegInferenceOperator +from monai.transforms import ( + Activationsd, + AsDiscreted, + CastToTyped, + Compose, + CropForegroundd, + EnsureChannelFirstd, + EnsureTyped, + Invertd, + LoadImaged, + Orientationd, + SaveImaged, + ScaleIntensityRanged, + Spacingd, +) + + +# this operator performs inference with the new version of the bundle +class AbdomenSegOperator(Operator): + """Performs segmentation inference with a custom model architecture.""" + + DEFAULT_OUTPUT_FOLDER = Path.cwd() / "output" + + def __init__( + self, + fragment: Fragment, + *args, + app_context: AppContext, + model_path: Path, + output_folder: Path = DEFAULT_OUTPUT_FOLDER, + output_labels: List[int], + **kwargs, + ): + + self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + self._input_dataset_key = "image" + self._pred_dataset_key = "pred" + + # self.model_path is compatible with TorchScript and PyTorch model workflows (pythonic and MAP) + self.model_path = self._find_model_file_path(model_path) + + self.output_folder = output_folder + self.output_folder.mkdir(parents=True, exist_ok=True) + self.output_labels = output_labels + self.app_context = app_context + self.input_name_image = "image" + self.output_name_seg = "seg_image" + self.output_name_text_dicom_sr = "result_text_dicom_sr" + self.output_name_text_mongodb = "result_text_mongodb" + self.output_name_sc_path = "dicom_sc_dir" + + # the base class has an attribute called fragment to hold the reference to the fragment object + super().__init__(fragment, *args, **kwargs) + + # find model path; supports TorchScript and PyTorch model workflows (pythonic and MAP) + def _find_model_file_path(self, model_path: Path): + # when executing pythonically, model_path is a file + # when executing as MAP, model_path is a directory (/opt/holoscan/models) + # torch.load() from PyTorch workflow needs file path; can't load model from directory + # returns first found file in directory in this case + if model_path: + if model_path.is_file(): + return model_path + elif model_path.is_dir(): + for file in model_path.rglob("*"): + if file.is_file(): + return file + + raise ValueError(f"Model file not found in the provided path: {model_path}") + + # load a PyTorch model and register it in app_context + def _load_pytorch_model(self): + + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _kernel_size: tuple = (3, 3, 3, 3, 3, 3) + _strides: tuple = (1, 2, 2, 2, 2, (2, 2, 1)) + _upsample_kernel_size: tuple = (2, 2, 2, 2, (2, 2, 1)) + + # create DynUNet model with the specified architecture parameters + move to computational device (GPU or CPU) + # parameters pulled from inference.yaml file of the MONAI bundle + model = monai.networks.nets.dynunet.DynUNet( + spatial_dims=3, + in_channels=1, + out_channels=4, + kernel_size=_kernel_size, + strides=_strides, + upsample_kernel_size=_upsample_kernel_size, + norm_name="INSTANCE", + deep_supervision=False, + res_block=True, + ).to(_device) + + # load model state dictionary (i.e. mapping param names to tensors) via torch.load + # weights_only=True to avoid arbitrary code execution during unpickling + state_dict = torch.load(self.model_path, weights_only=True) + + # assign loaded weights to model architecture via load_state_dict + model.load_state_dict(state_dict) + + # set model in evaluation (inference) mode + model.eval() + + # create a MONAI Model object to encapsulate the PyTorch model and metadata + loaded_model = Model(self.model_path, name="ped_abd_ct_seg") + + # assign loaded PyTorch model as the predictor for the Model object + loaded_model.predictor = model + + # register the loaded Model object in the application context so other operators can access it + # MonaiSegInferenceOperator uses _get_model method to load models; looks at app_context.models first + self.app_context.models = loaded_model + + def setup(self, spec: OperatorSpec): + spec.input(self.input_name_image) + + # DICOM SEG + spec.output(self.output_name_seg) + + # DICOM SR + spec.output(self.output_name_text_dicom_sr) + + # MongoDB + spec.output(self.output_name_text_mongodb) + + # DICOM SC + spec.output(self.output_name_sc_path) + + def compute(self, op_input, op_output, context): + input_image = op_input.receive(self.input_name_image) + if not input_image: + raise ValueError("Input image is not found.") + + # this operator gets an in-memory Image object, so a specialized ImageReader is needed. + _reader = InMemImageReader(input_image) + + # preprocessing and postprocessing + pre_transforms = self.pre_process(_reader) + post_transforms = self.post_process(pre_transforms) + + # if PyTorch model + if self.model_path.suffix.lower() == ".pt": + # load the PyTorch model + self._logger.info("PyTorch model detected") + self._load_pytorch_model() + # else, we have TorchScript model + else: + self._logger.info("TorchScript model detected") + + # delegates inference and saving output to the built-in operator. + infer_operator = MonaiSegInferenceOperator( + self.fragment, + roi_size=(96, 96, 96), + pre_transforms=pre_transforms, + post_transforms=post_transforms, + overlap=0.75, + app_context=self.app_context, + model_name="", + inferer=InfererType.SLIDING_WINDOW, + sw_batch_size=4, + model_path=self.model_path, + name="monai_seg_inference_op", + ) + + # setting the keys used in the dictionary-based transforms + infer_operator.input_dataset_key = self._input_dataset_key + infer_operator.pred_dataset_key = self._pred_dataset_key + + seg_image = infer_operator.compute_impl(input_image, context) + + # DICOM SEG + op_output.emit(seg_image, self.output_name_seg) + + # grab result_text_dicom_sr and result_text_mongodb from ExractVolumeToTextd transform + result_text_dicom_sr, result_text_mongodb = self.get_result_text_from_transforms(post_transforms) + if not result_text_dicom_sr or not result_text_mongodb: + raise ValueError("Result text could not be generated.") + + # only log volumes for target organs so logs reflect MAP behavior + self._logger.info(f"Calculated Organ Volumes: {result_text_dicom_sr}") + + # DICOM SR + op_output.emit(result_text_dicom_sr, self.output_name_text_dicom_sr) + + # MongoDB + op_output.emit(result_text_mongodb, self.output_name_text_mongodb) + + # DICOM SC + # temporary DICOM SC (w/o source DICOM metadata) saved in output_folder / temp directory + dicom_sc_dir = self.output_folder / "temp" + + self._logger.info(f"Temporary DICOM SC saved at: {dicom_sc_dir}") + + op_output.emit(dicom_sc_dir, self.output_name_sc_path) + + def pre_process(self, img_reader) -> Compose: + """Composes transforms for preprocessing the input image before predicting on a model.""" + + my_key = self._input_dataset_key + + return Compose( + [ + # img_reader: specialized InMemImageReader, derived from MONAI ImageReader + LoadImaged(keys=my_key, reader=img_reader), + EnsureChannelFirstd(keys=my_key), + Orientationd(keys=my_key, axcodes="RAS"), + Spacingd(keys=my_key, pixdim=[1.5, 1.5, 3.0], mode=["bilinear"]), + ScaleIntensityRanged(keys=my_key, a_min=-250, a_max=400, b_min=0.0, b_max=1.0, clip=True), + CropForegroundd(keys=my_key, source_key=my_key, mode="minimum"), + EnsureTyped(keys=my_key), + CastToTyped(keys=my_key, dtype=float32), + ] + ) + + def post_process(self, pre_transforms: Compose) -> Compose: + """Composes transforms for postprocessing the prediction results.""" + + pred_key = self._pred_dataset_key + + labels = {"background": 0, "liver": 1, "spleen": 2, "pancreas": 3} + + return Compose( + [ + Activationsd(keys=pred_key, softmax=True), + Invertd( + keys=[pred_key, self._input_dataset_key], + transform=pre_transforms, + orig_keys=[self._input_dataset_key, self._input_dataset_key], + meta_key_postfix="meta_dict", + nearest_interp=[False, False], + to_tensor=True, + ), + AsDiscreted(keys=pred_key, argmax=True), + # custom post-processing steps + CalculateVolumeFromMaskd(keys=pred_key, label_names=labels), + # optional code for saving segmentation masks as a NIfTI + # SaveImaged( + # keys=pred_key, + # output_ext=".nii.gz", + # output_dir=self.output_folder / "NIfTI", + # meta_keys="pred_meta_dict", + # separate_folder=False, + # output_dtype=int16 + # ), + # volume data stored in dictionary under pred_key + '_volumes' key + ExtractVolumeToTextd( + keys=[pred_key + "_volumes"], label_names=labels, output_labels=self.output_labels + ), + # comment out LabelToContourd for seg masks instead of contours; organ filtering will be lost + LabelToContourd(keys=pred_key, output_labels=self.output_labels), + OverlayImageLabeld(image_key=self._input_dataset_key, label_key=pred_key, overlay_key="overlay"), + SaveImaged( + keys="overlay", + output_ext=".dcm", + # save temporary DICOM SC (w/o source DICOM metadata) in output_folder / temp directory + output_dir=self.output_folder / "temp", + separate_folder=False, + output_dtype=int16, + ), + ] + ) + + # grab volume data from ExtractVolumeToTextd transform + def get_result_text_from_transforms(self, post_transforms: Compose): + """Extracts the result_text variables from post-processing transforms output.""" + + # grab the result_text variables from ExractVolumeToTextd transfor + for transform in post_transforms.transforms: + if isinstance(transform, ExtractVolumeToTextd): + return transform.result_text_dicom_sr, transform.result_text_mongodb + return None diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/app.py b/examples/apps/cchmc_ped_abd_ct_seg_app/app.py new file mode 100644 index 00000000..845954b0 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/app.py @@ -0,0 +1,281 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +# custom inference operator +from abdomen_seg_operator import AbdomenSegOperator + +# custom DICOM Secondary Capture (SC) writer operator +from dicom_sc_writer_operator import DICOMSCWriterOperator + +# custom MongoDB operators +from mongodb_entry_creator_operator import MongoDBEntryCreatorOperator +from mongodb_writer_operator import MongoDBWriterOperator + +# required for setting SegmentDescription attributes +# direct import as this is not part of App SDK package +from pydicom.sr.codedict import codes + +from monai.deploy.conditions import CountCondition +from monai.deploy.core import Application +from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator +from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator, SegmentDescription +from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator +from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator +from monai.deploy.operators.dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo + + +# inherit new Application class instance, AIAbdomenSegApp, from MONAI Application base class +# base class provides support for chaining up operators and executing application +class AIAbdomenSegApp(Application): + """Demonstrates inference with customized CCHMC pediatric abdominal segmentation bundle inference operator, with + DICOM files as input/output + + This application loads a set of DICOM instances, selects the appropriate series, converts the series to + 3D volume image, performs inference with a custom inference operator, including pre-processing + and post-processing, saves a DICOM SEG (organ contours), a DICOM Secondary Capture (organ contours overlay), + and a DICOM SR (organ volumes), and writes organ volumes and relevant DICOM tags to the MONAI Deploy Express + MongoDB database (optional). + + Pertinent MONAI Bundle: + https://github.com/cchmc-dll/pediatric_abdominal_segmentation_bundle/tree/original + + Execution Time Estimate: + With a NVIDIA GeForce RTX 3090 24GB GPU, for an input DICOM Series of 204 instances, the execution time is around + 25 seconds for DICOM SEG, DICOM SC, and DICOM SR outputs, as well as the MDE MongoDB database write. + """ + + def __init__(self, *args, **kwargs): + """Creates an application instance.""" + self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__)) + super().__init__(*args, **kwargs) + + def run(self, *args, **kwargs): + # this method calls the base class to run; can be omitted if simply calling through + self._logger.info(f"Begin {self.run.__name__}") + super().run(*args, **kwargs) + self._logger.info(f"End {self.run.__name__}") + + # use compose method to instantiate operators and connect them to form a Directed Acyclic Graph (DAG) + def compose(self): + """Creates the app specific operators and chain them up in the processing DAG.""" + + logging.info(f"Begin {self.compose.__name__}") + + # use Commandline options over environment variables to init context + app_context = Application.init_app_context(self.argv) + app_input_path = Path(app_context.input_path) + app_output_path = Path(app_context.output_path) + model_path = Path(app_context.model_path) + + # create the custom operator(s) as well as SDK built-in operator(s) + # DICOM Data Loader op + study_loader_op = DICOMDataLoaderOperator( + self, CountCondition(self, 1), input_folder=app_input_path, name="study_loader_op" + ) + + # custom DICOM Series Selector op + # all_matched and sort_by_sop_instance_count = True; want all series that meet the selection criteria + # to be matched, and SOP sorting + series_selector_op = DICOMSeriesSelectorOperator( + self, rules=Sample_Rules_Text, all_matched=True, sort_by_sop_instance_count=True, name="series_selector_op" + ) + + # DICOM Series to Volume op + series_to_vol_op = DICOMSeriesToVolumeOperator(self, name="series_to_vol_op") + + # custom inference op + # output_labels specifies which of the organ segmentations are desired in the DICOM SEG, DICOM SC, and DICOM SR outputs + # 1 = Liver, 2 = Spleen, 3 = Pancreas; all segmentations performed, but visibility in outputs (SEG, SC, SR) controlled here + # all organ volumes will be written to MongoDB + output_labels = [1, 2, 3] + abd_seg_op = AbdomenSegOperator( + self, app_context=app_context, model_path=model_path, output_labels=output_labels, name="abd_seg_op" + ) + + # create DICOM Seg writer providing the required segment description for each segment with + # the actual algorithm and the pertinent organ/tissue; the segment_label, algorithm_name, + # and algorithm_version are of DICOM VR LO type, limited to 64 chars + # https://dicom.nema.org/medical/dicom/current/output/chtml/part05/sect_6.2.html + + # general algorithm information + _algorithm_name = "CCHMC Pediatric CT Abdominal Segmentation" + _algorithm_family = codes.DCM.ArtificialIntelligence + _algorithm_version = "0.4.3" + + segment_descriptions = [ + SegmentDescription( + segment_label="Liver", + segmented_property_category=codes.SCT.Organ, + segmented_property_type=codes.SCT.Liver, + algorithm_name=_algorithm_name, + algorithm_family=_algorithm_family, + algorithm_version=_algorithm_version, + ), + SegmentDescription( + segment_label="Spleen", + segmented_property_category=codes.SCT.Organ, + segmented_property_type=codes.SCT.Spleen, + algorithm_name=_algorithm_name, + algorithm_family=_algorithm_family, + algorithm_version=_algorithm_version, + ), + SegmentDescription( + segment_label="Pancreas", + segmented_property_category=codes.SCT.Organ, + segmented_property_type=codes.SCT.Pancreas, + algorithm_name=_algorithm_name, + algorithm_family=_algorithm_family, + algorithm_version=_algorithm_version, + ), + ] + + # custom tags - add Device UID to DICOM SEG to match SR and SC tags + custom_tags_seg = {"SeriesDescription": "AI Generated DICOM SEG; Not for Clinical Use.", "DeviceUID": "0.0.1"} + custom_tags_sr = {"SeriesDescription": "AI Generated DICOM SR; Not for Clinical Use."} + custom_tags_sc = {"SeriesDescription": "AI Generated DICOM Secondary Capture; Not for Clinical Use."} + + # DICOM SEG Writer op writes content from segment_descriptions to output DICOM images as DICOM tags + dicom_seg_writer = DICOMSegmentationWriterOperator( + self, + segment_descriptions=segment_descriptions, + custom_tags=custom_tags_seg, + # store DICOM SEG in SEG subdirectory; necessary for routing in CCHMC MDE workflow definition + output_folder=app_output_path / "SEG", + # omit_empty_frames is a default parameteter (type bool) of DICOMSegmentationWriterOperator + # dictates whether or not to omit frames that contain no segmented pixels from the output segmentation + # default value is True; changed to False to ensure input and output DICOM series #'s match + omit_empty_frames=False, + name="dicom_seg_writer", + ) + + # model and equipment info + my_model_info = ModelInfo("CCHMC CAIIR", "CCHMC Pediatric CT Abdominal Segmentation", "0.4.3", "0.0.1") + my_equipment = EquipmentInfo(manufacturer="The MONAI Consortium", manufacturer_model="MONAI Deploy App SDK") + + # DICOM SR Writer op + dicom_sr_writer = DICOMTextSRWriterOperator( + self, + # copy_tags is a default parameteter (type bool) of DICOMTextSRWriterOperator; default value is True + # dictates whether or not to copy DICOM attributes from the selected DICOM series + # changed to True to copy DICOM attributes so DICOM SR has same Study UID + copy_tags=True, + model_info=my_model_info, + equipment_info=my_equipment, + custom_tags=custom_tags_sr, + # store DICOM SR in SR subdirectory; necessary for routing in CCHMC MDE workflow definition + output_folder=app_output_path / "SR", + ) + + # custom DICOM SC Writer op + dicom_sc_writer = DICOMSCWriterOperator( + self, + model_info=my_model_info, + equipment_info=my_equipment, + custom_tags=custom_tags_sc, + # store DICOM SC in SC subdirectory; necessary for routing in CCHMC MDE workflow definition + output_folder=app_output_path / "SC", + ) + + # MongoDB database, collection, and MAP version info + database_name = "CTLiverSpleenSegPredictions" + collection_name = "OrganVolumes" + map_version = "0.0.1" + + # custom MongoDB Entry Creator op + mongodb_entry_creator = MongoDBEntryCreatorOperator(self, map_version=map_version) + + # custom MongoDB Writer op + mongodb_writer = MongoDBWriterOperator(self, database_name=database_name, collection_name=collection_name) + + # create the processing pipeline, by specifying the source and destination operators, and + # ensuring the output from the former matches the input of the latter, in both name and type + # instantiate and connect operators using self.add_flow(); specify current operator, next operator, and tuple to match I/O + self.add_flow(study_loader_op, series_selector_op, {("dicom_study_list", "dicom_study_list")}) + self.add_flow( + series_selector_op, series_to_vol_op, {("study_selected_series_list", "study_selected_series_list")} + ) + self.add_flow(series_to_vol_op, abd_seg_op, {("image", "image")}) + + # note below the dicom_seg_writer, dicom_sr_writer, dicom_sc_writer, and mongodb_entry_creator each require + # two inputs, each coming from a source operator + + # DICOM SEG + self.add_flow( + series_selector_op, dicom_seg_writer, {("study_selected_series_list", "study_selected_series_list")} + ) + self.add_flow(abd_seg_op, dicom_seg_writer, {("seg_image", "seg_image")}) + + # DICOM SR + self.add_flow( + series_selector_op, dicom_sr_writer, {("study_selected_series_list", "study_selected_series_list")} + ) + self.add_flow(abd_seg_op, dicom_sr_writer, {("result_text_dicom_sr", "text")}) + + # DICOM SC + self.add_flow( + series_selector_op, dicom_sc_writer, {("study_selected_series_list", "study_selected_series_list")} + ) + self.add_flow(abd_seg_op, dicom_sc_writer, {("dicom_sc_dir", "dicom_sc_dir")}) + + # MongoDB + self.add_flow( + series_selector_op, mongodb_entry_creator, {("study_selected_series_list", "study_selected_series_list")} + ) + self.add_flow(abd_seg_op, mongodb_entry_creator, {("result_text_mongodb", "text")}) + self.add_flow(mongodb_entry_creator, mongodb_writer, {("mongodb_database_entry", "mongodb_database_entry")}) + + logging.info(f"End {self.compose.__name__}") + + +# series selection rule in JSON, which selects for axial CT series; flexible ST choices: +# StudyDescription: matches any value +# Modality: matches "CT" value (case-insensitive); filters out non-CT modalities +# ImageType: matches value that contains "PRIMARY", "ORIGINAL", and "AXIAL"; filters out most cor and sag views +# SeriesDescription: matches any values that do not contain "cor" or "sag" (case-insensitive); filters out cor and sag views +# SliceThickness: supports list, string, and numerical matching: +# [3, 5]: matches ST values between 3 and 5 +# "^(5(\\\\.0+)?|5)$": RegEx; matches ST values of 5, 5.0, 5.00, etc. +# 5: matches ST values of 5, 5.0, 5.00, etc. +# all valid series will be selected; downstream operators only perform inference and write outputs for 1st selected series +# please see more detail in DICOMSeriesSelectorOperator + +Sample_Rules_Text = """ +{ + "selections": [ + { + "name": "Axial CT Series", + "conditions": { + "StudyDescription": "(.*?)", + "Modality": "(?i)CT", + "ImageType": ["PRIMARY", "ORIGINAL", "AXIAL"], + "SeriesDescription": "(?i)^(?!.*(cor|sag)).*$", + "SliceThickness": [3, 5] + } + } + ] +} +""" + +# if executing application code using python interpreter: +if __name__ == "__main__": + # creates the app and test it standalone; when running is this mode, please note the following: + # -m , for model file path + # -i , for input DICOM CT series folder + # -o , for the output folder, default $PWD/output + # e.g. + # monai-deploy exec app.py -i input -m model/dynunet_FT.ts + # + logging.info(f"Begin {__name__}") + AIAbdomenSegApp().run() + logging.info(f"End {__name__}") diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/app.yaml b/examples/apps/cchmc_ped_abd_ct_seg_app/app.yaml new file mode 100644 index 00000000..badfac7c --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/app.yaml @@ -0,0 +1,32 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +--- + +# app.yaml is a configuration file that specifies MAP settings +# used by MONAI App SDK to understand how to run our app in a MAP and what resources it needs + +# specifies high-level information about our app +application: + title: MONAI Deploy App Package - CCHMC Pediatric CT Abdominal Segmentation + version: 0.0.1 + inputFormats: ["file"] + outputFormats: ["file"] + +# specifies the resources our app needs to run +# per MONAI docs (https://docs.monai.io/projects/monai-deploy-app-sdk/en/latest/developing_with_sdk/executing_packaged_app_locally.html) +# MAR does not validate all of the resource requirements embedded in the MAP to ensure they are met in host system +# e.g, MAR will throw an error if gpu requirement is not met on host system; however, gpuMemory parameter doesn't appear to be validated +resources: + cpu: 1 + gpu: 1 + memory: 1Gi + # during MAP execution, for an input DICOM Series of 204 instances, GPU usage peaks at just under 8900 MiB ~= 9.3 GB ~= 8.7 Gi + gpuMemory: 9Gi diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/dicom_sc_writer_operator.py b/examples/apps/cchmc_ped_abd_ct_seg_app/dicom_sc_writer_operator.py new file mode 100644 index 00000000..9479485e --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/dicom_sc_writer_operator.py @@ -0,0 +1,253 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from pathlib import Path +from typing import Dict, Optional, Union + +import pydicom + +from monai.deploy.core import Fragment, Operator, OperatorSpec +from monai.deploy.core.domain.dicom_series import DICOMSeries +from monai.deploy.core.domain.dicom_series_selection import StudySelectedSeries +from monai.deploy.operators.dicom_utils import EquipmentInfo, ModelInfo, write_common_modules +from monai.deploy.utils.importutil import optional_import +from monai.deploy.utils.version import get_sdk_semver + +dcmread, _ = optional_import("pydicom", name="dcmread") +dcmwrite, _ = optional_import("pydicom.filewriter", name="dcmwrite") +generate_uid, _ = optional_import("pydicom.uid", name="generate_uid") +ImplicitVRLittleEndian, _ = optional_import("pydicom.uid", name="ImplicitVRLittleEndian") +Dataset, _ = optional_import("pydicom.dataset", name="Dataset") +FileDataset, _ = optional_import("pydicom.dataset", name="FileDataset") +Sequence, _ = optional_import("pydicom.sequence", name="Sequence") + + +class DICOMSCWriterOperator(Operator): + """Class to write a new DICOM Secondary Capture (DICOM SC) instance with source DICOM Series metadata included. + + Named inputs: + dicom_sc_dir: file path of temporary DICOM SC (w/o source DICOM Series metadata). + study_selected_series_list: DICOM Series for copying metadata from. + + Named output: + None. + + File output: + New, updated DICOM SC file (with source DICOM Series metadata) in the provided output folder. + """ + + # file extension for the generated DICOM Part 10 file + DCM_EXTENSION = ".dcm" + # the default output folder for saving the generated DICOM instance file + # DEFAULT_OUTPUT_FOLDER = Path(os.path.join(os.path.dirname(__file__))) / "output" + DEFAULT_OUTPUT_FOLDER = Path.cwd() / "output" + + def __init__( + self, + fragment: Fragment, + *args, + output_folder: Union[str, Path], + model_info: ModelInfo, + equipment_info: Optional[EquipmentInfo] = None, + custom_tags: Optional[Dict[str, str]] = None, + **kwargs, + ): + """Class to write a new DICOM Secondary Capture (DICOM SC) instance with source DICOM Series metadata. + + Args: + output_folder (str or Path): The folder for saving the generated DICOM SC instance file. + model_info (ModelInfo): Object encapsulating model creator, name, version and UID. + equipment_info (EquipmentInfo, optional): Object encapsulating info for DICOM Equipment Module. + Defaults to None. + custom_tags (Dict[str, str], optional): Dictionary for setting custom DICOM tags using Keywords and str values only. + Defaults to None. + + Raises: + ValueError: If result cannot be found either in memory or from file. + """ + + self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + + # need to init the output folder until the execution context supports dynamic FS path + # not trying to create the folder to avoid exception on init + self.output_folder = Path(output_folder) if output_folder else DICOMSCWriterOperator.DEFAULT_OUTPUT_FOLDER + self.input_name_sc_dir = "dicom_sc_dir" + self.input_name_study_series = "study_selected_series_list" + + # for copying DICOM attributes from a provided DICOMSeries + # required input for write_common_modules; will always be True for this implementation + self.copy_tags = True + + self.model_info = model_info if model_info else ModelInfo() + self.equipment_info = equipment_info if equipment_info else EquipmentInfo() + self.custom_tags = custom_tags + + # set own Modality and SOP Class UID + # Standard SOP Classes: https://dicom.nema.org/dicom/2013/output/chtml/part04/sect_B.5.html + # Modality, e.g., + # "OT" for PDF + # "SR" for Structured Report. + # Media Storage SOP Class UID, e.g., + # "1.2.840.10008.5.1.4.1.1.88.11" for Basic Text SR Storage + # "1.2.840.10008.5.1.4.1.1.104.1" for Encapsulated PDF Storage, + # "1.2.840.10008.5.1.4.1.1.88.34" for Comprehensive 3D SR IOD + # "1.2.840.10008.5.1.4.1.1.66.4" for Segmentation Storage + self.modality_type = "OT" # OT Modality for Secondary Capture + self.sop_class_uid = ( + "1.2.840.10008.5.1.4.1.1.7.4" # SOP Class UID for Multi-frame True Color Secondary Capture Image Storage + ) + # custom OverlayImageLabeld post-processing transform creates an RBG overlay + + # equipment version may be different from contributing equipment version + try: + self.software_version_number = get_sdk_semver() # SDK Version + except Exception: + self.software_version_number = "" + self.operators_name = f"AI Algorithm {self.model_info.name}" + + super().__init__(fragment, *args, **kwargs) + + def setup(self, spec: OperatorSpec): + """Set up the named input(s), and output(s) if applicable. + + This operator does not have an output for the next operator, rather file output only. + + Args: + spec (OperatorSpec): The Operator specification for inputs and outputs etc. + """ + + spec.input(self.input_name_sc_dir) + spec.input(self.input_name_study_series) + + def compute(self, op_input, op_output, context): + """Performs computation for this operator and handles I/O. + + For now, only a single result content is supported, which could be in memory or an accessible file. + The DICOM Series used during inference is required (and copy_tags is hardcoded to True). + + When there are multiple selected series in the input, the first series' containing study will + be used for retrieving DICOM Study module attributes, e.g. StudyInstanceUID. + + Raises: + NotADirectoryError: When temporary DICOM SC path is not a directory. + FileNotFoundError: When result object not in the input, and result file not found either. + ValueError: Content object and file path not in the inputs, or no DICOM series provided. + IOError: If the input content is blank. + """ + + # receive the temporary DICOM SC file path and study selected series list + dicom_sc_dir = Path(op_input.receive(self.input_name_sc_dir)) + if not dicom_sc_dir: + raise IOError("Temporary DICOM SC path is read but blank.") + if not dicom_sc_dir.is_dir(): + raise NotADirectoryError(f"Provided temporary DICOM SC path is not a directory: {dicom_sc_dir}") + self._logger.info(f"Received temporary DICOM SC path: {dicom_sc_dir}") + + study_selected_series_list = op_input.receive(self.input_name_study_series) + if not study_selected_series_list or len(study_selected_series_list) < 1: + raise ValueError("Missing input, list of 'StudySelectedSeries'.") + + # retrieve the DICOM Series used during inference in order to grab appropriate study/series level tags + # this will be the 1st Series in study_selected_series_list + dicom_series = None + for study_selected_series in study_selected_series_list: + if not isinstance(study_selected_series, StudySelectedSeries): + raise ValueError(f"Element in input is not expected type, {StudySelectedSeries}.") + selected_series = study_selected_series.selected_series[0] + dicom_series = selected_series.series + break + + # log basic DICOM metadata for the retrieved DICOM Series + self._logger.debug(f"Dicom Series: {dicom_series}") + + # the output folder should come from the execution context when it is supported + self.output_folder.mkdir(parents=True, exist_ok=True) + + # write the new DICOM SC instance + self.write(dicom_sc_dir, dicom_series, self.output_folder) + + def write(self, dicom_sc_dir, dicom_series: DICOMSeries, output_dir: Path): + """Writes a new, updated DICOM SC instance and deletes the temporary DICOM SC instance. + The new, updated DICOM SC instance is the temporary DICOM SC instance with source + DICOM Series metadata copied. + + Args: + dicom_sc_dir: temporary DICOM SC file path. + dicom_series (DICOMSeries): DICOMSeries object encapsulating the original series. + + Returns: + None + + File output: + New, updated DICOM SC file (with source DICOM Series metadata) in the provided output folder. + """ + + if not isinstance(output_dir, Path): + raise ValueError("output_dir is not a valid Path.") + + output_dir.mkdir(parents=True, exist_ok=True) # just in case + + # find the temporary DICOM SC file in the directory; there should only be one .dcm file present + dicom_files = list(dicom_sc_dir.glob("*.dcm")) + dicom_sc_file = dicom_files[0] + + # load the temporary DICOM SC file using pydicom + dicom_sc_dataset = pydicom.dcmread(dicom_sc_file) + self._logger.info(f"Loaded temporary DICOM SC file: {dicom_sc_file}") + + # use write_common_modules to copy metadata from dicom_series + # this will copy metadata and return an updated Dataset + ds = write_common_modules( + dicom_series, + self.copy_tags, # always True for this implementation + self.modality_type, + self.sop_class_uid, + self.model_info, + self.equipment_info, + ) + + # Secondary Capture specific tags + ds.ImageType = ["DERIVED", "SECONDARY"] + + # for now, only allow str Keywords and str value + if self.custom_tags: + for k, v in self.custom_tags.items(): + if isinstance(k, str) and isinstance(v, str): + try: + ds.update({k: v}) + except Exception as ex: + # best effort for now + logging.warning(f"Tag {k} was not written, due to {ex}") + + # merge the copied metadata into the loaded temporary DICOM SC file (dicom_sc_dataset) + for tag, value in ds.items(): + dicom_sc_dataset[tag] = value + + # save the updated DICOM SC file to the output folder + # instance file name is the same as the new SOP instance UID + output_file_path = self.output_folder.joinpath( + f"{dicom_sc_dataset.SOPInstanceUID}{DICOMSCWriterOperator.DCM_EXTENSION}" + ) + dicom_sc_dataset.save_as(output_file_path) + self._logger.info(f"Saved updated DICOM SC file at: {output_file_path}") + + # remove the temporary DICOM SC file + os.remove(dicom_sc_file) + self._logger.info(f"Removed temporary DICOM SC file: {dicom_sc_file}") + + # check if the temp directory is empty, then delete it + if not any(dicom_sc_dir.iterdir()): + os.rmdir(dicom_sc_dir) + self._logger.info(f"Removed temporary directory: {dicom_sc_dir}") + else: + self._logger.warning(f"Temporary directory {dicom_sc_dir} is not empty, skipping removal.") diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/mongodb_entry_creator_operator.py b/examples/apps/cchmc_ped_abd_ct_seg_app/mongodb_entry_creator_operator.py new file mode 100644 index 00000000..4f2f275c --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/mongodb_entry_creator_operator.py @@ -0,0 +1,349 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from datetime import datetime +from typing import Any, Dict, Union + +import pydicom +import pytz + +from monai.deploy.core import Fragment, Operator, OperatorSpec +from monai.deploy.core.domain.dicom_series import DICOMSeries +from monai.deploy.core.domain.dicom_series_selection import StudySelectedSeries + + +class MongoDBEntryCreatorOperator(Operator): + """Class to create a database entry for downstream MONAI Deploy Express MongoDB database writing. + Provided text input and source DICOM Series DICOM tags are used to create the entry. + + Named inputs: + text: text content to be included in the database entry. + study_selected_series_list: DICOM series for copying metadata from. + + Named output: + mongodb_database_entry: formatted MongoDB database entry. Downstream receiver MongoDBWriterOperator will write + the entry to the MONAI Deploy Express MongoDB database. + """ + + def __init__(self, fragment: Fragment, *args, map_version: str, **kwargs): + """Class to create a MONAI Deploy Express MongoDB database entry. Provided text input and + source DICOM Series DICOM tags are used to create the entry. + + Args: + map_version (str): version of the MAP. + + Raises: + ValueError: If result cannot be found either in memory or from file. + """ + + self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__)) + + self.map_version = map_version + + self.input_name_text = "text" + self.input_name_dcm_series = "study_selected_series_list" + + self.output_name_db_entry = "mongodb_database_entry" + + super().__init__(fragment, *args, **kwargs) + + def setup(self, spec: OperatorSpec): + """Set up the named input(s), and output(s). + + Args: + spec (OperatorSpec): The Operator specification for inputs and outputs etc. + """ + + spec.input(self.input_name_text) + spec.input(self.input_name_dcm_series) + + spec.output(self.output_name_db_entry) + + def compute(self, op_input, op_output, context): + """Performs computation for this operator and handles I/O. + + For now, only a single result content is supported, which could be in memory or an accessible file. + The DICOM Series used during inference is required. + + When there are multiple selected series in the input, the first series' containing study will + be used for retrieving DICOM Study module attributes, e.g. StudyInstanceUID. + + Raises: + FileNotFoundError: When result object not in the input, and result file not found either. + ValueError: Content object and file path not in the inputs, or no DICOM series provided. + IOError: If the input content is blank. + """ + + # receive the result text and study selected series list + result_text = str(op_input.receive(self.input_name_text)).strip() + if not result_text: + raise IOError("Input is read but blank.") + + study_selected_series_list = None + try: + study_selected_series_list = op_input.receive(self.input_name_dcm_series) + except Exception: + pass + if not study_selected_series_list or len(study_selected_series_list) < 1: + raise ValueError("Missing input, list of 'StudySelectedSeries'.") + + # retrieve the DICOM Series used during inference in order to grab appropriate Study/Series level tags + # this will be the 1st Series in study_selected_series_list + dicom_series = None + for study_selected_series in study_selected_series_list: + if not isinstance(study_selected_series, StudySelectedSeries): + raise ValueError(f"Element in input is not expected type, {StudySelectedSeries}.") + selected_series = study_selected_series.selected_series[0] + dicom_series = selected_series.series + break + + # create MongoDB entry + mongodb_database_entry = self.create_entry(result_text, dicom_series, self.map_version) + + # emit MongoDB entry + op_output.emit(mongodb_database_entry, self.output_name_db_entry) + + def create_entry(self, result_text: str, dicom_series: DICOMSeries, map_version: str): + """Creates the MONAI Deploy Express MongoDB database entry. + + Args: + result_text (str): text content to be included in the database entry. + dicom_series (DICOMSeries): DICOMSeries object encapsulating the original series. + map_version (str): version of the MAP. + + Returns: + mongodb_database_entry: formatted MongoDB database entry. + """ + + if not result_text or not len(result_text.strip()): + raise ValueError("Content is empty.") + + # get one of the SOP instance's native sop instance dataset + # we will pull Study level (and some Series level) DICOM tags from this SOP instance + # this same strategy is employed by write_common_modules + orig_ds = dicom_series.get_sop_instances()[0].get_native_sop_instance() + + # # loop through dicom series tags; look for discrepancies from SOP instances + # for sop_instance in dicom_series.get_sop_instances(): + # # get the native SOP instance dataset + # dicom_image = sop_instance.get_native_sop_instance() + + # # check if the tag is present in the dataset + # if hasattr(dicom_image, 'Exposure'): + # tag = dicom_image.Exposure + # print(f"Exposure: {tag}") + # else: + # print("Exposure tag not found in this SOP instance.") + + # DICOM TAG WRITING TO MONGODB + # edge cases addressed by looking at DICOM tag Type, Value Representation (VR), + # and Value Multiplicity (VM) specifically for the CT Image CIOD + # https://dicom.innolitics.com/ciods/ct-image + + # define Tag Absent variable + tag_absent = "Tag Absent" + + # STUDY AND SERIES LEVEL DICOM TAGS + + # AccessionNumber - Type: Required (2), VR: SH, VM: 1 + accession_number = orig_ds.AccessionNumber + + # StudyInstanceUID - Type: Required (1), VR: UI, VM: 1 + study_instance_uid = orig_ds.StudyInstanceUID + + # StudyDescription: Type: Optional (3), VR: LO, VM: 1 + # while Optional, only studies with this tag will be routed from Compass and MAP launched per workflow def + study_description = orig_ds.get("StudyDescription", tag_absent) + + # SeriesInstanceUID: Type: Required (1), VR: UI, VM: 1 + series_instance_uid = dicom_series._series_instance_uid + + # SeriesDescription: Type: Optional (3), VR: LO, VM: 1 + series_description = orig_ds.get("SeriesDescription", tag_absent) + + # sop instances should always be available on the MONAI DICOM Series object + series_sop_instances = len(dicom_series._sop_instances) + + # PATIENT DETAIL DICOM TAGS + + # PatientID - Type: Required (2), VR: LO, VM: 1 + patient_id = orig_ds.PatientID + + # PatientName - Type: Required (2), VR: PN, VM: 1 + # need to convert to str; pydicom can't encode PersonName object + patient_name = str(orig_ds.PatientName) + + # PatientSex - Type: Required (2), VR: CS, VM: 1 + patient_sex = orig_ds.PatientSex + + # PatientBirthDate - Type: Required (2), VR: DA, VM: 1 + patient_birth_date = orig_ds.PatientBirthDate + + # PatientAge - Type: Optional (3), VR: AS, VM: 1 + patient_age = orig_ds.get("PatientAge", tag_absent) + + # EthnicGroup - Type: Optional (3), VR: SH, VM: 1 + ethnic_group = orig_ds.get("EthnicGroup", tag_absent) + + # SCAN ACQUISITION PARAMETER DICOM TAGS + + # on CCHMC test cases, the following tags had consistent values for all SOP instances + + # Manufacturer - Type: Required (2), VR: LO, VM: 1 + manufacturer = orig_ds.Manufacturer + + # ManufacturerModelName - Type: Optional (3), VR: LO, VM: 1 + manufacturer_model_name = orig_ds.get("ManufacturerModelName", tag_absent) + + # BodyPartExamined - Type: Optional (3), VR: CS, VM: 1 + body_part_examined = orig_ds.get("BodyPartExamined", tag_absent) + + # row and column pixel spacing are derived from PixelSpacing + # PixelSpacing - Type: Required (1), VR: DS, VM: 2 (handled by MONAI) + row_pixel_spacing = dicom_series._row_pixel_spacing + column_pixel_spacing = dicom_series._col_pixel_spacing + + # per DICOMSeriesToVolumeOperator, depth pixel spacing will always be defined + depth_pixel_spacing = dicom_series._depth_pixel_spacing + + # SliceThickness - Type: Required (2), VR: DS, VM: 1 + slice_thickness = orig_ds.SliceThickness + + # PixelRepresentation - Type: Required (1), VR: US, VM: 1 + pixel_representation = orig_ds.PixelRepresentation + + # BitsStored - Type: Required (1), VR: US, VM: 1 + bits_stored = orig_ds.BitsStored + + # WindowWidth - Type: Conditionally Required (1C), VR: DS, VM: 1-n + window_width = orig_ds.get("WindowWidth", tag_absent) + # for MultiValue case: + if isinstance(window_width, pydicom.multival.MultiValue): + # join multiple values into a single string separated by a | + # convert DSfloat objects to strs to allow joining + window_width = " | ".join([str(window) for window in window_width]) + + # RevolutionTime - Type: Optional (3), VR: FD, VM: 1 + revolution_time = orig_ds.get("RevolutionTime", tag_absent) + + # FocalSpots - Type: Optional (3), VR: DS, VM: 1-n + focal_spots = orig_ds.get("FocalSpots", tag_absent) + # for MultiValue case: + if isinstance(focal_spots, pydicom.multival.MultiValue): + # join multiple values into a single string separated by a | + # convert DSfloat objects to strs to allow joining + focal_spots = " | ".join([str(spot) for spot in focal_spots]) + + # SpiralPitchFactor - Type: Optional (3), VR: FD, VM: 1 + spiral_pitch_factor = orig_ds.get("SpiralPitchFactor", tag_absent) + + # ConvolutionKernel - Type: Optional (3), VR: SH, VM: 1-n + convolution_kernel = orig_ds.get("ConvolutionKernel", tag_absent) + # for MultiValue case: + if isinstance(convolution_kernel, pydicom.multival.MultiValue): + # join multiple values into a single string separated by a | + convolution_kernel = " | ".join(convolution_kernel) + + # ReconstructionDiameter - Type: Optional (3), VR: DS, VM: 1 + reconstruction_diameter = orig_ds.get("ReconstructionDiameter", tag_absent) + + # KVP - Type: Required (2), VR: DS, VM: 1 + kvp = orig_ds.KVP + + # on CCHMC test cases, the following tags did NOT have consistent values for all SOP instances + # as such, if the tag value exists, it will be averaged over all SOP instances + + # initialize an averaged values dictionary + averaged_values: Dict[str, Union[float, str]] = {} + + # tags to check and average + tags_to_average = { + "XRayTubeCurrent": tag_absent, # Type: Optional (3), VR: IS, VM: 1 + "Exposure": tag_absent, # Type: Optional (3), VR: IS, VM: 1 + "CTDIvol": tag_absent, # Type: Optional (3), VR: FD, VM: 1 + } + + # check which tags are present on the 1st SOP instance + for tag, default_value in tags_to_average.items(): + # if the tag exists + if tag in orig_ds: + # loop through SOP instances, grab tag values + values = [] + for sop_instance in dicom_series.get_sop_instances(): + ds = sop_instance.get_native_sop_instance() + value = ds.get(tag, default_value) + # if tag is present on current SOP instance + if value != default_value: + # add tag value to values; convert to float for averaging + values.append(float(value)) + # compute the average if values were collected + if values: + averaged_values[tag] = round(sum(values) / len(values), 3) + else: + averaged_values[tag] = default_value + else: + # if the tag is absent in the first SOP instance, keep the default value + averaged_values[tag] = default_value + + # parse result_text (i.e. predicted organ volumes) and format + map_results = {} + for line in result_text.split("\n"): + if ":" in line: + key, value = line.split(":") + key = key.replace(" ", "") + map_results[key] = value.strip() + + # create the MongoDB database entry + mongodb_database_entry: Dict[str, Any] = { + "Timestamp": datetime.now(pytz.UTC), # timestamp in UTC + "MAPVersion": map_version, + "DICOMSeriesDetails": { + "AccessionNumber": accession_number, + "StudyInstanceUID": study_instance_uid, + "StudyDescription": study_description, + "SeriesInstanceUID": series_instance_uid, + "SeriesDescription": series_description, + "SeriesFileCount": series_sop_instances, + }, + "PatientDetails": { + "PatientID": patient_id, + "PatientName": patient_name, + "PatientSex": patient_sex, + "PatientBirthDate": patient_birth_date, + "PatientAge": patient_age, + "EthnicGroup": ethnic_group, + }, + "ScanAcquisitionDetails": { + "Manufacturer": manufacturer, + "ManufacturerModelName": manufacturer_model_name, + "BodyPartExamined": body_part_examined, + "RowPixelSpacing": row_pixel_spacing, + "ColumnPixelSpacing": column_pixel_spacing, + "DepthPixelSpacing": depth_pixel_spacing, + "SliceThickness": slice_thickness, + "PixelRepresentation": pixel_representation, + "BitsStored": bits_stored, + "WindowWidth": window_width, + "RevolutionTime": revolution_time, + "FocalSpots": focal_spots, + "SpiralPitchFactor": spiral_pitch_factor, + "ConvolutionKernel": convolution_kernel, + "ReconstructionDiameter": reconstruction_diameter, + "KVP": kvp, + }, + "MAPResults": map_results, + } + + # integrate averaged tags into MongoDB entry: + mongodb_database_entry["ScanAcquisitionDetails"].update(averaged_values) + + return mongodb_database_entry diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/mongodb_writer_operator.py b/examples/apps/cchmc_ped_abd_ct_seg_app/mongodb_writer_operator.py new file mode 100644 index 00000000..6d18e395 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/mongodb_writer_operator.py @@ -0,0 +1,235 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from dotenv import load_dotenv + +load_dotenv() + +from pymongo import MongoClient, errors + +from monai.deploy.core import Fragment, Operator, OperatorSpec + + +class MongoDBWriterOperator(Operator): + """Class to write the MONAI Deploy Express MongoDB database with provided database entry. + + Named inputs: + mongodb_database_entry: formatted MongoDB database entry. + + Named output: + None + + Result: + MONAI Deploy Express MongoDB database write of the database entry. + """ + + def __init__(self, fragment: Fragment, *args, database_name: str, collection_name: str, **kwargs): + """Class to write the MONAI Deploy Express MongoDB database with provided database entry. + + Args: + database_name (str): name of the MongoDB database that will be written. + collection_name (str): name of the MongoDB collection that will be written. + + Raises: + Relevant MongoDB errors if database writing is unsuccessful. + """ + + self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__)) + + self.database_name = database_name + self.collection_name = collection_name + + self.input_name_db_entry = "mongodb_database_entry" + + # MongoDB credentials + self.mongodb_username = os.environ.get("MONGODB_USERNAME") + self.mongodb_password = os.environ.get("MONGODB_PASSWORD") + self.mongodb_port = os.environ.get("MONGODB_PORT") + self.docker_mongodb_ip = os.environ.get("MONGODB_IP_DOCKER") + + # determine the MongoDB IP address based on execution environment + self.mongo_ip = self._get_mongo_ip() + self._logger.info(f"Using MongoDB IP: {self.mongo_ip}") + + # connect to the MongoDB database + self.client = None + + try: + self.client = MongoClient( + f"mongodb://{self.mongodb_username}:{self.mongodb_password}@{self.mongo_ip}:{self.mongodb_port}/?authSource=admin", + serverSelectionTimeoutMS=10000, # 10s timeout for testing connection; 20s by default + ) + if self.client is None: + raise RuntimeError("MongoClient was not created successfully") + ping_response = self.client.admin.command("ping") + self._logger.info( + f"Successfully connected to MongoDB at: {self.client.address}. Ping response: {ping_response}" + ) + self.db = self.client[self.database_name] + self.collection = self.db[self.collection_name] + except errors.ServerSelectionTimeoutError as e: + self._logger.error("Failed to connect to MongoDB: Server selection timeout.") + self._logger.debug(f"Detailed error: {e}") + raise + except errors.ConnectionFailure as e: + self._logger.error("Failed to connect to MongoDB: Connection failure.") + self._logger.debug(f"Detailed error: {e}") + raise + except errors.OperationFailure as e: + self._logger.error("Failed to authenticate with MongoDB.") + self._logger.debug(f"Detailed error: {e}") + raise + except Exception as e: + self._logger.error("Unexpected error occurred while connecting to MongoDB.") + self._logger.debug(f"Detailed error: {e}") + raise + super().__init__(fragment, *args, **kwargs) + + def setup(self, spec: OperatorSpec): + """Set up the named input(s), and output(s) if applicable. + + This operator does not have an output for the next operator - MongoDB write only. + + Args: + spec (OperatorSpec): The Operator specification for inputs and outputs etc. + """ + + spec.input(self.input_name_db_entry) + + def compute(self, op_input, op_output, context): + """Performs computation for this operator""" + + mongodb_database_entry = op_input.receive(self.input_name_db_entry) + + # write to MongoDB + self.write(mongodb_database_entry) + + def write(self, mongodb_database_entry): + """Writes the database entry to the MONAI Deploy Express MongoDB database. + + Args: + mongodb_database_entry: formatted MongoDB database entry. + + Returns: + None + """ + + # MongoDB writing + try: + insert_result = self.collection.insert_one(mongodb_database_entry) + if insert_result.acknowledged: + self._logger.info(f"Document inserted with ID: {insert_result.inserted_id}") + else: + self._logger.error("Failed to write document to MongoDB.") + except errors.PyMongoError as e: + self._logger.error("Failed to insert document into MongoDB.") + self._logger.debug(f"Detailed error: {e}") + raise + + def _get_mongo_ip(self): + """Determine the MongoDB IP based on the execution environment. + + If the pipeline is being run pythonically, use localhost. + + If MAP is being run via MAR or MONAI Deploy Express, use Docker bridge network IP. + """ + + # if running in a Docker container (/.dockerenv file present) + if os.path.exists("/.dockerenv"): + self._logger.info("Detected Docker environment") + return self.docker_mongodb_ip + + # if not executing as Docker container, we are executing pythonically + self._logger.info("Detected local environment (pythonic execution)") + return "localhost" + + +# Module function (helper function) +def test(): + """Test writing to and deleting from the MDE MongoDB instance locally""" + + # MongoDB credentials + mongodb_username = os.environ.get("MONGODB_USERNAME") + mongodb_password = os.environ.get("MONGODB_PASSWORD") + mongodb_port = os.environ.get("MONGODB_PORT") + + # sample information + database_name = "CTLiverSpleenSegPredictions" + collection_name = "OrganVolumes" + test_entry = {"test_key": "test_value"} + + # connect to MongoDB instance (localhost as we are testing locally) + try: + client = MongoClient( + f"mongodb://{mongodb_username}:{mongodb_password}@localhost:{mongodb_port}/?authSource=admin", + serverSelectionTimeoutMS=10000, # 10s timeout for testing connection; 20s by default + ) + if client is None: + raise RuntimeError("MongoClient was not created successfully") + ping_response = client.admin.command("ping") + print(f"Successfully connected to MongoDB at: {client.address}. Ping response: {ping_response}") + db = client[database_name] + collection = db[collection_name] + except errors.ServerSelectionTimeoutError as e: + print("Failed to connect to MongoDB: Server selection timeout.") + print(f"Detailed error: {e}") + raise + except errors.ConnectionFailure as e: + print("Failed to connect to MongoDB: Connection failure.") + print(f"Detailed error: {e}") + raise + except errors.OperationFailure as e: + print("Failed to authenticate with MongoDB.") + print(f"Detailed error: {e}") + raise + except Exception as e: + print("Unexpected error occurred while connecting to MongoDB.") + print(f"Detailed error: {e}") + raise + + # insert document + try: + insert_result = collection.insert_one(test_entry) + if insert_result.acknowledged: + print(f"Document inserted with ID: {insert_result.inserted_id}") + else: + print("Failed to write document to MongoDB.") + except errors.PyMongoError as e: + print("Failed to insert document into MongoDB.") + print(f"Detailed error: {e}") + raise + + # verify the inserted document + try: + inserted_doc = collection.find_one({"_id": insert_result.inserted_id}) + if inserted_doc: + print(f"Inserted document: {inserted_doc}") + else: + print("Document not found in the collection after insertion.") + except errors.PyMongoError as e: + print("Failed to retrieve the inserted document from MongoDB.") + print(f"Detailed error: {e}") + return + + # # delete a database + # try: + # client.drop_database(database_name) + # print(f"Test database '{database_name}' deleted successfully.") + # except errors.PyMongoError as e: + # print("Failed to delete the test database.") + # print(f"Detailed error: {e}") + + +if __name__ == "__main__": + test() diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/post_transforms.py b/examples/apps/cchmc_ped_abd_ct_seg_app/post_transforms.py new file mode 100644 index 00000000..607bcd47 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/post_transforms.py @@ -0,0 +1,387 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from typing import List + +import matplotlib.cm as cm +import numpy as np + +from monai.config import KeysCollection +from monai.data import MetaTensor +from monai.transforms import LabelToContour, MapTransform + + +# Calculate segmentation volumes in ml +class CalculateVolumeFromMaskd(MapTransform): + """ + Dictionary-based transform to calculate the volume of predicted organ masks. + + Args: + keys (list): The keys corresponding to the predicted organ masks in the dictionary. + label_names (list): The list of organ names corresponding to the masks. + """ + + def __init__(self, keys, label_names): + self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + super().__init__(keys) + self.label_names = label_names + + def __call__(self, data): + # Initialize a dictionary to store the volumes of each organ + pred_volumes = {} + + for key in self.keys: + for label_name in self.label_names.keys(): + # self._logger.info('Key: ', key, ' organ_name: ', label_name) + if label_name != "background": + # Get the predicted mask from the dictionary + pred_mask = data[key] + # Calculate the voxel size in cubic millimeters (voxel size should be in the metadata) + # Assuming the metadata contains 'spatial_shape' with voxel dimensions in mm + if hasattr(pred_mask, "affine"): + voxel_size = np.abs(np.linalg.det(pred_mask.affine[:3, :3])) + else: + raise ValueError("Affine transformation matrix with voxel spacing information is required.") + + # Calculate the volume in cubic millimeters + label_volume_mm3 = np.sum(pred_mask == self.label_names[label_name]) * voxel_size + + # Convert to milliliters (1 ml = 1000 mm^3) + label_volume_ml = label_volume_mm3 / 1000.0 + + # Store the result in the pred_volumes dictionary + # convert to int - radiologists prefer whole number with no decimals + pred_volumes[label_name] = int(round(label_volume_ml, 0)) + + # Add the calculated volumes to the data dictionary + key_name = key + "_volumes" + + data[key_name] = pred_volumes + # self._logger.info('pred_volumes: ', pred_volumes) + return data + + +class LabelToContourd(MapTransform): + def __init__(self, keys: KeysCollection, output_labels: list, allow_missing_keys: bool = False): + + self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + super().__init__(keys, allow_missing_keys) + + self.output_labels = output_labels + + def __call__(self, data): + d = dict(data) + for key in self.keys: + label_image = d[key] + assert isinstance(label_image, MetaTensor), "Input image must be a MetaTensor." + + # Initialize the contour image with the same shape as the label image + contour_image = np.zeros_like(label_image.cpu().numpy()) + + if label_image.ndim == 4: # Check if the label image is 4D with a channel dimension + # Process each 2D slice independently along the last axis (z-axis) + for i in range(label_image.shape[-1]): + slice_image = label_image[:, :, :, i].cpu().numpy() + + # Extract unique labels excluding background (assumed to be 0) + unique_labels = np.unique(slice_image) + unique_labels = unique_labels[unique_labels != 0] + + slice_contour = np.zeros_like(slice_image) + + # Generate contours for each label in the slice + for label in unique_labels: + # skip contour generation for labels that are not in output_labels + if label not in self.output_labels: + continue + + # Create a binary mask for the current label + binary_mask = np.zeros_like(slice_image) + binary_mask[slice_image == label] = 1.0 + + # Apply LabelToContour to the 2D slice (replace this with actual contour logic) + thick_edges = LabelToContour()(binary_mask) + + # Assign the label value to the contour image at the edge positions + slice_contour[thick_edges > 0] = label + + # Stack the processed slice back into the 4D contour image + contour_image[:, :, :, i] = slice_contour + else: + # If the label image is not 4D, process it directly + slice_image = label_image.cpu().numpy() + unique_labels = np.unique(slice_image) + unique_labels = unique_labels[unique_labels != 0] + + for label in unique_labels: + binary_mask = np.zeros_like(slice_image) + binary_mask[slice_image == label] = 1.0 + + thick_edges = LabelToContour()(binary_mask) + contour_image[thick_edges > 0] = label + + # Convert the contour image back to a MetaTensor with the original metadata + contour_image_meta = MetaTensor(contour_image, meta=label_image.meta) # , affine=label_image.affine) + + # Store the contour MetaTensor in the output dictionary + d[key] = contour_image_meta + + return d + + +class OverlayImageLabeld(MapTransform): + def __init__( + self, + image_key: KeysCollection, + label_key: str, + overlay_key: str = "overlay", + alpha: float = 0.7, + allow_missing_keys: bool = False, + ): + + self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + super().__init__(image_key, allow_missing_keys) + + self.image_key = image_key + self.label_key = label_key + self.overlay_key = overlay_key + self.alpha = alpha + self.jet_colormap = cm.get_cmap("jet", 256) # Get the Jet colormap with 256 discrete colors + + def apply_jet_colormap(self, label_volume): + """ + Apply the Jet colormap to a 3D label volume using matplotlib's colormap. + """ + assert label_volume.ndim == 3, "Label volume should have 3 dimensions (H, W, D) after removing channel." + + label_volume_normalized = (label_volume / label_volume.max()) * 255.0 + label_volume_uint8 = label_volume_normalized.astype(np.uint8) + + # Apply the colormap to each label + label_rgb = self.jet_colormap(label_volume_uint8)[:, :, :, :3] # Only take the RGB channels + + label_rgb = (label_rgb * 255).astype(np.uint8) + # Rearrange axes to get (3, H, W, D) + label_rgb = np.transpose(label_rgb, (3, 0, 1, 2)) + + assert label_rgb.shape == ( + 3, + *label_volume.shape, + ), f"Label RGB shape should be (3,H, W, D) but got {label_rgb.shape}" + + return label_rgb + + def convert_to_rgb(self, image_volume): + """ + Convert a single-channel grayscale 3D image to an RGB 3D image. + """ + assert image_volume.ndim == 3, "Image volume should have 3 dimensions (H, W, D) after removing channel." + + image_volume_normalized = (image_volume - image_volume.min()) / (image_volume.max() - image_volume.min()) + image_rgb = np.stack([image_volume_normalized] * 3, axis=0) + image_rgb = (image_rgb * 255).astype(np.uint8) + + assert image_rgb.shape == ( + 3, + *image_volume.shape, + ), f"Image RGB shape should be (3,H, W, D) but got {image_rgb.shape}" + + return image_rgb + + def _create_overlay(self, image_volume, label_volume): + # Convert the image volume and label volume to RGB + image_rgb = self.convert_to_rgb(image_volume) + label_rgb = self.apply_jet_colormap(label_volume) + + # Create an alpha-blended overlay + overlay = image_rgb.copy() + mask = label_volume > 0 + + # Apply the overlay where the mask is present + for i in range(3): # For each color channel + overlay[i, mask] = (self.alpha * label_rgb[i, mask] + (1 - self.alpha) * overlay[i, mask]).astype(np.uint8) + + assert ( + overlay.shape == image_rgb.shape + ), f"Overlay shape should match image RGB shape: {overlay.shape} vs {image_rgb.shape}" + + return overlay + + def __call__(self, data): + d = dict(data) + + # Get the image and label tensors + image = d[self.image_key] # Expecting shape (1, H, W, D) + label = d[self.label_key] # Expecting shape (1, H, W, D) + + # uncomment when running pipeline with mask (non-contour) outputs, i.e. LabelToContourd transform absent + # if image.device.type == "cuda": + # image = image.cpu() + # d[self.image_key] = image + # if label.device.type == "cuda": + # label = label.cpu() + # d[self.label_key] = label + # # ----------------------- + + # Ensure that the input has the correct dimensions + assert image.shape[0] == 1 and label.shape[0] == 1, "Image and label must have a channel dimension of 1." + assert image.shape == label.shape, f"Image and label must have the same shape: {image.shape} vs {label.shape}" + + # Remove the channel dimension for processing + image_volume = image[0] # Shape: (H, W, D) + label_volume = label[0] # Shape: (H, W, D) + + # Convert to 3D overlay + overlay = self._create_overlay(image_volume, label_volume) + + # Add the channel dimension back + # d[self.overlay_key] = np.expand_dims(overlay, axis=0) # Shape: (1, H, W, D, 3) + d[self.overlay_key] = MetaTensor(overlay, meta=label.meta, affine=label.affine) # Shape: (3, H, W, D) + + # Assert the final output shape + # assert d[self.overlay_key].shape == (1, *image_volume.shape, 3), \ + # f"Final overlay shape should be (1, H, W, D, 3) but got {d[self.overlay_key].shape}" + + assert d[self.overlay_key].shape == ( + 3, + *image_volume.shape, + ), f"Final overlay shape should be (3, H, W, D) but got {d[self.overlay_key].shape}" + + # Log the overlay creation (debugging) + self._logger.info(f"Overlay created with shape: {overlay.shape}") + # self._logger.info(f"Dictionary keys: {d.keys()}") + + # self._logger.info('overlay_image shape: ', d[self.overlay_key].shape) + return d + + +class SaveData(MapTransform): + """ + Save the output dictionary into JSON files. + + The name of the saved file will be `{key}_{output_postfix}.json`. + + Args: + keys: keys of the corresponding items to be saved in the dictionary. + output_dir: directory to save the output files. + output_postfix: a string appended to all output file names, default is `data`. + separate_folder: whether to save each file in a separate folder. Default is `True`. + print_log: whether to print logs when saving. Default is `True`. + """ + + def __init__( + self, + keys: KeysCollection, + namekey: str = "image", + output_dir: str = "./", + output_postfix: str = "data", + separate_folder: bool = False, + print_log: bool = True, + allow_missing_keys: bool = False, + ): + self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + super().__init__(keys, allow_missing_keys) + self.output_dir = output_dir + self.output_postfix = output_postfix + self.separate_folder = separate_folder + self.print_log = print_log + self.namekey = namekey + + def __call__(self, data): + d = dict(data) + image_name = os.path.basename(d[self.namekey].meta["filename_or_obj"]).split(".")[0] + for key in self.keys: + # Get the data + output_data = d[key] + + # Determine the file name + file_name = f"{image_name}_{self.output_postfix}.json" + if self.separate_folder: + file_path = os.path.join(self.output_dir, image_name, file_name) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + else: + file_path = os.path.join(self.output_dir, file_name) + + # Save the dictionary as a JSON file + with open(file_path, "w") as f: + json.dump(output_data, f) + + if self.print_log: + self._logger.info(f"Saved data to {file_path}") + + return d + + +# custom transform (not in original post_transforms.py in bundle): +class ExtractVolumeToTextd(MapTransform): + """ + Custom transform to extract volume information from the segmentation results and format it as a textual summary. + Filters organ volumes based on output_labels for DICOM SR write, while including all organs for MongoDB write. + The upstream CalculateVolumeFromMaskd transform calculates organ volumes and stores them in the dictionary + under the pred_key + '_volumes' key. The input dictionary is outputted unchanged as to not affect downstream operators. + + Args: + keys: keys of the corresponding items to be saved in the dictionary. + label_names: dictionary mapping organ names to their corresponding label indices. + output_labels: list of target label indices for organs to include in the DICOM SR output. + """ + + def __init__( + self, + keys: KeysCollection, + label_names: dict, + output_labels: List[int], + allow_missing_keys: bool = False, + ): + self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}") + super().__init__(keys, allow_missing_keys) + + self.label_names = label_names + self.output_labels = output_labels + + # create separate result_texts for DICOM SR write (target organs) and MongoDB write (all organs) + self.result_text_dicom_sr: str = "" + self.result_text_mongodb: str = "" + + def __call__(self, data): + d = dict(data) + # use the first key in `keys` to access the volume data (e.g., pred_key + '_volumes') + volumes_key = self.keys[0] + organ_volumes = d.get(volumes_key, None) + + if organ_volumes is None: + raise ValueError(f"Volume data not found for key {volumes_key}.") + + # create the volume text outputs + volume_text_dicom_sr = [] + volume_text_mongodb = [] + + # loop through calculated organ volumes + for organ, volume in organ_volumes.items(): + + # append all organ volumes for MongoDB entry + volume_entry = f"{organ.capitalize()} Volume: {volume} mL" + volume_text_mongodb.append(volume_entry) + + # if the organ's label index is in output_labels + label_index = self.label_names.get(organ, None) + if label_index in self.output_labels: + # append organ volume for DICOM SR entry + volume_text_dicom_sr.append(volume_entry) + + self.result_text_dicom_sr = "\n".join(volume_text_dicom_sr) + self.result_text_mongodb = "\n".join(volume_text_mongodb) + + # not adding result_text to dictionary; return dictionary unchanged as to not affect downstream operators + return d diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/requirements.txt b/examples/apps/cchmc_ped_abd_ct_seg_app/requirements.txt new file mode 100644 index 00000000..309428d7 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/requirements.txt @@ -0,0 +1,27 @@ +monai>=1.3.0 +torch>=1.12.0 +pytorch-ignite>=0.4.9 +fire>=0.4.0 +numpy>=1.22.2 +nibabel>=4.0.1 +# pydicom v3.0.0 removed pydicom._storage_sopclass_uids; don't meet or exceed this version +pydicom>=2.3.0,<3.0.0 +highdicom>=0.18.2 +itk>=5.3.0 +SimpleITK>=2.0.0 +scikit-image>=0.17.2 +Pillow>=8.0.0 +numpy-stl>=2.12.0 +trimesh>=3.8.11 +matplotlib>=3.7.2 +setuptools>=59.5.0 # for pkg_resources +python-dotenv>=1.0.1 + +# pymongo for MongoDB writing +pymongo>=4.10.1 + +# pytz for MongoDB Timestamp +pytz>=2024.1 + +# MONAI Deploy App SDK package installation +monai-deploy-app-sdk diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_build.sh b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_build.sh new file mode 100755 index 00000000..d4302ad6 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_build.sh @@ -0,0 +1,29 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# build a MAP + +# check if the correct number of arguments are provided +if [ "$#" -ne 3 ]; then + echo "Please provide all arguments. Usage: $0 " + exit 1 +fi + +# assign command-line arguments to variables +tag_prefix=$1 +image_version=$2 +sdk_version=$3 + +# load in environment variables +source .env + +# build MAP +monai-deploy package cchmc_ped_abd_ct_seg_app -m $HOLOSCAN_MODEL_PATH -c cchmc_ped_abd_ct_seg_app/app.yaml -t ${tag_prefix}:${image_version} --platform x64-workstation --sdk-version ${sdk_version} -l DEBUG diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_extract.sh b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_extract.sh new file mode 100755 index 00000000..a87287cb --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_extract.sh @@ -0,0 +1,31 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# display and extract MAP contents + +# check if the correct number of arguments are provided +if [ "$#" -ne 2 ]; then + echo "Please provide all arguments. Usage: $0 " + exit 1 +fi + +# assign command-line arguments to variables +tag_prefix=$1 +image_version=$2 + +# display basic MAP manifests +docker run --rm ${tag_prefix}-x64-workstation-dgpu-linux-amd64:${image_version} show + +# remove and subsequently create export folder +rm -rf `pwd`/export && mkdir -p `pwd`/export + +# extract MAP contents +docker run --rm -v `pwd`/export/:/var/run/holoscan/export/ ${tag_prefix}-x64-workstation-dgpu-linux-amd64:${image_version} extract diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_run.sh b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_run.sh new file mode 100755 index 00000000..f4d5251a --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_run.sh @@ -0,0 +1,31 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# execute MAP locally with MAR + +# check if the correct number of arguments are provided +if [ "$#" -ne 2 ]; then + echo "Please provide all arguments. Usage: $0 " + exit 1 +fi + +# assign command-line arguments to variables +tag_prefix=$1 +image_version=$2 + +# load in environment variables +source .env + +# remove the output directory +rm -rf "$HOLOSCAN_OUTPUT_PATH" + +# execute MAP locally via MAR +monai-deploy run -i $HOLOSCAN_INPUT_PATH -o $HOLOSCAN_OUTPUT_PATH ${tag_prefix}-x64-workstation-dgpu-linux-amd64:${image_version} diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_run_interactive.sh b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_run_interactive.sh new file mode 100755 index 00000000..422ae16e --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/map_run_interactive.sh @@ -0,0 +1,37 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# run an interactive MAP container + +# check if the correct number of arguments are provided +if [ "$#" -ne 2 ]; then + echo "Please provide all arguments. Usage: $0 " + exit 1 +fi + +# assign command-line arguments to variables +tag_prefix=$1 +image_version=$2 + +# load in environment variables +source .env + +# remove the output directory +rm -rf "$HOLOSCAN_OUTPUT_PATH" + +# execute MAP locally via MAR and start interactive container +monai-deploy run -i $HOLOSCAN_INPUT_PATH -o $HOLOSCAN_OUTPUT_PATH ${tag_prefix}-x64-workstation-dgpu-linux-amd64:${image_version} --terminal + +# # start interactive MAP container without MAR +# docker run -it --entrypoint /bin/bash ${tag_prefix}-x64-workstation-dgpu-linux-amd64:${image_version} + +# # see dependencies installed in MAP +# pip freeze diff --git a/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/model_run.sh b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/model_run.sh new file mode 100755 index 00000000..6decca04 --- /dev/null +++ b/examples/apps/cchmc_ped_abd_ct_seg_app/scripts/model_run.sh @@ -0,0 +1,21 @@ +# Copyright 2021-2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# execute model bundle locally (pythonically) + +# load in environment variables +source .env + +# remove the output directory +rm -rf "$HOLOSCAN_OUTPUT_PATH" + +# execute model bundle locally (pythonically) +python cchmc_ped_abd_ct_seg_app -i "$HOLOSCAN_INPUT_PATH" -o "$HOLOSCAN_OUTPUT_PATH" -m "$HOLOSCAN_MODEL_PATH" diff --git a/requirements-examples.txt b/requirements-examples.txt index 14756af7..56301f72 100644 --- a/requirements-examples.txt +++ b/requirements-examples.txt @@ -1,6 +1,7 @@ scikit-image>=0.17.2 pydicom>=2.3.0 PyPDF2>=2.11.1 +types-pytz>=2024.1.0.20240203 highdicom>=0.18.2 SimpleITK>=2.0.0 Pillow>=8.4.0