2
2
# http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
3
3
4
4
import os
5
- import numpy as np
6
5
import torch
7
- from PIL import Image
8
6
9
7
import torchvision
10
8
from torchvision .models .detection .faster_rcnn import FastRCNNPredictor
11
9
from torchvision .models .detection .mask_rcnn import MaskRCNNPredictor
10
+ from torchvision .io import read_image
11
+ from torchvision .ops .boxes import masks_to_boxes
12
+ from torchvision import datapoints as dp
13
+ from torchvision .transforms .v2 import functional as F
14
+ from torchvision .transforms import v2 as T
15
+
12
16
13
17
from engine import train_one_epoch , evaluate
14
18
import utils
15
- import transforms as T
16
19
17
20
18
- class PennFudanDataset (object ):
21
+ class PennFudanDataset (torch . utils . data . Dataset ):
19
22
def __init__ (self , root , transforms ):
20
23
self .root = root
21
24
self .transforms = transforms
@@ -28,47 +31,36 @@ def __getitem__(self, idx):
28
31
# load images and masks
29
32
img_path = os .path .join (self .root , "PNGImages" , self .imgs [idx ])
30
33
mask_path = os .path .join (self .root , "PedMasks" , self .masks [idx ])
31
- img = Image .open (img_path ).convert ("RGB" )
32
- # note that we haven't converted the mask to RGB,
33
- # because each color corresponds to a different instance
34
- # with 0 being background
35
- mask = Image .open (mask_path )
36
-
37
- mask = np .array (mask )
34
+ img = read_image (img_path )
35
+ mask = read_image (mask_path )
38
36
# instances are encoded as different colors
39
- obj_ids = np .unique (mask )
37
+ obj_ids = torch .unique (mask )
40
38
# first id is the background, so remove it
41
39
obj_ids = obj_ids [1 :]
40
+ num_objs = len (obj_ids )
42
41
43
42
# split the color-encoded mask into a set
44
43
# of binary masks
45
- masks = mask == obj_ids [:, None , None ]
44
+ masks = ( mask == obj_ids [:, None , None ]). to ( dtype = torch . uint8 )
46
45
47
46
# get bounding box coordinates for each mask
48
- num_objs = len (obj_ids )
49
- boxes = []
50
- for i in range (num_objs ):
51
- pos = np .where (masks [i ])
52
- xmin = np .min (pos [1 ])
53
- xmax = np .max (pos [1 ])
54
- ymin = np .min (pos [0 ])
55
- ymax = np .max (pos [0 ])
56
- boxes .append ([xmin , ymin , xmax , ymax ])
57
-
58
- boxes = torch .as_tensor (boxes , dtype = torch .float32 )
47
+ boxes = masks_to_boxes (masks )
48
+
59
49
# there is only one class
60
50
labels = torch .ones ((num_objs ,), dtype = torch .int64 )
61
- masks = torch .as_tensor (masks , dtype = torch .uint8 )
62
51
63
- image_id = torch . tensor ([ idx ])
52
+ image_id = idx
64
53
area = (boxes [:, 3 ] - boxes [:, 1 ]) * (boxes [:, 2 ] - boxes [:, 0 ])
65
54
# suppose all instances are not crowd
66
55
iscrowd = torch .zeros ((num_objs ,), dtype = torch .int64 )
67
56
57
+ # Wrap sample and targets into torchvision datapoints:
58
+ img = dp .Image (img )
59
+
68
60
target = {}
69
- target ["boxes" ] = boxes
61
+ target ["boxes" ] = dp .BoundingBoxes (boxes , format = "XYXY" , canvas_size = F .get_size (img ))
62
+ target ["masks" ] = dp .Mask (masks )
70
63
target ["labels" ] = labels
71
- target ["masks" ] = masks
72
64
target ["image_id" ] = image_id
73
65
target ["area" ] = area
74
66
target ["iscrowd" ] = iscrowd
@@ -81,9 +73,10 @@ def __getitem__(self, idx):
81
73
def __len__ (self ):
82
74
return len (self .imgs )
83
75
76
+
84
77
def get_model_instance_segmentation (num_classes ):
85
- # load an instance segmentation model pre-trained pre-trained on COCO
86
- model = torchvision .models .detection .maskrcnn_resnet50_fpn (pretrained = True )
78
+ # load an instance segmentation model pre-trained on COCO
79
+ model = torchvision .models .detection .maskrcnn_resnet50_fpn (weights = "DEFAULT" )
87
80
88
81
# get number of input features for the classifier
89
82
in_features = model .roi_heads .box_predictor .cls_score .in_features
@@ -103,9 +96,11 @@ def get_model_instance_segmentation(num_classes):
103
96
104
97
def get_transform (train ):
105
98
transforms = []
106
- transforms .append (T .ToTensor ())
99
+ transforms .append (T .ToImage ())
107
100
if train :
108
101
transforms .append (T .RandomHorizontalFlip (0.5 ))
102
+ transforms .append (T .ToDtype (torch .float , scale = True ))
103
+ transforms .append (T .ToPureTensor ())
109
104
return T .Compose (transforms )
110
105
111
106
@@ -160,6 +155,6 @@ def main():
160
155
evaluate (model , data_loader_test , device = device )
161
156
162
157
print ("That's it!" )
163
-
158
+
164
159
if __name__ == "__main__" :
165
160
main ()
0 commit comments