Commit dc1c24f

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Parent: 45b5da3

2 files changed (+29, -14 lines)
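
Note: the commit message does not say which hooks ran, but the diffs below are consistent with a code formatter such as black (double-quoted strings, long calls wrapped, multi-line imports exploded with a trailing comma, two blank lines between top-level definitions) plus a trailing-whitespace hook. As a rough illustration only, assuming black is installed and a line length of around 120, the quote normalization seen in the notebook cell could be reproduced like this:

# Hypothetical sketch, not part of the commit: reproduce black-style quote
# normalization on one of the lines touched below.
import black

src = "average_time.plot(kind='bar', color=['skyblue', 'orange', 'green', 'red'])\n"
print(black.format_str(src, mode=black.Mode(line_length=120)))
# expected output:
# average_time.plot(kind="bar", color=["skyblue", "orange", "green", "red"])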

acceleration/fast_inference_tutorial/fast_inference_tutorial.ipynb

Lines changed: 8 additions & 6 deletions
@@ -354,7 +354,7 @@
     "    total_time_dict = {}\n",
     "    roi_size = (96, 96, 96)\n",
     "    sw_batch_size = 1\n",
-    "    \n",
+    "\n",
     "    for idx, sample in enumerate(data_list[:10]):\n",
     "        start = timer()\n",
     "        data = infer_transforms({\"image\": sample})\n",
@@ -465,7 +465,9 @@
     "all_df = pd.merge(all_df, df, on=\"file_name\", how=\"left\")\n",
     "\n",
     "# for each file, add it's size\n",
-    "all_df[\"file_size\"] = all_df[\"file_name\"].apply(lambda x: os.path.getsize(os.path.join(root_dir, \"Task03_Liver\", \"imagesTs_nii\", x)))\n",
+    "all_df[\"file_size\"] = all_df[\"file_name\"].apply(\n",
+    "    lambda x: os.path.getsize(os.path.join(root_dir, \"Task03_Liver\", \"imagesTs_nii\", x))\n",
+    ")\n",
     "# sort by file size\n",
     "all_df = all_df.sort_values(by=\"file_size\", ascending=True)\n",
     "# convert file size to MB\n",
@@ -489,10 +491,10 @@
     "outputs": [],
     "source": [
     "plt.figure(figsize=(10, 6))\n",
-    "average_time.plot(kind='bar', color=['skyblue', 'orange', 'green', 'red'])\n",
-    "plt.title('Average Inference Time for Each Benchmark Type')\n",
-    "plt.xlabel('Benchmark Type')\n",
-    "plt.ylabel('Average Time (seconds)')\n",
+    "average_time.plot(kind=\"bar\", color=[\"skyblue\", \"orange\", \"green\", \"red\"])\n",
+    "plt.title(\"Average Inference Time for Each Benchmark Type\")\n",
+    "plt.xlabel(\"Benchmark Type\")\n",
+    "plt.ylabel(\"Average Time (seconds)\")\n",
     "plt.xticks(rotation=45)\n",
     "plt.tight_layout()\n",
     "plt.show()"

acceleration/fast_inference_tutorial/run_benchmark.py

Lines changed: 21 additions & 8 deletions
@@ -20,13 +20,21 @@
 import torch_tensorrt
 from monai.inferers import sliding_window_inference
 from monai.networks.nets import SegResNet
-from monai.transforms import (Activationsd, AsDiscreted, Compose,
-                              EnsureChannelFirstd, EnsureTyped, Invertd,
-                              LoadImaged, NormalizeIntensityd, Orientationd,
-                              ScaleIntensityd, Spacingd)
-
-from utils import (prepare_model_weights, prepare_tensorrt_model,
-                   prepare_test_datalist)
+from monai.transforms import (
+    Activationsd,
+    AsDiscreted,
+    Compose,
+    EnsureChannelFirstd,
+    EnsureTyped,
+    Invertd,
+    LoadImaged,
+    NormalizeIntensityd,
+    Orientationd,
+    ScaleIntensityd,
+    Spacingd,
+)
+
+from utils import prepare_model_weights, prepare_tensorrt_model, prepare_test_datalist
 
 
 def get_transforms(device, gpu_loading_flag=False, gpu_transforms_flag=False):
@@ -49,6 +57,7 @@ def get_transforms(device, gpu_loading_flag=False, gpu_transforms_flag=False):
 
     return infer_transforms
 
+
 def get_post_transforms(infer_transforms):
     post_transforms = Compose(
         [
@@ -65,6 +74,7 @@ def get_post_transforms(infer_transforms):
     )
     return post_transforms
 
+
 def get_model(device, weights_path, trt_model_path, trt_flag=False):
     if not trt_flag:
         model = SegResNet(
@@ -84,11 +94,12 @@ def get_model(device, weights_path, trt_model_path, trt_flag=False):
         model = torch.jit.load(trt_model_path)
     return model
 
+
 def run_inference(data_list, infer_transforms, model, device, benchmark_type):
     total_time_dict = {}
     roi_size = (96, 96, 96)
     sw_batch_size = 1
-    
+
     for idx, sample in enumerate(data_list):
         start = timer()
         data = infer_transforms({"image": sample})
@@ -117,6 +128,7 @@ def run_inference(data_list, infer_transforms, model, device, benchmark_type):
 
     return total_time_dict
 
+
 def main():
     parser = argparse.ArgumentParser(description="Run inference benchmark.")
     parser.add_argument("--benchmark_type", type=str, default="original", help="Type of benchmark to run")
@@ -146,5 +158,6 @@ def main():
     df = pd.DataFrame(list(total_time_dict.items()), columns=["file_name", "time"])
     df.to_csv(os.path.join(root_dir, f"time_{args.benchmark_type}.csv"), index=False)
 
+
 if __name__ == "__main__":
     main()
