|
| 1 | +# Sample code from the TorchVision 0.3 Object Detection Finetuning Tutorial |
| 2 | +# http://pytorch.org/tutorials/intermediate/torchvision_tutorial.html |
| 3 | + |
| 4 | +import os |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | +from PIL import Image |
| 8 | + |
| 9 | +import torchvision |
| 10 | +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor |
| 11 | +from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor |
| 12 | + |
| 13 | +from engine import train_one_epoch, evaluate |
| 14 | +import utils |
| 15 | +import transforms as T |
| 16 | + |
| 17 | + |
class PennFudanDataset(object):
    """Penn-Fudan pedestrian dataset yielding ``(image, target)`` pairs.

    Images are read from ``<root>/PNGImages`` and instance masks from
    ``<root>/PedMasks``.  Both directory listings are sorted so that the
    i-th image is paired with the i-th mask.

    Args:
        root: Dataset directory containing ``PNGImages`` and ``PedMasks``.
        transforms: Callable ``(img, target) -> (img, target)`` applied to
            every sample, or ``None`` to return samples untouched.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, target)``.

        ``img`` is the RGB PIL image (or whatever ``transforms`` turns it
        into) and ``target`` is the dict produced by ``_build_target``.
        """
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

        # Context managers release the underlying file handles as soon as
        # the pixel data has been copied out.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        # note that we haven't converted the mask to RGB,
        # because each color corresponds to a different instance
        # with 0 being background
        with Image.open(mask_path) as raw_mask:
            mask = np.array(raw_mask)

        target = self._build_target(mask, idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def _build_target(self, mask, idx):
        """Convert an instance-id mask array into the target dict.

        Args:
            mask: Array in which 0 marks background and every other distinct
                value marks one object instance.
            idx: Sample index, stored as ``image_id``.

        Returns:
            Dict with keys ``boxes`` (float32, ``[N, 4]`` as
            xmin, ymin, xmax, ymax), ``labels`` (int64 ones), ``masks``
            (uint8, one binary mask per instance), ``image_id``, ``area``
            and ``iscrowd`` (int64 zeros).
        """
        # instances are encoded as different colors; drop the background id
        # explicitly rather than assuming it is always the first value, so a
        # mask without any background pixel keeps all of its instances
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        num_objs = len(obj_ids)

        # split the color-encoded mask into a set of binary masks
        masks = mask == obj_ids[:, None, None]

        # get bounding box coordinates for each mask
        boxes = []
        for instance_mask in masks:
            pos = np.where(instance_mask)
            boxes.append([np.min(pos[1]), np.min(pos[0]),
                          np.max(pos[1]), np.max(pos[0])])

        # the explicit reshape keeps an empty box list two-dimensional
        # ([0, 4]), so the column indexing below also works when the mask
        # contains no instances
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(num_objs, 4)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        return {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

    def __len__(self):
        """Return the number of images found under ``PNGImages``."""
        return len(self.imgs)
| 83 | + |
def get_model_instance_segmentation(num_classes):
    """Return a Mask R-CNN with both prediction heads sized for ``num_classes``.

    The model is created with ``pretrained=True`` (weights pre-trained on
    COCO); its box predictor and mask predictor are then swapped for freshly
    constructed ones that output ``num_classes`` classes.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    roi_heads = model.roi_heads

    # box head: reuse the existing classifier's input width for the new head
    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # mask head: same idea, keyed on the input channels of the mask conv layer
    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    roi_heads.mask_predictor = MaskRCNNPredictor(
        mask_in_channels, hidden_layer, num_classes)

    return model
| 102 | + |
| 103 | + |
def get_transform(train):
    """Compose the sample transforms: tensor conversion, plus a random
    horizontal flip (probability 0.5) when ``train`` is truthy."""
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)
| 110 | + |
| 111 | + |
def main():
    """Fine-tune Mask R-CNN on PennFudanPed, evaluating after every epoch."""
    # use the GPU when one is available, otherwise fall back to the CPU
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # background and person are the only two classes
    num_classes = 2

    # the same files are wrapped twice so that training samples get the
    # training transforms while test samples do not
    train_set = PennFudanDataset('PennFudanPed', get_transform(train=True))
    test_set = PennFudanDataset('PennFudanPed', get_transform(train=False))

    # one random permutation drives both subsets: the last 50 shuffled
    # indices are held out for testing, the rest are used for training
    shuffled = torch.randperm(len(train_set)).tolist()
    train_set = torch.utils.data.Subset(train_set, shuffled[:-50])
    test_set = torch.utils.data.Subset(test_set, shuffled[-50:])

    # training and validation data loaders
    data_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=utils.collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        test_set,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        collate_fn=utils.collate_fn,
    )

    # build the model with the helper and move it to the chosen device
    model = get_model_instance_segmentation(num_classes)
    model.to(device)

    # SGD over the trainable parameters only
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        trainable, lr=0.005, momentum=0.9, weight_decay=0.0005)

    # multiply the learning rate by 0.1 every 3 epochs
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=3, gamma=0.1)

    # train for 10 epochs
    num_epochs = 10
    for epoch in range(num_epochs):
        # one training pass, printing every 10 iterations
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        print_freq=10)
        # update the learning rate
        lr_scheduler.step()
        # evaluate on the test dataset
        evaluate(model, data_loader_test, device=device)

    print("That's it!")
| 163 | + |
# Start training only when this file is executed as a script, so importing
# the module (e.g. to reuse the dataset class) has no side effects.
if __name__ == "__main__":
    main()
0 commit comments