diff --git a/googleapiclient/http.py b/googleapiclient/http.py index 187f6f5dac8..099a8253f20 100644 --- a/googleapiclient/http.py +++ b/googleapiclient/http.py @@ -661,6 +661,119 @@ def __init__( ) +class MediaGenBaseDownload(object): + """ Download media resources using generator. + Example: + request = farms.animals().get_media(id='cow') + downloader = MediaGenBaseDownload(request) + for chunk, status, done in downloader.next_chunk(): + with open('cow.png', 'ab') as save_file: + save_file.write(chunk) + print("Download %d%%." % round(status.progress() * 100, 2)) + print(done) + """ + + @util.positional(2) + def __init__(self, request, chunksize=DEFAULT_CHUNK_SIZE): + """Constructor. + Args: + request: googleapiclient.http.HttpRequest, the media request to perform in + chunks. + chunksize: int, File will be downloaded in chunks of this many bytes. + """ + self._request = request + self._uri = request.uri + self._chunksize = chunksize + self._progress = 0 + self._total_size = None + self._done = False + + # Stubs for testing. + self._sleep = time.sleep + self._rand = random.random + + self._headers = {} + for k, v in request.headers.items(): + # allow users to supply custom headers by setting them on the request + # but strip out the ones that are set by default on requests generated by + # API methods like Drive's files().get(fileId=...) + if not k.lower() in ("accept", "accept-encoding", "user-agent"): + self._headers[k] = v + + @util.positional(1) + def next_chunk(self, num_retries=0): + """Get the next chunk of the download. + Args: + num_retries: Integer, number of times to retry with randomized + exponential backoff. If all retries fail, the raised HttpError + represents the last request. If zero (default), we attempt the + request only once. + Returns: + (content, status, done): (file object, MediaDownloadProgress, boolean) + The value of 'done' will be True when the media has been fully + downloaded or the total size of the media is unknown. + Raises: + googleapiclient.errors.HttpError if the response was not a 2xx. + httplib2.HttpLib2Error if a transport error has occurred. + """ + while self._done is False: + headers = self._headers.copy() + headers["range"] = "bytes=%d-%d" % ( + self._progress, + self._progress + self._chunksize - 1, + ) + http = self._request.http + + resp, content = _retry_request( + http, + num_retries, + "media download", + self._sleep, + self._rand, + self._uri, + "GET", + headers=headers, + ) + + if resp.status in [200, 206]: + if "content-location" in resp and resp["content-location"] != self._uri: + self._uri = resp["content-location"] + self._progress += len(content) + + if "content-range" in resp: + content_range = resp["content-range"] + length = content_range.rsplit("/", 1)[1] + self._total_size = int(length) + elif "content-length" in resp: + self._total_size = int(resp["content-length"]) + + if self._total_size is None or self._progress == self._total_size: + self._done = True + + yield ( + content, + MediaDownloadProgress(self._progress, self._total_size), + self._done + ) + continue + elif resp.status == 416: + # 416 is Range Not Satisfiable + # This typically occurs with a zero byte file + content_range = resp["content-range"] + length = content_range.rsplit("/", 1)[1] + self._total_size = int(length) + if self._total_size == 0: + self._done = True + yield ( + content, + MediaDownloadProgress(self._progress, self._total_size), + self._done, + ) + continue + + raise HttpError(resp, content, uri=self._uri) + + class MediaIoBaseDownload(object): """ "Download media resources. diff --git a/tests/test_http.py b/tests/test_http.py index 42110adfab1..0d6b1bf64a7 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -50,6 +50,7 @@ MediaFileUpload, MediaInMemoryUpload, MediaIoBaseDownload, + MediaGenBaseDownload, MediaIoBaseUpload, MediaUpload, _StreamSlice, @@ -458,6 +459,292 @@ def test_media_io_base_empty_file(self): self.assertEqual("0", http.request_sequence[-1][-1]["Content-Length"]) +class TestMediaGenBaseDownload(unittest.TestCase): + def setUp(self): + http = HttpMock(datafile("zoo.json"), {"status": "200"}) + zoo = build("zoo", "v1", http=http, static_discovery=False) + self.request = zoo.animals().get_media(name="Lion") + + def test_media_gen_base_download(self): + fd = io.BytesIO() + self.request.http = HttpMockSequence( + [ + ({"status": "200", "content-range": "0-2/5"}, b"123"), + ({"status": "200", "content-range": "3-4/5"}, b"45"), + ] + ) + self.assertEqual(True, self.request.http.follow_redirects) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + self.assertEqual(3, download._chunksize) + self.assertEqual(0, download._progress) + self.assertEqual(None, download._total_size) + self.assertEqual(False, download._done) + self.assertEqual(self.request.uri, download._uri) + + for content, status, done in download.next_chunk(): + fd.write(content) + self.assertEqual(fd.getvalue(), b"123") + self.assertEqual(False, done) + self.assertEqual(3, download._progress) + self.assertEqual(5, download._total_size) + self.assertEqual(3, status.resumable_progress) + break + + for content, status, done in download.next_chunk(): + fd.write(content) + self.assertEqual(fd.getvalue(), b"12345") + self.assertEqual(True, done) + self.assertEqual(5, download._progress) + self.assertEqual(5, download._total_size) + break + + def test_media_gen_base_download_range_request_header(self): + fd = io.BytesIO() + + self.request.http = HttpMockSequence( + [ + ( + {"status": "200", "content-range": "0-2/5"}, + "echo_request_headers_as_json", + ), + ] + ) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + for content, status, done in download.next_chunk(): + fd.write(content) + result = json.loads(fd.getvalue().decode("utf-8")) + self.assertEqual(result.get("range"), "bytes=0-2") + break + + + def test_media_gen_base_download_custom_request_headers(self): + fd = io.BytesIO() + self.request.http = HttpMockSequence( + [ + ( + {"status": "200", "content-range": "0-2/5"}, + "echo_request_headers_as_json", + ), + ( + {"status": "200", "content-range": "3-4/5"}, + "echo_request_headers_as_json", + ), + ] + ) + self.assertEqual(True, self.request.http.follow_redirects) + + self.request.headers["Cache-Control"] = "no-store" + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + self.assertEqual(download._headers.get("Cache-Control"), "no-store") + + # assert that that the header we added to the original request is + # sent up to the server on each call to next_chunk + for content, status, done in download.next_chunk(): + fd.write(content) + result = json.loads(fd.getvalue().decode("utf-8")) + self.assertEqual(result.get("Cache-Control"), "no-store") + break + + fd = io.BytesIO() + for content, status, done in download.next_chunk(): + fd.write(content) + result = json.loads(fd.getvalue().decode("utf-8")) + self.assertEqual(result.get("Cache-Control"), "no-store") + break + + def test_media_gen_base_download_handle_redirects(self): + fd = io.BytesIO() + self.request.http = HttpMockSequence( + [ + ( + { + "status": "200", + "content-location": "https://secure.example.net/lion", + }, + b"", + ), + ({"status": "200", "content-range": "0-2/5"}, b"abc"), + ] + ) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + for content, status, done in download.next_chunk(): + break + + self.assertEqual("https://secure.example.net/lion", download._uri) + + def test_media_gen_base_download_handle_4xx(self): + fd = io.BytesIO() + self.request.http = HttpMockSequence([({"status": "400"}, "")]) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + try: + for content, status, done in download.next_chunk(): + self.fail("Should raise an exception") + except HttpError: + pass + + # Even after raising an exception we can pick up where we left off. + self.request.http = HttpMockSequence( + [({"status": "200", "content-range": "0-2/5"}, b"123")] + ) + + for content, status, done in download.next_chunk(): + fd.write(content) + break + + self.assertEqual(fd.getvalue(), b"123") + + + def test_media_gen_base_download_retries_connection_errors(self): + fd = io.BytesIO() + self.request.http = HttpMockWithErrors( + 5, {"status": "200", "content-range": "0-2/3"}, b"123" + ) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + download._sleep = lambda _x: 0 # do nothing + download._rand = lambda: 10 + + for content, status, done in download.next_chunk(num_retries=5): + fd.write(content) + + self.assertEqual(fd.getvalue(), b"123") + self.assertEqual(True, done) + + + def test_media_gen_base_download_retries_5xx(self): + fd = io.BytesIO() + self.request.http = HttpMockSequence( + [ + ({"status": "500"}, ""), + ({"status": "500"}, ""), + ({"status": "500"}, ""), + ({"status": "200", "content-range": "0-2/5"}, b"123"), + ({"status": "503"}, ""), + ({"status": "503"}, ""), + ({"status": "503"}, ""), + ({"status": "200", "content-range": "3-4/5"}, b"45"), + ] + ) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + self.assertEqual(3, download._chunksize) + self.assertEqual(0, download._progress) + self.assertEqual(None, download._total_size) + self.assertEqual(False, download._done) + self.assertEqual(self.request.uri, download._uri) + + # Set time.sleep and random.random stubs. + sleeptimes = [] + download._sleep = lambda x: sleeptimes.append(x) + download._rand = lambda: 10 + + for content, status, done in download.next_chunk(num_retries=3): + fd.write(content) + break + + # Check for exponential backoff using the rand function above. + self.assertEqual([20, 40, 80], sleeptimes) + + self.assertEqual(fd.getvalue(), b"123") + self.assertEqual(False, done) + self.assertEqual(3, download._progress) + self.assertEqual(5, download._total_size) + self.assertEqual(3, status.resumable_progress) + + # Reset time.sleep stub. + del sleeptimes[0 : len(sleeptimes)] + + for content, status, done in download.next_chunk(num_retries=3): + fd.write(content) + break + + # Check for exponential backoff using the rand function above. + self.assertEqual([20, 40, 80], sleeptimes) + + self.assertEqual(fd.getvalue(), b"12345") + self.assertEqual(True, done) + self.assertEqual(5, download._progress) + self.assertEqual(5, download._total_size) + + def test_media_gen_base_download_empty_file(self): + fd = io.BytesIO() + + self.request.http = HttpMockSequence( + [({"status": "200", "content-range": "0-0/0"}, b"")] + ) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + self.assertEqual(0, download._progress) + self.assertEqual(None, download._total_size) + self.assertEqual(False, download._done) + self.assertEqual(self.request.uri, download._uri) + + for content, status, done in download.next_chunk(): + fd.write(content) + + self.assertEqual(fd.getvalue(), b"") + self.assertEqual(True, done) + self.assertEqual(0, download._progress) + self.assertEqual(0, download._total_size) + self.assertEqual(0, status.progress()) + + + def test_media_gen_base_download_empty_file_416_response(self): + fd = io.BytesIO() + self.request.http = HttpMockSequence( + [({"status": "416", "content-range": "0-0/0"}, b"")] + ) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + self.assertEqual(0, download._progress) + self.assertEqual(None, download._total_size) + self.assertEqual(False, download._done) + self.assertEqual(self.request.uri, download._uri) + + for content, status, done in download.next_chunk(): + fd.write(content) + + self.assertEqual(fd.getvalue(), b"") + self.assertEqual(True, done) + self.assertEqual(0, download._progress) + self.assertEqual(0, download._total_size) + self.assertEqual(0, status.progress()) + + + def test_media_gen_base_download_unknown_media_size(self): + fd = io.BytesIO() + self.request.http = HttpMockSequence([({"status": "200"}, b"123")]) + + download = MediaGenBaseDownload(request=self.request, chunksize=3) + + self.assertEqual(0, download._progress) + self.assertEqual(None, download._total_size) + self.assertEqual(False, download._done) + self.assertEqual(self.request.uri, download._uri) + + for content, status, done in download.next_chunk(): + fd.write(content) + + self.assertEqual(fd.getvalue(), b"123") + self.assertEqual(True, done) + self.assertEqual(3, download._progress) + self.assertEqual(None, download._total_size) + self.assertEqual(0, status.progress()) + + class TestMediaIoBaseDownload(unittest.TestCase): def setUp(self): http = HttpMock(datafile("zoo.json"), {"status": "200"})