PostFix, _ = optional_import("monai.utils.enums", name="PostFix")  # For the default meta_key_postfix
first, _ = optional_import("monai.utils.misc", name="first")
+ensure_tuple, _ = optional_import("monai.utils", name="ensure_tuple")
Compose_, _ = optional_import("monai.transforms", name="Compose")
ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser")
MapTransform_, _ = optional_import("monai.transforms", name="MapTransform")
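For reference, `optional_import` defers the import and reports availability instead of raising at import time; a minimal sketch of the pattern used above (the availability check and print are illustrative only):

    from monai.utils import optional_import

    # optional_import returns (attribute, available_flag); if the import failed,
    # the returned attribute raises a deferred ImportError only when first used.
    ensure_tuple, has_ensure_tuple = optional_import("monai.utils", name="ensure_tuple")

    if has_ensure_tuple:
        print(ensure_tuple("image"))  # -> ('image',)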
@@ -236,7 +237,7 @@ def __init__(
        output_mapping: List[IOMapping],
        model_name: Optional[str] = "",
        bundle_path: Optional[str] = None,
-        bundle_config_names: BundleConfigNames = None,
+        bundle_config_names: Optional[BundleConfigNames] = None,
        *args,
        **kwargs,
    ):
@@ -261,9 +262,9 @@ def __init__(
        self._input_mapping = input_mapping
        self._output_mapping = output_mapping

-        self._parser = None  # Needs known bundle path, either on init or when compute function is called.
-        self._inferer = None  # Will be set during bundle parsing.
-        self._init_completed = False
+        self._parser: ConfigParser = None  # Needs known bundle path, either on init or when compute function is called.
+        self._inferer: Any = None  # Will be set during bundle parsing.
+        self._init_completed: bool = False

        # Need to set the operator's input(s) and output(s). Even when the bundle parsing is done in init,
        # there is still a need to define what op inputs/outputs map to what keys in the bundle config,
@@ -289,6 +290,9 @@ def __init__(
            logging.warn("Bundle parsing is not completed on init, delayed till this operator is called to execute.")
            self._bundle_path = None

+        # Lazy init of model network till execution time when the context is fully set.
+        self._model_network: Any = None
+
    @property
    def model_name(self) -> str:
        return self._model_name
@@ -390,7 +394,7 @@ def _get_meta_key_postfix(self, compose: Compose, key_name: str = "meta_key_post
                post_fix = post_fix[0]
                break

-        return post_fix
+        return str(post_fix)

    def _get_io_data_type(self, conf):
        """
@@ -441,28 +445,32 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe

        # Try to get the Model object and its path from the context.
        # If operator is not fully initialized, use model path as bundle path to finish it.
-        # If Model not loaded, but bundle path exists, load model, just in case.
+        # If Model not loaded, but bundle path exists, load model; edge case for local dev.
        #
        # `context.models.get(model_name)` returns a model instance if exists.
        # If model_name is not specified and only one model exists, it returns that model.
-        model = context.models.get(self._model_name) if context.models else None
-        if model:
+
+        self._model_network = context.models.get(self._model_name) if context.models else None
+        if self._model_network:
            if not self._init_completed:
                with self._lock:
                    if not self._init_completed:
-                        self._bundle_path = model.path
+                        self._bundle_path = self._model_network.path
                        self._init_config(self._bundle_config_names.config_names)
                        self._init_completed = True
        elif self._bundle_path:
+            # For the case of local dev/testing when the bundle path is not passed in as an exec cmd arg.
+            # When run as a MAP docker, the bundle file is expected to be in the context, even if the model
+            # network is loaded on a remote inference server (when the feature is introduced).
            logging.debug(f"Model network not loaded. Trying to load from model path: {self._bundle_path}")
-            model = torch.jit.load(self.bundle_path, map_location=self._device).eval()
+            self._model_network = torch.jit.load(self.bundle_path, map_location=self._device).eval()
        else:
            raise IOError("Model network is not loaded and model file not found.")

        first_input_name, *other_names = list(self._inputs.keys())

        with torch.no_grad():
-            inputs = {}
+            inputs: Any = {}  # Use type Any to quiet MyPy type checking complaints.

            start = time.time()
            for name in self._inputs.keys():
@@ -482,13 +490,13 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
            logging.debug(f"Ingest and Pre-processing elapsed time (seconds): {time.time() - start}")

            start = time.time()
-            outputs = self.predict(data=first_input, network=model, **other_inputs)
+            outputs: Any = self.predict(data=first_input, **other_inputs)  # Use type Any to quiet MyPy complaints.
            logging.debug(f"Inference elapsed time (seconds): {time.time() - start}")

            # TODO: Does this work for models where multiple outputs are returned?
            # Note that the inputs are needed because the invert transform requires it.
            start = time.time()
-            outputs = self.post_process(outputs[0], inputs)
+            outputs = self.post_process(ensure_tuple(outputs)[0], preprocessed_inputs=inputs)
            logging.debug(f"Post-processing elapsed time (seconds): {time.time() - start}")
        if isinstance(outputs, (tuple, list)):
            output_dict = dict(zip(self._outputs.keys(), outputs))
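The `ensure_tuple(outputs)[0]` guard above makes single-output and multi-output networks look the same before post-processing; a small sketch of the behavior (the values are illustrative):

    from monai.utils import ensure_tuple

    # A scalar-like value is wrapped, a sequence is passed through as a tuple,
    # so indexing [0] is safe either way.
    assert ensure_tuple("seg")[0] == "seg"
    assert ensure_tuple(("seg", "cls"))[0] == "seg"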
@@ -502,19 +510,27 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
                # Please see the comments in the called function for the reasons.
                self._send_output(output_dict[name], name, input_metadata, op_output, context)

-    def predict(self, data: Any, network: Any, *args, **kwargs) -> Union[Image, Any]:
+    def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
        """Predicts output using the inferer."""
-        return self._inferer(inputs=data, network=network, *args, **kwargs)
+        return self._inferer(inputs=data, network=self._model_network, *args, **kwargs)

-    def pre_process(self, data: Any) -> Union[Image, Any]:
+    def pre_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
        """Processes the input dictionary with the stored transform sequence `self._preproc`."""

        if is_map_compose(self._preproc):
            return self._preproc(data)
        return {k: self._preproc(v) for k, v in data.items()}

-    def post_process(self, data: Any, inputs: Dict) -> Union[Image, Any]:
-        """Processes the output list/dictionary with the stored transform sequence `self._postproc`."""
+    def post_process(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
+        """Processes the output list/dictionary with the stored transform sequence `self._postproc`.
+
+        The "preprocessed_inputs", in fact the metadata in it, needs to be passed in so that the
+        invertible transforms in post processing can work properly.
+        """
+
+        # Expect the inputs to be passed in so that the inversion can work.
+        inputs = kwargs.get("preprocessed_inputs", {})

        if is_map_compose(self._postproc):
            if isinstance(data, (list, tuple)):
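With the keyword-based signature, a caller passes the pre-processed batch explicitly so the invertible transforms can read its metadata; a hypothetical call site (the `op`, `inputs`, and `"image"` names are assumed, not from the source):

    # inputs: the dict produced by pre_process; outputs: the raw network result.
    outputs = op.predict(data=inputs["image"])
    seg = op.post_process(ensure_tuple(outputs)[0], preprocessed_inputs=inputs)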
@@ -585,7 +601,7 @@ def _receive_input(self, name: str, op_input: InputContext, context: ExecutionCo
        return value, metadata

-    def _send_output(self, value, name: str, metadata: Dict, op_output: OutputContext, context: ExecutionContext):
+    def _send_output(self, value: Any, name: str, metadata: Dict, op_output: OutputContext, context: ExecutionContext):
        """Send the given output value to the output context."""

        logging.debug(f"Setting output {name}")
@@ -610,7 +626,7 @@ def _send_output(self, value, name: str, metadata: Dict, op_output: OutputContex
                raise TypeError("arg 1 must be of type torch.Tensor or ndarray.")

            logging.debug(f"Output {name} numpy image shape: {value.shape}")
-            result = Image(np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8), metadata=metadata)
+            result: Any = Image(np.swapaxes(np.squeeze(value, 0), 0, 2).astype(np.uint8), metadata=metadata)
            logging.debug(f"Converted Image shape: {result.asnumpy().shape}")
        elif otype == np.ndarray:
            result = np.asarray(value)
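The conversion above drops the batch dimension and swaps the first and last spatial axes before wrapping the array in `Image`; a toy sketch of just the shape handling (the shape values are made up):

    import numpy as np

    value = np.zeros((1, 200, 320, 16))            # e.g. batch, W, H, D from the network
    arr = np.swapaxes(np.squeeze(value, 0), 0, 2)  # -> shape (16, 320, 200)
    print(arr.astype(np.uint8).shape)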