72 | 72 | "!python -c \"import matplotlib\" || pip install -q matplotlib\n",
73 | 73 | "!python -c \"import torch_tensorrt\" || pip install torch_tensorrt\n",
74 | 74 | "!python -c \"import kvikio\" || pip install kvikio-cu12\n",
75 | | - "!python -c \"import ignite\" || pip install pytorch-ignite\n",
76 | 75 | "!python -c \"import pandas\" || pip install pandas\n",
77 | 76 | "!python -c \"import requests\" || pip install requests\n",
78 | | - "!python -c \"import fire\" || pip install fire\n",
79 | 77 | "!python -c \"import onnx\" || pip install onnx\n",
80 | 78 | "%matplotlib inline"
81 | 79 | ]
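The `kvikio-cu12` wheel installed above provides GPU Direct Storage (GDS) file I/O, which the GDS benchmark later in the notebook depends on. A minimal sketch of its `CuFile` API, following the library's basic round-trip example (the file name and buffer size are placeholder assumptions; the benchmark itself reads NIfTI volumes):

```python
import cupy as cp
import kvikio

# placeholder file name and size; the real benchmark reads NIfTI volumes
path = "demo.bin"
a = cp.arange(1024, dtype=cp.uint8)

with kvikio.CuFile(path, "w") as f:
    f.write(a)  # DMA from GPU memory to storage

b = cp.empty_like(a)
with kvikio.CuFile(path, "r") as f:
    f.read(b)   # DMA from storage straight into GPU memory, bypassing host RAM

assert (a == b).all()
```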

106 | 104 | "    Spacingd,\n",
107 | 105 | "    NormalizeIntensityd,\n",
108 | 106 | "    ScaleIntensityd,\n",
109 | | - "    Invertd,\n",
110 | | - "    Activationsd,\n",
111 | | - "    AsDiscreted,\n",
112 | 107 | "    Compose,\n",
113 | 108 | ")\n",
114 | 109 | "from monai.inferers import sliding_window_inference\n",
115 | 110 | "from monai.networks.nets import SegResNet\n",
| 111 | + "import matplotlib.pyplot as plt\n",
116 | 112 | "import torch\n",
| 113 | + "import gc\n",
117 | 114 | "import pandas as pd\n",
118 | 115 | "from timeit import default_timer as timer\n",
119 | 116 | "\n",
120 | | - "print(f\"Torch-TensorRT version: {torch_tensorrt.__version__}.\")\n",
121 | | - "\n",
122 | 117 | "print_config()"
123 | 118 | ]
124 | 119 | },

163 | 158 | "    precision=\"fp16\",\n",
164 | 159 | "    input_shape=[1, 1, 96, 96, 96],\n",
165 | 160 | "    dynamic_batchsize=[1, 1, 1],\n",
166 | | - "    use_trace=False,\n",
167 | | - "    verify=True,\n",
| 161 | + "    use_trace=True,\n",
| 162 | + "    verify=False,\n",
168 | 163 | ")\n",
169 | 164 | "\n",
170 | 165 | "save_net_with_metadata(torchscript_model, \"segresnet_trt\")\n",
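This change switches the export from scripting to tracing (`use_trace=True`) and skips output verification. Tracing records the operators executed for one example input rather than compiling the model's Python source. A minimal sketch of the two TorchScript export modes this flag chooses between, assuming the converter follows the usual trace/script split; the `Conv3d` stand-in is an assumption (the real network is the bundle's SegResNet), and the example matches the `input_shape` above:

```python
import torch

# stand-in 3D network; the bundle's actual model is a SegResNet
model = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1).eval()
example = torch.randn(1, 1, 96, 96, 96)  # matches input_shape=[1, 1, 96, 96, 96]

with torch.no_grad():
    traced = torch.jit.trace(model, example)  # export path with use_trace=True
    scripted = torch.jit.script(model)        # export path with use_trace=False
```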

236 | 231 | "\n",
237 | 232 | "A variable `benchmark_type` is used to specify the type of benchmark to run. To have a fair comparison, each benchmark type should be run after restarting the notebook kernel. `benchmark_type` can be one of the following:\n",
238 | 233 | "\n",
239 | | - "- `\"original\"`: benchmark the original model inference (with `amp` enabled).\n",
| 234 | + "- `\"original\"`: benchmark the original model inference.\n",
240 | 235 | "- `\"trt\"`: benchmark the TensorRT accelerated model inference.\n",
241 | | - "- `\"trt_gpu_transforms\"`: benchmark the TensorRT accelerated model inference with GPU transforms.\n",
242 | | - "- `\"trt_gds_gpu_transforms\"`: benchmark the TensorRT accelerated model inference with GPU data loading and GPU transforms."
| 236 | + "- `\"trt_gpu_transforms\"`: benchmark the model inference with GPU transforms.\n",
| 237 | + "- `\"trt_gds_gpu_transforms\"`: benchmark the model inference with GPU data loading and GPU transforms."
243 | 238 | ]
244 | 239 | },
245 | 240 | {
246 | 241 | "cell_type": "code",
247 | | - "execution_count": 3,
| 242 | + "execution_count": 4,
248 | 243 | "metadata": {},
249 | 244 | "outputs": [],
250 | 245 | "source": [

276 | 271 | "from utils import prepare_test_datalist, prepare_model_weights, prepare_tensorrt_model\n",
277 | 272 | "\n",
278 | 273 | "root_dir = \".\"\n",
| 274 | + "torch.backends.cudnn.benchmark = True\n",
| 275 | + "torch_tensorrt.runtime.set_multi_device_safe_mode(True)\n",
279 | 276 | "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
280 | 277 | "train_files = prepare_test_datalist(root_dir)\n",
| 278 | + "# the dataset is large, so keep only the 21 smallest files: the first file of each run is used for warm-up and the rest for benchmarking\n",
| 279 | + "train_files = sorted(train_files, key=os.path.getsize)[:21]\n",
281 | 280 | "weights_path = prepare_model_weights(root_dir=root_dir, bundle_name=\"wholeBody_ct_segmentation\")\n",
282 | 281 | "trt_model_name = \"model_trt.ts\"\n",
283 | 282 | "trt_model_path = prepare_tensorrt_model(root_dir, weights_path, trt_model_name)"
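Later cells branch on `benchmark_type` to set three flags: `trt_flag` (use the TensorRT model), `gpu_transforms_flag` (run preprocessing transforms on GPU), and `gpu_loading_flag` (load data straight to GPU via GDS). A hypothetical helper, not part of the notebook, that mirrors the mapping implied by the cells below (the flag names come from those cells):

```python
def flags_for(benchmark_type: str) -> tuple[bool, bool, bool]:
    # hypothetical helper mirroring the notebook's branching on benchmark_type
    trt_flag = benchmark_type != "original"
    gpu_transforms_flag = benchmark_type in ("trt_gpu_transforms", "trt_gds_gpu_transforms")
    gpu_loading_flag = benchmark_type == "trt_gds_gpu_transforms"
    return trt_flag, gpu_transforms_flag, gpu_loading_flag

print(flags_for("trt_gds_gpu_transforms"))  # (True, True, True)
```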

292 | 291 | },
293 | 292 | {
294 | 293 | "cell_type": "code",
295 | | - "execution_count": 5,
| 294 | + "execution_count": 6,
296 | 295 | "metadata": {},
297 | 296 | "outputs": [],
298 | 297 | "source": [

317 | 316 | "    return infer_transforms\n",
318 | 317 | "\n",
319 | 318 | "\n",
320 | | - "def get_post_transforms(infer_transforms):\n",
321 | | - "    post_transforms = Compose(\n",
322 | | - "        [\n",
323 | | - "            Activationsd(keys=\"pred\", softmax=True),\n",
324 | | - "            AsDiscreted(keys=\"pred\", argmax=True),\n",
325 | | - "            Invertd(\n",
326 | | - "                keys=\"pred\",\n",
327 | | - "                transform=infer_transforms,\n",
328 | | - "                orig_keys=\"image\",\n",
329 | | - "                nearest_interp=True,\n",
330 | | - "                to_tensor=True,\n",
331 | | - "            ),\n",
332 | | - "        ]\n",
333 | | - "    )\n",
334 | | - "    return post_transforms\n",
335 | | - "\n",
336 | | - "\n",
337 | 319 | "def get_model(device, weights_path, trt_model_path, trt_flag=False):\n",
338 | 320 | "    if not trt_flag:\n",
339 | 321 | "        model = SegResNet(\n",

364 | 346 | },
365 | 347 | {
366 | 348 | "cell_type": "code",
367 | | - "execution_count": 6,
| 349 | + "execution_count": 7,
368 | 350 | "metadata": {},
369 | 351 | "outputs": [],
370 | 352 | "source": [
371 | | - "def run_inference(data_list, infer_transforms, post_transforms, model, device, benchmark_type):\n",
| 353 | + "def run_inference(data_list, infer_transforms, model, device, benchmark_type):\n",
372 | 354 | "    total_time_dict = {}\n",
373 | 355 | "    roi_size = (96, 96, 96)\n",
374 | 356 | "    sw_batch_size = 1\n",
375 | | - "\n",
376 | | - "    for idx, sample in enumerate(data_list[:5]):\n",
| 357 | + "\n",
| 358 | + "    for idx, sample in enumerate(data_list[:10]):\n",
377 | 359 | "        start = timer()\n",
378 | 360 | "        data = infer_transforms({\"image\": sample})\n",
379 | 361 | "\n",

383 | 365 | "            if benchmark_type in [\"trt\", \"original\"]\n",
384 | 366 | "            else data[\"image\"].unsqueeze(0)\n",
385 | 367 | "        )\n",
386 | | - "        if benchmark_type == \"original\":\n",
387 | | - "            with torch.autocast(device_type=\"cuda\"):\n",
388 | | - "                output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n",
389 | | - "        else:\n",
390 | | - "            output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n",
391 | 368 | "\n",
392 | | - "        data[\"pred\"] = output_image.squeeze(0)\n",
393 | | - "        # data = post_transforms(data)\n",
| 369 | + "        output_image = sliding_window_inference(input_image, roi_size, sw_batch_size, model)\n",
| 370 | + "        output_image = output_image.cpu()\n",
394 | 371 | "\n",
395 | 372 | "        end = timer()\n",
396 | 373 | "\n",
| 374 | + "        print(output_image.mean())  # quick sanity check on the prediction\n",
| 375 | + "\n",
| 376 | + "        del data\n",
| 377 | + "        del input_image\n",
| 378 | + "        del output_image\n",
| 379 | + "        torch.cuda.empty_cache()\n",
| 380 | + "        gc.collect()\n",
| 381 | + "\n",
397 | 382 | "        sample_name = sample.split(\"/\")[-1]\n",
398 | 383 | "        if idx > 0:\n",
399 | 384 | "            total_time_dict[sample_name] = end - start\n",
400 | | - "\n",
| 385 | + "        print(end - start)\n",
401 | 386 | "    return total_time_dict"
402 | 387 | ]
403 | 388 | },
404 | 389 | {
405 | 390 | "cell_type": "markdown",
406 | 391 | "metadata": {},
407 | 392 | "source": [
408 | | - "## Benchmark the end-to-end bundle inference"
| 393 | + "### Running the Benchmark\n",
| 394 | + "\n",
| 395 | + "The cell below runs the benchmark selected by the `benchmark_type` variable.\n",
| 396 | + "\n",
| 397 | + "#### Optional: Using the Python Script\n",
| 398 | + "\n",
| 399 | + "For convenience, the script [`run_benchmark.py`](./run_benchmark.py) runs the same benchmark from a terminal. The following command runs all benchmark types:\n",
| 400 | + "\n",
| 401 | + "\n",
| 402 | + "```bash\n",
| 403 | + "for benchmark_type in \"original\" \"trt\" \"trt_gpu_transforms\" \"trt_gds_gpu_transforms\"; do\n",
| 404 | + "    python run_benchmark.py --benchmark_type \"$benchmark_type\"\n",
| 405 | + "done\n",
| 406 | + "```"
409 | 407 | ]
410 | 408 | },
411 | 409 | {
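`run_inference` times `sliding_window_inference`, which tiles a whole volume into `roi_size` patches, runs the model patch by patch, and stitches the patch predictions back into a full-size output. A self-contained sketch of that call; the `Conv3d` stand-in network and the volume size are assumptions, while `roi_size` and `sw_batch_size` match the values in the function above:

```python
import torch
from monai.inferers import sliding_window_inference

# toy 3D network standing in for the SegResNet used by the notebook
net = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1).eval()
volume = torch.randn(1, 1, 200, 180, 160)  # (batch, channel, D, H, W)

with torch.no_grad():
    pred = sliding_window_inference(volume, roi_size=(96, 96, 96), sw_batch_size=1, predictor=net)

print(pred.shape)  # torch.Size([1, 2, 200, 180, 160]): stitched back to full size
```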

426 | 424 | "    gpu_loading_flag = True\n",
427 | 425 | "\n",
428 | 426 | "infer_transforms = get_transforms(device, gpu_loading_flag, gpu_transforms_flag)\n",
429 | | - "post_transforms = get_post_transforms(infer_transforms)\n",
430 | 427 | "model = get_model(device, weights_path, trt_model_path, trt_flag)\n",
431 | 428 | "\n",
432 | | - "total_time_dict = run_inference(train_files, infer_transforms, post_transforms, model, device, benchmark_type)"
| 429 | + "total_time_dict = run_inference(train_files, infer_transforms, model, device, benchmark_type)\n",
| 430 | + "\n",
| 431 | + "df = pd.DataFrame(list(total_time_dict.items()), columns=[\"file_name\", \"time\"])\n",
| 432 | + "df.to_csv(os.path.join(root_dir, f\"time_{benchmark_type}.csv\"), index=False)"
| 433 | + ]
| 434 | + },
| 435 | + {
| 436 | + "cell_type": "markdown",
| 437 | + "metadata": {},
| 438 | + "source": [
| 439 | + "## Analyze and Visualize the Results\n",
| 440 | + "\n",
| 441 | + "This section analyzes and visualizes the benchmark results.\n",
| 442 | + "All cell outputs shown here were obtained on an NVIDIA RTX A6000 GPU."
| 443 | + ]
| 444 | + },
| 445 | + {
| 446 | + "cell_type": "markdown",
| 447 | + "metadata": {},
| 448 | + "source": [
| 449 | + "### Collect Benchmark Results"
433 | 450 | ]
434 | 451 | },
435 | 452 | {
436 | 453 | "cell_type": "code",
437 | | - "execution_count": 8,
| 454 | + "execution_count": 18,
438 | 455 | "metadata": {},
439 | 456 | "outputs": [],
440 | 457 | "source": [
441 | | - "df = pd.DataFrame(list(total_time_dict.items()), columns=[\"file_name\", \"time\"])\n",
442 | | - "df.to_csv(os.path.join(root_dir, f\"time_{benchmark_type}.csv\"), index=False)"
| 458 | + "# collect benchmark results\n",
| 459 | + "all_df = pd.read_csv(os.path.join(root_dir, \"time_original.csv\"))\n",
| 460 | + "all_df.columns = [\"file_name\", \"original_time\"]\n",
| 461 | + "\n",
| 462 | + "for benchmark_type in [\"trt\", \"trt_gpu_transforms\", \"trt_gds_gpu_transforms\"]:\n",
| 463 | + "    df = pd.read_csv(os.path.join(root_dir, f\"time_{benchmark_type}.csv\"))\n",
| 464 | + "    df.columns = [\"file_name\", f\"{benchmark_type}_time\"]\n",
| 465 | + "    all_df = pd.merge(all_df, df, on=\"file_name\", how=\"left\")\n",
| 466 | + "\n",
| 467 | + "# for each file, add its size\n",
| 468 | + "all_df[\"file_size\"] = all_df[\"file_name\"].apply(lambda x: os.path.getsize(os.path.join(root_dir, \"Task03_Liver\", \"imagesTs_nii\", x)))\n",
| 469 | + "# sort by file size\n",
| 470 | + "all_df = all_df.sort_values(by=\"file_size\", ascending=True)\n",
| 471 | + "# convert file size to MB\n",
| 472 | + "all_df[\"file_size\"] = all_df[\"file_size\"] / 1024 / 1024\n",
| 473 | + "# get the average time for each benchmark type\n",
| 474 | + "average_time = all_df.mean(numeric_only=True)\n",
| 475 | + "del average_time[\"file_size\"]"
| 476 | + ]
| 477 | + },
| 478 | + {
| 479 | + "cell_type": "markdown",
| 480 | + "metadata": {},
| 481 | + "source": [
| 482 | + "### Visualize Average Inference Time for Each Benchmark Type"
| 483 | + ]
| 484 | + },
| 485 | + {
| 486 | + "cell_type": "code",
| 487 | + "execution_count": null,
| 488 | + "metadata": {},
| 489 | + "outputs": [],
| 490 | + "source": [
| 491 | + "plt.figure(figsize=(10, 6))\n",
| 492 | + "average_time.plot(kind=\"bar\", color=[\"skyblue\", \"orange\", \"green\", \"red\"])\n",
| 493 | + "plt.title(\"Average Inference Time for Each Benchmark Type\")\n",
| 494 | + "plt.xlabel(\"Benchmark Type\")\n",
| 495 | + "plt.ylabel(\"Average Time (seconds)\")\n",
| 496 | + "plt.xticks(rotation=45)\n",
| 497 | + "plt.tight_layout()\n",
| 498 | + "plt.show()"
443 | 499 | ]
| 500 | + },
| 501 | + {
| 502 | + "cell_type": "code",
| 503 | + "execution_count": null,
| 504 | + "metadata": {},
| 505 | + "outputs": [],
| 506 | + "source": []
444 | 507 | }
445 | 508 | ],
446 | 509 | "metadata": {
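Once `all_df` has been built, a natural follow-up is to compute per-file speedups over the original pipeline. This is not part of the notebook: the `speedup` column is an addition of this sketch, while the `*_time` column names follow the merge above:

```python
# hypothetical follow-up: per-file speedup of the fully accelerated pipeline
all_df["speedup"] = all_df["original_time"] / all_df["trt_gds_gpu_transforms_time"]
print(all_df[["file_name", "file_size", "speedup"]].to_string(index=False))
```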