Skip to content

Commit f4840a7

Browse files
finalize report
Signed-off-by: Yiheng Wang <vennw@nvidia.com>
1 parent 45b5da3 commit f4840a7

File tree

4 files changed

+156
-45
lines changed

4 files changed

+156
-45
lines changed

acceleration/fast_inference_tutorial/fast_inference_tutorial.ipynb

Lines changed: 143 additions & 39 deletions
Large diffs are not rendered by default.

acceleration/fast_inference_tutorial/run_benchmark.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def get_transforms(device, gpu_loading_flag=False, gpu_transforms_flag=False):
4949

5050
return infer_transforms
5151

52+
5253
def get_post_transforms(infer_transforms):
5354
post_transforms = Compose(
5455
[
@@ -65,6 +66,7 @@ def get_post_transforms(infer_transforms):
6566
)
6667
return post_transforms
6768

69+
6870
def get_model(device, weights_path, trt_model_path, trt_flag=False):
6971
if not trt_flag:
7072
model = SegResNet(
@@ -84,11 +86,12 @@ def get_model(device, weights_path, trt_model_path, trt_flag=False):
8486
model = torch.jit.load(trt_model_path)
8587
return model
8688

89+
8790
def run_inference(data_list, infer_transforms, model, device, benchmark_type):
8891
total_time_dict = {}
8992
roi_size = (96, 96, 96)
90-
sw_batch_size = 1
91-
93+
sw_batch_size = 4
94+
9295
for idx, sample in enumerate(data_list):
9396
start = timer()
9497
data = infer_transforms({"image": sample})
@@ -114,9 +117,10 @@ def run_inference(data_list, infer_transforms, model, device, benchmark_type):
114117
sample_name = sample.split("/")[-1]
115118
if idx > 0:
116119
total_time_dict[sample_name] = end - start
117-
120+
print(f"Time taken for {sample_name}: {end - start} seconds")
118121
return total_time_dict
119122

123+
120124
def main():
121125
parser = argparse.ArgumentParser(description="Run inference benchmark.")
122126
parser.add_argument("--benchmark_type", type=str, default="original", help="Type of benchmark to run")
@@ -128,8 +132,8 @@ def main():
128132
torch_tensorrt.runtime.set_multi_device_safe_mode(True)
129133
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
130134
train_files = prepare_test_datalist(root_dir)
131-
# since the dataset is too large, the smallest 21 files are used for warm up (1 file) and benchmarking (11 files)
132-
train_files = sorted(train_files, key=lambda x: os.path.getsize(x), reverse=False)[:21]
135+
# since the dataset is too large, the smallest 31 files are used for warm up (1 file) and benchmarking (30 files)
136+
train_files = sorted(train_files, key=lambda x: os.path.getsize(x), reverse=False)[:31]
133137
weights_path = prepare_model_weights(root_dir=root_dir, bundle_name="wholeBody_ct_segmentation")
134138
trt_model_name = "model_trt.ts"
135139
trt_model_path = prepare_tensorrt_model(root_dir, weights_path, trt_model_name)
@@ -146,5 +150,6 @@ def main():
146150
df = pd.DataFrame(list(total_time_dict.items()), columns=["file_name", "time"])
147151
df.to_csv(os.path.join(root_dir, f"time_{args.benchmark_type}.csv"), index=False)
148152

153+
149154
if __name__ == "__main__":
150155
main()

acceleration/fast_inference_tutorial/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def prepare_tensorrt_model(root_dir, weights_path, trt_model_name="model_trt.ts"
7878
model=model,
7979
precision="fp16",
8080
input_shape=[1, 1, 96, 96, 96],
81-
dynamic_batchsize=[1, 1, 1],
81+
dynamic_batchsize=[1, 4, 4],
8282
use_trace=True,
8383
verify=False,
8484
)

runner.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" TCIA_PROSTATEx_Pros
7070
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" lazy_resampling_functional.ipynb)
7171
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" lazy_resampling_compose.ipynb)
7272
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" TensorRT_inference_acceleration.ipynb)
73+
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" fast_inference_tutorial.ipynb)
7374
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" lazy_resampling_benchmark.ipynb)
7475
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" modular_patch_inferer.ipynb)
7576
doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" GDS_dataset.ipynb)
@@ -117,6 +118,7 @@ skip_run_papermill=("${skip_run_papermill[@]}" .*swinunetr_finetune*)
117118
skip_run_papermill=("${skip_run_papermill[@]}" .*active_learning*)
118119
skip_run_papermill=("${skip_run_papermill[@]}" .*transform_visualization*) # https://github.com/Project-MONAI/tutorials/issues/1155
119120
skip_run_papermill=("${skip_run_papermill[@]}" .*TensorRT_inference_acceleration*)
121+
skip_run_papermill=("${skip_run_papermill[@]}" .*fast_inference_tutorial*)
120122
skip_run_papermill=("${skip_run_papermill[@]}" .*mednist_classifier_ray*) # https://github.com/Project-MONAI/tutorials/issues/1307
121123
skip_run_papermill=("${skip_run_papermill[@]}" .*TorchIO_MONAI_PyTorch_Lightning*) # https://github.com/Project-MONAI/tutorials/issues/1324
122124
skip_run_papermill=("${skip_run_papermill[@]}" .*GDS_dataset*) # https://github.com/Project-MONAI/tutorials/issues/1324

0 commit comments

Comments
 (0)