diff --git a/httpie/downloads.py b/httpie/downloads.py index 9c4b895e6f..7a11e78311 100644 --- a/httpie/downloads.py +++ b/httpie/downloads.py @@ -7,7 +7,7 @@ import re from mailbox import Message from time import monotonic -from typing import IO, Optional, Tuple +from typing import IO, Mapping, Optional, Tuple from urllib.parse import urlsplit import requests @@ -24,6 +24,22 @@ class ContentRangeError(ValueError): pass +def has_content_encoding(headers: Mapping[str, str]) -> bool: + content_encoding = headers.get('Content-Encoding', '') + return any( + coding.strip().lower() != 'identity' + for coding in content_encoding.split(',') + if coding.strip() + ) + + +def get_raw_body_bytes_read(response: requests.Response) -> Optional[int]: + try: + return response.raw.tell() + except AttributeError: + return None + + def parse_content_range(content_range: str, resumed_from: int) -> int: """ Parse and validate Content-Range header. @@ -182,6 +198,8 @@ def __init__( self._output_file = output_file self._resume = resume self._resumed_from = 0 + self._raw_body_bytes_read = None + self._downloaded_size_getter = None def pre_request(self, request_headers: dict): """Called just before the HTTP request is sent. @@ -251,6 +269,15 @@ def start( on_body_chunk_downloaded=self.chunk_downloaded, ) + if has_content_encoding(final_response.headers): + self._raw_body_bytes_read = get_raw_body_bytes_read(final_response) + if self._raw_body_bytes_read is None: + total_size = None + else: + self._downloaded_size_getter = ( + lambda: get_raw_body_bytes_read(final_response) + ) + self.status.started( output_file=self._output_file, resumed_from=self._resumed_from, @@ -283,7 +310,13 @@ def chunk_downloaded(self, chunk: bytes): been downloaded and written to the output. """ - self.status.chunk_downloaded(len(chunk)) + size = len(chunk) + if self._downloaded_size_getter: + raw_body_bytes_read = self._downloaded_size_getter() + if raw_body_bytes_read is not None: + size = raw_body_bytes_read - self._raw_body_bytes_read + self._raw_body_bytes_read = raw_body_bytes_read + self.status.chunk_downloaded(size) @staticmethod def _get_output_file_from_response( diff --git a/tests/test_downloads.py b/tests/test_downloads.py index b646a0e6a5..5371c132c0 100644 --- a/tests/test_downloads.py +++ b/tests/test_downloads.py @@ -17,10 +17,20 @@ class Response(requests.Response): # noinspection PyDefaultArgument - def __init__(self, url, headers={}, status_code=200): + def __init__(self, url, headers={}, status_code=200, raw=None): self.url = url self.headers = CaseInsensitiveDict(headers) self.status_code = status_code + self.raw = raw + + +class RawResponse: + + def __init__(self): + self.body_bytes_read = 0 + + def tell(self): + return self.body_bytes_read class TestDownloadUtils: @@ -201,6 +211,27 @@ def test_download_interrupted(self, mock_env, httpbin_both): downloader.finish() assert downloader.interrupted + def test_download_with_encoded_Content_Length(self, mock_env): + raw = RawResponse() + with open(os.devnull, 'w') as devnull: + downloader = Downloader(mock_env, output_file=devnull) + downloader.start( + final_response=Response( + url='http://example.org/', + headers={ + 'Content-Encoding': 'gzip', + 'Content-Length': 5, + }, + raw=raw, + ), + initial_url='/' + ) + raw.body_bytes_read = 5 + downloader.chunk_downloaded(b'1234567890') + downloader.finish() + assert downloader.status.downloaded == 5 + assert not downloader.interrupted + def test_download_resumed(self, mock_env, httpbin_both): with tempfile.TemporaryDirectory() as tmp_dirname: file = os.path.join(tmp_dirname, 'file.bin')