
Commit 1891254

[skip-ci] Updated python according to main version
1 parent: a02c120

1 file changed: +35 -13 lines

intermediate_source/torchvision_tutorial.py

@@ -5,6 +5,14 @@
 """
 
 ######################################################################
+#
+# .. tip::
+#
+#    To get the most of this tutorial, we suggest using this
+#    `Colab Version <https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/torchvision_finetuning_instance_segmentation.ipynb>`__.
+#    This will allow you to experiment with the information presented below.
+#
+#
 # For this tutorial, we will be finetuning a pre-trained `Mask
 # R-CNN <https://arxiv.org/abs/1703.06870>`__ model on the `Penn-Fudan
 # Database for Pedestrian Detection and
@@ -17,6 +25,8 @@
 # .. note ::
 #
 #    This tutorial works only with torchvision version >=0.16 or nightly.
+#    If you're using torchvision<=0.15, please follow
+#    `this tutorial instead <https://github.com/pytorch/tutorials/blob/d686b662932a380a58b7683425faa00c06bcf502/intermediate_source/torchvision_tutorial.rst>`_.
 #
 #
 # Defining the Dataset
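Since the note now hard-requires torchvision >= 0.16, readers following along may want their script to fail fast on older installs. A minimal sketch of such a check (not part of this commit; it assumes the `packaging` helper, which most Python environments already have):

# Sketch only: abort early if the installed torchvision predates 0.16.
from packaging import version

import torchvision

# Strip any local build suffix such as "+cu118" before comparing.
installed = version.parse(torchvision.__version__.split("+")[0])
if installed < version.parse("0.16"):
    raise RuntimeError(
        f"torchvision >= 0.16 required, found {torchvision.__version__}; "
        "see the linked tutorial for torchvision <= 0.15."
    )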
@@ -252,8 +262,10 @@ def __len__(self):
 # ratios. We have a Tuple[Tuple[int]] because each feature
 # map could potentially have different sizes and
 # aspect ratios
-anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
-                                   aspect_ratios=((0.5, 1.0, 2.0),))
+anchor_generator = AnchorGenerator(
+    sizes=((32, 64, 128, 256, 512),),
+    aspect_ratios=((0.5, 1.0, 2.0),)
+)
 
 # let's define what are the feature maps that we will
 # use to perform the region of interest cropping, as well as
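For context on the `AnchorGenerator` call reformatted above: it takes one inner tuple of sizes and one of aspect ratios per feature map, so a backbone with a single feature map gets exactly one of each. A standalone sketch of what this configuration yields; the import path assumes a recent torchvision (older releases exposed the class via `torchvision.models.detection.rpn`):

from torchvision.models.detection.anchor_utils import AnchorGenerator

# One feature map -> one inner tuple of sizes, one of aspect ratios.
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128, 256, 512),),
    aspect_ratios=((0.5, 1.0, 2.0),),
)

# 5 sizes x 3 aspect ratios = 15 anchors at each spatial location
# of the single feature map.
print(anchor_generator.num_anchors_per_location())  # [15]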
@@ -262,15 +274,19 @@ def __len__(self):
 # be [0]. More generally, the backbone should return an
 # ``OrderedDict[Tensor]``, and in ``featmap_names`` you can choose which
 # feature maps to use.
-roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
-                                                output_size=7,
-                                                sampling_ratio=2)
+roi_pooler = torchvision.ops.MultiScaleRoIAlign(
+    featmap_names=['0'],
+    output_size=7,
+    sampling_ratio=2
+)
 
 # put the pieces together inside a Faster-RCNN model
-model = FasterRCNN(backbone,
-                   num_classes=2,
-                   rpn_anchor_generator=anchor_generator,
-                   box_roi_pool=roi_pooler)
+model = FasterRCNN(
+    backbone,
+    num_classes=2,
+    rpn_anchor_generator=anchor_generator,
+    box_roi_pool=roi_pooler
+)
 
 ######################################################################
 # Object detection and instance segmentation model for PennFudan Dataset
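Reassembled, the custom-backbone passage that the two hunks above reformat reads roughly as follows. The `mobilenet_v2` backbone and `out_channels = 1280` come from surrounding tutorial lines not shown in this diff, so treat this as a reconstruction under that assumption rather than part of the commit:

import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator

# A pre-trained backbone that returns a single feature map.
backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
backbone.out_channels = 1280  # FasterRCNN needs the channel count

anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128, 256, 512),),
    aspect_ratios=((0.5, 1.0, 2.0),),
)

# A single feature map is keyed '0' in the backbone's output.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
    featmap_names=['0'],
    output_size=7,
    sampling_ratio=2,
)

model = FasterRCNN(
    backbone,
    num_classes=2,  # background + pedestrian
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
)

# Smoke test: one random image through the assembled detector.
model.eval()
with torch.no_grad():
    out = model([torch.rand(3, 300, 400)])
print(out[0].keys())  # boxes, labels, scores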
@@ -301,9 +317,11 @@ def get_model_instance_segmentation(num_classes):
     in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
     hidden_layer = 256
     # and replace the mask predictor with a new one
-    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
-                                                       hidden_layer,
-                                                       num_classes)
+    model.roi_heads.mask_predictor = MaskRCNNPredictor(
+        in_features_mask,
+        hidden_layer,
+        num_classes
+    )
 
     return model
 
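This hunk only shows the tail of `get_model_instance_segmentation`. For readers skimming the commit, here is a sketch of the whole function as the tutorial defines it; the first half (loading the COCO-pretrained model and swapping the box predictor) is reconstructed from context rather than visible in the diff, and `weights="DEFAULT"` is an assumption about those unshown lines:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_model_instance_segmentation(num_classes):
    # Load an instance segmentation model pre-trained on COCO
    # (assumed from context, not shown in this hunk).
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Swap the box predictor head for one sized to our classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Swap the mask predictor head, as reformatted in the hunk above.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes,
    )
    return model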
@@ -477,6 +495,7 @@ def get_transform(train):
     predictions = model([x, ])
     pred = predictions[0]
 
+
 image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
 image = image[:3, ...]
 pred_labels = [f"pedestrian: {score:.3f}" for label, score in zip(pred["labels"], pred["scores"])]
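The only change in this hunk is an added blank line, but the context lines carry a detail worth spelling out: the drawing utilities used below expect a uint8 image (older torchvision releases, at least, enforce this), so the float tensor is min-max rescaled into [0, 255] first. A toy illustration with a stand-in tensor, not the tutorial's data:

import torch

# Stand-in for the tutorial's float image, arbitrary value range.
image = torch.randn(3, 480, 640)

# Min-max rescale to [0, 255] and cast, as in the context lines above.
image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
image = image[:3, ...]  # keep RGB only, in case an alpha channel is present

print(image.dtype, int(image.min()), int(image.max()))  # torch.uint8 0 255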
@@ -486,7 +505,8 @@ def get_transform(train):
 masks = (pred["masks"] > 0.7).squeeze(1)
 output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")
 
-plt.figure()
+
+plt.figure(figsize=(12, 12))
 plt.imshow(output_image.permute(1, 2, 0))
 
 ######################################################################
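The visible change just enlarges the matplotlib figure; the subtler step sits in the context lines: Mask R-CNN returns soft masks shaped [N, 1, H, W] with values in [0, 1], while `draw_segmentation_masks` expects boolean masks shaped [N, H, W]. A self-contained sketch of that shape handling, with random stand-ins for the tutorial's data:

import torch
from torchvision.utils import draw_segmentation_masks

# Stand-ins: soft masks for 4 detections plus a uint8 RGB image.
soft_masks = torch.rand(4, 1, 480, 640)  # shaped like pred["masks"]
image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)

# Threshold probabilities, drop the channel dim: [N, 1, H, W] -> [N, H, W].
bool_masks = (soft_masks > 0.7).squeeze(1)

overlay = draw_segmentation_masks(image, bool_masks, alpha=0.5, colors="blue")
print(overlay.shape)  # torch.Size([3, 480, 640])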
@@ -506,3 +526,5 @@ def get_transform(train):
 # training, check ``references/detection/train.py``, which is present in
 # the torchvision repository.
 #
+# You can download a full source file for this tutorial
+# `here <https://pytorch.org/tutorials/_static/tv-training-code.py>`__.
