Skip to content

Commit 2b1b6fd

Browse files
committed
Correct all MyPy type checking errors and improve inline doc.
The base inference operator needs to adjust the function signatures due to the expanded use of types resulting from the intro if bundle inference operator. The new version of MyPy also seems to be stricter on types. Signed-off-by: mmelqin <mingmelvinq@nvidia.com>
1 parent 152424d commit 2b1b6fd

File tree

3 files changed

+59
-30
lines changed

3 files changed

+59
-30
lines changed

monai/deploy/operators/inference_operator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# limitations under the License.
1111

1212
from abc import abstractmethod
13-
from typing import Any, Union
13+
from typing import Any, Dict, Tuple, Union
1414

1515
from monai.deploy.core import ExecutionContext, Image, InputContext, Operator, OutputContext
1616

@@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs):
2727
super().__init__()
2828

2929
@abstractmethod
30-
def pre_process(self, data: Any) -> Union[Image, Any]:
30+
def pre_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
3131
"""Transforms input before being used for predicting on a model.
3232
3333
This method must be overridden by a derived class.
@@ -50,7 +50,7 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
5050
pass
5151

5252
@abstractmethod
53-
def predict(self, data: Any) -> Union[Image, Any]:
53+
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
5454
"""Predicts results using the models(s) with input tensors.
5555
5656
This method must be overridden by a derived class.
@@ -61,7 +61,7 @@ def predict(self, data: Any) -> Union[Image, Any]:
6161
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
6262

6363
@abstractmethod
64-
def post_process(self, data: Any) -> Union[Image, Any]:
64+
def post_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
6565
"""Transform the prediction results from the model(s).
6666
6767
This method must be overridden by a derived class.

monai/deploy/operators/monai_bundle_inference_operator.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
PostFix, _ = optional_import("monai.utils.enums", name="PostFix") # For the default meta_key_postfix
3636
first, _ = optional_import("monai.utils.misc", name="first")
37+
ensure_tuple, _ = optional_import("monai.utils", name="ensure_tuple")
3738
Compose_, _ = optional_import("monai.transforms", name="Compose")
3839
ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser")
3940
MapTransform_, _ = optional_import("monai.transforms", name="MapTransform")
@@ -236,7 +237,7 @@ def __init__(
236237
output_mapping: List[IOMapping],
237238
model_name: Optional[str] = "",
238239
bundle_path: Optional[str] = None,
239-
bundle_config_names: BundleConfigNames = None,
240+
bundle_config_names: Optional[BundleConfigNames] = None,
240241
*args,
241242
**kwargs,
242243
):
@@ -261,9 +262,9 @@ def __init__(
261262
self._input_mapping = input_mapping
262263
self._output_mapping = output_mapping
263264

264-
self._parser = None # Needs known bundle path, either on init or when compute function is called.
265-
self._inferer = None # Will be set during bundle parsing.
266-
self._init_completed = False
265+
self._parser: ConfigParser = None # Needs known bundle path, either on init or when compute function is called.
266+
self._inferer: Any = None # Will be set during bundle parsing.
267+
self._init_completed: bool = False
267268

268269
# Need to set the operator's input(s) and output(s). Even when the bundle parsing is done in init,
269270
# there is still a need to define what op inputs/outputs map to what keys in the bundle config,
@@ -289,6 +290,9 @@ def __init__(
289290
logging.warn("Bundle parsing is not completed on init, delayed till this operator is called to execute.")
290291
self._bundle_path = None
291292

293+
# Lazy init of model network till execution time when the context is fully set.
294+
self._model_network: Any = None
295+
292296
@property
293297
def model_name(self) -> str:
294298
return self._model_name
@@ -390,7 +394,7 @@ def _get_meta_key_postfix(self, compose: Compose, key_name: str = "meta_key_post
390394
post_fix = post_fix[0]
391395
break
392396

393-
return post_fix
397+
return str(post_fix)
394398

395399
def _get_io_data_type(self, conf):
396400
"""
@@ -441,28 +445,32 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
441445

