From 68089abe8bebf6b39c5483e83681023fe9c557dd Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 22 Aug 2023 14:06:05 +0200 Subject: [PATCH 1/3] Updated torchvision_tutorial.rst and _static/tv-training-code.py --- _static/tv-training-code.py | 59 ++++---- intermediate_source/torchvision_tutorial.rst | 148 +++++++++---------- 2 files changed, 98 insertions(+), 109 deletions(-) diff --git a/_static/tv-training-code.py b/_static/tv-training-code.py index 6fa60d0df6e..890f8fe9223 100644 --- a/_static/tv-training-code.py +++ b/_static/tv-training-code.py @@ -2,20 +2,23 @@ # http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html import os -import numpy as np import torch -from PIL import Image import torchvision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor +from torchvision.io import read_image +from torchvision.ops.boxes import masks_to_boxes +from torchvision import datapoints as dp +from torchvision.transforms.v2 import functional as F +from torchvision.transforms import v2 as T + from engine import train_one_epoch, evaluate import utils -import transforms as T -class PennFudanDataset(object): +class PennFudanDataset(torch.utils.data.Dataset): def __init__(self, root, transforms): self.root = root self.transforms = transforms @@ -28,47 +31,36 @@ def __getitem__(self, idx): # load images and masks img_path = os.path.join(self.root, "PNGImages", self.imgs[idx]) mask_path = os.path.join(self.root, "PedMasks", self.masks[idx]) - img = Image.open(img_path).convert("RGB") - # note that we haven't converted the mask to RGB, - # because each color corresponds to a different instance - # with 0 being background - mask = Image.open(mask_path) - - mask = np.array(mask) + img = read_image(img_path) + mask = read_image(mask_path) # instances are encoded as different colors - obj_ids = np.unique(mask) + obj_ids = torch.unique(mask) # first id is the background, so remove it obj_ids = obj_ids[1:] + num_objs = len(obj_ids) # split the color-encoded mask into a set # of binary masks - masks = mask == obj_ids[:, None, None] + masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8) # get bounding box coordinates for each mask - num_objs = len(obj_ids) - boxes = [] - for i in range(num_objs): - pos = np.where(masks[i]) - xmin = np.min(pos[1]) - xmax = np.max(pos[1]) - ymin = np.min(pos[0]) - ymax = np.max(pos[0]) - boxes.append([xmin, ymin, xmax, ymax]) - - boxes = torch.as_tensor(boxes, dtype=torch.float32) + boxes = masks_to_boxes(masks) + # there is only one class labels = torch.ones((num_objs,), dtype=torch.int64) - masks = torch.as_tensor(masks, dtype=torch.uint8) - image_id = torch.tensor([idx]) + image_id = idx area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # suppose all instances are not crowd iscrowd = torch.zeros((num_objs,), dtype=torch.int64) + # Wrap sample and targets into torchvision datapoints: + img = dp.Image(img) + target = {} - target["boxes"] = boxes + target["boxes"] = dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img)) + target["masks"] = dp.Mask(masks) target["labels"] = labels - target["masks"] = masks target["image_id"] = image_id target["area"] = area target["iscrowd"] = iscrowd @@ -81,9 +73,10 @@ def __getitem__(self, idx): def __len__(self): return len(self.imgs) + def get_model_instance_segmentation(num_classes): - # load an instance segmentation model pre-trained pre-trained on COCO - model = 
torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
+    # load an instance segmentation model pre-trained on COCO
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
 
     # get number of input features for the classifier
     in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -103,9 +96,11 @@ def get_model_instance_segmentation(num_classes):
 
 def get_transform(train):
     transforms = []
-    transforms.append(T.ToTensor())
+    transforms.append(T.ToImage())
     if train:
         transforms.append(T.RandomHorizontalFlip(0.5))
+    transforms.append(T.ToDtype(torch.float, scale=True))
+    transforms.append(T.ToPureTensor())
     return T.Compose(transforms)
 
@@ -160,6 +155,6 @@ def main():
     evaluate(model, data_loader_test, device=device)
 
     print("That's it!")
-    
+
 if __name__ == "__main__":
     main()
diff --git a/intermediate_source/torchvision_tutorial.rst b/intermediate_source/torchvision_tutorial.rst
index 21d47e258f7..c208831d470 100644
--- a/intermediate_source/torchvision_tutorial.rst
+++ b/intermediate_source/torchvision_tutorial.rst
@@ -2,17 +2,17 @@ TorchVision Object Detection Finetuning Tutorial
 ====================================================
 
 .. tip::
-   To get the most of this tutorial, we suggest using this
-   `Colab Version `__.
-   This will allow you to experiment with the information presented below.
+   To get the most out of this tutorial, we suggest using this
+   `Colab Version `__.
+   This will allow you to experiment with the information presented below.
 
 For this tutorial, we will be finetuning a pre-trained `Mask
-R-CNN `__ model in the `Penn-Fudan
+R-CNN `__ model on the `Penn-Fudan
 Database for Pedestrian Detection and Segmentation
 `__. It contains
 170 images with 345 instances of pedestrians, and we will use it to
 illustrate how to use the new features in torchvision in order to train
-an instance segmentation model on a custom dataset.
+an object detection model on a custom dataset.
 
 Defining the Dataset
 --------------------
 
@@ -24,44 +24,39 @@ adding new custom datasets. The dataset should inherit from the standard
 ``__getitem__``.
 
 The only specificity that we require is that the dataset ``__getitem__``
-should return:
+should return a tuple:
 
-- image: a PIL Image of size ``(H, W)``
+- image: ``torchvision.datapoints.Image[3, H, W]`` or a PIL Image of size ``(H, W)``
 - target: a dict containing the following fields
 
-    - ``boxes (FloatTensor[N, 4])``: the coordinates of the ``N``
+    - ``boxes (torchvision.datapoints.BoundingBoxes[N, 4])``: the coordinates of the ``N``
       bounding boxes in ``[x0, y0, x1, y1]`` format, ranging from ``0``
       to ``W`` and ``0`` to ``H``
     - ``labels (Int64Tensor[N])``: the label for each bounding box. ``0`` represents always the background class.
-    - ``image_id (Int64Tensor[1])``: an image identifier. It should be
+    - ``image_id (int)``: an image identifier. It should be
      unique between all the images in the dataset, and is used during
      evaluation
-    - ``area (Tensor[N])``: The area of the bounding box. This is used
+    - ``area (Float32Tensor[N])``: The area of the bounding box. This is used
      during evaluation with the COCO metric, to separate the metric
      scores between small, medium and large boxes.
    - ``iscrowd (UInt8Tensor[N])``: instances with iscrowd=True will be
      ignored during evaluation.
-    - (optionally) ``masks (UInt8Tensor[N, H, W])``: The segmentation
+    - (optionally) ``masks (torchvision.datapoints.Mask[N, H, W])``: The segmentation
      masks for each one of the objects
-    - (optionally) ``keypoints (FloatTensor[N, K, 3])``: For each one of
-      the N objects, it contains the K keypoints in
-      ``[x, y, visibility]`` format, defining the object. visibility=0
-      means that the keypoint is not visible. Note that for data
-      augmentation, the notion of flipping a keypoint is dependent on
-      the data representation, and you should probably adapt
-      ``references/detection/transforms.py`` for your new keypoint
-      representation
-
-If your model returns the above methods, they will make it work for both
-training and evaluation, and will use the evaluation scripts from
+
+If your dataset complies with the above requirements, it will work with both the
+training and evaluation code from the reference scripts. The evaluation code relies on
 ``pycocotools`` which can be installed with ``pip install pycocotools``.
 
 .. note ::
-  For Windows, please install ``pycocotools`` from `gautamchitnis `__ with command
+  For Windows, please install ``pycocotools`` from `gautamchitnis `__ with command
 
  ``pip install git+https://github.com/gautamchitnis/cocoapi.git@cocodataset-master#subdirectory=PythonAPI``
 
-One note on the ``labels``. The model considers class ``0`` as background. If your dataset does not contain the background class, you should not have ``0`` in your ``labels``. For example, assuming you have just two classes, *cat* and *dog*, you can define ``1`` (not ``0``) to represent *cats* and ``2`` to represent *dogs*. So, for instance, if one of the images has both classes, your ``labels`` tensor should look like ``[1,2]``.
+One note on the ``labels``. The model considers class ``0`` as background. If your dataset does not contain the background class,
+you should not have ``0`` in your ``labels``. For example, assuming you have just two classes, *cat* and *dog*, you can
+define ``1`` (not ``0``) to represent *cats* and ``2`` to represent *dogs*. So, for instance, if one of the images has both
+classes, your ``labels`` tensor should look like ``[1,2]``.
 
 Additionally, if you want to use aspect ratio grouping during training
 (so that each batch only contains images with similar aspect ratios),
@@ -94,7 +89,7 @@ have the following folder structure:
         FudanPed00003.png
         FudanPed00004.png
 
-Here is one example of a pair of images and segmentation masks 
+Here is one example of a pair of images and segmentation masks
 
 .. image:: ../../_static/img/tv_tutorial/tv_image01.png
 
@@ -103,13 +98,21 @@ Here is one example of a pair of images and segmentation masks
 
 So each image has a corresponding segmentation mask, where each color
 correspond to a different instance. Let’s write a
 ``torch.utils.data.Dataset`` class for this dataset.
+In the code below, we are wrapping images, bounding boxes and masks into
+``torchvision.datapoints`` structures so that we will be able to apply torchvision
+built-in transformations (`new Transforms API `_)
+that cover the object detection and segmentation tasks.
+For more information about torchvision datapoints see `this documentation `_.
 
 .. 
code:: python

    import os
-   import numpy as np
    import torch
-   from PIL import Image
+
+   from torchvision.io import read_image
+   from torchvision.ops.boxes import masks_to_boxes
+   from torchvision import datapoints as dp
+   from torchvision.transforms.v2 import functional as F


    class PennFudanDataset(torch.utils.data.Dataset):
@@ -125,48 +128,36 @@ Let’s write a ``torch.utils.data.Dataset`` class for this dataset.
         # load images and masks
         img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
         mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
-        img = Image.open(img_path).convert("RGB")
-        # note that we haven't converted the mask to RGB,
-        # because each color corresponds to a different instance
-        # with 0 being background
-        mask = Image.open(mask_path)
-        # convert the PIL Image into a numpy array
-        mask = np.array(mask)
+        img = read_image(img_path)
+        mask = read_image(mask_path)
         # instances are encoded as different colors
-        obj_ids = np.unique(mask)
+        obj_ids = torch.unique(mask)
         # first id is the background, so remove it
         obj_ids = obj_ids[1:]
+        num_objs = len(obj_ids)
 
         # split the color-encoded mask into a set
         # of binary masks
-        masks = mask == obj_ids[:, None, None]
+        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)
 
         # get bounding box coordinates for each mask
-        num_objs = len(obj_ids)
-        boxes = []
-        for i in range(num_objs):
-            pos = np.nonzero(masks[i])
-            xmin = np.min(pos[1])
-            xmax = np.max(pos[1])
-            ymin = np.min(pos[0])
-            ymax = np.max(pos[0])
-            boxes.append([xmin, ymin, xmax, ymax])
-
-        # convert everything into a torch.Tensor
-        boxes = torch.as_tensor(boxes, dtype=torch.float32)
+        boxes = masks_to_boxes(masks)
+
         # there is only one class
         labels = torch.ones((num_objs,), dtype=torch.int64)
-        masks = torch.as_tensor(masks, dtype=torch.uint8)
 
-        image_id = torch.tensor([idx])
+        image_id = idx
         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
         # suppose all instances are not crowd
         iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
 
+        # Wrap sample and targets into torchvision datapoints:
+        img = dp.Image(img)
+
         target = {}
-        target["boxes"] = boxes
+        target["boxes"] = dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
+        target["masks"] = dp.Mask(masks)
         target["labels"] = labels
-        target["masks"] = masks
         target["image_id"] = image_id
         target["area"] = area
         target["iscrowd"] = iscrowd
@@ -189,7 +180,7 @@ In this tutorial, we will be using `Mask R-CNN
 `__, which is based on top of
 `Faster R-CNN `__. Faster R-CNN is a
 model that predicts both bounding boxes and class scores for potential
-objects in the image. 
+objects in the image.
 
 .. image:: ../../_static/img/tv_tutorial/tv_image03.png
 
@@ -199,7 +190,7 @@ instance.
 
 .. image:: ../../_static/img/tv_tutorial/tv_image04.png
 
-There are two common 
+There are two common
 situations where one might want to modify one of the
 available models in torchvision modelzoo. The first is when we want to
 start from a pre-trained model, and just finetune the
@@ -229,7 +220,7 @@ way of doing it:
     # get number of input features for the classifier
     in_features = model.roi_heads.box_predictor.cls_score.in_features
     # replace the pre-trained head with a new one
-    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
 
 2 - Modifying the model to add a different backbone
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -252,7 +243,7 @@ way of doing it:
     # location, with 5 different sizes and 3 different aspect
    # ratios. 
We have a Tuple[Tuple[int]] because each feature
    # map could potentially have different sizes and
-   # aspect ratios 
+   # aspect ratios
    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                       aspect_ratios=((0.5, 1.0, 2.0),))

@@ -273,10 +264,10 @@ way of doing it:
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler)
 
-An Instance segmentation model for PennFudan Dataset
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Object detection model for PennFudan Dataset
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-In our case, we want to fine-tune from a pre-trained model, given that
+In our case, we want to finetune from a pre-trained model, given that
 our dataset is very small, so we will be following approach number 1.
 
 Here we want to also compute the instance segmentation masks, so we will
@@ -316,30 +307,33 @@ Putting everything together
 
 In ``references/detection/``, we have a number of helper functions to
 simplify training and evaluating detection models. Here, we will use
-``references/detection/engine.py``, ``references/detection/utils.py``
-and ``references/detection/transforms.py``. Just copy everything under
-``references/detection`` to your folder and use them here.
+``references/detection/engine.py`` and ``references/detection/utils.py``.
+Just copy everything under ``references/detection`` to your folder and use them here.
+
+Since v0.15.0, torchvision provides the `new Transforms API `_
+to easily write data augmentation pipelines for object detection and segmentation tasks.
 
 Let’s write some helper functions for data augmentation /
 transformation:
 
 .. code:: python
 
-    import transforms as T
+    from torchvision.transforms import v2 as T
+
 
     def get_transform(train):
        transforms = []
-        transforms.append(T.PILToTensor())
-        transforms.append(T.ConvertImageDtype(torch.float))
        if train:
-            transforms.append(T.RandomHorizontalFlip(0.5))
+            transforms.append(T.RandomHorizontalFlip(0.5))
+        transforms.append(T.ToDtype(torch.float, scale=True))
+        transforms.append(T.ToPureTensor())
        return T.Compose(transforms)
 
 Testing ``forward()`` method (Optional)
---------------------------------------
 
-Before iterating over the dataset, it's good to see what the model
+Before iterating over the dataset, it's good to see what the model
 expects during training and inference time on sample data.
 
.. code:: python
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT") dataset = PennFudanDataset('PennFudanPed', get_transform(train=True)) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=2, shuffle=True, num_workers=4, - collate_fn=utils.collate_fn) + dataset, batch_size=2, shuffle=True, num_workers=4, + collate_fn=utils.collate_fn) # For Training - images,targets = next(iter(data_loader)) + images, targets = next(iter(data_loader)) images = list(image for image in images) targets = [{k: v for k, v in t.items()} for t in targets] - output = model(images,targets) # Returns losses and detections + output = model(images, targets) # Returns losses and detections # For inference model.eval() x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] - predictions = model(x) # Returns predictions + predictions = model(x) # Returns predictions Let’s now write the main function which performs the training and the validation: @@ -504,12 +498,12 @@ After training for 10 epochs, I got the following metrics Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.818 But what do the predictions look like? Let’s take one image in the -dataset and verify +dataset and verify .. image:: ../../_static/img/tv_tutorial/tv_image05.png The trained model predicts 9 -instances of person in this image, let’s see a couple of them: +instances of person in this image, let’s see a couple of them: .. image:: ../../_static/img/tv_tutorial/tv_image06.png @@ -521,7 +515,7 @@ Wrapping up ----------- In this tutorial, you have learned how to create your own training -pipeline for instance segmentation models, on a custom dataset. For +pipeline for object detection models on a custom dataset. For that, you wrote a ``torch.utils.data.Dataset`` class that returns the images and the ground truth boxes and segmentation masks. You also leveraged a Mask R-CNN model pre-trained on COCO train2017 in order to @@ -531,7 +525,7 @@ For a more complete example, which includes multi-machine / multi-gpu training, check ``references/detection/train.py``, which is present in the torchvision repo. -You can download a full source file for this tutorial -`here `__. - +You can download a full source file for this tutorial +`here `__. + From 25efd1995787e6affbe79c44e78bc55c74528f3e Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 24 Aug 2023 09:27:11 +0200 Subject: [PATCH 2/3] Addressed review comments --- intermediate_source/torchvision_tutorial.rst | 35 +++++++++++--------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/intermediate_source/torchvision_tutorial.rst b/intermediate_source/torchvision_tutorial.rst index c208831d470..939a404aed1 100644 --- a/intermediate_source/torchvision_tutorial.rst +++ b/intermediate_source/torchvision_tutorial.rst @@ -12,7 +12,7 @@ Database for Pedestrian Detection and Segmentation `__. It contains 170 images with 345 instances of pedestrians, and we will use it to illustrate how to use the new features in torchvision in order to train -an object detection model on a custom dataset. +an object detection and instance segmentation model on a custom dataset. Defining the Dataset -------------------- @@ -26,22 +26,23 @@ adding new custom datasets. 
The dataset should inherit from the standard
 The only specificity that we require is that the dataset ``__getitem__``
 should return a tuple:
 
-- image: ``torchvision.datapoints.Image[3, H, W]`` or a PIL Image of size ``(H, W)``
+- image: :class:`torchvision.datapoints.Image` of shape ``[3, H, W]`` or a PIL Image of size ``(H, W)``
 - target: a dict containing the following fields
 
-    - ``boxes (torchvision.datapoints.BoundingBoxes[N, 4])``: the coordinates of the ``N``
-      bounding boxes in ``[x0, y0, x1, y1]`` format, ranging from ``0``
+    - ``boxes``, :class:`torchvision.datapoints.BoundingBoxes` of shape ``[N, 4]``:
+      the coordinates of the ``N`` bounding boxes in ``[x0, y0, x1, y1]`` format, ranging from ``0``
      to ``W`` and ``0`` to ``H``
-    - ``labels (Int64Tensor[N])``: the label for each bounding box. ``0`` represents always the background class.
-    - ``image_id (int)``: an image identifier. It should be
+    - ``labels``, integer :class:`torch.Tensor` of shape ``[N]``: the label for each bounding box.
+      ``0`` always represents the background class.
+    - ``image_id``, int: an image identifier. It should be
      unique between all the images in the dataset, and is used during
      evaluation
-    - ``area (Float32Tensor[N])``: The area of the bounding box. This is used
+    - ``area``, float :class:`torch.Tensor` of shape ``[N]``: the area of the bounding box. This is used
      during evaluation with the COCO metric, to separate the metric
      scores between small, medium and large boxes.
-    - ``iscrowd (UInt8Tensor[N])``: instances with iscrowd=True will be
+    - ``iscrowd``, uint8 :class:`torch.Tensor` of shape ``[N]``: instances with iscrowd=True will be
      ignored during evaluation.
-    - (optionally) ``masks (torchvision.datapoints.Mask[N, H, W])``: The segmentation
+    - (optionally) ``masks``, :class:`torchvision.datapoints.Mask` of shape ``[N, H, W]``: the segmentation
      masks for each one of the objects
 
 If your dataset complies with the above requirements, it will work with both the
@@ -97,12 +98,16 @@ Here is one example of a pair of images and segmentation masks
 
 So each image has a corresponding segmentation mask, where each color
 correspond to a different instance.
-Let’s write a ``torch.utils.data.Dataset`` class for this dataset.
+Let’s write a :class:`torch.utils.data.Dataset` class for this dataset.
 In the code below, we are wrapping images, bounding boxes and masks into
-``torchvision.datapoints`` structures so that we will be able to apply torchvision
+``torchvision.datapoints`` classes so that we will be able to apply torchvision
 built-in transformations (`new Transforms API `_)
-that cover the object detection and segmentation tasks.
-For more information about torchvision datapoints see `this documentation `_.
+for the given object detection and segmentation task.
+Namely, image tensors will be wrapped into :class:`torchvision.datapoints.Image`, bounding boxes into
+:class:`torchvision.datapoints.BoundingBoxes` and masks into :class:`torchvision.datapoints.Mask`.
+As datapoints are :class:`torch.Tensor` subclasses, wrapped objects are also tensors and inherit the plain
+:class:`torch.Tensor` API. For more information about torchvision datapoints see
+`this documentation `_.
 
.. 
code:: python @@ -264,8 +269,8 @@ way of doing it: rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler) -Object detection model for PennFudan Dataset -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Object detection and instance segmentation model for PennFudan Dataset +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In our case, we want to finetune from a pre-trained model, given that our dataset is very small, so we will be following approach number 1. From 615bf83118fd6a5f666c59ce0e50dca9c08287a9 Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 24 Aug 2023 14:54:49 +0200 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Nicolas Hug --- intermediate_source/torchvision_tutorial.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/intermediate_source/torchvision_tutorial.rst b/intermediate_source/torchvision_tutorial.rst index 939a404aed1..42cae8c4b7f 100644 --- a/intermediate_source/torchvision_tutorial.rst +++ b/intermediate_source/torchvision_tutorial.rst @@ -26,7 +26,7 @@ adding new custom datasets. The dataset should inherit from the standard The only specificity that we require is that the dataset ``__getitem__`` should return a tuple: -- image: :class:`torchvision.datapoints.Image` of shape ``[3, H, W]`` or a PIL Image of size ``(H, W)`` +- image: :class:`torchvision.datapoints.Image` of shape ``[3, H, W]``, a pure tensor, or a PIL Image of size ``(H, W)`` - target: a dict containing the following fields - ``boxes``, :class:`torchvision.datapoints.BoundingBoxes` of shape ``[N, 4]``: @@ -105,7 +105,7 @@ built-in transformations (`new Transforms API `_.
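
A minimal end-to-end sketch of the datapoints pipeline this patch series introduces, useful as a
sanity check when reviewing the diffs above. It is an illustration rather than part of the patch: it
assumes a torchvision build that ships ``torchvision.datapoints`` and ``torchvision.transforms.v2``
(v0.15+, which the diffs themselves require), and the 64x64 image and two-instance mask are
fabricated purely for the demo. Only APIs the patch already relies on (``masks_to_boxes``,
``datapoints.Image``, ``datapoints.BoundingBoxes``, ``datapoints.Mask`` and the ``v2`` transforms)
are used.

.. code:: python

    import torch
    from torchvision import datapoints as dp
    from torchvision.ops.boxes import masks_to_boxes
    from torchvision.transforms.v2 import functional as F
    from torchvision.transforms import v2 as T

    # fabricated color-encoded mask: 0 is background, 1 and 2 are two instances
    mask = torch.zeros((1, 64, 64), dtype=torch.uint8)
    mask[0, 10:30, 10:25] = 1
    mask[0, 40:60, 30:55] = 2

    # same steps as PennFudanDataset.__getitem__ in the diffs above
    obj_ids = torch.unique(mask)[1:]  # drop the background id
    masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)
    boxes = masks_to_boxes(masks)     # one XYXY box per binary mask

    img = dp.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
    target = {
        "boxes": dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img)),
        "masks": dp.Mask(masks),
        "labels": torch.ones((len(obj_ids),), dtype=torch.int64),
    }

    # p=1.0 makes the flip deterministic, so the effect is easy to inspect:
    # image, boxes and masks are all mirrored together by the v2 transform
    transform = T.Compose([T.RandomHorizontalFlip(p=1.0)])
    img_t, target_t = transform(img, target)
    print(target["boxes"])    # original x-coordinates
    print(target_t["boxes"])  # x-coordinates mirrored around the image width

Because datapoints are ``torch.Tensor`` subclasses, the transformed sample can be passed straight to
the detection models discussed in the tutorial; this is the reason the patch wraps the dataset
outputs in datapoints instead of keeping plain tensors.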