Skip to content

Ported torchvision detection tutorial into sphinx gallery format #2540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ download:
wget -nv -N https://www.manythings.org/anki/deu-eng.zip -P $(DATADIR)
unzip -o $(DATADIR)/deu-eng.zip -d beginner_source/data/

# Download PennFudanPed dataset for intermediate_source/torchvision_tutorial.py
wget https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip -P $(DATADIR)
unzip -o $(DATADIR)/PennFudanPed.zip -d intermediate_source/data/

docs:
make download
Expand All @@ -103,3 +106,5 @@ html-noplot:
clean-cache:
make clean
rm -rf advanced beginner intermediate recipes
# remove additional python files downloaded for torchvision_tutorial.py
rm -rf intermediate_source/engine.py intermediate_source/utils.py intermediate_source/transforms.py intermediate_source/coco_eval.py intermediate_source/coco_utils.py
Binary file removed _static/img/tv_tutorial/tv_image01.png
Binary file not shown.
Binary file removed _static/img/tv_tutorial/tv_image02.png
Binary file not shown.
Binary file removed _static/img/tv_tutorial/tv_image05.png
Binary file not shown.
Binary file removed _static/img/tv_tutorial/tv_image06.png
Binary file not shown.
10 changes: 10 additions & 0 deletions en-wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ RRef
OOM
subfolder
Dialogs
PennFudan
performant
multithreading
linearities
Expand All @@ -36,6 +37,8 @@ breakpoint
MobileNet
DeepLabV
Resampling
RCNN
RPN
APIs
ATen
AVX
Expand Down Expand Up @@ -145,6 +148,7 @@ LRSchedulers
Lua
Luong
macos
mAP
MLP
MLPs
MNIST
Expand Down Expand Up @@ -178,10 +182,12 @@ OU
PIL
PPO
Plotly
pre
Prec
Profiler
PyTorch's
RGB
RGBA
RL
RNN
RNNs
Expand Down Expand Up @@ -345,6 +351,7 @@ jit
jitter
jpg
judgements
keypoint
kwargs
labelled
learnable
Expand Down Expand Up @@ -425,6 +432,7 @@ reinitializes
relu
reproducibility
rescale
rescaling
resnet
restride
rewinded
Expand Down Expand Up @@ -476,10 +484,12 @@ torchscriptable
torchtext
torchtext's
torchvision
TorchVision
torchviz
traceback
tradeoff
tradeoffs
uint
uncomment
uncommented
underflowing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,10 @@

######################################################################
#
# .. tip::
#
# To get the most of this tutorial, we suggest using this
# `Colab Version <https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/torchvision_finetuning_instance_segmentation.ipynb>`__.
# This will allow you to experiment with the information presented below.
#
#
# For this tutorial, we will be finetuning a pre-trained `Mask
# R-CNN <https://arxiv.org/abs/1703.06870>`__ model on the `Penn-Fudan
# R-CNN <https://arxiv.org/abs/1703.06870>`_ model on the `Penn-Fudan
# Database for Pedestrian Detection and
# Segmentation <https://www.cis.upenn.edu/~jshi/ped_html/>`__. It contains
# Segmentation <https://www.cis.upenn.edu/~jshi/ped_html/>`_. It contains
# 170 images with 345 instances of pedestrians, and we will use it to
# illustrate how to use the new features in torchvision in order to train
# an object detection and instance segmentation model on a custom dataset.
Expand All @@ -35,7 +28,7 @@
# The reference scripts for training object detection, instance
# segmentation and person keypoint detection allows for easily supporting
# adding new custom datasets. The dataset should inherit from the standard
# ``torch.utils.data.Dataset`` class, and implement ``__len__`` and
# :class:`torch.utils.data.Dataset` class, and implement ``__len__`` and
# ``__getitem__``.
#
# The only specificity that we require is that the dataset ``__getitem__``
Expand Down Expand Up @@ -65,7 +58,7 @@
# ``pycocotools`` which can be installed with ``pip install pycocotools``.
#
# .. note ::
# For Windows, please install ``pycocotools`` from `gautamchitnis <https://github.com/gautamchitnis/cocoapi>`__ with command
# For Windows, please install ``pycocotools`` from `gautamchitnis <https://github.com/gautamchitnis/cocoapi>`_ with command
#
# ``pip install git+https://github.com/gautamchitnis/cocoapi.git@cocodataset-master#subdirectory=PythonAPI``
#
Expand All @@ -85,10 +78,16 @@
# Writing a custom dataset for PennFudan
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Let’s write a dataset for the PennFudan dataset. After `downloading and
# extracting the zip
# file <https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip>`__, we
# have the following folder structure:
# Let’s write a dataset for the PennFudan dataset. First, let's download the dataset and
# extract the `zip file <https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip>`_:
#
# .. code:: python
#
# wget https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip -P data
# cd data && unzip PennFudanPed.zip
#
#
# We have the following folder structure:
#
# ::
#
Expand All @@ -106,21 +105,33 @@
# FudanPed00004.png
#
# Here is one example of a pair of images and segmentation masks
#
# .. image:: ../../_static/img/tv_tutorial/tv_image01.png
#
# .. image:: ../../_static/img/tv_tutorial/tv_image02.png
#

