Updating torchvision tutorial using Transforms API v2 #2533

Merged · 4 commits · Aug 24, 2023
59 changes: 27 additions & 32 deletions _static/tv-training-code.py
@@ -2,20 +2,23 @@
 # http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

 import os
-import numpy as np
 import torch
-from PIL import Image

 import torchvision
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.io import read_image
+from torchvision.ops.boxes import masks_to_boxes
+from torchvision import datapoints as dp
+from torchvision.transforms.v2 import functional as F
+from torchvision.transforms import v2 as T
+

 from engine import train_one_epoch, evaluate
 import utils
-import transforms as T


-class PennFudanDataset(object):
+class PennFudanDataset(torch.utils.data.Dataset):
     def __init__(self, root, transforms):
         self.root = root
         self.transforms = transforms
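
A note for reviewers (not part of the diff): with read_image replacing PIL, the dataset works on uint8 CHW tensors end to end. A quick check of that assumption, using a hypothetical mask path:

import torch
from torchvision.io import read_image

# Hypothetical path; any PedMasks PNG from the dataset would do.
mask = read_image("PennFudanPed/PedMasks/FudanPed00001_mask.png")
print(mask.shape, mask.dtype)  # torch.Size([1, H, W]), torch.uint8
print(torch.unique(mask))      # 0 is the background, other values are instance ids
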
@@ -28,47 +31,36 @@ def __getitem__(self, idx):
         # load images and masks
         img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
         mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
-        img = Image.open(img_path).convert("RGB")
-        # note that we haven't converted the mask to RGB,
-        # because each color corresponds to a different instance
-        # with 0 being background
-        mask = Image.open(mask_path)
-
-        mask = np.array(mask)
+        img = read_image(img_path)
+        mask = read_image(mask_path)
         # instances are encoded as different colors
-        obj_ids = np.unique(mask)
+        obj_ids = torch.unique(mask)
         # first id is the background, so remove it
         obj_ids = obj_ids[1:]
+        num_objs = len(obj_ids)

         # split the color-encoded mask into a set
         # of binary masks
-        masks = mask == obj_ids[:, None, None]
+        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

         # get bounding box coordinates for each mask
-        num_objs = len(obj_ids)
-        boxes = []
-        for i in range(num_objs):
-            pos = np.where(masks[i])
-            xmin = np.min(pos[1])
-            xmax = np.max(pos[1])
-            ymin = np.min(pos[0])
-            ymax = np.max(pos[0])
-            boxes.append([xmin, ymin, xmax, ymax])
-
-        boxes = torch.as_tensor(boxes, dtype=torch.float32)
+        boxes = masks_to_boxes(masks)

         # there is only one class
         labels = torch.ones((num_objs,), dtype=torch.int64)
-        masks = torch.as_tensor(masks, dtype=torch.uint8)

-        image_id = torch.tensor([idx])
+        image_id = idx
         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
         # suppose all instances are not crowd
         iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

+        # Wrap sample and targets into torchvision datapoints:
+        img = dp.Image(img)
+
         target = {}
-        target["boxes"] = boxes
+        target["boxes"] = dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
+        target["masks"] = dp.Mask(masks)
         target["labels"] = labels
-        target["masks"] = masks
         target["image_id"] = image_id
         target["area"] = area
         target["iscrowd"] = iscrowd
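
Aside (not in the diff): wrapping the sample into datapoints is what lets the v2 transforms update boxes and masks together with the image, since the transforms dispatch on the datapoint type. A self-contained sketch with made-up coordinates, assuming the datapoints API imported above (torchvision ~0.16):

import torch
from torchvision import datapoints as dp
from torchvision.transforms import v2 as T

boxes = dp.BoundingBoxes(
    torch.tensor([[10.0, 10.0, 50.0, 80.0]]),
    format="XYXY",
    canvas_size=(100, 100),
)
img = dp.Image(torch.zeros(3, 100, 100, dtype=torch.uint8))
flip = T.RandomHorizontalFlip(p=1.0)  # p=1.0 so the flip always fires
img, boxes = flip(img, boxes)
print(boxes)  # x-coordinates are mirrored: [[50., 10., 90., 80.]]
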
@@ -81,9 +73,10 @@ def __getitem__(self, idx):
     def __len__(self):
         return len(self.imgs)

+
 def get_model_instance_segmentation(num_classes):
-    # load an instance segmentation model pre-trained pre-trained on COCO
-    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
+    # load an instance segmentation model pre-trained on COCO
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

     # get number of input features for the classifier
     in_features = model.roi_heads.box_predictor.cls_score.in_features
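
For reference (not in the diff): pretrained=True is deprecated in favor of the weights enums, and weights="DEFAULT" resolves to the best available checkpoint. The same call in its explicit form:

from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights

# Equivalent to weights="DEFAULT" in the diff above.
model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
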
@@ -103,9 +96,11 @@ def get_model_instance_segmentation(num_classes):

 def get_transform(train):
     transforms = []
-    transforms.append(T.ToTensor())
+    transforms.append(T.ToImage())
     if train:
         transforms.append(T.RandomHorizontalFlip(0.5))
+    transforms.append(T.ToDtype(torch.float, scale=True))
+    transforms.append(T.ToPureTensor())
     return T.Compose(transforms)


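One more aside (not in the diff): the old T.ToTensor is deliberately split in two here. T.ToImage only changes the container to an Image datapoint, T.ToDtype(torch.float, scale=True) converts uint8 [0, 255] to float [0, 1], and T.ToPureTensor unwraps the datapoints before the sample reaches the model. A self-contained check of the composed pipeline on a random uint8 image, under the same torchvision ~0.16 assumption:

import torch
from torchvision.transforms import v2 as T

pipeline = T.Compose([
    T.ToImage(),
    T.RandomHorizontalFlip(0.5),
    T.ToDtype(torch.float, scale=True),
    T.ToPureTensor(),
])
out = pipeline(torch.randint(0, 256, (3, 100, 100), dtype=torch.uint8))
print(type(out), out.dtype, out.max().item())  # plain Tensor, torch.float32, <= 1.0
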
@@ -160,6 +155,6 @@ def main():
     evaluate(model, data_loader_test, device=device)
-
+
     print("That's it!")
-
+
 if __name__ == "__main__":
     main()
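
Finally, outside the diff: a minimal smoke test to run after training, assuming model, device, and dataset_test from main() are in scope; the 0.5 score threshold is an arbitrary choice for illustration.

model.eval()
with torch.no_grad():
    img, _ = dataset_test[0]
    prediction = model([img.to(device)])[0]
keep = prediction["scores"] > 0.5
print(prediction["boxes"][keep])
print(prediction["labels"][keep])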