@@ -304,15 +304,19 @@ def print_size_of_model(model):
# ----------------------------------
#
# As our last major setup step, we define our dataloaders for our training and testing set.
- # The specific dataset we've created for this tutorial contains just 1000 images, one from
+ #
+ # ImageNet Data
+ # ^^^^^^^^^^^^^
+ #
+ # The specific dataset we've created for this tutorial contains just 1000 images from the ImageNet data, one from
# each class (this dataset, at just over 250 MB, is small enough that it can be downloaded
# relatively easily). The URL for this custom dataset is:
#
# .. code::
#
#     https://s3.amazonaws.com/pytorch-tutorial-assets/imagenet_1k.zip
#
- # To download this data locally using Python, then, you could use:
+ # To download this data locally using Python, you could use:
#
# .. code:: python
#
@@ -326,11 +330,32 @@ def print_size_of_model(model):
#     with open(filename, 'wb') as f:
#         f.write(r.content)
#
- #
# For this tutorial to run, we download this data and move it to the right place using
# `these lines <https://github.com/pytorch/tutorials/blob/master/Makefile#L97-L98>`_
# from the `Makefile <https://github.com/pytorch/tutorials/blob/master/Makefile>`_.
#
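+ # As a purely illustrative sketch (the archive name ``imagenet_1k.zip`` and the ``data/`` target
+ # directory are assumptions based on the paths used later in this tutorial, not the Makefile's
+ # exact commands), the downloaded archive could also be unpacked from Python with the
+ # standard-library ``zipfile`` module:
+ #
+ # .. code:: python
+ #
+ #     import zipfile
+ #
+ #     # Assumption: the file saved by the download snippet above is named 'imagenet_1k.zip';
+ #     # extracting into 'data/' should yield the data/imagenet_1k directory used below.
+ #     with zipfile.ZipFile('imagenet_1k.zip') as zf:
+ #         zf.extractall('data/')
+ #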
+ # To run the code in this tutorial using the entire ImageNet dataset, on the other hand, you could download
+ # the data using ``torchvision`` following the instructions
+ # `here <https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet>`_. For example,
+ # to download the training set and apply some standard transformations to it, you could use:
+ #
+ # .. code:: python
+ #
+ #     import torchvision
+ #     import torchvision.transforms as transforms
+ #
+ #     imagenet_dataset = torchvision.datasets.ImageNet(
+ #         '~/.data/imagenet',
+ #         split='train',
+ #         download=True,
+ #         transform=transforms.Compose([
+ #             transforms.RandomResizedCrop(224),
+ #             transforms.RandomHorizontalFlip(),
+ #             transforms.ToTensor(),
+ #             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+ #                                  std=[0.229, 0.224, 0.225]),
+ #         ]))
+ #
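+ # A minimal sketch (the batch size and worker count here are arbitrary choices, not values taken
+ # from this tutorial) of wrapping that dataset in a ``DataLoader`` would then be:
+ #
+ # .. code:: python
+ #
+ #     import torch
+ #
+ #     imagenet_loader = torch.utils.data.DataLoader(
+ #         imagenet_dataset, batch_size=32, shuffle=True, num_workers=4)
+ #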
# With the data downloaded, we show functions below that define dataloaders we'll use to read
# in this data. These functions mostly come from
# `here <https://github.com/pytorch/vision/blob/master/references/detection/train.py>`_.
@@ -374,12 +399,12 @@ def prepare_data_loaders(data_path):
    return data_loader, data_loader_test
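+
+ # Purely as an illustration of how this helper is used (not a line taken from this tutorial), a
+ # typical call with the ``data_path`` defined below would be
+ # ``data_loader, data_loader_test = prepare_data_loaders(data_path)``.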
######################################################################
- # Next, we'll load in the pre-trained MobileNetV2 model. Similarly to the data above, the file with the pre-trained
- # weights is stored at ``https://s3.amazonaws.com/pytorch-tutorial-assets/mobilenet_quantization.pth``:
+ # Next, we'll load in the pre-trained MobileNetV2 model. We provide the URL used to download the pre-trained
+ # weights in ``torchvision`` `here <https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py#L9>`_.
data_path = 'data/imagenet_1k'
saved_model_dir = 'data/'
- float_model_file = 'mobilenet_quantization.pth'
+ float_model_file = 'mobilenet_pretrained_float.pth'
scripted_float_model_file = 'mobilenet_quantization_scripted.pth'
scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'
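+
+ ######################################################################
+ # As a hedged illustration only (this assumes the ``.pth`` file holds a plain ``state_dict`` whose
+ # keys match ``torchvision``'s ``mobilenet_v2`` architecture; it is not this tutorial's own loading
+ # code), the pre-trained float weights could be restored along these lines:
+ #
+ # .. code:: python
+ #
+ #     import torch
+ #     import torchvision
+ #
+ #     # Assumption: the file stores a raw state_dict compatible with torchvision's mobilenet_v2;
+ #     # the tutorial may instead use its own MobileNetV2 definition and loading helper.
+ #     float_model = torchvision.models.mobilenet_v2()   # architecture only, no downloaded weights
+ #     state_dict = torch.load(saved_model_dir + float_model_file)
+ #     float_model.load_state_dict(state_dict)
+ #     float_model.eval()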