Skip to content

Commit 4d15f08

Browse files
authored
Add the liver tumor sample app with publishing to Render Server (#183)
* Add the liver tumor sample app with publishing to Render Server Signed-off-by: mmelqin <mingmelvinq@nvidia.com> * Update to use Path.mkdir Signed-off-by: mmelqin <mingmelvinq@nvidia.com>
1 parent 4f70ba4 commit 4d15f08

File tree

4 files changed

+235
-0
lines changed

4 files changed

+235
-0
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import os
2+
import sys
3+
4+
_current_dir = os.path.abspath(os.path.dirname(__file__))
5+
if sys.path and os.path.abspath(sys.path[0]) != _current_dir:
6+
sys.path.insert(0, _current_dir)
7+
del _current_dir
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from app import AIUnetrSegApp
2+
3+
if __name__ == "__main__":
4+
AIUnetrSegApp(do_run=True)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
14+
from livertumor_seg_operator import LiverTumorSegOperator
15+
16+
from monai.deploy.core import Application, resource
17+
from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator
18+
from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator
19+
from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
20+
from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
21+
from monai.deploy.operators.publisher_operator import PublisherOperator
22+
23+
24+
@resource(cpu=1, gpu=1, memory="7Gi")
25+
# pip_packages can be a string that is a path(str) to requirements.txt file or a list of packages.
26+
# The MONAI pkg is not required by this class, instead by the included operators.
27+
class AIUnetrSegApp(Application):
28+
def __init__(self, *args, **kwargs):
29+
"""Creates an application instance."""
30+
31+
self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
32+
super().__init__(*args, **kwargs)
33+
34+
def run(self, *args, **kwargs):
35+
# This method calls the base class to run. Can be omitted if simply calling through.
36+
self._logger.debug(f"Begin {self.run.__name__}")
37+
super().run(*args, **kwargs)
38+
self._logger.debug(f"End {self.run.__name__}")
39+
40+
def compose(self):
41+
"""Creates the app specific operators and chain them up in the processing DAG."""
42+
43+
self._logger.debug(f"Begin {self.compose.__name__}")
44+
# Creates the custom operator(s) as well as SDK built-in operator(s).
45+
study_loader_op = DICOMDataLoaderOperator()
46+
series_selector_op = DICOMSeriesSelectorOperator()
47+
series_to_vol_op = DICOMSeriesToVolumeOperator()
48+
# Model specific inference operator, supporting MONAI transforms.
49+
unetr_seg_op = LiverTumorSegOperator()
50+
51+
# Create the publisher operator
52+
publisher_op = PublisherOperator()
53+
54+
# Creates DICOM Seg writer with segment label name in a string list
55+
dicom_seg_writer = DICOMSegmentationWriterOperator(
56+
seg_labels=[
57+
"Liver",
58+
"Tumor",
59+
]
60+
)
61+
# Create the processing pipeline, by specifying the upstream and downstream operators, and
62+
# ensuring the output from the former matches the input of the latter, in both name and type.
63+
self.add_flow(study_loader_op, series_selector_op, {"dicom_study_list": "dicom_study_list"})
64+
self.add_flow(series_selector_op, series_to_vol_op, {"dicom_series": "dicom_series"})
65+
self.add_flow(series_to_vol_op, unetr_seg_op, {"image": "image"})
66+
# Note below the dicom_seg_writer requires two inputs, each coming from a upstream operator.
67+
# Also note that the DICOMSegmentationWriterOperator may throw exception with some inputs.
68+
# Bug has been created to track the issue.
69+
self.add_flow(series_selector_op, dicom_seg_writer, {"dicom_series": "dicom_series"})
70+
self.add_flow(unetr_seg_op, dicom_seg_writer, {"seg_image": "seg_image"})
71+
# Add the publishing operator to save the input and seg images for Render Server.
72+
# Note the PublisherOperator has temp impl till a proper rendering module is created.
73+
self.add_flow(unetr_seg_op, publisher_op, {"saved_images_folder": "saved_images_folder"})
74+
75+
self._logger.debug(f"End {self.compose.__name__}")
76+
77+
78+
if __name__ == "__main__":
79+
# Creates the app and test it standalone. When running is this mode, please note the following:
80+
# -m <model file>, for model file path
81+
# -i <DICOM folder>, for input DICOM CT series folder
82+
# -o <output folder>, for the output folder, default $PWD/output
83+
# e.g.
84+
# python3 app.py -i input -m model/model.ts
85+
#
86+
logging.basicConfig(level=logging.DEBUG)
87+
app_instance = AIUnetrSegApp() # Optional params' defaults are fine.
88+
app_instance.run()
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
14+
from numpy import uint8
15+
16+
import monai.deploy.core as md
17+
from monai.deploy.core import DataPath, ExecutionContext, Image, InputContext, IOType, Operator, OutputContext
18+
from monai.deploy.operators.monai_seg_inference_operator import InMemImageReader, MonaiSegInferenceOperator
19+
from monai.transforms import (
20+
Activationsd,
21+
AsDiscreted,
22+
Compose,
23+
CropForegroundd,
24+
EnsureChannelFirstd,
25+
Invertd,
26+
LoadImaged,
27+
SaveImaged,
28+
ScaleIntensityRanged,
29+
Spacingd,
30+
ToTensord,
31+
)
32+
33+
34+
@md.input("image", Image, IOType.IN_MEMORY)
35+
@md.output("seg_image", Image, IOType.IN_MEMORY)
36+
@md.output("saved_images_folder", DataPath, IOType.DISK)
37+
@md.env(pip_packages=["monai==0.6.0", "torch>=1.5", "numpy>=1.17", "nibabel"])
38+
class LiverTumorSegOperator(Operator):
39+
"""Performs liver and tumor segmentation using a DL model with an image converted from a DICOM CT series.
40+
41+
The model used in this application is from NVIDIA, publicly available at
42+
https://ngc.nvidia.com/catalog/models/nvidia:med:clara_pt_liver_and_tumor_ct_segmentation
43+
44+
Described in the downloaded model package, also called Medical Model Archive (MMAR), are the pre and post
45+
transforms before and after inference, and are using MONAI SDK transforms. As such, these transforms are
46+
simply ported to this operator, with changing SegmentationSaver handler to SaveImageD post transform.
47+
48+
This operator makes use of the App SDK MonaiSegInferenceOperator in a compsition approach.
49+
It creates the pre-transforms as well as post-transforms with MONAI dictionary based transforms.
50+
Note that the App SDK InMemImageReader, derived from MONAI ImageReader, is passed to LoadImaged.
51+
This derived reader is needed to parse the in memory image object, and return the expected data structure.
52+
Loading of the model, and predicting using in-proc PyTorch inference is done by MonaiSegInferenceOperator.
53+
"""
54+
55+
def __init__(self):
56+
57+
self.logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
58+
super().__init__()
59+
self._input_dataset_key = "image"
60+
self._pred_dataset_key = "pred"
61+
62+
def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
63+
64+
input_image = op_input.get("image")
65+
if not input_image:
66+
raise ValueError("Input image is not found.")
67+
68+
# Get the output path from the execution context for saving file(s) to app output.
69+
# Without using this path, operator would be saving files to its designated path, e.g.
70+
# $PWD/.monai_workdir/operators/6048d75a-5de1-45b9-8bd1-2252f88827f2/0/output
71+
op_output_folder_name = DataPath("saved_images_folder")
72+
op_output.set(op_output_folder_name, "saved_images_folder")
73+
op_output_folder_path = op_output.get("saved_images_folder").path
74+
op_output_folder_path.mkdir(parents=True, exist_ok=True)
75+
print(f"Operator output folder path: {op_output_folder_path}")
76+
77+
# This operator gets an in-memory Image object, so a specialized ImageReader is needed.
78+
_reader = InMemImageReader(input_image)
79+
pre_transforms = self.pre_process(_reader)
80+
post_transforms = self.post_process(pre_transforms, op_output_folder_path)
81+
82+
# Delegates inference and saving output to the built-in operator.
83+
infer_operator = MonaiSegInferenceOperator(
84+
(
85+
160,
86+
160,
87+
160,
88+
),
89+
pre_transforms,
90+
post_transforms,
91+
overlap=0.6,
92+
)
93+
94+
# Setting the keys used in the dictironary based transforms may change.
95+
infer_operator.input_dataset_key = self._input_dataset_key
96+
infer_operator.pred_dataset_key = self._pred_dataset_key
97+
98+
# Now let the built-in operator handles the work with the I/O spec and execution context.
99+
infer_operator.compute(op_input, op_output, context)
100+
101+
def pre_process(self, img_reader) -> Compose:
102+
"""Composes transforms for preprocessing input before predicting on a model."""
103+
104+
my_key = self._input_dataset_key
105+
return Compose(
106+
[
107+
LoadImaged(keys=my_key, reader=img_reader),
108+
EnsureChannelFirstd(keys=my_key),
109+
Spacingd(keys=my_key, pixdim=(1.0, 1.0, 1.0), mode=("bilinear"), align_corners=True),
110+
ScaleIntensityRanged(my_key, a_min=-21, a_max=189, b_min=0.0, b_max=1.0, clip=True),
111+
CropForegroundd(my_key, source_key=my_key),
112+
ToTensord(my_key),
113+
]
114+
)
115+
116+
def post_process(self, pre_transforms: Compose, out_dir: str = "./prediction_output") -> Compose:
117+
"""Composes transforms for postprocessing the prediction results."""
118+
119+
pred_key = self._pred_dataset_key
120+
return Compose(
121+
[
122+
Activationsd(keys=pred_key, softmax=True),
123+
AsDiscreted(keys=pred_key, argmax=True),
124+
Invertd(
125+
keys=pred_key, transform=pre_transforms, orig_keys=self._input_dataset_key, nearest_interp=True
126+
),
127+
SaveImaged(keys=pred_key, output_dir=out_dir, output_postfix="seg", output_dtype=uint8, resample=False),
128+
SaveImaged(
129+
keys=self._input_dataset_key,
130+
output_dir=out_dir,
131+
output_postfix="",
132+
output_dtype=uint8,
133+
resample=False,
134+
),
135+
]
136+
)

0 commit comments

Comments
 (0)