|
109 | 109 | " Invertd,\n",
|
110 | 110 | " Activationsd,\n",
|
111 | 111 | " AsDiscreted,\n",
|
112 |
| - " Compose\n", |
| 112 | + " Compose,\n", |
113 | 113 | ")\n",
|
114 | 114 | "from monai.inferers import sliding_window_inference\n",
|
115 | 115 | "from monai.networks.nets import SegResNet\n",
|
|
316 | 316 | "\n",
|
317 | 317 | " return infer_transforms\n",
|
318 | 318 | "\n",
|
| 319 | + "\n", |
319 | 320 | "def get_post_transforms(infer_transforms):\n",
|
320 | 321 | " post_transforms = Compose(\n",
|
321 | 322 | " [\n",
|
|
332 | 333 | " )\n",
|
333 | 334 | " return post_transforms\n",
|
334 | 335 | "\n",
|
| 336 | + "\n", |
335 | 337 | "def get_model(device, weights_path, trt_model_path, trt_flag=False):\n",
|
336 | 338 | " if not trt_flag:\n",
|
337 | 339 | " model = SegResNet(\n",
|
|
376 | 378 | " data = infer_transforms({\"image\": sample})\n",
|
377 | 379 | "\n",
|
378 | 380 | " with torch.no_grad():\n",
|
379 |
| - " input_image = data[\"image\"].unsqueeze(0).to(device) if benchmark_type in [\"trt\", \"original\"] else data[\"image\"].unsqueeze(0)\n", |
| 381 | + " input_image = (\n", |
| 382 | + " data[\"image\"].unsqueeze(0).to(device)\n", |
| 383 | + " if benchmark_type in [\"trt\", \"original\"]\n", |
| 384 | + " else data[\"image\"].unsqueeze(0)\n", |
| 385 | + " )\n", |
380 | 386 | " if benchmark_type == \"original\":\n",
|
381 | 387 | " with torch.autocast(device_type=\"cuda\"):\n",
|
382 | 388 | " output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n",
|
383 | 389 | " else:\n",
|
384 | 390 | " output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n",
|
385 |
| - " \n", |
| 391 | + "\n", |
386 | 392 | " data[\"pred\"] = output_image.squeeze(0)\n",
|
387 | 393 | " # data = post_transforms(data)\n",
|
388 |
| - " \n", |
| 394 | + "\n", |
389 | 395 | " end = timer()\n",
|
390 | 396 | "\n",
|
391 | 397 | " sample_name = sample.split(\"/\")[-1]\n",
|
|
0 commit comments