
Commit b538364

Updated torchvision_tutorial.rst and _static/tv-training-code.py
1 parent e7c86fd commit b538364

2 files changed: +84, -86 lines changed


_static/tv-training-code.py

Lines changed: 27 additions & 32 deletions
@@ -2,20 +2,23 @@
 # http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

 import os
-import numpy as np
 import torch
-from PIL import Image

 import torchvision
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.io import read_image
+from torchvision.ops.boxes import masks_to_boxes
+from torchvision import datapoints as dp
+from torchvision.transforms.v2 import functional as F
+from torchvision.transforms import v2 as T
+

 from engine import train_one_epoch, evaluate
 import utils
-import transforms as T


-class PennFudanDataset(object):
+class PennFudanDataset(torch.utils.data.Dataset):
     def __init__(self, root, transforms):
         self.root = root
         self.transforms = transforms
@@ -28,47 +31,36 @@ def __getitem__(self, idx):
         # load images and masks
         img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
         mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
-        img = Image.open(img_path).convert("RGB")
-        # note that we haven't converted the mask to RGB,
-        # because each color corresponds to a different instance
-        # with 0 being background
-        mask = Image.open(mask_path)
-
-        mask = np.array(mask)
+        img = read_image(img_path)
+        mask = read_image(mask_path)
         # instances are encoded as different colors
-        obj_ids = np.unique(mask)
+        obj_ids = torch.unique(mask)
         # first id is the background, so remove it
         obj_ids = obj_ids[1:]
+        num_objs = len(obj_ids)

         # split the color-encoded mask into a set
         # of binary masks
-        masks = mask == obj_ids[:, None, None]
+        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

         # get bounding box coordinates for each mask
-        num_objs = len(obj_ids)
-        boxes = []
-        for i in range(num_objs):
-            pos = np.where(masks[i])
-            xmin = np.min(pos[1])
-            xmax = np.max(pos[1])
-            ymin = np.min(pos[0])
-            ymax = np.max(pos[0])
-            boxes.append([xmin, ymin, xmax, ymax])
-
-        boxes = torch.as_tensor(boxes, dtype=torch.float32)
+        boxes = masks_to_boxes(masks)
+
         # there is only one class
         labels = torch.ones((num_objs,), dtype=torch.int64)
-        masks = torch.as_tensor(masks, dtype=torch.uint8)

-        image_id = torch.tensor([idx])
+        image_id = idx
         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
         # suppose all instances are not crowd
         iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

+        # Wrap sample and targets into torchvision datapoints:
+        img = dp.Image(img)
+
         target = {}
-        target["boxes"] = boxes
+        target["boxes"] = dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
+        target["masks"] = dp.Mask(masks)
         target["labels"] = labels
-        target["masks"] = masks
         target["image_id"] = image_id
         target["area"] = area
         target["iscrowd"] = iscrowd
@@ -81,9 +73,10 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self.imgs)

+
 def get_model_instance_segmentation(num_classes):
-    # load an instance segmentation model pre-trained pre-trained on COCO
-    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
+    # load an instance segmentation model pre-trained on COCO
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

     # get number of input features for the classifier
     in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -103,9 +96,11 @@ def get_model_instance_segmentation(num_classes):

 def get_transform(train):
     transforms = []
-    transforms.append(T.ToTensor())
+    transforms.append(T.ToImage())
     if train:
         transforms.append(T.RandomHorizontalFlip(0.5))
+    transforms.append(T.ToDtype(torch.float, scale=True))
+    transforms.append(T.ToPureTensor())
     return T.Compose(transforms)


@@ -160,6 +155,6 @@ def main():
         evaluate(model, data_loader_test, device=device)

     print("That's it!")
-
+
 if __name__ == "__main__":
     main()
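
The core of the change above swaps the PIL/NumPy pipeline for pure tensor ops: ``read_image`` loads the mask as a tensor and ``masks_to_boxes`` computes the boxes that the removed ``np.where`` loop used to build. A minimal sketch of that replacement (the mask path is illustrative; adjust it to wherever PennFudanPed is unpacked):

.. code:: python

    import torch
    from torchvision.io import read_image
    from torchvision.ops.boxes import masks_to_boxes

    # Illustrative path to one Penn-Fudan mask; each color id marks one instance.
    mask = read_image("PennFudanPed/PedMasks/FudanPed00001_mask.png")

    obj_ids = torch.unique(mask)[1:]   # drop the background id (0)
    masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

    # One (xmin, ymin, xmax, ymax) row per instance, replacing the manual np.where loop.
    boxes = masks_to_boxes(masks)
    print(boxes.shape)                 # torch.Size([num_objs, 4])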

intermediate_source/torchvision_tutorial.rst

Lines changed: 57 additions & 54 deletions
@@ -2,9 +2,9 @@ TorchVision Object Detection Finetuning Tutorial
 ====================================================

 .. tip::
-   To get the most of this tutorial, we suggest using this
-   `Colab Version <https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/torchvision_finetuning_instance_segmentation.ipynb>`__.
-   This will allow you to experiment with the information presented below.
+   To get the most of this tutorial, we suggest using this
+   `Colab Version <https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/torchvision_finetuning_instance_segmentation.ipynb>`__.
+   This will allow you to experiment with the information presented below.

 For this tutorial, we will be finetuning a pre-trained `Mask
 R-CNN <https://arxiv.org/abs/1703.06870>`__ model in the `Penn-Fudan
@@ -57,11 +57,14 @@ training and evaluation, and will use the evaluation scripts from
 ``pycocotools`` which can be installed with ``pip install pycocotools``.

 .. note ::
-  For Windows, please install ``pycocotools`` from `gautamchitnis <https://github.com/gautamchitnis/cocoapi>`__ with command
+  For Windows, please install ``pycocotools`` from `gautamchitnis <https://github.com/gautamchitnis/cocoapi>`__ with command

   ``pip install git+https://github.com/gautamchitnis/cocoapi.git@cocodataset-master#subdirectory=PythonAPI``

-One note on the ``labels``. The model considers class ``0`` as background. If your dataset does not contain the background class, you should not have ``0`` in your ``labels``. For example, assuming you have just two classes, *cat* and *dog*, you can define ``1`` (not ``0``) to represent *cats* and ``2`` to represent *dogs*. So, for instance, if one of the images has both classes, your ``labels`` tensor should look like ``[1,2]``.
+One note on the ``labels``. The model considers class ``0`` as background. If your dataset does not contain the background class,
+you should not have ``0`` in your ``labels``. For example, assuming you have just two classes, *cat* and *dog*, you can
+define ``1`` (not ``0``) to represent *cats* and ``2`` to represent *dogs*. So, for instance, if one of the images has both
+classes, your ``labels`` tensor should look like ``[1,2]``.

 Additionally, if you want to use aspect ratio grouping during training
 (so that each batch only contains images with similar aspect ratios),
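
As a tiny aside on the ``labels`` note in this hunk: the convention is simply one 1-based integer id per object, with ``0`` reserved for background. The class ids below are hypothetical:

.. code:: python

    import torch

    # Hypothetical image containing one cat (class 1) and one dog (class 2).
    labels = torch.tensor([1, 2], dtype=torch.int64)  # one entry per object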
@@ -94,7 +97,7 @@ have the following folder structure:
         FudanPed00003.png
         FudanPed00004.png

-Here is one example of a pair of images and segmentation masks
+Here is one example of a pair of images and segmentation masks

 .. image:: ../../_static/img/tv_tutorial/tv_image01.png

@@ -103,13 +106,21 @@ Here is one example of a pair of images and segmentation masks
 So each image has a corresponding
 segmentation mask, where each color correspond to a different instance.
 Let’s write a ``torch.utils.data.Dataset`` class for this dataset.
+In the code below, we are wrapping images, bounding boxes and masks into
+``torchvision.datapoints`` structures so that we will be able to apply torchvision
+built-in transformations (`new Transforms API <https://pytorch.org/vision/stable/transforms.html>`_)
+that cover the object detection and segmentation tasks.
+For more information about torchvision datapoints see `this documentation <https://pytorch.org/vision/stable/datapoints.html>`_.

 .. code:: python

    import os
-   import numpy as np
    import torch
-   from PIL import Image
+
+   from torchvision.io import read_image
+   from torchvision.ops.boxes import masks_to_boxes
+   from torchvision import datapoints as dp
+   from torchvision.transforms.v2 import functional as F


    class PennFudanDataset(torch.utils.data.Dataset):
@@ -125,48 +136,36 @@ Let’s write a ``torch.utils.data.Dataset`` class for this dataset.
         # load images and masks
         img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
         mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
-        img = Image.open(img_path).convert("RGB")
-        # note that we haven't converted the mask to RGB,
-        # because each color corresponds to a different instance
-        # with 0 being background
-        mask = Image.open(mask_path)
-        # convert the PIL Image into a numpy array
-        mask = np.array(mask)
+        img = read_image(img_path)
+        mask = read_image(mask_path)
         # instances are encoded as different colors
-        obj_ids = np.unique(mask)
+        obj_ids = torch.unique(mask)
         # first id is the background, so remove it
         obj_ids = obj_ids[1:]
+        num_objs = len(obj_ids)

         # split the color-encoded mask into a set
         # of binary masks
-        masks = mask == obj_ids[:, None, None]
+        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

         # get bounding box coordinates for each mask
-        num_objs = len(obj_ids)
-        boxes = []
-        for i in range(num_objs):
-            pos = np.nonzero(masks[i])
-            xmin = np.min(pos[1])
-            xmax = np.max(pos[1])
-            ymin = np.min(pos[0])
-            ymax = np.max(pos[0])
-            boxes.append([xmin, ymin, xmax, ymax])
-
-        # convert everything into a torch.Tensor
-        boxes = torch.as_tensor(boxes, dtype=torch.float32)
+        boxes = masks_to_boxes(masks)
+
         # there is only one class
         labels = torch.ones((num_objs,), dtype=torch.int64)
-        masks = torch.as_tensor(masks, dtype=torch.uint8)

         image_id = torch.tensor([idx])
         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
         # suppose all instances are not crowd
         iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

+        # Wrap sample and targets into torchvision datapoints:
+        img = dp.Image(img)
+
         target = {}
-        target["boxes"] = boxes
+        target["boxes"] = dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
+        target["masks"] = dp.Mask(masks)
         target["labels"] = labels
-        target["masks"] = masks
         target["image_id"] = image_id
         target["area"] = area
         target["iscrowd"] = iscrowd
@@ -189,7 +188,7 @@ In this tutorial, we will be using `Mask
 R-CNN <https://arxiv.org/abs/1703.06870>`__, which is based on top of
 `Faster R-CNN <https://arxiv.org/abs/1506.01497>`__. Faster R-CNN is a
 model that predicts both bounding boxes and class scores for potential
-objects in the image.
+objects in the image.

 .. image:: ../../_static/img/tv_tutorial/tv_image03.png

@@ -199,7 +198,7 @@ instance.

 .. image:: ../../_static/img/tv_tutorial/tv_image04.png

-There are two common
+There are two common
 situations where one might want
 to modify one of the available models in torchvision modelzoo. The first
 is when we want to start from a pre-trained model, and just finetune the
@@ -229,7 +228,7 @@ way of doing it:
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
-   model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+   model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

 2 - Modifying the model to add a different backbone
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -252,7 +251,7 @@ way of doing it:
    # location, with 5 different sizes and 3 different aspect
    # ratios. We have a Tuple[Tuple[int]] because each feature
    # map could potentially have different sizes and
-   # aspect ratios
+   # aspect ratios
    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                       aspect_ratios=((0.5, 1.0, 2.0),))

@@ -316,48 +315,52 @@ Putting everything together

 In ``references/detection/``, we have a number of helper functions to
 simplify training and evaluating detection models. Here, we will use
-``references/detection/engine.py``, ``references/detection/utils.py``
-and ``references/detection/transforms.py``. Just copy everything under
-``references/detection`` to your folder and use them here.
+``references/detection/engine.py`` and ``references/detection/utils.py``.
+Just copy everything under ``references/detection`` to your folder and use them here.
+
+Since v0.15.0 torchvision provides `new Transforms API <https://pytorch.org/vision/stable/transforms.html>`_
+to easily write data augmentation pipelines for Object Detection and Segmentation tasks.

 Let’s write some helper functions for data augmentation /
 transformation:

 .. code:: python

-   import transforms as T
+   from torchvision.transforms import v2 as T
+

    def get_transform(train):
        transforms = []
-       transforms.append(T.PILToTensor())
-       transforms.append(T.ConvertImageDtype(torch.float))
+       transforms.append(T.ToImage())
        if train:
-           transforms.append(T.RandomHorizontalFlip(0.5))
+           transforms.append(T.RandomHorizontalFlip(0.5))
+       transforms.append(T.ToDtype(torch.float, scale=True))
+       transforms.append(T.ToPureTensor())
        return T.Compose(transforms)


 Testing ``forward()`` method (Optional)
 ---------------------------------------

-Before iterating over the dataset, it's good to see what the model
+Before iterating over the dataset, it's good to see what the model
 expects during training and inference time on sample data.

 .. code:: python

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
    data_loader = torch.utils.data.DataLoader(
-       dataset, batch_size=2, shuffle=True, num_workers=4,
-       collate_fn=utils.collate_fn)
+       dataset, batch_size=2, shuffle=True, num_workers=4,
+       collate_fn=utils.collate_fn)
    # For Training
-   images,targets = next(iter(data_loader))
+   images, targets = next(iter(data_loader))
    images = list(image for image in images)
    targets = [{k: v for k, v in t.items()} for t in targets]
-   output = model(images,targets)  # Returns losses and detections
+   output = model(images, targets)  # Returns losses and detections
    # For inference
    model.eval()
    x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
-   predictions = model(x)  # Returns predictions
+   predictions = model(x)  # Returns predictions

 Let’s now write the main function which performs the training and the
 validation:
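
As a quick, purely illustrative sanity check of the v2 pipeline defined by ``get_transform`` in the hunk above (reusing the ``PennFudanDataset`` from earlier in this diff):

.. code:: python

    transform = get_transform(train=True)

    img, target = PennFudanDataset('PennFudanPed', transforms=None)[0]
    img, target = transform(img, target)

    # ToDtype(torch.float, scale=True) rescales the uint8 image to [0.0, 1.0],
    # and ToPureTensor() unwraps the datapoints back into plain tensors.
    print(img.dtype, float(img.min()), float(img.max()))
    print(type(img), type(target["boxes"]))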
@@ -504,12 +507,12 @@ After training for 10 epochs, I got the following metrics
  Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.818

 But what do the predictions look like? Let’s take one image in the
-dataset and verify
+dataset and verify

 .. image:: ../../_static/img/tv_tutorial/tv_image05.png

 The trained model predicts 9
-instances of person in this image, let’s see a couple of them:
+instances of person in this image, let’s see a couple of them:

 .. image:: ../../_static/img/tv_tutorial/tv_image06.png

@@ -531,7 +534,7 @@ For a more complete example, which includes multi-machine / multi-gpu
 training, check ``references/detection/train.py``, which is present in
 the torchvision repo.

-You can download a full source file for this tutorial
-`here <https://pytorch.org/tutorials/_static/tv-training-code.py>`__.
-
+You can download a full source file for this tutorial
+`here <https://pytorch.org/tutorials/_static/tv-training-code.py>`__.
+

