@@ -49,6 +49,7 @@ def get_transforms(device, gpu_loading_flag=False, gpu_transforms_flag=False):
 
     return infer_transforms
 
+
 def get_post_transforms(infer_transforms):
     post_transforms = Compose(
         [
@@ -65,6 +66,7 @@ def get_post_transforms(infer_transforms):
     )
     return post_transforms
 
+
 def get_model(device, weights_path, trt_model_path, trt_flag=False):
     if not trt_flag:
         model = SegResNet(
@@ -84,11 +86,12 @@ def get_model(device, weights_path, trt_model_path, trt_flag=False):
         model = torch.jit.load(trt_model_path)
     return model
 
+
 def run_inference(data_list, infer_transforms, model, device, benchmark_type):
     total_time_dict = {}
     roi_size = (96, 96, 96)
-    sw_batch_size = 1
-
+    sw_batch_size = 4
+
     for idx, sample in enumerate(data_list):
         start = timer()
         data = infer_transforms({"image": sample})
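
Note on the hunk above: in MONAI, sw_batch_size sets how many ROI patches sliding_window_inference stacks into a single forward pass, so raising it from 1 to 4 trades GPU memory for fewer model invocations over the (96, 96, 96) windows. A minimal sketch of such a call, assuming the benchmark uses the standard MONAI API (the actual call site sits outside the hunks shown, and the segment wrapper is illustrative):

import torch
from monai.inferers import sliding_window_inference

def segment(model: torch.nn.Module, image: torch.Tensor) -> torch.Tensor:
    # image: an (N, C, D, H, W) volume; roi_size matches the diff above.
    # sw_batch_size=4 runs four 96x96x96 patches per forward pass (was 1).
    return sliding_window_inference(
        inputs=image,
        roi_size=(96, 96, 96),
        sw_batch_size=4,
        predictor=model,
    )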
@@ -114,9 +117,10 @@ def run_inference(data_list, infer_transforms, model, device, benchmark_type):
         sample_name = sample.split("/")[-1]
         if idx > 0:
             total_time_dict[sample_name] = end - start
-
+            print(f"Time taken for {sample_name}: {end - start} seconds")
     return total_time_dict
 
+
 def main():
     parser = argparse.ArgumentParser(description="Run inference benchmark.")
     parser.add_argument("--benchmark_type", type=str, default="original", help="Type of benchmark to run")
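
The if idx > 0 guard above keeps the first sample out of total_time_dict: the first iteration absorbs one-time costs (CUDA context creation, cuDNN autotuning, TensorRT engine setup), which matches the warm-up comment later in the diff. The same warm-up-exclusion pattern in isolation, as a sketch (names are illustrative; the script's timer is presumably timeit.default_timer):

from timeit import default_timer as timer

def benchmark(fn, samples):
    times = {}
    for idx, sample in enumerate(samples):
        start = timer()
        fn(sample)
        end = timer()
        if idx > 0:  # discard the warm-up iteration
            times[sample] = end - start
    return times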
@@ -128,8 +132,8 @@ def main():
     torch_tensorrt.runtime.set_multi_device_safe_mode(True)
     device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
     train_files = prepare_test_datalist(root_dir)
-    # since the dataset is too large, the smallest 21 files are used for warm up (1 file) and benchmarking (11 files)
-    train_files = sorted(train_files, key=lambda x: os.path.getsize(x), reverse=False)[:21]
+    # since the dataset is too large, the smallest 31 files are used for warm up (1 file) and benchmarking (30 files)
+    train_files = sorted(train_files, key=lambda x: os.path.getsize(x), reverse=False)[:31]
     weights_path = prepare_model_weights(root_dir=root_dir, bundle_name="wholeBody_ct_segmentation")
     trt_model_name = "model_trt.ts"
     trt_model_path = prepare_tensorrt_model(root_dir, weights_path, trt_model_name)
@@ -146,5 +150,6 @@ def main():
     df = pd.DataFrame(list(total_time_dict.items()), columns=["file_name", "time"])
     df.to_csv(os.path.join(root_dir, f"time_{args.benchmark_type}.csv"), index=False)
 
+
 if __name__ == "__main__":
     main()
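
Each run of this script writes one time_<benchmark_type>.csv with file_name and time columns, so the variants can be compared afterwards. A hypothetical post-processing step (any benchmark-type name besides the default "original" is an assumption):

import pandas as pd

for benchmark_type in ("original", "trt"):  # "trt" is a hypothetical run name
    df = pd.read_csv(f"time_{benchmark_type}.csv")
    print(f"{benchmark_type}: mean={df['time'].mean():.3f}s, std={df['time'].std():.3f}s")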