13 | 13 | # illustrate how to use the new features in torchvision in order to train
14 | 14 | # an object detection and instance segmentation model on a custom dataset.
15 | 15 | #
| 16 | +#
| 17 | +# .. note::
| 18 | +#
| 19 | +#    This tutorial works only with torchvision version >= 0.16 or a nightly build.
| 20 | +#
| 21 | +#
16 | 22 | # Defining the Dataset
17 | 23 | # --------------------
18 | 24 | #
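Since the new note pins a minimum version, a runtime guard makes the failure mode explicit. This is a hypothetical addition, not part of the diff; it only assumes ``torchvision.__version__`` starts with a numeric ``major.minor`` prefix:

```python
from torchvision import __version__ as tv_version

# Compare only the numeric major.minor prefix; this also handles
# nightly strings such as "0.17.0.dev20230901".
major, minor = (int(p) for p in tv_version.split(".")[:2])
if (major, minor) < (0, 16):
    raise RuntimeError(f"This tutorial needs torchvision >= 0.16, found {tv_version}")
```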
106 | 112 | # :class:`torchvision.tv_tensors.BoundingBoxes` and masks into :class:`torchvision.tv_tensors.Mask`.
107 | 113 | # As ``torchvision.TVTensor`` classes are :class:`torch.Tensor` subclasses, wrapped objects are also tensors and inherit the plain
108 | 114 | # :class:`torch.Tensor` API. For more information about torchvision ``tv_tensors`` see
109 | | -# `this documentation <https://pytorch.org/vision/main/auto_examples/v2_transforms/plot_transforms_v2.html#sphx-glr-auto-examples-v2-transforms-plot-transforms-v2-py>`_.
| 115 | +# `this documentation <https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_getting_started.html#what-are-tvtensors>`_.
110 | 116 |
111 | 117 | import os
112 | 118 | import torch
113 | 119 |
114 | 120 | from torchvision.io import read_image
115 | 121 | from torchvision.ops.boxes import masks_to_boxes
116 | | -from torchvision import tv_tensors as dp
| 122 | +from torchvision import tv_tensors
117 | 123 | from torchvision.transforms.v2 import functional as F
118 | 124 |
119 | 125 |
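For context on the rename from ``dp`` to ``tv_tensors``: the wrappers stay ordinary tensors. A minimal sketch with dummy data (illustrative only, not from the tutorial) of what the renamed import provides:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = tv_tensors.Image(torch.rand(3, 224, 224))  # an Image is a Tensor subclass
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 20.0, 110.0, 120.0]]),  # one box
    format="XYXY",                               # (x1, y1, x2, y2) corner coordinates
    canvas_size=F.get_size(img),                 # (H, W) of the image the box lives on
)

print(isinstance(img, torch.Tensor))    # True: the plain Tensor API is inherited
print(boxes.format, boxes.canvas_size)  # the metadata carried by the wrapper
```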
@@ -154,11 +160,11 @@ def __getitem__(self, idx):
154 | 160 |         iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
155 | 161 |
156 | 162 |         # Wrap sample and targets into torchvision tv_tensors:
157 | | -        img = dp.Image(img)
| 163 | +        img = tv_tensors.Image(img)
158 | 164 |
159 | 165 |         target = {}
160 | | -        target["boxes"] = dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
161 | | -        target["masks"] = dp.Mask(masks)
| 166 | +        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
| 167 | +        target["masks"] = tv_tensors.Mask(masks)
162 | 168 |         target["labels"] = labels
163 | 169 |         target["image_id"] = image_id
164 | 170 |         target["area"] = area
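Because ``__getitem__`` now returns tv_tensors, the v2 transforms can update the image, boxes, and masks together. A quick sanity check, assuming the tutorial's ``PennFudanDataset`` class and its ``data/PennFudanPed`` layout:

```python
dataset = PennFudanDataset("data/PennFudanPed", get_transform(train=True))
img, target = dataset[0]

print(type(img))               # <class 'torchvision.tv_tensors.Image'>
print(target["boxes"].format)  # BoundingBoxFormat.XYXY
print(target["masks"].shape)   # (num_objs, H, W): one binary mask per instance
```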
@@ -362,11 +368,13 @@ def get_transform(train):
362 | 368 | images = list(image for image in images)
363 | 369 | targets = [{k: v for k, v in t.items()} for t in targets]
364 | 370 | output = model(images, targets)  # Returns losses and detections
| 371 | +print(output)
| 372 | +
365 | 373 | # For inference
366 | 374 | model.eval()
367 | 375 | x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
368 | 376 | predictions = model(x)  # Returns predictions
369 | | -print(predictions[0]["boxes"].shape, predictions[0]["labels"], predictions[0]["scores"])
| 377 | +print(predictions[0])
370 | 378 |
371 | 379 |
372 | 380 | ######################################################################
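The two new ``print`` calls surface the models' two output modes: a loss dict in training mode and a list of per-image prediction dicts in eval mode. A sketch of how both are typically consumed, assuming an ``optimizer`` built from ``model.parameters()``:

```python
# Training mode: a dict of losses (for Mask R-CNN: loss_classifier,
# loss_box_reg, loss_mask, loss_objectness, loss_rpn_box_reg).
model.train()
loss_dict = model(images, targets)
loss = sum(v for v in loss_dict.values())
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Eval mode: one dict per input image, detections sorted by descending score.
model.eval()
with torch.no_grad():
    preds = model(x)
print(preds[0]["boxes"].shape)   # (N, 4) boxes in XYXY format
print(preds[0]["labels"].shape)  # (N,) predicted class ids
print(preds[0]["scores"].shape)  # (N,) confidence scores
```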
@@ -453,13 +461,35 @@ def get_transform(train):
453 | 461 | #
454 | 462 | # .. image:: ../../_static/img/tv_tutorial/tv_image05.png
455 | 463 | #
456 | | -# The trained model predicts 9
457 | | -# instances of person in this image, let’s see a couple of them:
458 | | -#
459 | | -# .. image:: ../../_static/img/tv_tutorial/tv_image06.png
460 | | -#
461 | | -# .. image:: ../../_static/img/tv_tutorial/tv_image07.png
462 | | -#
| 464 | +import matplotlib.pyplot as plt
| 465 | +
| 466 | +from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
| 467 | +
| 468 | +
| 469 | +image = read_image("../_static/img/tv_tutorial/tv_image05.png")
| 470 | +eval_transform = get_transform(train=False)
| 471 | +
| 472 | +model.eval()
| 473 | +with torch.no_grad():
| 474 | +    x = eval_transform(image)
| 475 | +    # convert RGBA -> RGB and move to device
| 476 | +    x = x[:3, ...].to(device)
| 477 | +    predictions = model([x, ])
| 478 | +    pred = predictions[0]
| 479 | +
| 480 | +image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
| 481 | +image = image[:3, ...]
| 482 | +pred_labels = [f"pedestrian: {score:.3f}" for label, score in zip(pred["labels"], pred["scores"])]
| 483 | +pred_boxes = pred["boxes"].long()
| 484 | +output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="red")
| 485 | +
| 486 | +masks = (pred["masks"] > 0.7).squeeze(1)
| 487 | +output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")
| 488 | +
| 489 | +plt.figure()
| 490 | +plt.imshow(output_image.permute(1, 2, 0))
| 491 | +
| 492 | +######################################################################
463 | 493 | # The results look good!
464 | 494 | #
465 | 495 | # Wrapping up
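One optional refinement to the new visualization block above, not part of this diff: filter out low-confidence detections before drawing so the overlay is not cluttered by near-duplicate boxes. The 0.8 cutoff is an arbitrary example value:

```python
keep = pred["scores"] > 0.8  # boolean mask over the N detections
output_image = draw_bounding_boxes(
    image,
    pred["boxes"][keep].long(),
    [f"pedestrian: {s:.3f}" for s in pred["scores"][keep]],
    colors="red",
)
masks = (pred["masks"][keep] > 0.7).squeeze(1)  # (kept, H, W) binary masks
output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")
```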