Skip to content

Commit 5bd3f67

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 33f9bfc commit 5bd3f67

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

acceleration/fast_inference_tutorial/fast_inference_tutorial.ipynb

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
" Invertd,\n",
110110
" Activationsd,\n",
111111
" AsDiscreted,\n",
112-
" Compose\n",
112+
" Compose,\n",
113113
")\n",
114114
"from monai.inferers import sliding_window_inference\n",
115115
"from monai.networks.nets import SegResNet\n",
@@ -316,6 +316,7 @@
316316
"\n",
317317
" return infer_transforms\n",
318318
"\n",
319+
"\n",
319320
"def get_post_transforms(infer_transforms):\n",
320321
" post_transforms = Compose(\n",
321322
" [\n",
@@ -332,6 +333,7 @@
332333
" )\n",
333334
" return post_transforms\n",
334335
"\n",
336+
"\n",
335337
"def get_model(device, weights_path, trt_model_path, trt_flag=False):\n",
336338
" if not trt_flag:\n",
337339
" model = SegResNet(\n",
@@ -376,16 +378,20 @@
376378
" data = infer_transforms({\"image\": sample})\n",
377379
"\n",
378380
" with torch.no_grad():\n",
379-
" input_image = data[\"image\"].unsqueeze(0).to(device) if benchmark_type in [\"trt\", \"original\"] else data[\"image\"].unsqueeze(0)\n",
381+
" input_image = (\n",
382+
" data[\"image\"].unsqueeze(0).to(device)\n",
383+
" if benchmark_type in [\"trt\", \"original\"]\n",
384+
" else data[\"image\"].unsqueeze(0)\n",
385+
" )\n",
380386
" if benchmark_type == \"original\":\n",
381387
" with torch.autocast(device_type=\"cuda\"):\n",
382388
" output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n",
383389
" else:\n",
384390
" output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n",
385-
" \n",
391+
"\n",
386392
" data[\"pred\"] = output_image.squeeze(0)\n",
387393
" # data = post_transforms(data)\n",
388-
" \n",
394+
"\n",
389395
" end = timer()\n",
390396
"\n",
391397
" sample_name = sample.split(\"/\")[-1]\n",

0 commit comments

Comments
 (0)