import matplotlib.pyplot as plt
from torchvision.io import read_image


image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
mask = read_image("data/PennFudanPed/PedMasks/FudanPed00046_mask.png")

plt.figure(figsize=(16, 8))
plt.subplot(121)
plt.title("Image")
plt.imshow(image.permute(1, 2, 0))
plt.subplot(122)
plt.title("Mask")
plt.imshow(mask.permute(1, 2, 0))

######################################################################
# So each image has a corresponding
# segmentation mask, where each color correspond to a different instance.
# Let’s write a :class:`torch.utils.data.Dataset` class for this dataset.
# In the code below, we are wrapping images, bounding boxes and masks into
# ``torchvision.TVTensor`` classes so that we will be able to apply torchvision
# :class:`torchvision.tv_tensors.TVTensor` classes so that we will be able to apply torchvision
# built-in transformations (`new Transforms API <https://pytorch.org/vision/stable/transforms.html>`_)
# for the given object detection and segmentation task.
# Namely, image tensors will be wrapped by :class:`torchvision.tv_tensors.Image`, bounding boxes into
# :class:`torchvision.tv_tensors.BoundingBoxes` and masks into :class:`torchvision.tv_tensors.Mask`.
# As ``torchvision.TVTensor`` are :class:`torch.Tensor` subclasses, wrapped objects are also tensors and inherit the plain
# As :class:`torchvision.tv_tensors.TVTensor` are :class:`torch.Tensor` subclasses, wrapped objects are also tensors and inherit the plain
# :class:`torch.Tensor` API. For more information about torchvision ``tv_tensors`` see
# `this documentation <https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_getting_started.html#what-are-tvtensors>`_.

Expand Down Expand Up @@ -196,8 +207,8 @@ def __len__(self):
# -------------------
#
# In this tutorial, we will be using `Mask
# R-CNN <https://arxiv.org/abs/1703.06870>`__, which is based on top of
# `Faster R-CNN <https://arxiv.org/abs/1506.01497>`__. Faster R-CNN is a
# R-CNN <https://arxiv.org/abs/1703.06870>`_, which is based on top of
# `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_. Faster R-CNN is a
# model that predicts both bounding boxes and class scores for potential
# objects in the image.
#
Expand Down Expand Up @@ -345,6 +356,7 @@ def get_model_instance_segmentation(num_classes):
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py")

######################################################################
# Since v0.15.0 torchvision provides `new Transforms API <https://pytorch.org/vision/stable/transforms.html>`_
# to easily write data augmentation pipelines for Object Detection and Segmentation tasks.
#
Expand All @@ -362,7 +374,7 @@ def get_transform(train):
transforms.append(T.ToPureTensor())
return T.Compose(transforms)


######################################################################
# Testing ``forward()`` method (Optional)
# ---------------------------------------
#
Expand Down Expand Up @@ -455,8 +467,8 @@ def get_transform(train):
gamma=0.1
)

# let's train it for 5 epochs
num_epochs = 5
# let's train it just for 2 epochs
num_epochs = 2

for epoch in range(num_epochs):
# train for one epoch, printing every 10 iterations
Expand All @@ -477,14 +489,12 @@ def get_transform(train):
# But what do the predictions look like? Let’s take one image in the
# dataset and verify
#
# .. image:: ../../_static/img/tv_tutorial/tv_image05.png
#
import matplotlib.pyplot as plt

from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks


image = read_image("../_static/img/tv_tutorial/tv_image05.png")
image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
eval_transform = get_transform(train=False)

model.eval()
Expand Down Expand Up @@ -517,7 +527,7 @@ def get_transform(train):
#
# In this tutorial, you have learned how to create your own training
# pipeline for object detection models on a custom dataset. For
# that, you wrote a ``torch.utils.data.Dataset`` class that returns the
# that, you wrote a :class:`torch.utils.data.Dataset` class that returns the
# images and the ground truth boxes and segmentation masks. You also
# leveraged a Mask R-CNN model pre-trained on COCO train2017 in order to
# perform transfer learning on this new dataset.
Expand All @@ -526,5 +536,3 @@ def get_transform(train):
# training, check ``references/detection/train.py``, which is present in
# the torchvision repository.
#
# You can download a full source file for this tutorial
# `here <https://pytorch.org/tutorials/_static/tv-training-code.py>`__.
Loading