442446
# Try to get the Model object and its path from the context.
443447
# If operator is not fully initialized, use model path as bundle path to finish it.
444-
# If Model not loaded, but bundle path exists, load model, just in case.
448+
# If Model not loaded, but bundle path exists, load model; edge case for local dev.
445449
#
446450
# `context.models.get(model_name)` returns a model instance if exists.
447451
# If model_name is not specified and only one model exists, it returns that model.
448-
model = context.models.get(self._model_name) if context.models else None
449-
if model:
452+
453+
self._model_network = context.models.get(self._model_name) if context.models else None
454+
if self._model_network:
450455
if not self._init_completed:
451456
with self._lock:
452457
if not self._init_completed:
453-
self._bundle_path = model.path
458+
self._bundle_path = self._model_network.path
454459
self._init_config(self._bundle_config_names.config_names)
455460
self._init_completed
456461
elif self._bundle_path:
462+
# For the case of local dev/testing when the bundle path is not passed in as an exec cmd arg.
463+
# When run as a MAP docker, the bundle file is expected to be in the context, even if the model
464+
# network is loaded on a remote inference server (when the feature is introduced).
457465
logging.debug(f"Model network not loaded. Trying to load from model path: {self._bundle_path}")
458-
model = torch.jit.load(self.bundle_path, map_location=self._device).eval()
466+
self._model_network = torch.jit.load(self.bundle_path, map_location=self._device).eval()
459467
else:
460468
raise IOError("Model network is not load and model file not found.")
461469

462470
first_input_name, *other_names = list(self._inputs.keys())
463471

464472
with torch.no_grad():
465-
inputs = {}
473+
inputs: Any = {} # Use type Any to quiet MyPy type checking complaints.
466474

467475
start = time.time()
468476
for name in self._inputs.keys():
@@ -482,13 +490,13 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
482490
logging.debug(f"Ingest and Pre-processing elapsed time (seconds): {time.time() - start}")
483491

484492
start = time.time()
485-
outputs = self.predict(data=first_input, network=model, **other_inputs)
493+
outputs: Any = self.predict(data=first_input, **other_inputs) # Use type Any to quiet MyPy complaints.
486494
logging.debug(f"Inference elapsed time (seconds): {time.time() - start}")
487495

488496
# TODO: Does this work for models where multiple outputs are returned?
489497
# Note that the inputs are needed because the invert transform requires it.
490498
start = time.time()
491-
outputs = self.post_process(outputs[0], inputs)
499+
outputs = self.post_process(ensure_tuple(outputs)[0], preprocessed_inputs=inputs)
492500
logging.debug(f"Post-processing elapsed time (seconds): {time.time() - start}")
493501
if isinstance(outputs, (tuple, list)):
494502
output_dict = dict(zip(self._outputs.keys(), outputs))
@@ -502,19 +510,27 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
502510
# Please see the comments in the called function for the reasons.
503511
self._send_output(output_dict[name], name, input_metadata, op_output, context)
504512

505-
def predict(self, data: Any, network: Any, *args, **kwargs) -> Union[Image, Any]:
513+
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
506514
"""Predicts output using the inferer."""
507-
return self._inferer(inputs=data, network=network, *args, **kwargs)
508515

509-
def pre_process(self, data: Any) -> Union[Image, Any]:
516+
return self._inferer(inputs=data, network=self._model_network, *args, **kwargs)
517+
518+
def pre_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
510519
"""Processes the input dictionary with the stored transform sequence `self._preproc`."""
511520

512521
if is_map_compose(self._preproc):
513522
return self._preproc(data)
514523
return {k: self._preproc(v) for k, v in data.items()}
515524

516-
def post_process(self, data: Any, inputs: Dict) -> Union[Image, Any]:
517-
"""Processes the output list/dictionary with the stored transform sequence `self._postproc`."""
525+
def post_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
526+
"""Processes the output list/dictionary with the stored transform sequence `self._postproc`.
527+
528+
The "processed_inputs", in fact the metadata in it, need to be passed in so that the
529+
invertible transforms in the post processing can work properly.
530+
"""
531+
532+
# Expect the inputs be passed in so that the inversion can work.
533+
inputs = kwargs.get("preprocessed_inputs", {})
518534

