@@ -304,30 +304,6 @@ be using Mask R-CNN:
304
304
That’s it, this will make ``model `` be ready to be trained and evaluated
305
305
on your custom dataset.
306
306
307
- Checking the model with random tensors (Optional)
308
- ---------------------------
309
-
310
- Before iterating over the dataset, it's always good to see what the model
311
- expects during training and inference time with random tensors.
312
-
313
- .. code :: python
314
-
315
- model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True )
316
- images,boxes,labels = torch.rand(4 ,3 ,600 ,1200 ), torch.rand(4 ,11 ,4 ), torch.rand(4 ,11 ) # For Training
317
- images = list (image for image in images)
318
- targets = []
319
- for i in range (len (images)):
320
- d = {}
321
- d[' boxes' ] = boxes[i]
322
- d[' labels' ] = labels[i].type(torch.int64)
323
- targets.append(d)
324
- output = model(images,targets) # Returns losses and detections
325
-
326
- model.eval() # For inference
327
- x = [torch.rand(3 , 300 , 400 ), torch.rand(3 , 500 , 400 )]
328
- predictions = model(x) # Returns predictions
329
-
330
-
331
307
Putting everything together
332
308
---------------------------
333
309
@@ -351,6 +327,30 @@ transformation:
351
327
transforms.append(T.RandomHorizontalFlip(0.5 ))
352
328
return T.Compose(transforms)
353
329
330
+
331
+ Testing ``forward() `` method (Optional)
332
+ ---------------------------------------
333
+
334
+ Before iterating over the dataset, it's good to see what the model
335
+ expects during training and inference time on sample data.
336
+
337
+ .. code :: python
338
+
339
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True )
340
+ dataset = PennFudanDataset(' PennFudanPed' , get_transform(train = True ))
341
+ data_loader = torch.utils.data.DataLoader(
342
+ dataset, batch_size = 2 , shuffle = True , num_workers = 4 ,
343
+ collate_fn = utils.collate_fn)
344
+ # For Training
345
+ images,targets = next (iter (data_loader))
346
+ images = list (image for image in images)
347
+ targets = [{k: v for k, v in t.items()} for t in targets]
348
+ output = model(images,targets) # Returns losses and detections
349
+ # For inference
350
+ model.eval()
351
+ x = [torch.rand(3 , 300 , 400 ), torch.rand(3 , 500 , 400 )]
352
+ predictions = model(x) # Returns predictions
353
+
354
354
Let’s now write the main function which performs the training and the
355
355
validation:
356
356
0 commit comments