Skip to content

Commit 8244bff

Browse files
authored
Merge pull request #796 from prajjwal1/master
added sample code for fasterrcnn_resnet50_fpn (optional)
2 parents 212c55b + fe6080f commit 8244bff

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

intermediate_source/torchvision_tutorial.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,30 @@ transformation:
327327
transforms.append(T.RandomHorizontalFlip(0.5))
328328
return T.Compose(transforms)
329329
330+
331+
Testing ``forward()`` method (Optional)
332+
---------------------------------------
333+
334+
Before iterating over the dataset, it's good to see what the model
335+
expects during training and inference time on sample data.
336+
337+
.. code:: python
338+
339+
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
340+
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
341+
data_loader = torch.utils.data.DataLoader(
342+
dataset, batch_size=2, shuffle=True, num_workers=4,
343+
collate_fn=utils.collate_fn)
344+
# For Training
345+
images,targets = next(iter(data_loader))
346+
images = list(image for image in images)
347+
targets = [{k: v for k, v in t.items()} for t in targets]
348+
output = model(images,targets) # Returns losses and detections
349+
# For inference
350+
model.eval()
351+
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
352+
predictions = model(x) # Returns predictions
353+
330354
Let’s now write the main function which performs the training and the
331355
validation:
332356

0 commit comments

Comments
 (0)