519535
if is_map_compose(self._postproc):
520536
if isinstance(data, (list, tuple)):
@@ -585,7 +601,7 @@ def _receive_input(self, name: str, op_input: InputContext, context: ExecutionCo
585601

586602
return value, metadata
587603

588-
def _send_output(self, value, name: str, metadata: Dict, op_output: OutputContext, context: ExecutionContext):
604+
def _send_output(self, value: Any, name: str, metadata: Dict, op_output: OutputContext, context: ExecutionContext):
589605
"""Send the given output value to the output context."""
590606

591607
logging.debug(f"Setting output {name}")
@@ -610,7 +626,7 @@ def _send_output(self, value, name: str, metadata: Dict, op_output: OutputContex
610626
raise TypeError("arg 1 must be of type torch.Tensor or ndarray.")
611627

612628
logging.debug(f"Output {name} numpy image shape: {value.shape}")
613-
result = Image(np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8), metadata=metadata)
629+
result: Any = Image(np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8), metadata=metadata)
614630
logging.debug(f"Converted Image shape: {result.asnumpy().shape}")
615631
elif otype == np.ndarray:
616632
result = np.asarray(value)

monai/deploy/operators/monai_seg_inference_operator.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# limitations under the License.
1111

1212
from threading import Lock
13-
from typing import Any, Dict, List, Optional, Sequence, Union
13+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
1414

1515
import numpy as np
1616

@@ -31,7 +31,6 @@
3131
Compose_, _ = optional_import("monai.transforms", name="Compose")
3232
# Dynamic class is not handled so make it Any for now: https://github.com/python/mypy/issues/2477
3333
Compose: Any = Compose_
34-
sliding_window_inference, _ = optional_import("monai.inferers", name="sliding_window_inference")
3534

3635
import monai.deploy.core as md
3736
from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, OutputContext
@@ -246,30 +245,44 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
246245
with self._lock:
247246
self._executing = False
248247

249-
def pre_process(self, img_reader) -> Union[Any, Image, Compose]:
248+
def pre_process(self, data: Any, *args, **kwargs) -> Union[Any, Image, Tuple[Any, ...], Dict[Any, Any]]:
250249
"""Transforms input before being used for predicting on a model.
251250
252251
This method must be overridden by a derived class.
252+
Expected return is monai.transforms.Compose.
253+
254+
Args:
255+
data(monai.data.ImageReader): Reader used in LoadImage to load `monai.deploy.core.Image` as the input.
256+
257+
Returns:
258+
monai.transforms.Compose encapsulating pre transforms
253259
254260
Raises:
255261
NotImplementedError: When the subclass does not override this method.
256262
"""
257263
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
258264

259-
def post_process(self, pre_transforms: Compose, out_dir: str = "./infer_out") -> Union[Any, Image, Compose]:
265+
def post_process(self, data: Any, *args, **kwargs) -> Union[Any, Image, Tuple[Any, ...], Dict[Any, Any]]:
260266
"""Transforms the prediction results from the model(s).
261267
262268
This method must be overridden by a derived class.
269+
Expected return is monai.transforms.Compose.
270+
271+
Args:
272+
data(monai.transforms.Compose): The pre-processing transforms in a Compose object.
273+
274+
Returns:
275+
monai.transforms.Compose encapsulating post-processing transforms.
263276
264277
Raises:
265278
NotImplementedError: When the subclass does not override this method.
266279
"""
267280
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
268281

269-
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any]:
282+
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
270283
"""Predicts results using the models(s) with input tensors.
271284
272-
This method must be overridden by a derived class.
285+
This method is currently not used in this class, instead monai.inferers.sliding_window_inference is used.
273286
274287
Raises:
275288
NotImplementedError: When the subclass does not override this method.

0 commit comments

Comments
 (0)