20
20
import torch_tensorrt
21
21
from monai .inferers import sliding_window_inference
22
22
from monai .networks .nets import SegResNet
23
- from monai .transforms import (Activationsd , AsDiscreted , Compose ,
24
- EnsureChannelFirstd , EnsureTyped , Invertd ,
25
- LoadImaged , NormalizeIntensityd , Orientationd ,
26
- ScaleIntensityd , Spacingd )
27
-
28
- from utils import (prepare_model_weights , prepare_tensorrt_model ,
29
- prepare_test_datalist )
23
+ from monai .transforms import (
24
+ Activationsd ,
25
+ AsDiscreted ,
26
+ Compose ,
27
+ EnsureChannelFirstd ,
28
+ EnsureTyped ,
29
+ Invertd ,
30
+ LoadImaged ,
31
+ NormalizeIntensityd ,
32
+ Orientationd ,
33
+ ScaleIntensityd ,
34
+ Spacingd ,
35
+ )
36
+
37
+ from utils import prepare_model_weights , prepare_tensorrt_model , prepare_test_datalist
30
38
31
39
32
40
def get_transforms (device , gpu_loading_flag = False , gpu_transforms_flag = False ):
@@ -49,6 +57,7 @@ def get_transforms(device, gpu_loading_flag=False, gpu_transforms_flag=False):
49
57
50
58
return infer_transforms
51
59
60
+
52
61
def get_post_transforms (infer_transforms ):
53
62
post_transforms = Compose (
54
63
[
@@ -65,6 +74,7 @@ def get_post_transforms(infer_transforms):
65
74
)
66
75
return post_transforms
67
76
77
+
68
78
def get_model (device , weights_path , trt_model_path , trt_flag = False ):
69
79
if not trt_flag :
70
80
model = SegResNet (
@@ -84,11 +94,12 @@ def get_model(device, weights_path, trt_model_path, trt_flag=False):
84
94
model = torch .jit .load (trt_model_path )
85
95
return model
86
96
97
+
87
98
def run_inference (data_list , infer_transforms , model , device , benchmark_type ):
88
99
total_time_dict = {}
89
100
roi_size = (96 , 96 , 96 )
90
101
sw_batch_size = 1
91
-
102
+
92
103
for idx , sample in enumerate (data_list ):
93
104
start = timer ()
94
105
data = infer_transforms ({"image" : sample })
@@ -117,6 +128,7 @@ def run_inference(data_list, infer_transforms, model, device, benchmark_type):
117
128
118
129
return total_time_dict
119
130
131
+
120
132
def main ():
121
133
parser = argparse .ArgumentParser (description = "Run inference benchmark." )
122
134
parser .add_argument ("--benchmark_type" , type = str , default = "original" , help = "Type of benchmark to run" )
@@ -146,5 +158,6 @@ def main():
146
158
df = pd .DataFrame (list (total_time_dict .items ()), columns = ["file_name" , "time" ])
147
159
df .to_csv (os .path .join (root_dir , f"time_{ args .benchmark_type } .csv" ), index = False )
148
160
161
+
149
162
if __name__ == "__main__" :
150
163
main ()
0 commit comments