
Commit 70a0f8c

[BE] Move data download logic to download_data.py (#2581)
- `download_url_to_file`, which is heavily inspired by [torch/hub.py](https://github.com/pytorch/pytorch/blob/efb73fe8e4413a0d6db078e85c7ed7c91f05ca5d/torch/hub.py#L600)
- `size_fmt`, which is borrowed from [torch/autograd/profiler_util.py](https://github.com/pytorch/pytorch/blob/efb73fe8e4413a0d6db078e85c7ed7c91f05ca5d/torch/autograd/profiler_util.py#L372)
- Skip downloads if `FILES_TO_RUN` is defined but the tutorial is not in this shard (see the usage sketch below), i.e.:
  - Call `download_dcgan_data` (a 1 GB+ download) only for `dcgan_faces_tutorial`
  - Call `download_lenet_mnist` (downloads from rate-limited GDrive) only for `fgsm_tutorial`
  - Call `download_hymenoptera_data` only for `transfer_learning_tutorial`
  - Call `download_nlp_data` for `seq2seq_translation_tutorial`, `char_rnn_classification_tutorial` and `char_rnn_generation_tutorial`
1 parent 309c889 commit 70a0f8c
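As a rough usage sketch (assumed to be run from the repository root; the exact shard contents passed by CI are not part of this diff), `main()` does a plain substring check of each tutorial name against the `FILES_TO_RUN` environment variable:

    # Download everything (FILES_TO_RUN unset)
    python3 .jenkins/download_data.py

    # Fetch only the LeNet/MNIST checkpoint needed by fgsm_tutorial;
    # the hymenoptera, NLP and CelebA downloads are skipped.
    FILES_TO_RUN=fgsm_tutorial python3 .jenkins/download_data.py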

2 files changed (+130, -21 lines)


.jenkins/download_data.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+import hashlib
+import os
+
+from typing import Optional
+from urllib.request import urlopen, Request
+from pathlib import Path
+from zipfile import ZipFile
+
+REPO_BASE_DIR = Path(__file__).absolute().parent.parent
+DATA_DIR = REPO_BASE_DIR / "_data"
+BEGINNER_DATA_DIR = REPO_BASE_DIR / "beginner_source" / "data"
+INTERMEDIATE_DATA_DIR = REPO_BASE_DIR / "intermediate_source" / "data"
+ADVANCED_DATA_DIR = REPO_BASE_DIR / "advanced_source" / "data"
+PROTOTYPE_DATA_DIR = REPO_BASE_DIR / "prototype_source" / "data"
+FILES_TO_RUN = os.getenv("FILES_TO_RUN")
+
+
+def size_fmt(nbytes: int) -> str:
+    """Returns a formatted file size string"""
+    KB = 1024
+    MB = 1024 * KB
+    GB = 1024 * MB
+    if abs(nbytes) >= GB:
+        return f"{nbytes * 1.0 / GB:.2f} Gb"
+    elif abs(nbytes) >= MB:
+        return f"{nbytes * 1.0 / MB:.2f} Mb"
+    elif abs(nbytes) >= KB:
+        return f"{nbytes * 1.0 / KB:.2f} Kb"
+    return str(nbytes) + " bytes"
+
+
+def download_url_to_file(url: str,
+                         dst: Optional[str] = None,
+                         prefix: Optional[Path] = None,
+                         sha256: Optional[str] = None) -> Path:
+    dst = dst if dst is not None else Path(url).name
+    dst = dst if prefix is None else str(prefix / dst)
+    if Path(dst).exists():
+        print(f"Skip downloading {url} as {dst} already exists")
+        return Path(dst)
+    file_size = None
+    u = urlopen(Request(url, headers={"User-Agent": "tutorials.downloader"}))
+    meta = u.info()
+    if hasattr(meta, 'getheaders'):
+        content_length = meta.getheaders("Content-Length")
+    else:
+        content_length = meta.get_all("Content-Length")
+    if content_length is not None and len(content_length) > 0:
+        file_size = int(content_length[0])
+    sha256_sum = hashlib.sha256()
+    with open(dst, "wb") as f:
+        while True:
+            buffer = u.read(32768)
+            if len(buffer) == 0:
+                break
+            sha256_sum.update(buffer)
+            f.write(buffer)
+    digest = sha256_sum.hexdigest()
+    if sha256 is not None and sha256 != digest:
+        Path(dst).unlink()
+        raise RuntimeError(f"Downloaded {url} has unexpected sha256sum {digest} should be {sha256}")
+    print(f"Downloaded {url} sha256sum={digest} size={size_fmt(file_size)}")
+    return Path(dst)
+
+
+def unzip(archive: Path, tgt_dir: Path) -> None:
+    with ZipFile(str(archive), "r") as zip_ref:
+        zip_ref.extractall(str(tgt_dir))
+
+
+def download_hymenoptera_data():
+    # transfer learning tutorial data
+    z = download_url_to_file("https://download.pytorch.org/tutorial/hymenoptera_data.zip",
+                             prefix=DATA_DIR,
+                             sha256="fbc41b31d544714d18dd1230b1e2b455e1557766e13e67f9f5a7a23af7c02209",
+                             )
+    unzip(z, BEGINNER_DATA_DIR)
+
+
+def download_nlp_data() -> None:
+    # nlp tutorial data
+    z = download_url_to_file("https://download.pytorch.org/tutorial/data.zip",
+                             prefix=DATA_DIR,
+                             sha256="fb317e80248faeb62dc25ef3390ae24ca34b94e276bbc5141fd8862c2200bff5",
+                             )
+    # This will unzip all files in data.zip to intermediate_source/data/ folder
+    unzip(z, INTERMEDIATE_DATA_DIR.parent)
+
+
+def download_dcgan_data() -> None:
+    # Download dataset for beginner_source/dcgan_faces_tutorial.py
+    z = download_url_to_file("https://s3.amazonaws.com/pytorch-tutorial-assets/img_align_celeba.zip",
+                             prefix=DATA_DIR,
+                             sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
+                             )
+    unzip(z, BEGINNER_DATA_DIR / "celeba")
+
+
+def download_lenet_mnist() -> None:
+    # Download model for beginner_source/fgsm_tutorial.py
+    download_url_to_file("https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl",
+                         prefix=BEGINNER_DATA_DIR,
+                         dst="lenet_mnist_model.pth",
+                         sha256="cb5f8e578aef96d5c1a2cc5695e1aa9bbf4d0fe00d25760eeebaaac6ebc2edcb",
+                         )
+
+
+def main() -> None:
+    DATA_DIR.mkdir(exist_ok=True)
+    BEGINNER_DATA_DIR.mkdir(exist_ok=True)
+    ADVANCED_DATA_DIR.mkdir(exist_ok=True)
+    INTERMEDIATE_DATA_DIR.mkdir(exist_ok=True)
+    PROTOTYPE_DATA_DIR.mkdir(exist_ok=True)
+
+    if FILES_TO_RUN is None or "transfer_learning_tutorial" in FILES_TO_RUN:
+        download_hymenoptera_data()
+    nlp_tutorials = ["seq2seq_translation_tutorial", "char_rnn_classification_tutorial", "char_rnn_generation_tutorial"]
+    if FILES_TO_RUN is None or any(x in FILES_TO_RUN for x in nlp_tutorials):
+        download_nlp_data()
+    if FILES_TO_RUN is None or "dcgan_faces_tutorial" in FILES_TO_RUN:
+        download_dcgan_data()
+    if FILES_TO_RUN is None or "fgsm_tutorial" in FILES_TO_RUN:
+        download_lenet_mnist()
+
+
+if __name__ == "__main__":
+    main()

Makefile

Lines changed: 2 additions & 21 deletions
@@ -38,20 +38,8 @@ download:
 	# Step2-2. UNTAR: tar -xzf $(DATADIR)/[SOURCE_FILE] -C [*_source/data/]
 	# Step2-3. AS-IS: cp $(DATADIR)/[SOURCE_FILE] [*_source/data/]
 
-	# make data directories
-	mkdir -p $(DATADIR)
-	mkdir -p advanced_source/data
-	mkdir -p beginner_source/data
-	mkdir -p intermediate_source/data
-	mkdir -p prototype_source/data
-
-	# transfer learning tutorial data
-	wget -nv -N https://download.pytorch.org/tutorial/hymenoptera_data.zip -P $(DATADIR)
-	unzip $(ZIPOPTS) $(DATADIR)/hymenoptera_data.zip -d beginner_source/data/
-
-	# nlp tutorial data
-	wget -nv -N https://download.pytorch.org/tutorial/data.zip -P $(DATADIR)
-	unzip $(ZIPOPTS) $(DATADIR)/data.zip -d intermediate_source/ # This will unzip all files in data.zip to intermediate_source/data/ folder
+	# Run structured downloads first (will also make directories)
+	python3 .jenkins/download_data.py
 
 	# data loader tutorial
 	wget -nv -N https://download.pytorch.org/tutorial/faces.zip -P $(DATADIR)
@@ -65,10 +53,6 @@ download:
 	mkdir -p advanced_source/data/images/
 	cp -r _static/img/neural-style/ advanced_source/data/images/
 
-	# Download dataset for beginner_source/dcgan_faces_tutorial.py
-	wget -nv -N https://s3.amazonaws.com/pytorch-tutorial-assets/img_align_celeba.zip -P $(DATADIR)
-	unzip $(ZIPOPTS) $(DATADIR)/img_align_celeba.zip -d beginner_source/data/celeba
-
 	# Download dataset for beginner_source/hybrid_frontend/introduction_to_hybrid_frontend_tutorial.py
 	wget -nv -N https://s3.amazonaws.com/pytorch-tutorial-assets/iris.data -P $(DATADIR)
 	cp $(DATADIR)/iris.data beginner_source/data/
@@ -81,9 +65,6 @@ download:
 	wget -nv -N https://s3.amazonaws.com/pytorch-tutorial-assets/UrbanSound8K.tar.gz -P $(DATADIR)
 	tar $(TAROPTS) -xzf $(DATADIR)/UrbanSound8K.tar.gz -C ./beginner_source/data/
 
-	# Download model for beginner_source/fgsm_tutorial.py
-	wget -nv 'https://docs.google.com/uc?export=download&id=1HJV2nUHJqclXQ8flKvcWmjZ-OU5DGatl' -O $(DATADIR)/lenet_mnist_model.pth
-	cp $(DATADIR)/lenet_mnist_model.pth ./beginner_source/data/lenet_mnist_model.pth
 
 	# Download model for advanced_source/dynamic_quantization_tutorial.py
 	wget -nv -N https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth -P $(DATADIR)
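For completeness, a hedged sketch of how the updated flow is expected to be driven (the `download` target name comes from the hunk headers above; everything else about the local environment is assumed):

    # Runs the structured, checksummed downloads via .jenkins/download_data.py first,
    # then the remaining wget/unzip/tar steps still kept in this Makefile.
    make download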
