Skip to content

Commit 169319e

Browse files
author
Svetlana Karslioglu
authored
Merge branch 'main' into tb-profiler-tutorial-docs-update
2 parents 6de6381 + 6b31dd0 commit 169319e

13 files changed

+951
-140
lines changed
Loading
Loading
Loading
Loading
Loading
Loading

_static/tv-training-code.py

Lines changed: 27 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -2,20 +2,23 @@
22
# http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
33

44
import os
5-
import numpy as np
65
import torch
7-
from PIL import Image
86

97
import torchvision
108
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
119
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
10+
from torchvision.io import read_image
11+
from torchvision.ops.boxes import masks_to_boxes
12+
from torchvision import datapoints as dp
13+
from torchvision.transforms.v2 import functional as F
14+
from torchvision.transforms import v2 as T
15+
1216

1317
from engine import train_one_epoch, evaluate
1418
import utils
15-
import transforms as T
1619

1720

18-
class PennFudanDataset(object):
21+
class PennFudanDataset(torch.utils.data.Dataset):
1922
def __init__(self, root, transforms):
2023
self.root = root
2124
self.transforms = transforms
@@ -28,47 +31,36 @@ def __getitem__(self, idx):
2831
# load images and masks
2932
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
3033
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
31-
img = Image.open(img_path).convert("RGB")
32-
# note that we haven't converted the mask to RGB,
33-
# because each color corresponds to a different instance
34-
# with 0 being background
35-
mask = Image.open(mask_path)
36-
37-
mask = np.array(mask)
34+
img = read_image(img_path)
35+
mask = read_image(mask_path)
3836
# instances are encoded as different colors
39-
obj_ids = np.unique(mask)
37+
obj_ids = torch.unique(mask)
4038
# first id is the background, so remove it
4139
obj_ids = obj_ids[1:]
40+
num_objs = len(obj_ids)
4241

4342
# split the color-encoded mask into a set
4443
# of binary masks
45-
masks = mask == obj_ids[:, None, None]
44+
masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)
4645

4746
# get bounding box coordinates for each mask
48-
num_objs = len(obj_ids)
49-
boxes = []
50-
for i in range(num_objs):
51-
pos = np.where(masks[i])
52-
xmin = np.min(pos[1])
53-
xmax = np.max(pos[1])
54-
ymin = np.min(pos[0])
55-
ymax = np.max(pos[0])
56-
boxes.append([xmin, ymin, xmax, ymax])
57-
58-
boxes = torch.as_tensor(boxes, dtype=torch.float32)
47+
boxes = masks_to_boxes(masks)
48+
5949
# there is only one class
6050
labels = torch.ones((num_objs,), dtype=torch.int64)
61-
masks = torch.as_tensor(masks, dtype=torch.uint8)
6251

63-
image_id = torch.tensor([idx])
52+
image_id = idx
6453
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
6554
# suppose all instances are not crowd
6655
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
6756

57+
# Wrap sample and targets into torchvision datapoints:
58+
img = dp.Image(img)
59+
6860
target = {}
69-
target["boxes"] = boxes
61+
target["boxes"] = dp.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
62+
target["masks"] = dp.Mask(masks)
7063
target["labels"] = labels
71-
target["masks"] = masks
7264
target["image_id"] = image_id
7365
target["area"] = area
7466
target["iscrowd"] = iscrowd
@@ -81,9 +73,10 @@ def __getitem__(self, idx):
8173
def __len__(self):
8274
return len(self.imgs)
8375

76+
8477
def get_model_instance_segmentation(num_classes):
85-
# load an instance segmentation model pre-trained pre-trained on COCO
86-
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
78+
# load an instance segmentation model pre-trained on COCO
79+
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
8780

8881
# get number of input features for the classifier
8982
in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -103,9 +96,11 @@ def get_model_instance_segmentation(num_classes):
10396

10497
def get_transform(train):
10598
transforms = []
106-
transforms.append(T.ToTensor())
99+
transforms.append(T.ToImage())
107100
if train:
108101
transforms.append(T.RandomHorizontalFlip(0.5))
102+
transforms.append(T.ToDtype(torch.float, scale=True))
103+
transforms.append(T.ToPureTensor())
109104
return T.Compose(transforms)
110105

111106

@@ -160,6 +155,6 @@ def main():
160155
evaluate(model, data_loader_test, device=device)
161156

162157
print("That's it!")
163-
158+
164159
if __name__ == "__main__":
165160
main()

0 commit comments

Comments
 (0)